Skip to content

Commit

Permalink
fix: model, dtype, assertions
Browse files Browse the repository at this point in the history
Signed-off-by: Will Johnson <[email protected]>
  • Loading branch information
willmj committed Nov 15, 2024
1 parent 549f1af commit ac757a8
Show file tree
Hide file tree
Showing 2 changed files with 9 additions and 9 deletions.
17 changes: 9 additions & 8 deletions tests/acceleration/test_acceleration_framework.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,7 @@
# for some reason the CI will raise an import error if we try to import
# these from tests.data
TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join(
os.path.dirname(__file__), "../data/twitter_complaints_json.json"
os.path.dirname(__file__), "../data/twitter_complaints_small.json"
)
TWITTER_COMPLAINTS_TOKENIZED = os.path.join(
os.path.dirname(__file__),
Expand Down Expand Up @@ -365,7 +365,7 @@ def test_framework_raises_due_to_invalid_arguments(
acceleration_configs_map,
ids=["bitsandbytes", "auto_gptq"],
)
def test_framework_intialized_properly_peft(
def test_framework_initialized_properly_peft(
quantized_lora_config, model_name_or_path, mock_and_spy
):
"""Ensure that specifying a properly configured acceleration dataclass
Expand Down Expand Up @@ -417,7 +417,7 @@ def test_framework_intialized_properly_peft(
"and foak plugins"
),
)
def test_framework_intialized_properly_foak():
def test_framework_initialized_properly_foak():
"""Ensure that specifying a properly configured acceleration dataclass
properly activates the framework plugin and runs the train sucessfully.
"""
Expand Down Expand Up @@ -486,19 +486,20 @@ def test_framework_intialized_properly_foak():
not is_fms_accelerate_available(plugins="moe"),
reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin",
)
def test_framework_intialized_properly_moe():
def test_framework_initialized_properly_moe():
"""Ensure that specifying a properly configured acceleration dataclass
properly activates the framework plugin and runs the train sucessfully.
"""

with tempfile.TemporaryDirectory() as tempdir:

model_args = copy.deepcopy(MODEL_ARGS)
model_args.model_name_or_path = "ibm-granite/granite-3.0-1b-a400m-instruct"
model_args.use_flash_attn = True
model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE"
model_args.torch_dtype = torch.bfloat16
train_args = copy.deepcopy(TRAIN_ARGS)
train_args.output_dir = tempdir
train_args.save_strategy = "no"
train_args.bf16 = True
data_args = copy.deepcopy(DATA_ARGS)
data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT
data_args.response_template = "\n\n### Label:"
Expand Down Expand Up @@ -530,8 +531,8 @@ def test_framework_intialized_properly_moe():
)

# spy inside the train to ensure that the ilab plugin is called
assert spy["model_loader_calls"] == 0
assert spy["augmentation_calls"] == 1
assert spy["model_loader_calls"] == 1
assert spy["augmentation_calls"] == 0
assert spy["get_ready_for_train_calls"] == 1


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -322,7 +322,6 @@ def _descend_and_set(path: List[str], d: Dict):
already_set.add(prefix_path)
_descend_and_set(path, asdict(datacls))

print(configuration_contents)
return configuration_contents

def to_yaml(self, filename: str):
Expand Down

0 comments on commit ac757a8

Please sign in to comment.