
tfra.dynamic_embedding.get_variable

View source on GitHub




Gets a Variable object with this name if it exists, or creates a new one.

tfra.dynamic_embedding.get_variable(
    name,
    key_dtype=dtypes.int64,
    value_dtype=dtypes.float32,
    dim=1,
    devices=None,
    partitioner=default_partition_fn,
    shared_name='get_variable',
    initializer=None,
    trainable=(True),
    checkpoint=(True),
    init_size=0,
    kv_creator=None,
    restrict_policy=None,
    bp_v2=(False)
)
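
A minimal creation sketch, assuming TensorFlow and tensorflow_recommenders_addons are importable as shown; the name, dim, and initializer values are arbitrary illustration choices:

import tensorflow as tf
import tensorflow_recommenders_addons as tfra

# Create (or fetch, if one with this name already exists) a dynamic
# embedding table whose values are 8-dimensional float32 vectors.
user_embeddings = tfra.dynamic_embedding.get_variable(
    name="user_embeddings",
    key_dtype=tf.int64,
    value_dtype=tf.float32,
    dim=8,
    initializer=tf.random_normal_initializer(mean=0.0, stddev=0.1),
)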

Args:

  • name: A unique name for the Variable.
  • key_dtype: The data type of the key tensors.
  • value_dtype: The data type of the value tensors.
  • dim: The length of the value array for each key.
  • devices: The list of devices holding the tables. One table will be created on each device.
  • partitioner: Partition function for keys; returns the partition index for each key. See the example partition function below and the custom-partitioner sketch after this list.

Example partition function:

def default_partition_fn(keys, shard_num):
  # Route each key to shard (key % shard_num).
  return tf.cast(keys % shard_num, dtype=tf.int32)
  • shared_name: Not used.
  • initializer: The value to use if a key is missing in the hash table. This can be a Python number, a NumPy array, or a tf.initializer instance. If initializer is None (the default), 0 will be used.
  • trainable: Bool. If True, the variable will be treated as trainable. Default is True.
  • checkpoint: if True, the contents of the SparseVariable are saved to and restored from checkpoints. If shared_name is empty for a checkpointed table, it is shared using the table node name.
  • init_size: The initial size of the Variable; each underlying hash table will be created with an initial size of int(init_size / N), where N is the number of devices.
  • restrict_policy: A restrict policy that specifies the rule used to restrict the size of the variable. If the variable is updated by an optimizer during training, the sparse slot variables in the optimizer are restricted as well.
  • bp_v2: By default, with bp_v2=False, the optimizer updates dynamic embedding values by setting (key, value) pairs after optimizer.apply_gradient. If a key is used by multiple workers at the same time, only one worker's update is kept and the others are overwritten. With bp_v2=True, the optimizer updates parameters by adding deltas instead of setting values, which resolves the race condition among workers during backpropagation in large-scale distributed asynchronous training.
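
As referenced in the partitioner entry above, a custom partition function only needs to match the (keys, shard_num) -> int32 signature of default_partition_fn. The hash_partition_fn below is a hypothetical sketch, not part of the library:

import tensorflow as tf

def hash_partition_fn(keys, shard_num):
  # Mix the key bits with a multiplicative hash before the modulo so that
  # clustered key ranges spread more evenly across shards.
  return tf.cast((keys * 2654435761) % shard_num, dtype=tf.int32)

# Passed to get_variable alongside a device list, e.g.:
#   tfra.dynamic_embedding.get_variable(..., devices=devices, partitioner=hash_partition_fn)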

Returns:

A Variable object.
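
For reference, a brief usage sketch of the returned Variable; the upsert, lookup, and size calls follow the Variable API documented in this package, and the shapes shown are assumptions tied to the creation example above:

ids = tf.constant([3, 7, 1234567], dtype=tf.int64)

# Write explicit values for some keys.
user_embeddings.upsert(ids, tf.random.normal([3, 8]))

# Read values back; keys that were never written are filled from the initializer.
vectors = user_embeddings.lookup(ids)   # shape: [3, 8]
print(user_embeddings.size())           # number of keys currently stored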