Merge branch 'support_sparse_table_resize' into 'master'
add sparse resize
See merge request deep-learning/tensornet!13
Showing 8 changed files with 440 additions and 98 deletions.
This file was deleted.
This file was deleted.
@@ -0,0 +1,54 @@
## Merging sparse files

Because the generated sparse files are spread across multiple directories, Spark can be used to extract them and merge the results into a single output directory.

```bash
spark-submit3 --executor-memory 8g --driver-memory 10g --py-files utils.py merge_sparse.py -i "/user/test/model_path/sparse_table/*/*/*bin.gz" -o "/user/test/model_merge_path" -f 'bin' -n 500 -b
```

### Parameters

Option | Default | Description
----------- | ----------- | -----------
-i/--input | None | Input path
-o/--output | None | Output path
-f/--format | bin | Input file format
-n/--number | 20 | Output parallelism
-b/--bracker | False | Whether to wrap the output weights in `[]`; the bracketed weights are treated as one column, with columns separated by `\t`
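
The merged result can then be consumed by other jobs. The exact column layout is produced by merge_sparse.py, which is not shown here, so the reader below is only a minimal sketch: it assumes each output line is a sign followed by a tab and the (optionally bracketed) weights.

```python
# Hypothetical reader for the merged output. The "sign \t [w1, w2, ...]"
# layout is an assumption, since merge_sparse.py is not shown here.
from pyspark import SparkContext

sc = SparkContext.getOrCreate()

def parse_line(line):
    sign, weights = line.split("\t", 1)
    # strip the surrounding [] written when -b/--bracker is set
    values = [float(w) for w in weights.strip("[]").split(",")]
    return sign, values

merged = sc.textFile("/user/test/model_merge_path").map(parse_line)
print(merged.take(1))
```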

## Resizing sparse parallelism

At the moment the parallelism of a generated sparse_table directory cannot be changed: if the parallelism used for saving and loading differs, data is lost, so the cluster cannot be scaled up or down. This tool reads the original data with Spark and writes it out again using the file pattern of the requested parallelism.

Because hdfs3 is used to write the files, a packaged Python environment has to be uploaded; use the [env file](config/tn_tool_env.yaml).

```bash
spark-submit3 --conf spark.executor.memory=10g --conf spark.archives=hdfs://nn/user/test/cache/python.tar.gz#envs --conf spark.pyspark.driver.python=/home/test/micromamba/envs/tn_tool_env/bin/python --conf spark.pyspark.python=./envs/bin/python --py-files utils.py resize_sparse.py --input /user/test/model/* --output /user/test/resize --number 50
```
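
Conceptually, resizing re-assigns every sparse sign to a rank and block under the new parallelism. The actual mapping is implemented by get_sign_partition_key and resize_partition in utils.py (not shown here); the snippet below is only a minimal sketch of the idea, assuming a modulo-based assignment and a fixed block count.

```python
# Minimal sketch of re-bucketing a sparse sign under a new parallelism.
# The real mapping lives in get_sign_partition_key (utils.py) and may differ.
BLOCK_NUM = 8  # assumed number of blocks per rank

def assign(sign: int, number: int):
    """Map a sign to (rank, block) for an output with `number` ranks."""
    rank = sign % number
    block = (sign // number) % BLOCK_NUM
    return rank, block

# With 50 output ranks, sign 123457 lands in rank 7, block 5 under this
# assumed scheme; the file would then be written as
# <handle_name>/7/5.gz in the output directory.
print(assign(123457, 50))
```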

### Parameters

Option | Default | Description
----------- | ----------- | -----------
-i/--input | None | Input path; its HDFS prefix is extracted and used for writing back to HDFS. If no HDFS prefix is given, hdfs://ss-hadoop2 is used by default
-o/--output | None | Output path; handle_name/rank_number/block_num.gz files are written under it
-f/--format | bin | Input file format
-n/--number | 20 | Output parallelism
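
The resized blocks are written straight to HDFS through hdfs3, which is why the packaged environment above must include it. The actual write happens inside utils.resize_partition (not shown here); the following is a minimal sketch of such a write, with the namenode host and helper name as assumptions.

```python
# Minimal sketch of writing one resized block to HDFS with hdfs3.
# Host, port and the helper name are assumptions; the real logic lives in
# resize_partition (utils.py), which is not shown here.
import gzip
from hdfs3 import HDFileSystem

def write_block(records, output, handle_name, rank, block,
                host="ss-hadoop2", port=8020):
    hdfs = HDFileSystem(host=host, port=port)
    target_dir = f"{output}/{handle_name}/{rank}"
    hdfs.mkdir(target_dir)  # creates parent directories on HDFS
    payload = gzip.compress("\n".join(records).encode("utf-8"))
    with hdfs.open(f"{target_dir}/{block}.gz", "wb") as f:
        f.write(payload)

# write_block(["<sign>\t<weights>"], "/user/test/resize", "some_handle", 7, 5)
```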

## Resizing dense parallelism

Same idea as the sparse case.

```bash
spark-submit3 --conf spark.executor.memory=10g --conf spark.archives=hdfs://nn/user/test/cache/python.tar.gz#envs --conf spark.pyspark.driver.python=/home/test/micromamba/envs/tn_tool_env/bin/python --conf spark.pyspark.python=./envs/bin/python --py-files utils.py resize_dense.py --input /user/test/model/* --output /user/test/resize --number 50
```

### Parameters

Option | Default | Description
----------- | ----------- | -----------
-i/--input | None | Input path; its HDFS prefix is extracted and used for writing back to HDFS. If no HDFS prefix is given, hdfs://ss-hadoop2 is used by default
-o/--output | None | Output path; handle_name/rank_number files are written under it
-n/--number | 20 | Output parallelism
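
resize_dense.py collects every dense block to the driver, rebuilds the data with process_whole_text from utils.py, and writes one file per output partition. Since utils.py is not shown here, the sketch below only illustrates the assumed idea of re-chunking a flat list of dense values into `number` pieces.

```python
# Illustrative sketch of re-chunking dense values into `number` partitions.
# The real implementation is process_whole_text (utils.py), which is not
# shown here; the even-split scheme below is an assumption.
import math

def rechunk(values, number):
    """Split a flat list of dense values into `number` nearly equal chunks."""
    size = math.ceil(len(values) / number)
    return [values[i * size:(i + 1) * size] for i in range(number)]

print(rechunk(list(range(10)), 3))  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```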
@@ -0,0 +1,8 @@
name: tn_build
channels:
  - conda-forge
dependencies:
  - python=3.8
  - nomkl
  - openssl>=3
  - hdfs3
@@ -0,0 +1,42 @@
#coding=utf-8
import sys
import argparse
import os
import math

from pyspark import SparkContext, SparkConf
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.sql.types import *

# Helpers such as mapIndexToDenseRecord, process_whole_text and
# write_dense_partition are provided by utils.py (shipped via --py-files).
from utils import *


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, help="dense table input path")
    parser.add_argument("-o", "--output", type=str, help="merged file output path")
    parser.add_argument("-n", "--number", type=int, help="output file parallelism", default=30)
    args = parser.parse_args()
    return args


def main(args):
    spark = SparkSession.builder \
        .appName("[spark][resize dense table]") \
        .master('yarn') \
        .enableHiveSupport() \
        .getOrCreate()

    sc = spark.sparkContext
    output_bc_value = sc.broadcast(args.output)

    # Read every dense file whole, keep (file_name, handle_name, content)
    # and expand each file into indexed dense records.
    dense_file_rdd = sc.wholeTextFiles(args.input) \
        .map(lambda x: (x[0].split("/")[-1], x[0].split("/")[-2], x[1])) \
        .flatMap(mapIndexToDenseRecord)

    # Collect all dense records on the driver and re-chunk them into the
    # requested number of partitions.
    whole_data = dense_file_rdd.collect()
    res = process_whole_text(whole_data, args.number)

    # Write one output file per partition.
    res_rdd = sc.parallelize(res, args.number)
    res_rdd.foreachPartition(lambda p: write_dense_partition(p, output_bc_value))


if __name__ == '__main__':
    args = parse_args()
    main(args)
@@ -0,0 +1,46 @@
#coding=utf-8
import sys
import argparse
import os

from pyspark import SparkContext, SparkConf
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.sql.types import *

# Helpers such as fetch_hanlds, load_sparse_table_to_df,
# get_sign_partition_key, resize_partition and BLOCK_NUM are provided by
# utils.py (shipped via --py-files).
from utils import *


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, help="sparse table input path")
    parser.add_argument("-o", "--output", type=str, help="merged file output path")
    parser.add_argument("-f", "--format", type=str, help="input file format, 'txt' or 'bin'")
    parser.add_argument("-n", "--number", type=int, help="output file parallelism", default=30)
    args = parser.parse_args()
    return args


def main(args):
    spark = SparkSession.builder \
        .appName("[spark][resize sparse table]") \
        .master('yarn') \
        .enableHiveSupport() \
        .getOrCreate()

    sc = spark.sparkContext

    # Broadcast the small job parameters so executors can use them
    # inside foreachPartition.
    output_bc_value = sc.broadcast(args.output)
    format_bc_value = sc.broadcast(args.format)
    number_bc_value = sc.broadcast(args.number)

    # Collect the handle (table) names found under the input path.
    handle_names = fetch_hanlds(args.input)
    handle_names_bc_value = sc.broadcast(handle_names)

    # Load the sparse table into a DataFrame.
    dims_df = load_sparse_table_to_df(sc, args.input, args.format)

    # Re-key every row by its first column (the sign), repartition into
    # number * BLOCK_NUM partitions, and write each partition back to HDFS.
    dims_df.rdd.map(lambda row: (get_sign_partition_key(row[0], args.number), row)) \
        .partitionBy(args.number * BLOCK_NUM) \
        .foreachPartition(lambda p: resize_partition(p, output_bc_value, format_bc_value, number_bc_value, handle_names_bc_value))


if __name__ == '__main__':
    args = parse_args()
    main(args)