Commit
Merge pull request #45 from Qihoo360/zhangys
support sparse table save with binary format
zhangys-lucky authored Dec 27, 2020
2 parents 43521f7 + ed933f3 commit 8ab0661
Showing 15 changed files with 261 additions and 85 deletions.
10 changes: 6 additions & 4 deletions core/main/py_wrapper.cc
@@ -134,13 +134,15 @@ PYBIND11_MODULE(_pywrap_tn, m) {

         return table->GetHandle();
     })
-    .def("save_sparse_table", [](uint32_t table_handle, std::string filepath) {
+    .def("save_sparse_table", [](uint32_t table_handle, std::string filepath,
+                                 const std::string& mode="txt") {
         SparseTable* table = SparseTableRegistry::Instance()->Get(table_handle);
-        return table->Save(filepath);
+        return table->Save(filepath, mode);
     })
-    .def("load_sparse_table", [](uint32_t table_handle, std::string filepath) {
+    .def("load_sparse_table", [](uint32_t table_handle, std::string filepath,
+                                 const std::string& mode="txt") {
         SparseTable* table = SparseTableRegistry::Instance()->Get(table_handle);
-        return table->Load(filepath);
+        return table->Load(filepath, mode);
     })
     .def("save_dense_table", [](uint32_t table_handle, std::string filepath) {
         DenseTable* table = DenseTableRegistry::Instance()->Get(table_handle);
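A note on where the mode string goes: SparseTable::Save and Load themselves are in files not shown on this page. The sketch below is a hedged guess at how mode could be mapped onto the per-stream format flag that core/ps/optimizer/data_struct.h introduces later in this diff; the helper name SaveWithMode and the "bin" literal are assumptions, not code from this commit.

#include <fstream>
#include <string>

#include "core/ps/optimizer/data_struct.h"

// Hypothetical helper; "bin" as the non-default mode string is a guess.
void SaveWithMode(const std::string& filepath, const std::string& mode) {
    std::ofstream os(filepath,
                     mode == "bin" ? std::ios::binary : std::ios::out);
    // Tag the stream once; every SparseOptValue::Serialize call on it
    // then routes to the matching SerializeTxt_/SerializeBin_ hook.
    os.iword(tensornet::SERIALIZE_FMT_ID) =
        mode == "bin" ? tensornet::SF_BIN : tensornet::SF_TXT;
    // ... iterate the table and call value->Serialize(os, dim) ...
}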
22 changes: 17 additions & 5 deletions core/ps/optimizer/ada_grad_kernel.cc
@@ -89,7 +89,7 @@ SparseAdaGradValue::SparseAdaGradValue(int dim, const AdaGrad* opt) {
 }

 void SparseAdaGradValue::Apply(const AdaGrad* opt, SparseGradInfo& grad_info, int dim) {
-    delta_show += grad_info.batch_show;
+    delta_show_ += grad_info.batch_show;

     float* w = Weight();

@@ -106,22 +106,34 @@ void SparseAdaGradValue::Apply(const AdaGrad* opt, SparseGradInfo& grad_info, in
     }
 }

-void SparseAdaGradValue::Serialize(std::ostream& os, int dim) {
+void SparseAdaGradValue::SerializeTxt_(std::ostream& os, int dim) {
     for (int i = 0; i < dim; i++) {
         os << Weight()[i] << "\t";
     }

     os << g2sum_ << "\t";
-    os << show;
+    os << show_;
 }

-void SparseAdaGradValue::DeSerialize(std::istream& is, int dim) {
+void SparseAdaGradValue::DeSerializeTxt_(std::istream& is, int dim) {
     for (int i = 0; i < dim; i++) {
         is >> Weight()[i];
     }

     is >> g2sum_;
-    is >> show;
+    is >> show_;
 }

+void SparseAdaGradValue::SerializeBin_(std::ostream& os, int dim) {
+    os.write(reinterpret_cast<const char*>(Weight()), dim * sizeof(float));
+    os.write(reinterpret_cast<const char*>(&g2sum_), sizeof(g2sum_));
+    os.write(reinterpret_cast<const char*>(&show_), sizeof(show_));
+}
+
+void SparseAdaGradValue::DeSerializeBin_(std::istream& is, int dim) {
+    is.read(reinterpret_cast<char*>(Weight()), dim * sizeof(float));
+    is.read(reinterpret_cast<char*>(&g2sum_), sizeof(g2sum_));
+    is.read(reinterpret_cast<char*>(&show_), sizeof(show_));
+}
+
 } // namespace tensornet
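The binary path writes exactly the fields the text path prints, in the same order (dim weights, then g2sum_, then show_), as raw undelimited bytes. The following standalone sketch (illustration only, not repo code) shows the write/read symmetry that DeSerializeBin_ depends on:

#include <cassert>
#include <sstream>
#include <vector>

int main() {
    // Same layout as SparseAdaGradValue's binary form: weights, g2sum, show.
    std::vector<float> weights = {0.1f, 0.2f, 0.3f};
    float g2sum = 4.2f, show = 7.0f;

    std::stringstream ss(std::ios::in | std::ios::out | std::ios::binary);
    ss.write(reinterpret_cast<const char*>(weights.data()), weights.size() * sizeof(float));
    ss.write(reinterpret_cast<const char*>(&g2sum), sizeof(g2sum));
    ss.write(reinterpret_cast<const char*>(&show), sizeof(show));

    // Read back in the identical order; any mismatch would corrupt every
    // field that follows, since the binary form has no delimiters.
    std::vector<float> w2(weights.size());
    float g2 = 0.f, s2 = 0.f;
    ss.read(reinterpret_cast<char*>(w2.data()), w2.size() * sizeof(float));
    ss.read(reinterpret_cast<char*>(&g2), sizeof(g2));
    ss.read(reinterpret_cast<char*>(&s2), sizeof(s2));

    assert(w2 == weights && g2 == g2sum && s2 == show);
    return 0;
}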
10 changes: 6 additions & 4 deletions core/ps/optimizer/ada_grad_kernel.h
@@ -53,7 +53,7 @@ class DenseAdaGradValue {
 std::ostream& operator<<(std::ostream& os, const DenseAdaGradValue& value);
 std::istream& operator>>(std::istream& is, DenseAdaGradValue& value);

-struct alignas(4) SparseAdaGradValue
+class alignas(4) SparseAdaGradValue
     : public SparseOptValue {
 public:
     SparseAdaGradValue(int dim, const AdaGrad* opt);
@@ -74,9 +74,11 @@ struct alignas(4) SparseAdaGradValue

     void Apply(const AdaGrad* opt, SparseGradInfo& grad_info, int dim);

-    void Serialize(std::ostream& os, int dim);
-
-    void DeSerialize(std::istream& is, int dim);
+protected:
+    virtual void SerializeTxt_(std::ostream& os, int dim);
+    virtual void DeSerializeTxt_(std::istream& is, int dim);
+    virtual void SerializeBin_(std::ostream& os, int dim);
+    virtual void DeSerializeBin_(std::istream& is, int dim);

 private:
     float g2sum_;
24 changes: 19 additions & 5 deletions core/ps/optimizer/adam_kernel.cc
@@ -104,7 +104,7 @@ SparseAdamValue::SparseAdamValue(int dim, const Adam* opt) {
 }

 void SparseAdamValue::Apply(const Adam* opt, SparseGradInfo& grad_info, int dim) {
-    delta_show += grad_info.batch_show;
+    delta_show_ += grad_info.batch_show;

     float* w = Weight();
     float* m = M(dim);
@@ -118,7 +118,7 @@ void SparseAdamValue::Apply(const Adam* opt, SparseGradInfo& grad_info, int dim)
     }
 }

-void SparseAdamValue::Serialize(std::ostream& os, int dim) {
+void SparseAdamValue::SerializeTxt_(std::ostream& os, int dim) {
     float* w = Weight();
     float* m = M(dim);
     float* v = V(dim);
@@ -129,10 +129,10 @@ void SparseAdamValue::Serialize(std::ostream& os, int dim) {
         os << v[i] << "\t";
     }

-    os << show;
+    os << show_;
 }

-void SparseAdamValue::DeSerialize(std::istream& is, int dim) {
+void SparseAdamValue::DeSerializeTxt_(std::istream& is, int dim) {
     float* w = Weight();
     float* m = M(dim);
     float* v = V(dim);
@@ -143,7 +143,21 @@ void SparseAdamValue::DeSerialize(std::istream& is, int dim) {
         is >> v[i];
     }

-    is >> show;
+    is >> show_;
 }

+void SparseAdamValue::SerializeBin_(std::ostream& os, int dim) {
+    os.write(reinterpret_cast<const char*>(Weight()), dim * sizeof(float));
+    os.write(reinterpret_cast<const char*>(M(dim)), dim * sizeof(float));
+    os.write(reinterpret_cast<const char*>(V(dim)), dim * sizeof(float));
+    os.write(reinterpret_cast<const char*>(&show_), sizeof(show_));
+}
+
+void SparseAdamValue::DeSerializeBin_(std::istream& is, int dim) {
+    is.read(reinterpret_cast<char*>(Weight()), dim * sizeof(float));
+    is.read(reinterpret_cast<char*>(M(dim)), dim * sizeof(float));
+    is.read(reinterpret_cast<char*>(V(dim)), dim * sizeof(float));
+    is.read(reinterpret_cast<char*>(&show_), sizeof(show_));
+}
+
 } // namespace tensornet {
12 changes: 6 additions & 6 deletions core/ps/optimizer/adam_kernel.h
@@ -56,7 +56,7 @@ class DenseAdamValue {
 std::ostream& operator<<(std::ostream& os, const DenseAdamValue& value);
 std::istream& operator>>(std::istream& is, DenseAdamValue& value);

-struct alignas(4) SparseAdamValue
+class alignas(4) SparseAdamValue
     : public SparseOptValue {
 public:
     SparseAdamValue(int dim, const Adam* opt);
@@ -76,12 +76,7 @@ struct alignas(4) SparseAdamValue

     void Apply(const Adam* opt, SparseGradInfo& grad_info, int dim);

-    void Serialize(std::ostream& os, int dim);
-
-    void DeSerialize(std::istream& is, int dim);
-
 protected:
-
     float* M(int dim) {
         return data_ + dim * 1;
     }
@@ -98,6 +93,11 @@ struct alignas(4) SparseAdamValue
         return data_ + dim * 2;
     }

+    virtual void SerializeTxt_(std::ostream& os, int dim);
+    virtual void DeSerializeTxt_(std::istream& is, int dim);
+    virtual void SerializeBin_(std::ostream& os, int dim);
+    virtual void DeSerializeBin_(std::istream& is, int dim);
+
 private:
     float data_[0];
 };
46 changes: 46 additions & 0 deletions core/ps/optimizer/data_struct.cc
@@ -0,0 +1,46 @@
+// Copyright (c) 2020, Qihoo, Inc. All rights reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include <iostream>
+
+#include "core/ps/optimizer/data_struct.h"
+
+namespace tensornet {
+
+int const SERIALIZE_FMT_ID = std::ios_base::xalloc();
+
+void SparseOptValue::Serialize(std::ostream& os, int dim) {
+    switch (os.iword(SERIALIZE_FMT_ID)) {
+    case SF_TXT:
+        SerializeTxt_(os, dim);
+        break;
+    case SF_BIN:
+        SerializeBin_(os, dim);
+        break;
+    }
+}
+
+void SparseOptValue::DeSerialize(std::istream& is, int dim) {
+    switch (is.iword(SERIALIZE_FMT_ID)) {
+    case SF_TXT:
+        DeSerializeTxt_(is, dim);
+        break;
+    case SF_BIN:
+        DeSerializeBin_(is, dim);
+        break;
+    }
+}
+
+} // namespace tensornet
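This new file is the heart of the change. std::ios_base::xalloc() reserves a process-wide per-stream integer slot once at startup, and iword(SERIALIZE_FMT_ID) reads or writes that slot on a particular stream; slots default to 0, so an untagged stream falls through to SF_TXT. A minimal standalone sketch of the pattern (kFmtId and Fmt are illustrative names, not the repo's):

#include <iostream>
#include <sstream>

// Illustrative equivalents of SERIALIZE_FMT_ID and SerializeFormat.
const int kFmtId = std::ios_base::xalloc();  // reserve a fresh per-stream slot
enum Fmt { TXT = 0, BIN = 1 };               // iword slots default to 0, i.e. TXT

int main() {
    std::ostringstream os;
    os.iword(kFmtId) = BIN;  // tag this one stream as binary
    float w = 1.5f;
    if (os.iword(kFmtId) == BIN) {
        os.write(reinterpret_cast<const char*>(&w), sizeof(w));  // 4 raw bytes
    } else {
        os << w << "\t";  // the text path used elsewhere in this commit
    }
    std::cout << os.str().size() << " bytes written\n";  // prints "4 bytes written"
    return 0;
}

Keeping the format in stream state rather than adding a parameter leaves Serialize's signature unchanged at every call site; the choice is made once, where the file is opened.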
34 changes: 29 additions & 5 deletions core/ps/optimizer/data_struct.h
@@ -22,14 +22,38 @@ struct SparseGradInfo {
     int batch_show;
 };

-struct alignas(4) SparseOptValue {
-    float show = 0.0;
-    int delta_show = 0;
+extern int const SERIALIZE_FMT_ID;
+
+enum SerializeFormat {
+    SF_TXT,
+    SF_BIN,
+};

+class alignas(4) SparseOptValue {
+public:
     void ShowDecay(float decay_rate) {
-        show = (1 - decay_rate) * delta_show + decay_rate * show;
-        delta_show = 0;
+        show_ = (1 - decay_rate) * delta_show_ + decay_rate * show_;
+        delta_show_ = 0;
     }

+    void Serialize(std::ostream& os, int dim);
+
+    void DeSerialize(std::istream& is, int dim);
+
+    float Show() const {
+        return show_;
+    }
+
+protected:
+    virtual void SerializeTxt_(std::ostream& os, int dim) = 0;
+    virtual void DeSerializeTxt_(std::istream& is, int dim) = 0;
+    virtual void SerializeBin_(std::ostream& os, int dim) = 0;
+    virtual void DeSerializeBin_(std::istream& is, int dim) = 0;
+
+protected:
+    float show_ = 0.0;
+    int delta_show_ = 0;
+
 };

 } // namespace tensornet {
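The reshaped base class follows a template-method design: the public Serialize and DeSerialize entry points are non-virtual and route to four protected hooks that each optimizer value implements. A hypothetical minimal subclass, only to show the contract (SparseSgdValue does not exist in this repo):

#include <istream>
#include <ostream>

#include "core/ps/optimizer/data_struct.h"

namespace tensornet {

// Hypothetical one-weight optimizer value; only the four hooks are needed,
// since the base class owns the txt-vs-bin decision.
class alignas(4) SparseSgdValue : public SparseOptValue {
protected:
    void SerializeTxt_(std::ostream& os, int /*dim*/) override {
        os << w_ << "\t" << show_;
    }
    void DeSerializeTxt_(std::istream& is, int /*dim*/) override {
        is >> w_ >> show_;
    }
    void SerializeBin_(std::ostream& os, int /*dim*/) override {
        os.write(reinterpret_cast<const char*>(&w_), sizeof(w_));
        os.write(reinterpret_cast<const char*>(&show_), sizeof(show_));
    }
    void DeSerializeBin_(std::istream& is, int /*dim*/) override {
        is.read(reinterpret_cast<char*>(&w_), sizeof(w_));
        is.read(reinterpret_cast<char*>(&show_), sizeof(show_));
    }

private:
    float w_ = 0.0f;
};

}  // namespace tensornet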
24 changes: 19 additions & 5 deletions core/ps/optimizer/ftrl_kernel.cc
@@ -57,7 +57,7 @@ SparseFtrlValue::SparseFtrlValue(int dim, const Ftrl* opt) {
 }

 void SparseFtrlValue::Apply(const Ftrl* opt, SparseGradInfo& grad_info, int dim) {
-    delta_show += grad_info.batch_show;
+    delta_show_ += grad_info.batch_show;

     float* w = Weight();
     float* z = Z(dim);
@@ -81,7 +81,7 @@ void SparseFtrlValue::Apply(const Ftrl* opt, SparseGradInfo& grad_info, int dim)
     }
 }

-void SparseFtrlValue::Serialize(std::ostream& os, int dim) {
+void SparseFtrlValue::SerializeTxt_(std::ostream& os, int dim) {
     float* w = Weight();
     float* z = Z(dim);
     float* n = N(dim);
@@ -92,10 +92,10 @@ void SparseFtrlValue::Serialize(std::ostream& os, int dim) {
         os << n[i] << "\t";
     }

-    os << show;
+    os << show_;
 }

-void SparseFtrlValue::DeSerialize(std::istream& is, int dim) {
+void SparseFtrlValue::DeSerializeTxt_(std::istream& is, int dim) {
     float* w = Weight();
     float* z = Z(dim);
     float* n = N(dim);
@@ -106,7 +106,21 @@ void SparseFtrlValue::DeSerialize(std::istream& is, int dim) {
         is >> n[i];
     }

-    is >> show;
+    is >> show_;
 }

+void SparseFtrlValue::SerializeBin_(std::ostream& os, int dim) {
+    os.write(reinterpret_cast<const char*>(Weight()), dim * sizeof(float));
+    os.write(reinterpret_cast<const char*>(Z(dim)), dim * sizeof(float));
+    os.write(reinterpret_cast<const char*>(N(dim)), dim * sizeof(float));
+    os.write(reinterpret_cast<const char*>(&show_), sizeof(show_));
+}
+
+void SparseFtrlValue::DeSerializeBin_(std::istream& is, int dim) {
+    is.read(reinterpret_cast<char*>(Weight()), dim * sizeof(float));
+    is.read(reinterpret_cast<char*>(Z(dim)), dim * sizeof(float));
+    is.read(reinterpret_cast<char*>(N(dim)), dim * sizeof(float));
+    is.read(reinterpret_cast<char*>(&show_), sizeof(show_));
+}
+
 } // namespace tensornet
11 changes: 6 additions & 5 deletions core/ps/optimizer/ftrl_kernel.h
@@ -52,7 +52,7 @@ class DenseFtrlValue {
 std::ostream& operator<<(std::ostream& os, const DenseFtrlValue& value);
 std::istream& operator>>(std::istream& is, DenseFtrlValue& value);

-struct alignas(4) SparseFtrlValue
+class alignas(4) SparseFtrlValue
     : public SparseOptValue {
 public:
     SparseFtrlValue(int dim, const Ftrl* opt);
@@ -73,10 +73,6 @@ struct alignas(4) SparseFtrlValue

     void Apply(const Ftrl* opt, SparseGradInfo& grad_info, int dim);

-    void Serialize(std::ostream& os, int dim);
-
-    void DeSerialize(std::istream& is, int dim);
-
 protected:
     float* Z(int dim) {
         return data_ + dim * 1;
@@ -94,6 +90,11 @@ struct alignas(4) SparseFtrlValue
         return data_ + dim * 2;
     }

+    virtual void SerializeTxt_(std::ostream& os, int dim);
+    virtual void DeSerializeTxt_(std::istream& is, int dim);
+    virtual void SerializeBin_(std::ostream& os, int dim);
+    virtual void DeSerializeBin_(std::istream& is, int dim);
+
 private:
     float data_[0];
 };