Skip to content

Commit

Permalink
Update to always using ops.shape (#1231)
Browse files Browse the repository at this point in the history
With recently released changes to keras-core, ops.shape will always
return a tuple.
  • Loading branch information
mattdangerw authored Aug 31, 2023
1 parent a1987b8 commit 60af93f
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 5 deletions.
4 changes: 2 additions & 2 deletions keras_nlp/samplers/beam_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,12 +124,12 @@ def create_beams(x):

def flatten_beams(x):
"""Combine the beam dim and batch dim."""
flat_shape = (batch_size * self.num_beams,) + tuple(x.shape)[2:]
flat_shape = (batch_size * self.num_beams,) + ops.shape(x)[2:]
return ops.reshape(x, flat_shape)

def unflatten_beams(x):
"""Separate the beam dim and batch dim."""
unflat_shape = (batch_size, self.num_beams) + tuple(x.shape)[1:]
unflat_shape = (batch_size, self.num_beams) + ops.shape(x)[1:]
return ops.reshape(x, unflat_shape)

if mask is None:
Expand Down
6 changes: 3 additions & 3 deletions keras_nlp/samplers/contrastive_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -103,17 +103,17 @@ def __call__(
def create_beams(x):
"""Add initial beam state."""
x = ops.repeat(x, self.k, axis=0)
flat_shape = (batch_size * self.k,) + tuple(x.shape)[1:]
flat_shape = (batch_size * self.k,) + ops.shape(x)[1:]
return ops.reshape(x, flat_shape)

def flatten_beams(x):
"""Combine the beam dim and batch dim."""
flat_shape = (batch_size * self.k,) + tuple(x.shape)[2:]
flat_shape = (batch_size * self.k,) + ops.shape(x)[2:]
return ops.reshape(x, flat_shape)

def unflatten_beams(x):
"""Separate the beam dim and batch dim."""
unflat_shape = (batch_size, self.k) + tuple(x.shape)[1:]
unflat_shape = (batch_size, self.k) + ops.shape(x)[1:]
return ops.reshape(x, unflat_shape)

mask = ops.zeros_like(prompt, dtype="bool") if mask is None else mask
Expand Down

0 comments on commit 60af93f

Please sign in to comment.