get_quantization_config to return dict
deep-diver committed Aug 15, 2024
1 parent a0f38e3 commit e911214
Showing 5 changed files with 6 additions and 6 deletions.
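
The change is the same in each of the four training scripts: the .to_dict() call moves out of the call sites and into get_quantization_config itself, so every script receives a plain dict (or None) and forwards it unchanged, as the diffs below show. A minimal sketch of the pattern, with the surrounding code elided:

# Before: each call site converted the config object itself. This would
# also raise AttributeError when quantization is disabled, since the
# helper returns None in that case.
quantization_config = get_quantization_config(model_args).to_dict()

# After: the helper returns a dict (or None) and call sites pass it through.
quantization_config = get_quantization_config(model_args)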
2 changes: 1 addition & 1 deletion scripts/run_cpt.py
@@ -122,7 +122,7 @@ def main():
     torch_dtype = (
         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_args).to_dict()
+    quantization_config = get_quantization_config(model_args)

     model_kwargs = dict(
         revision=model_args.model_revision,
2 changes: 1 addition & 1 deletion scripts/run_dpo.py
@@ -141,7 +141,7 @@ def main():
     torch_dtype = (
         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_args).to_dict()
+    quantization_config = get_quantization_config(model_args)

     model_kwargs = dict(
         revision=model_args.model_revision,
2 changes: 1 addition & 1 deletion scripts/run_orpo.py
@@ -100,7 +100,7 @@ def main():
     torch_dtype = (
         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_args).to_dict()
+    quantization_config = get_quantization_config(model_args)

     model = AutoModelForCausalLM.from_pretrained(
         model_args.model_name_or_path,
4 changes: 2 additions & 2 deletions scripts/run_sft.py
@@ -108,7 +108,7 @@ def main():
     torch_dtype = (
         model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
     )
-    quantization_config = get_quantization_config(model_args).to_dict()
+    quantization_config = get_quantization_config(model_args)

     model_kwargs = dict(
         revision=model_args.model_revision,
@@ -117,7 +117,7 @@ def main():
         torch_dtype=torch_dtype,
         use_cache=False if training_args.gradient_checkpointing else True,
         device_map=get_kbit_device_map() if quantization_config is not None else None,
-        quantization_config=quantization_config.to_dict(),
+        quantization_config=quantization_config,
     )

     model = model_args.model_name_or_path
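
In run_sft.py the config is consumed twice: once in the device_map check and once inside model_kwargs. This works because recent versions of transformers accept quantization_config as either a BitsAndBytesConfig object or its plain-dict form; that acceptance is an assumption about the installed transformers version, not something stated in the diff. A self-contained usage sketch (the checkpoint id and device_map value are illustrative, not from the diff):

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig

# Stand-in for what the updated helper returns: the dict form of an
# 8-bit BitsAndBytesConfig, or None when quantization is disabled.
quantization_config = BitsAndBytesConfig(load_in_8bit=True).to_dict()

model_kwargs = dict(
    torch_dtype=torch.float16,
    # run_sft.py uses get_kbit_device_map() here; "auto" is a simple stand-in.
    device_map="auto" if quantization_config is not None else None,
    quantization_config=quantization_config,
)

# Hypothetical checkpoint id; any causal LM would do. Running in 8-bit
# requires bitsandbytes and a CUDA device.
model = AutoModelForCausalLM.from_pretrained("facebook/opt-125m", **model_kwargs)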
2 changes: 1 addition & 1 deletion src/alignment/model_utils.py
@@ -56,7 +56,7 @@ def get_quantization_config(model_args: ModelArguments) -> BitsAndBytesConfig |
     elif model_args.load_in_8bit:
         quantization_config = BitsAndBytesConfig(
             load_in_8bit=True,
-        )
+        ).to_dict()
     else:
         quantization_config = None

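Only the 8-bit branch appears in this hunk; the 4-bit branch sits outside the diff context. A sketch of how the full helper plausibly reads after the commit, assuming the 4-bit branch mirrors the 8-bit one (the ModelArguments stand-in is reduced to the two flags used here, and the return annotation is updated for clarity even though the hunk shown does not touch the signature line):

from dataclasses import dataclass
from transformers import BitsAndBytesConfig


@dataclass
class ModelArguments:
    # Reduced stand-in for the project's ModelArguments; only the
    # flags this helper reads are included.
    load_in_4bit: bool = False
    load_in_8bit: bool = False


def get_quantization_config(model_args: ModelArguments) -> dict | None:
    """Return the quantization config as a plain dict, or None if disabled."""
    if model_args.load_in_4bit:
        # Assumed to mirror the 8-bit branch; the 4-bit hunk is not part
        # of this diff, so the exact arguments are illustrative.
        quantization_config = BitsAndBytesConfig(load_in_4bit=True).to_dict()
    elif model_args.load_in_8bit:
        quantization_config = BitsAndBytesConfig(load_in_8bit=True).to_dict()
    else:
        quantization_config = None
    return quantization_config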
