Skip to content

Commit

Permalink
Merge pull request #1066 from llmware-ai/update-model-catalog-1026
Browse files Browse the repository at this point in the history
model catalog updates
  • Loading branch information
doberst authored Oct 26, 2024
2 parents df2000d + 5348010 commit 5b97ccc
Showing 1 changed file with 35 additions and 4 deletions.
39 changes: 35 additions & 4 deletions llmware/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,7 +219,7 @@ def validate(cls, model_card_dict):
return True

@classmethod
def add_model(cls, model_card_dict):
def add_model(cls, model_card_dict, over_write=True):

""" Adds a model to the registry """

Expand All @@ -231,8 +231,15 @@ def add_model(cls, model_card_dict):
if (model["model_name"] in [model_card_dict["model_name"], model_card_dict["display_name"]] or
model["display_name"] in [model_card_dict["model_name"], model_card_dict["display_name"]]):

raise LLMWareException(message=f"Exception: model name overlaps with another model already "
f"in the ModelCatalog - {model}")
if not over_write:

raise LLMWareException(message=f"Exception: model name overlaps with another model already "
f"in the ModelCatalog - {model}")

else:
# logger.warning(f"_ModelRegistry - over-write = True - {model['model_name']} - mew model added.")

del cls.registered_models[i]

# go ahead and add model to the catalog

Expand Down Expand Up @@ -477,6 +484,9 @@ def __init__(self):
self.api_key= None
self.custom_loader = None

# new - add - 102024
self.model_kwargs = {}

def to_state_dict(self):

""" Writes selected model state parameters to dictionary. """
Expand Down Expand Up @@ -889,6 +899,7 @@ def model_load_optimizer(self):
# to "re-direct" the model loading parameters
if isinstance(success_dict, dict):
for k, v in success_dict.items():
# updating and setting attrs
setattr(self,k,v)

return True
Expand Down Expand Up @@ -935,6 +946,14 @@ def load_model (self, selected_model, api_key=None, use_gpu=True, sample=True,ge

raise ModelNotFoundException(self.selected_model)

# new - 1020 add
if self.model_kwargs:
if not kwargs:
kwargs = {}
for k,v in self.model_kwargs.items():
kwargs.update({k:v})
# end - new add

# step 2- instantiate the right model class
my_model = self.get_model_by_name(model_card["model_name"], api_key=self.api_key,
api_endpoint=self.api_endpoint, **kwargs)
Expand Down Expand Up @@ -1697,7 +1716,19 @@ def logit_analysis(self, response, model_card, hf_tokenizer_name,api_key=None):

for x in range(0, len(logits[i])):
if logits[i][x][0] in marker_tokens:
new_entry = (marker_token_lookup[logits[i][x][0]],

# if model catalog loaded from json config file, then dict number converted to str

if logits[i][x][0] in marker_token_lookup:
entry0 = marker_token_lookup[logits[i][x][0]]

elif str(logits[i][x][0]) in marker_token_lookup:
entry0 = marker_token_lookup[str(logits[i][x][0])]

else:
entry0 = "NA"

new_entry = (entry0,
logits[i][x][0],
logits[i][x][1])
marker_token_probs.append(new_entry)
Expand Down

0 comments on commit 5b97ccc

Please sign in to comment.