Skip to content

Commit

Permalink
More Pre-Commit Fixes, I need to make this actually run pre-commit
Browse files Browse the repository at this point in the history
  • Loading branch information
Helw150 committed Nov 20, 2024
1 parent 708955a commit a87e8a6
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 5 deletions.
3 changes: 2 additions & 1 deletion src/levanter/data/audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ def __init__(
padding=True,
):
self.feature_extractor: SequenceFeatureExtractor = processor.feature_extractor
if tokenizer.pad_token_id == None:
if tokenizer.pad_token_id is None:
override_token = list(tokenizer.added_tokens_decoder.items())[-1]
tokenizer.pad_token_id = override_token[0]
tokenizer.pad_tokn = str(override_token[1])
Expand Down Expand Up @@ -276,6 +276,7 @@ class ProcessedAudioCache(AsyncDataset[AudioTextDict]):
def __init__(self, cache: TreeCache[AudioTextDict]):
super().__init__()
self.cache = cache
self._cached_len: Optional[int] = None

async def async_len(self) -> int:
return await self.cache.async_len()
Expand Down
11 changes: 7 additions & 4 deletions src/levanter/main/train_asr.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,8 @@
from levanter.compat.hf_checkpoints import HFCompatConfig, ModelWithHfSerializationMixin, save_hf_checkpoint_callback
from levanter.data.audio import AudioIODatasetConfig, AudioMixtureDatasetConfig, AudioTextDataset
from levanter.models.asr_model import ASRConfig, AudioTextExample
from levanter.models.diva import DivaASRModel, diva_connector_only
from levanter.models.whisper import WhisperConfig
from levanter.models.diva import diva_connector_only
from levanter.optim import AdamConfig, OptimizerConfig
from levanter.trainer import Trainer, TrainerConfig
from levanter.utils.jax_utils import parameter_count
Expand Down Expand Up @@ -138,12 +138,15 @@ def compute_loss(
if vocab_size != Vocab.size:
logger.info(f"Rounding vocab size from {vocab_size} to {Vocab.size} for partitioning")

state = trainer.initial_state(training_key, model_init=lambda: config.model.build_asr(Vocab, key=model_key), )
state = trainer.initial_state(
training_key,
model_init=lambda: config.model.build_asr(Vocab, key=model_key),
)

if int(state.step) == 0:
if config.diva_training:
if config.diva_training and config.model.asr_model_type == DivaASRModel:
state = dataclasses.replace(state, model=None)
model = config.model.asr_model_type.init(Vocab, config.model, key=model_key, init_from_submodels=True)
model = DivaASRModel.init(Vocab, config.model, key=model_key, init_from_submodels=True)
model = named_jit(trainer.mp.cast_to_param, parameter_axis_mapping)(model)
state = dataclasses.replace(state, model=model, is_trainable=diva_connector_only(model))
# TODO: I don't love that we init the model twice, but it's not a big deal i think?
Expand Down

0 comments on commit a87e8a6

Please sign in to comment.