-
Notifications
You must be signed in to change notification settings - Fork 71
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge branch 'output_file_format_upgrade' into 'master'
input format detect/sparse merge tool See merge request deep-learning/tensornet!12
- Loading branch information
Showing
7 changed files
with
173 additions
and
9 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,17 @@ | ||
## 合并sparse file | ||
|
||
由于生成的sparse文件分布在各个目录, 可以通过spark将sparse文件结果抽取结果并合并到一个目录下 | ||
|
||
```bash | ||
spark-submit3 --executor-memory 8g --driver-memory 10g --py-files utils.py merge_sparse.py -i "/user/test/model_path/sparse_table/*/*/*bin.gz" -o "/user/test/model_merge_path" -f 'bin' -n 500 -b | ||
``` | ||
|
||
### 参数配置 | ||
|
||
配置名称 | 默认值 | 含义 | ||
----------- | ----------- | ----------- | ||
-i/--input | None | 输入路径 | ||
-o/--output | None | 输出路径 | ||
-f/--format | bin | 输入文件格式 | ||
-n/--number | 30 | 输出并行度 | ||
-b/--bracket | False | 输出的Weights是否需要用[]包括, []当作一列, 用\t分割 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,53 @@ | ||
#!/usr/bin/python3.6 | ||
#coding=utf-8 | ||
import sys | ||
import argparse | ||
import os | ||
from pyspark import SparkContext, SparkConf | ||
from pyspark.sql import * | ||
from pyspark.sql.functions import * | ||
from pyspark.sql import functions as F | ||
from pyspark.sql.types import * | ||
from utils import * | ||
|
||
|
||
def parse_args():
    """Parse command-line arguments for the sparse-table merge job.

    Returns:
        argparse.Namespace with attributes: input, output, format,
        number, bracket.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, help="sparse table input path")
    parser.add_argument("-o", "--output", type=str, help="merged file output path")
    # Default to 'bin' (as documented in the README); without a default an
    # omitted -f leaves format=None, which matches neither branch in main()
    # and ends in a NameError on dims_df.
    parser.add_argument("-f", "--format", type=str, help="input file format, 'txt' or 'bin'", default="bin")
    parser.add_argument("-n", "--number", type=int, help="output file parallelism", default=30)
    parser.add_argument("-b", "--bracket", help="if dims need bracket", action="store_true", default=False)
    args = parser.parse_args()
    return args
|
||
|
||
def main(args):
    """Merge per-handle sparse-table files into one output directory.

    Args:
        args: parsed command-line arguments (see parse_args). args.format
            selects the reader: 'txt' reads tab-separated text lines,
            'bin' reads gzipped binary files.

    Raises:
        ValueError: if args.format is neither 'txt' nor 'bin'.
    """
    spark = SparkSession.builder \
        .appName("[spark][merge sparse table]") \
        .master('yarn') \
        .enableHiveSupport() \
        .getOrCreate()

    sc = spark.sparkContext

    if args.format == 'txt':
        get_handle_name_udf = udf(get_handle_name, StringType())
        # The handle name is not in the record itself; recover it from the
        # input file path, then drop malformed rows (empty key sentinel).
        dims_df = sc.textFile(args.input)\
            .map(lambda x: process_txt_line(x))\
            .toDF(["key", "dims"])\
            .withColumn("input_file_name", F.input_file_name())\
            .withColumn("handle", get_handle_name_udf(col("input_file_name")))\
            .drop("input_file_name")\
            .filter(col("key") != "").dropDuplicates(['key', 'handle'])
    elif args.format == 'bin':
        dims_df = sc.binaryFiles(args.input)\
            .mapPartitions(process_binary_partition)\
            .toDF(['handle', 'key', 'dims'])
    else:
        # Fail fast with a clear message instead of a NameError on dims_df below.
        raise ValueError("unsupported input format: %r (expected 'txt' or 'bin')" % args.format)

    dims_df.dropDuplicates(['key', 'handle']).drop('handle').rdd\
        .map(lambda x: output_line(x, args.bracket))\
        .repartition(args.number)\
        .saveAsTextFile(args.output)
|
||
|
||
if __name__ == '__main__':
    # Script entry point: parse CLI options and run the merge job.
    main(parse_args())
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
#coding=utf-8 | ||
import gzip | ||
import struct | ||
import io | ||
from struct import unpack | ||
from io import BytesIO | ||
|
||
def get_handle_name(path):
    """Extract the handle name from a sparse-table file path.

    The expected layout is prefix/sparse_table/<handle_name>/<rank_num>/file.gz,
    so the handle is the third-from-last path component.

    Args:
        path: slash-separated file path.

    Returns:
        The handle name string, or None when the path has fewer than three
        components. (str.split never returns an empty list, so the original
        `if elements` guard was dead and short paths raised IndexError.)
    """
    elements = path.split('/')
    return elements[-3] if len(elements) >= 3 else None
|
||
def process_txt_line(line):
    """Parse one text-format sparse record into (sign, weights).

    Fields are separated by '\t': sign, dim_num, then dim_num weights.
    Malformed lines -- too few fields, or a non-integer dim_num -- return
    the ("", []) sentinel so they can be filtered out downstream instead
    of raising and killing the Spark task.
    """
    data_list = line.split('\t')
    if len(data_list) < 3:
        return ("", [])
    sign = data_list[0]
    try:
        dim = int(data_list[1])
    except ValueError:
        # Corrupt dim field: treat the whole line as malformed.
        return ("", [])
    weights = data_list[2: dim + 2]
    return (sign, weights)
|
||
|
||
def process_binary_partition(iterator):
    """Decode gzipped binary sparse files into (handle, sign, weights) records.

    Intended for RDD.mapPartitions over sc.binaryFiles output, where each
    element is (file_path, file_bytes). Each file starts with one int32
    dim_num, followed by repeated records of:
        uint64 sign, dim_num * float32 weights, float32 g2sum,
        float32 show, int32 no_show_days
    The handle name is the third-from-last path component
    (prefix/sparse_table/<handle>/<rank>/file.gz).

    Fixes vs. the original implementation:
      * weights are read as `dim` floats instead of a hard-coded 8, matching
        the dim_num the file header declares;
      * end-of-file is detected by a short read and terminates the loop
        cleanly, instead of relying on struct.error and yielding a bogus
        ("", "", []) record per file into the merged output.
    """
    for filename, file_content in iterator:
        handle = filename.split('/')[-3]
        with io.BytesIO(file_content) as fc:
            with gzip.open(fc, 'rb') as gzip_file:
                dim = unpack('i', gzip_file.read(4))[0]
                # '=' = native byte order, standard sizes, no padding --
                # byte-compatible with the original per-field reads.
                record_fmt = '=Q%dfffi' % dim
                record_size = struct.calcsize(record_fmt)
                while True:
                    raw = gzip_file.read(record_size)
                    if len(raw) < record_size:
                        # Clean EOF (or a truncated trailing record): stop.
                        break
                    fields = unpack(record_fmt, raw)
                    sign_str = str(fields[0])
                    weights = list(fields[1:1 + dim])
                    # fields[-3:] are g2sum, show, no_show_days -- read to
                    # advance the stream but not emitted.
                    yield (handle, sign_str, weights)
|
||
|
||
def output_line(line, need_bracket):
    """Render a (key, dims) row as one tab-separated output line.

    When need_bracket is set, literal "[" and "]" columns surround the
    weights so the whole list can be treated as a single bracketed field.
    """
    key, dims = line[0], list(line[1])
    parts = [key, "[", *dims, "]"] if need_bracket else [key, *dims]
    return "\t".join(str(part) for part in parts)