Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support on_unused_input for string parameter names in eval #1085

Merged
merged 5 commits into from
Nov 12, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 6 additions & 2 deletions pytensor/graph/basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,16 +616,20 @@ def eval(
"""
from pytensor.compile.function import function

ignore_unused_input = kwargs.get("on_unused_input", None) in ("ignore", "warn")

def convert_string_keys_to_variables(inputs_to_values) -> dict["Variable", Any]:
new_input_to_values = {}
for key, value in inputs_to_values.items():
if isinstance(key, str):
matching_vars = get_var_by_name([self], key)
if not matching_vars:
raise ValueError(f"{key} not found in graph")
if not ignore_unused_input:
raise ValueError(f"{key} not found in graph")
elif len(matching_vars) > 1:
raise ValueError(f"Found multiple variables with name {key}")
new_input_to_values[matching_vars[0]] = value
else:
new_input_to_values[matching_vars[0]] = value
else:
new_input_to_values[key] = value
return new_input_to_values
Expand Down
4 changes: 4 additions & 0 deletions tests/graph/test_basic.py
Original file line number Diff line number Diff line change
Expand Up @@ -367,6 +367,10 @@ def test_eval_kwargs(self):
self.w.eval({self.z: 3, self.x: 2.5})
assert self.w.eval({self.z: 3, self.x: 2.5}, on_unused_input="ignore") == 6.0

# regression test for https://github.com/pymc-devs/pytensor/issues/1084
q = self.x + 1
assert q.eval({"x": 1, "y": 2}, on_unused_input="ignore") == 2.0
tvwenger marked this conversation as resolved.
Show resolved Hide resolved

@pytest.mark.filterwarnings("error")
def test_eval_unashable_kwargs(self):
y_repl = constant(2.0, dtype="floatX")
Expand Down
Loading