Skip to content

Commit

Permalink
Fix quantizer base dtype
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 660442990
  • Loading branch information
bignamehyp authored and pax authors committed Aug 7, 2024
1 parent 81154d8 commit f99d63c
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions praxis/layers/video/quantizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def __call__(self, inputs: JTensor) -> tuple[JTensor, NestedMap]:
result_dict = NestedMap()
group_size = int(math.ceil(self.embedding_dim / self.num_token_groups))
base = jnp.power(
2, jnp.arange(self.embedding_dim, dtype=jnp.uint8) % group_size
2, jnp.arange(self.embedding_dim, dtype=jnp.uint32) % group_size
)
samples = inputs >= 0
quantized = jnp.where(samples, 1.0, -1.0)
Expand Down Expand Up @@ -215,7 +215,7 @@ def __call__(self, inputs: JTensor) -> tuple[JTensor, NestedMap]:
def decode_ids(self, inputs: JTensor) -> JTensor:
group_size = int(math.ceil(self.embedding_dim / self.num_token_groups))
base = jnp.power(
2, jnp.arange(self.embedding_dim, dtype=jnp.uint8) % group_size
2, jnp.arange(self.embedding_dim, dtype=jnp.uint32) % group_size
)
if self.num_token_groups == 1:
inputs = inputs[..., None]
Expand Down

0 comments on commit f99d63c

Please sign in to comment.