Merge branch 'use_pctr_bn_logic' into 'master'
Use pctr bn logic

See merge request deep-learning/tensornet!19
gzm55 committed Jan 13, 2025
2 parents a1191ba + 441b14e commit 19eac0d
Showing 6 changed files with 76 additions and 21 deletions.
4 changes: 2 additions & 2 deletions core/main/py_wrapper.cc
@@ -136,10 +136,10 @@ PYBIND11_MODULE(_pywrap_tn, m) {

return table->GetHandle();
})
.def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, uint64_t max_count) {
.def("create_bn_table", [](std::string name, uint32_t bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn) {
PsCluster* cluster = PsCluster::Instance();

BnTable* table = CreateBnTable(name, cluster->RankNum(), cluster->Rank(), bn_size, sync, moment, max_count);
BnTable* table = CreateBnTable(name, cluster->RankNum(), cluster->Rank(), bn_size, sync, moment, max_count, use_pctr_dnn_bn);

return table->GetHandle();
})
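Note: the extra use_pctr_dnn_bn argument is supplied from the Python layer (see the create_bn_table call in normalization_layer.py below). A minimal sketch of the extended binding call, assuming the package is imported as tn and the PS cluster is already initialized; the table name and size are placeholder values:

    import tensornet as tn

    # Arguments per this diff: name, bn_size, sync, moment, max_count, use_pctr_dnn_bn
    handle = tn.core.create_bn_table("bn_demo", 128, False, 0.99, 100000, True)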
31 changes: 21 additions & 10 deletions core/ps/table/bn_table.cc
@@ -29,14 +29,15 @@

namespace tensornet {

BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count)
BnTable::BnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool synchronized, float moment, uint64_t max_count, bool use_pctr_dnn_bn)
: shard_num_(shard_num)
, self_shard_id_(self_shard_id)
, name_(name)
, synchronized_(synchronized)
, moment_(moment)
, max_count_(max_count)
, bn_size_(bn_size) {
, bn_size_(bn_size)
, use_pctr_dnn_bn_(use_pctr_dnn_bn) {
total_sum_.setZero(bn_size);
total_sum_err_.setZero(bn_size);
total_squared_sum_.setZero(bn_size);
@@ -104,9 +105,13 @@ void BnTable::TotalSumAcc(Eigen::ArrayXd acc){

std::tuple<Eigen::ArrayXf,Eigen::ArrayXf> BnTable::GetMoments() {
Eigen::ArrayXf global_mean = DivideNoNan(total_sum_, total_count_);
Eigen::ArrayXf global_squared_mean = DivideNoNan(total_squared_sum_, total_count_);
Eigen::ArrayXf global_var = (global_squared_mean - global_mean.square()).max(0.0);
return std::make_tuple(global_mean, global_var);
if(use_pctr_dnn_bn_){
return std::make_tuple(global_mean, total_squared_sum_.cast<float>());
} else {
Eigen::ArrayXf global_squared_mean = DivideNoNan(total_squared_sum_, total_count_);
Eigen::ArrayXf global_var = (global_squared_mean - global_mean.square()).max(0.0);
return std::make_tuple(global_mean, global_var);
}
}
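This branch is the core of the change. A small numpy sketch of the two variance paths (names mirror the diff; that total_squared_sum_ holds a running estimate of E[(x - moving_mean)^2] in PCTR mode is inferred from the Python layer below, which accumulates squared deviations from the moving mean):

    import numpy as np

    def get_moments(total_sum, total_squared_sum, total_count, use_pctr_dnn_bn):
        # DivideNoNan analogue: returns 0 where the count is 0.
        mean = np.divide(total_sum, total_count,
                         out=np.zeros_like(total_sum), where=total_count != 0)
        if use_pctr_dnn_bn:
            # PCTR-DNN mode: the squared-sum slot already holds the running
            # estimate of E[(x - moving_mean)^2], so it is returned directly.
            return mean, total_squared_sum
        # Classic mode: Var[x] = E[x^2] - (E[x])^2, clamped at zero.
        sq_mean = np.divide(total_squared_sum, total_count,
                            out=np.zeros_like(total_squared_sum),
                            where=total_count != 0)
        return mean, np.maximum(sq_mean - mean ** 2, 0.0)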

void BnTable::GetStatistics(const BnStatisticsPullRequest* req, butil::IOBuf& bn_statistics_buf, BnStatisticsPullResponse* resp) {
@@ -120,9 +125,9 @@ void BnTable::GetIncStatistics(butil::IOBuf& bn_statistics_buf) {
bn_statistics_buf.append(inc_sum_.data(), inc_sum_.size() * sizeof(double));
bn_statistics_buf.append(inc_squared_sum_.data(), inc_squared_sum_.size() * sizeof(double));
bn_statistics_buf.append(inc_count_.data(), inc_count_.size() * sizeof(double));
inc_sum_.setZero();
inc_squared_sum_.setZero();
inc_count_.setZero();
inc_sum_.setZero();
inc_squared_sum_.setZero();
inc_count_.setZero();
}


@@ -167,6 +172,11 @@ void BnTable::Load(const std::string& filepath) {
in_stream.iword(SERIALIZE_FMT_ID) = SF_BIN;

int bn_size = 0;
bool use_pctr_dnn_bn = false;

in_stream.read(reinterpret_cast<char*>(&use_pctr_dnn_bn), sizeof(use_pctr_dnn_bn));
CHECK_EQ(use_pctr_dnn_bn_, use_pctr_dnn_bn) << "bn calculation logic must match; the checkpoint was saved with use_pctr_dnn_bn = " << use_pctr_dnn_bn;

in_stream.read(reinterpret_cast<char*>(&bn_size), sizeof(bn_size));

for( int i = 0; i < bn_size; i++) {
@@ -187,6 +197,7 @@ void BnTable::Save(const std::string& filepath) {
boost::iostreams::stream<FileWriterSink> out_stream(writer_sink);
out_stream.iword(SERIALIZE_FMT_ID) = SF_BIN;

out_stream.write(reinterpret_cast<const char*>(&use_pctr_dnn_bn_), sizeof(use_pctr_dnn_bn_));
out_stream.write(reinterpret_cast<const char*>(&bn_size_), sizeof(bn_size_));

for( int i = 0; i < bn_size_; i++) {
@@ -217,8 +228,8 @@ uint32_t BnTableRegistry::Register(BnTable* table) {
return table_handle;
}

BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count) {
BnTable* table = new BnTable(name, shard_num, self_shard_id, bn_size, sync, moment, max_count);
BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn) {
BnTable* table = new BnTable(name, shard_num, self_shard_id, bn_size, sync, moment, max_count, use_pctr_dnn_bn);

table->SetHandle(BnTableRegistry::Instance()->Register(table));

7 changes: 4 additions & 3 deletions core/ps/table/bn_table.h
@@ -31,7 +31,7 @@ namespace tensornet {

class BnTable {
public:
BnTable(const std::string& name,int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count);
BnTable(const std::string& name,int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn);

~BnTable() = default;

@@ -67,7 +67,8 @@ class BnTable {
uint32_t handle_ = 0;
std::string name_;
uint32_t bn_size_ = 0;
bool synchronized_ = false;
bool synchronized_ = false;
bool use_pctr_dnn_bn_ = false;
float moment_ = 0.0;
uint64_t max_count_ = 0;
Eigen::ArrayXd total_sum_;
@@ -103,7 +104,7 @@ class BnTableRegistry {
std::vector<BnTable*> tables_;
};

BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count);
BnTable* CreateBnTable(const std::string& name, int shard_num, int self_shard_id, int bn_size, bool sync, float moment, uint64_t max_count, bool use_pctr_dnn_bn);

} // namespace tensornet

2 changes: 2 additions & 0 deletions tensornet/layers/__init__.py
@@ -14,4 +14,6 @@

from .embedding_features import EmbeddingFeatures
from .sequence_embedding_features import SequenceEmbeddingFeatures
from .normalization_layer import TNBatchNormalizationBase
from .normalization_layer import TNBatchNormalization
from .normalization_layer import PCTRDNNBatchNormalization
49 changes: 45 additions & 4 deletions tensornet/layers/normalization_layer.py
@@ -12,7 +12,7 @@
from tensorflow.python.ops import variable_scope, array_ops


class TNBatchNormalization(Layer):
class TNBatchNormalizationBase(Layer):
"""
Reference: https://github.com/keras-team/keras/blob/v3.5.0/keras/src/layers/normalization/batch_normalization.py
@@ -25,8 +25,10 @@ class TNBatchNormalization(Layer):
sync_freq: frequency (in batches) at which bn statistics are sent to other ranks. Should only be used when 'synchronized' is True
max_count: threshold used to avoid bn statistics overflow. Note: this counts records, not batches. It is an empirical parameter that should be tuned to the size of the training data.
"""
_USE_PCTR_DNN_BN = False

def __init__(self, center=True, scale=True, epsilon=1e-5, momentum=0.99, name=None, synchronized=False, sync_freq=1, max_count=100000, **kwargs):
super(TNBatchNormalization, self).__init__(**kwargs)
super(TNBatchNormalizationBase, self).__init__(**kwargs)
self.center = center
self.scale = scale
self.epsilon = epsilon
@@ -97,8 +99,7 @@ def build(self, input_shape):
initializer=self.local_squared_num_initializer,
trainable=False)

self.bn_table_handle = tn.core.create_bn_table(self.name, self.apply_axis[0], self.synchronized, self.momentum, self.max_count)

self.bn_table_handle = tn.core.create_bn_table(self.name, self.apply_axis[0], self.synchronized, self.momentum, self.max_count, self._USE_PCTR_DNN_BN)

def call(self, inputs, training=None):

@@ -147,3 +148,43 @@ def save_bn_table(self, filepath):
def load_bn_table(self, filepath):
return tn.core.load_bn_table(self.bn_table_handle, filepath)


class TNBatchNormalization(TNBatchNormalizationBase):
    """
    Accumulates incremental count, sum, and squared sum; uses
    (squared_sum / count - (sum / count).square()) as the variance.
    """

class PCTRDNNBatchNormalization(TNBatchNormalizationBase):
    """
    Accumulates incremental count and sum; accumulates squared deviations
    (data - moving_mean).square() and uses them directly as the variance.
    """
    _USE_PCTR_DNN_BN = True

    def call(self, inputs, training=None):

        @tf.function
        def _increment_and_check_count():
            self.batch_counter.assign_add(1)
            if tf.equal(self.batch_counter, self.sync_freq):
                self.bn_statistics_push(True)
                self.batch_counter.assign(0)
            else:
                self.bn_statistics_push(False)

        self.update_moments()
        mean = self.moving_mean
        var = self.moving_variance

        if training:
            local_count_sample = tf.ones_like(inputs, name="count")
            self.local_sum.assign(tf.reduce_sum(inputs, axis=self.moments_axes))
            # Key difference from TNBatchNormalization: accumulate squared
            # deviations from the moving mean rather than raw squared values.
            self.local_squared_sum.assign(tf.reduce_sum(tf.square(inputs - self.moving_mean), axis=self.moments_axes))
            self.local_count.assign(tf.reduce_sum(local_count_sample, axis=self.moments_axes))
            if self.synchronized:
                _increment_and_check_count()
            else:
                self.bn_statistics_push(False)

        outputs = tf.nn.batch_normalization(x=inputs, mean=mean, variance=var, offset=self.beta, scale=self.gamma, variance_epsilon=self.epsilon)

        return outputs
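Taken together, a hedged usage sketch of the two layer variants (constructor arguments come from the base class in this diff; the surrounding model wiring and import alias are assumed):

    import tensornet as tn

    # Classic behavior: variance from E[x^2] - (E[x])^2.
    bn = tn.layers.TNBatchNormalization(synchronized=True, sync_freq=4,
                                        max_count=100000)

    # PCTR-DNN behavior: variance accumulated as (x - moving_mean)^2.
    pctr_bn = tn.layers.PCTRDNNBatchNormalization(synchronized=True, sync_freq=4,
                                                  max_count=100000)

Since the flag is serialized into the BnTable checkpoint, a table saved by one variant will fail the Load() check in the other.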
4 changes: 2 additions & 2 deletions tensornet/model/Model.py
@@ -177,7 +177,7 @@ def save_weights(self, filepath, overwrite=True, save_format=None, dt="", root=T
layer.save_sparse_table(cp_dir, mode)
elif isinstance(layer, tn.layers.SequenceEmbeddingFeatures):
layer.save_sparse_table(cp_dir, mode)
elif isinstance(layer, tn.layers.TNBatchNormalization):
elif isinstance(layer, tn.layers.TNBatchNormalizationBase):
if tn.core.self_shard_id() == 0:
layer.bn_statistics_pull()
layer.save_bn_table(cp_dir)
@@ -223,7 +223,7 @@ def load_weights(self, filepath, by_name=False, skip_mismatch=False, include_dt=
layer.load_sparse_table(cp_dir, mode)
elif isinstance(layer, tn.layers.SequenceEmbeddingFeatures):
layer.load_sparse_table(cp_dir, mode)
elif isinstance(layer, tn.layers.TNBatchNormalization):
elif isinstance(layer, tn.layers.TNBatchNormalizationBase):
layer.load_bn_table(cp_dir)

# dense weight