Skip to content

Commit

Permalink
Bugfix for output_generation_logits in tensorrtllm (#11820)
Browse files Browse the repository at this point in the history
Signed-off-by: Abhishree <[email protected]>
  • Loading branch information
athitten authored Jan 11, 2025
1 parent 1ab22d1 commit 7f3ac6b
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion nemo/export/tensorrt_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,7 +1177,7 @@ def triton_infer_fn(self, **inputs: np.ndarray):
lora_uids = np.char.decode(inputs.pop("lora_uids").astype("bytes"), encoding="utf-8")
infer_input["lora_uids"] = lora_uids[0].tolist()
if "output_generation_logits" in inputs:
generation_logits_available = inputs["output_generation_logits"]
generation_logits_available = inputs["output_generation_logits"][0][0]
infer_input["output_generation_logits"] = inputs.pop("output_generation_logits")[0][0]

if generation_logits_available:
Expand Down

0 comments on commit 7f3ac6b

Please sign in to comment.