diff --git a/instruct_qa/generation/generator.py b/instruct_qa/generation/generator.py index d616ae2..c8416a1 100644 --- a/instruct_qa/generation/generator.py +++ b/instruct_qa/generation/generator.py @@ -74,7 +74,7 @@ def __init__(self, *args, **kwargs): } if completion_type is not None: - self.model_map[model_name] = completion_type + self.model_map[self.model_name] = completion_type assert ( self.model_name in self.model_map