diff --git a/tests/acceleration/test_acceleration_framework.py b/tests/acceleration/test_acceleration_framework.py index deb3319ad..62481ffd1 100644 --- a/tests/acceleration/test_acceleration_framework.py +++ b/tests/acceleration/test_acceleration_framework.py @@ -57,7 +57,7 @@ # for some reason the CI will raise an import error if we try to import # these from tests.data TWITTER_COMPLAINTS_JSON_FORMAT = os.path.join( - os.path.dirname(__file__), "../data/twitter_complaints_json.json" + os.path.dirname(__file__), "../data/twitter_complaints_small.json" ) TWITTER_COMPLAINTS_TOKENIZED = os.path.join( os.path.dirname(__file__), @@ -365,7 +365,7 @@ def test_framework_raises_due_to_invalid_arguments( acceleration_configs_map, ids=["bitsandbytes", "auto_gptq"], ) -def test_framework_intialized_properly_peft( +def test_framework_initialized_properly_peft( quantized_lora_config, model_name_or_path, mock_and_spy ): """Ensure that specifying a properly configured acceleration dataclass @@ -417,7 +417,7 @@ def test_framework_intialized_properly_peft( "and foak plugins" ), ) -def test_framework_intialized_properly_foak(): +def test_framework_initialized_properly_foak(): """Ensure that specifying a properly configured acceleration dataclass properly activates the framework plugin and runs the train sucessfully. """ @@ -486,7 +486,7 @@ def test_framework_intialized_properly_foak(): not is_fms_accelerate_available(plugins="moe"), reason="Only runs if fms-accelerate is installed along with accelerated-moe plugin", ) -def test_framework_intialized_properly_moe(): +def test_framework_initialized_properly_moe(): """Ensure that specifying a properly configured acceleration dataclass properly activates the framework plugin and runs the train sucessfully. """ @@ -494,11 +494,12 @@ def test_framework_intialized_properly_moe(): with tempfile.TemporaryDirectory() as tempdir: model_args = copy.deepcopy(MODEL_ARGS) - model_args.model_name_or_path = "ibm-granite/granite-3.0-1b-a400m-instruct" - model_args.use_flash_attn = True + model_args.model_name_or_path = "Isotonic/TinyMixtral-4x248M-MoE" + model_args.torch_dtype = torch.bfloat16 train_args = copy.deepcopy(TRAIN_ARGS) train_args.output_dir = tempdir train_args.save_strategy = "no" + train_args.bf16 = True data_args = copy.deepcopy(DATA_ARGS) data_args.training_data_path = TWITTER_COMPLAINTS_JSON_FORMAT data_args.response_template = "\n\n### Label:" @@ -530,8 +531,8 @@ def test_framework_intialized_properly_moe(): ) # spy inside the train to ensure that the ilab plugin is called - assert spy["model_loader_calls"] == 0 - assert spy["augmentation_calls"] == 1 + assert spy["model_loader_calls"] == 1 + assert spy["augmentation_calls"] == 0 assert spy["get_ready_for_train_calls"] == 1 diff --git a/tuning/config/acceleration_configs/acceleration_framework_config.py b/tuning/config/acceleration_configs/acceleration_framework_config.py index e533e8839..7da11a5de 100644 --- a/tuning/config/acceleration_configs/acceleration_framework_config.py +++ b/tuning/config/acceleration_configs/acceleration_framework_config.py @@ -322,7 +322,6 @@ def _descend_and_set(path: List[str], d: Dict): already_set.add(prefix_path) _descend_and_set(path, asdict(datacls)) - print(configuration_contents) return configuration_contents def to_yaml(self, filename: str):