diff --git a/pretraining.py b/pretraining.py index c6523d6..37fe9b5 100644 --- a/pretraining.py +++ b/pretraining.py @@ -428,8 +428,7 @@ def main(): model.print_trainable_parameters() else: logger.info("Full parameters training") - if model_args.model_type in ['chatglm']: - model = model.half() + model = model.float() print_trainable_parameters(model) # Preprocessing the datasets. diff --git a/supervised_finetuning.py b/supervised_finetuning.py index fe5d1ff..ab843ce 100644 --- a/supervised_finetuning.py +++ b/supervised_finetuning.py @@ -873,8 +873,7 @@ def preprocess_function(examples): model.print_trainable_parameters() else: logger.info("Fine-tuning method: Full parameters training") - if model_args.model_type in ['chatglm']: - model = model.half() + model = model.float() print_trainable_parameters(model) logger.debug(f"Model: {model}")