Skip to content

Commit

Permalink
Pr comments
Browse files Browse the repository at this point in the history
  • Loading branch information
Ch0ronomato committed Dec 8, 2024
1 parent b6ce485 commit 2f70694
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 23 deletions.
2 changes: 1 addition & 1 deletion pytensor/link/pytorch/dispatch/elemwise.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def elemwise_fn(*inputs):

out_shape = bcasted_inputs[0].size()
out_size = out_shape.numel()
raveled_outputs = [torch.zeros(out_size) for out in node.outputs]
raveled_outputs = [torch.empty(out_size) for out in node.outputs]

for i in range(out_size):
core_outs = base_fn(*(inp[i] for inp in raveled_inputs))
Expand Down
27 changes: 5 additions & 22 deletions pytensor/link/pytorch/linker.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,19 +9,6 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.gen_functors = []

def input_filter(self, inp):
from pytensor.link.pytorch.dispatch import pytorch_typify

return pytorch_typify(inp)

def output_filter(self, var, out):
from torch import is_tensor

if is_tensor(out):
return out.cpu()
else:
return out

def fgraph_convert(self, fgraph, input_storage, storage_map, **kwargs):
from pytensor.link.pytorch.dispatch import pytorch_funcify

Expand Down Expand Up @@ -67,34 +54,30 @@ def __init__(self, fn, gen_functors):
self.fn = torch.compile(fn)
self.gen_functors = gen_functors.copy()

def __call__(self, *args, **kwargs):
def __call__(self, *inputs, **kwargs):
import pytensor.link.utils

# set attrs
for n, fn in self.gen_functors:
setattr(pytensor.link.utils, n[1:], fn)

res = self.fn(*args, **kwargs)
# Torch does not accept numpy inputs and may return GPU objects
outs = self.fn(*(pytorch_typify(inp) for inp in inputs), **kwargs)

# unset attrs
for n, _ in self.gen_functors:
if getattr(pytensor.link.utils, n[1:], False):
delattr(pytensor.link.utils, n[1:])

return res
return tuple(out.cpu().numpy() for out in outs)

def __del__(self):
del self.gen_functors

inner_fn = wrapper(fn, self.gen_functors)
self.gen_functors = []

# Torch does not accept numpy inputs and may return GPU objects
def create_outputs(*inputs, inner_fn=inner_fn):
outs = inner_fn(*(pytorch_typify(inp) for inp in inputs))
return tuple(out.cpu().numpy() for out in outs)

return create_outputs
return inner_fn

def create_thunk_inputs(self, storage_map):
thunk_inputs = []
Expand Down

0 comments on commit 2f70694

Please sign in to comment.