Merge branch 'output_file_format_upgrade' into 'master'
input format detect/sparse merge tool

See merge request deep-learning/tensornet!12
gzm55 committed Jul 1, 2024
2 parents c5e95fb + 50b1462 commit 0eaa912
Showing 7 changed files with 173 additions and 9 deletions.
38 changes: 31 additions & 7 deletions core/ps/optimizer/optimizer_kernel.h
@@ -28,6 +28,7 @@
#include <Eigen/Dense>
#include <cstring>
#include <cstdio>
#include <regex>

#include <boost/iostreams/stream.hpp>

@@ -505,16 +506,39 @@ class SparseOptimizerKernel : public SparseOptimizerKernelBase {

void DeSerialized(const std::string& filepath, const std::string& mode) {
std::vector<std::thread> threads;
std::string actual_mode = mode;
std::string file_prefix = "sparse_block_";
std::string file_suffix = "";

/*
The file name may match sparse_block_0.gz | block_0.gz | sparse_block_0_[bin|txt].gz;
without the _[bin|txt] suffix, the mode defaults to txt.
*/
std::vector<std::string> child_files;
if(FileUtils::GetChildren(filepath, &child_files)){
std::string first_file_str = child_files.front();
size_t index = first_file_str.find_last_of("/");
std::string first_file_name = first_file_str.substr(index + 1);

std::regex regex_pattern(R"((.*?)(\d+)(.*)\.gz)");
std::smatch match;
if (std::regex_match(first_file_name, match, regex_pattern)) {
file_prefix = match[1];
if(match[3].matched && match[3].length() > 0){
file_suffix = match[3];
actual_mode = file_suffix.substr(1);
} else {
actual_mode = "txt";
}
}
std::cerr << file_prefix << std::endl;
std::cerr << file_suffix << std::endl;
}

for (size_t i = 0; i < SPARSE_KERNEL_BLOCK_NUM; ++i) {
threads.push_back(std::thread([this, i, &mode, &filepath]() {
threads.push_back(std::thread([this, i, &actual_mode, &filepath, &file_prefix, &file_suffix]() {
std::string file = filepath;
if(FileUtils::CheckFileExists(filepath + "/block_" + std::to_string(i) + ".gz")){
file.append("/block_").append(std::to_string(i)).append(".gz");
} else {
file.append("/sparse_block_").append(std::to_string(i)).append(".gz");
}

file.append("/" + file_prefix).append(std::to_string(i)).append(file_suffix).append(".gz");
FileReaderSource reader_source(file, FCT_ZLIB);
boost::iostreams::stream<FileReaderSource> in_stream(reader_source);

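
The detection above is easier to see outside of C++. Below is a minimal Python sketch of the same regex-based filename inspection; the helper name `detect_prefix_suffix_mode` and the sample filenames are invented for illustration:

```python
import re

def detect_prefix_suffix_mode(filename):
    # Mirrors the C++ regex R"((.*?)(\d+)(.*)\.gz)": a lazy prefix,
    # the block index, and an optional _[bin|txt] suffix before ".gz".
    match = re.fullmatch(r"(.*?)(\d+)(.*)\.gz", filename)
    if not match:
        return "sparse_block_", "", "txt"  # fall back to the defaults
    prefix, _, suffix = match.groups()
    mode = suffix[1:] if suffix else "txt"  # strip the leading "_"
    return prefix, suffix, mode

# The three supported patterns from the comment above.
for name in ["sparse_block_0.gz", "block_0.gz", "sparse_block_0_bin.gz"]:
    print(name, "->", detect_prefix_suffix_mode(name))
```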
5 changes: 5 additions & 0 deletions core/utility/file_io.cc
@@ -15,6 +15,7 @@
#include "core/utility/file_io.h"

#include <string>
#include <vector>

#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/lib/io/buffered_inputstream.h"
@@ -165,5 +166,9 @@ bool FileUtils::CheckFileExists(const std::string& filepath) {
return tensorflow::Env::Default()->FileExists(filepath).ok();
}

bool FileUtils::GetChildren(const std::string& dir, std::vector<std::string>* result) {
return tensorflow::Env::Default()->GetChildren(dir, result).ok();
}

} // namespace tensornet

3 changes: 2 additions & 1 deletion core/utility/file_io.h
@@ -16,6 +16,7 @@
#define TENSORNET_CORE_UTILITY_FILE_IO_H_

#include <string>
#include <vector>
#include <memory>
#include <iosfwd> // streamsize
#include <boost/iostreams/categories.hpp> // sink_tag, source_tag
@@ -73,7 +74,7 @@ class FileReaderSource {
class FileUtils {
public:
static bool CheckFileExists(const std::string& filepath);

static bool GetChildren(const std::string& dir, std::vector<std::string>* result);
};

} // namespace tensornet
3 changes: 2 additions & 1 deletion tensornet/callbacks/callbacks.py
@@ -36,6 +36,7 @@ def __init__(self, checkpoint_dir, checkpoint_save=None, need_save_model=False,
self.need_save_model = need_save_model
self.need_load_model = kwargs.get('need_load_model', True)
self.save_mode = save_mode
self.load_mode = kwargs.get('load_mode', self.save_mode)
self.model_path_incl_dt = model_path_incl_dt
self.dt = dt
self.delta_days = delta_days
@@ -45,7 +46,7 @@ def __init__(self, checkpoint_dir, checkpoint_save=None, need_save_model=False,
def load_model(self):
tn.core.barrier()
if self.need_load_model:
self.model.load_weights(self.checkpoint_dir, include_dt=self.model_path_incl_dt, mode=self.save_mode)
self.model.load_weights(self.checkpoint_dir, include_dt=self.model_path_incl_dt, mode=self.load_mode)
tn.core.barrier()

def reset_balance_dataset(self):
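
With this change a job can write checkpoints in one format while reading the previous checkpoint in another. A hypothetical usage sketch follows; the class name `PsWeightCheckpoint` is an assumption, but the keyword arguments are the ones handled by the constructor above:

```python
import tensornet as tn

# Hypothetical callback class name; checkpoint_dir, need_save_model,
# save_mode and load_mode are the constructor arguments shown above.
callback = tn.callbacks.PsWeightCheckpoint(
    checkpoint_dir="/user/test/model_path",
    need_save_model=True,
    save_mode="bin",  # write new checkpoints in binary
    load_mode="txt",  # but read the previous checkpoint as txt
)
```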
17 changes: 17 additions & 0 deletions tools/merge_sparse/README.md
@@ -0,0 +1,17 @@
## Merging sparse files

Because the generated sparse files are spread across multiple directories, Spark can be used to extract them and merge the results into a single directory:

```bash
spark-submit3 --executor-memory 8g --driver-memory 10g --py-files utils.py merge_sparse.py -i "/user/test/model_path/sparse_table/*/*/*bin.gz" -o "/user/test/model_merge_path" -f 'bin' -n 500 -b
```

### Parameters

Option | Default | Description
----------- | ----------- | -----------
-i/--input | None | Input path
-o/--output | None | Output path
-f/--format | bin | Input file format ('txt' or 'bin')
-n/--number | 30 | Output parallelism (number of output files)
-b/--bracket | False | Whether to wrap the output weights in [], with each bracket as its own tab-separated column
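
For reference, a minimal sketch of the two output layouts controlled by -b/--bracket, mirroring output_line in utils.py (the key and weights are made up):

```python
key, weights = "12345", [0.1, 0.2, 0.3]

# Without -b: key and weights as plain tab-separated columns.
print("\t".join(str(v) for v in [key] + weights))   # 12345\t0.1\t0.2\t0.3

# With -b: "[" and "]" become their own tab-separated columns.
print("\t".join(str(v) for v in [key, "["] + weights + ["]"]))  # 12345\t[\t0.1\t0.2\t0.3\t]
```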
53 changes: 53 additions & 0 deletions tools/merge_sparse/merge_sparse.py
@@ -0,0 +1,53 @@
#!/usr/bin/python3.6
#coding=utf-8
import sys
import argparse
import os
from pyspark import SparkContext, SparkConf
from pyspark.sql import *
from pyspark.sql.functions import *
from pyspark.sql import functions as F
from pyspark.sql.types import *
from utils import *


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("-i", "--input", type=str, help="sparse table input path")
    parser.add_argument("-o", "--output", type=str, help="merged file output path")
    parser.add_argument("-f", "--format", type=str, help="input file format, 'txt' or 'bin'")
    parser.add_argument("-n", "--number", type=int, help="output file parallelism", default=30)
    parser.add_argument("-b", "--bracket", help="whether dims need brackets", action="store_true", default=False)
    args = parser.parse_args()
    return args


def main(args):
    spark = SparkSession.builder \
        .appName("[spark][merge sparse table]") \
        .master('yarn') \
        .enableHiveSupport() \
        .getOrCreate()

    sc = spark.sparkContext

    if args.format == 'txt':
        get_handle_name_udf = udf(get_handle_name, StringType())
        dims_df = sc.textFile(args.input)\
            .map(lambda x: process_txt_line(x))\
            .toDF(["key", "dims"])\
            .withColumn("input_file_name", F.input_file_name())\
            .withColumn("handle", get_handle_name_udf(col("input_file_name")))\
            .drop("input_file_name")\
            .filter(col("key") != "").dropDuplicates(['key', 'handle'])
    elif args.format == 'bin':
        dims_df = sc.binaryFiles(args.input)\
            .mapPartitions(process_binary_partition)\
            .toDF(['handle', 'key', 'dims'])
    else:
        raise ValueError("unsupported input format: %s" % args.format)

    dims_df.dropDuplicates(['key', 'handle']).drop('handle').rdd.map(lambda x: output_line(x, args.bracket)).repartition(args.number).saveAsTextFile(args.output)


if __name__ == '__main__':
    args = parse_args()
    main(args)
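
As a quick way to sanity-check the txt path without a Spark cluster, the sketch below runs a sample record through the helpers from utils.py; the sign, weights, and path are invented:

```python
from utils import get_handle_name, process_txt_line, output_line

# A sample txt record: sign \t dim_num \t dim_num * weight.
key, dims = process_txt_line("12345\t3\t0.1\t0.2\t0.3")

# The handle name is the third-from-last path component.
path = "/user/test/model_path/sparse_table/handle_a/0/sparse_block_0_txt.gz"
print(get_handle_name(path))           # handle_a
print(output_line((key, dims), True))  # 12345\t[\t0.1\t0.2\t0.3\t]
```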
63 changes: 63 additions & 0 deletions tools/merge_sparse/utils.py
@@ -0,0 +1,63 @@
#coding=utf-8
import gzip
import struct
import io
from struct import unpack
from io import BytesIO

def get_handle_name(path):
    """
    Get the handle name from a file path.
    The path should look like prefix/sparse_table/handle_name/rank_num/file.gz
    """
    elements = path.split('/')
    return elements[-3] if elements else None

def process_txt_line(line):
    """
    Fetch the sign and weights from a sparse record.
    Fields are separated by '\t': sign\tdim_num\tdim_num*weight
    """
    data_list = line.split('\t')
    if len(data_list) < 3:
        return ("", [])
    else:
        sign = data_list[0]
        dim = int(data_list[1])
        weights = data_list[2: dim+2]
        return (sign, weights)


def process_binary_partition(iterator):
    """
    Used by mapPartitions to convert a binary file into line records.
    A file starts with an int dim_num; each record is then
    sign, dim_num * weight, g2sum, show, no_show_days.
    """
    for filename, file_content in iterator:
        handle = filename.split('/')[-3]
        with io.BytesIO(file_content) as fc:
            with gzip.open(fc, 'rb') as gzip_file:
                dim = unpack('i', gzip_file.read(4))[0]
                while True:
                    sign_bytes = gzip_file.read(8)
                    if len(sign_bytes) < 8:  # clean end of file
                        break
                    try:
                        sign_str = str(unpack('Q', sign_bytes)[0])
                        weights = [unpack('f', gzip_file.read(4))[0] for _ in range(dim)]
                        g2sum = unpack('f', gzip_file.read(4))[0]
                        show_rate = unpack('f', gzip_file.read(4))[0]
                        no_show_days = unpack('i', gzip_file.read(4))[0]
                        yield (handle, sign_str, weights)
                    except Exception as e:  # truncated or corrupt record
                        print(e)
                        break


def output_line(line, need_bracket):
    key = line[0]
    dims = list(line[1])
    if need_bracket:
        output_list = [key, "["] + dims + ["]"]
    else:
        output_list = [key] + dims
    return "\t".join(str(item) for item in output_list)

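
To make the binary layout concrete, here is a small self-contained sketch that packs one record in the format described above and reads it back; the sign and values are invented:

```python
import gzip
import io
from struct import pack, unpack

# Build a tiny .gz payload: an int dim header, then one record of
# sign (uint64), dim floats, g2sum, show, no_show_days.
dim = 3
buf = io.BytesIO()
with gzip.open(buf, 'wb') as gz:
    gz.write(pack('i', dim))
    gz.write(pack('Q', 12345))
    gz.write(pack('%df' % dim, 0.1, 0.2, 0.3))
    gz.write(pack('f', 0.5))  # g2sum
    gz.write(pack('f', 0.9))  # show
    gz.write(pack('i', 2))    # no_show_days

with gzip.open(io.BytesIO(buf.getvalue()), 'rb') as gz:
    dim = unpack('i', gz.read(4))[0]
    sign = unpack('Q', gz.read(8))[0]
    weights = [unpack('f', gz.read(4))[0] for _ in range(dim)]
    print(sign, weights)  # 12345 [~0.1, ~0.2, ~0.3] (float32 precision)
```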