From 7a29d78ffd9426d713fb406902716d3978335574 Mon Sep 17 00:00:00 2001 From: Ch0ronomato Date: Sun, 8 Dec 2024 09:08:01 -0800 Subject: [PATCH] Add invert and split --- pytensor/link/pytorch/dispatch/basic.py | 9 +++++++++ pytensor/link/pytorch/dispatch/scalar.py | 6 ++++++ 2 files changed, 15 insertions(+) diff --git a/pytensor/link/pytorch/dispatch/basic.py b/pytensor/link/pytorch/dispatch/basic.py index 11e1d6c63a..6cf4f29aab 100644 --- a/pytensor/link/pytorch/dispatch/basic.py +++ b/pytensor/link/pytorch/dispatch/basic.py @@ -19,6 +19,7 @@ Eye, Join, MakeVector, + Split, TensorFromScalar, ) @@ -185,3 +186,11 @@ def tensorfromscalar(x): return torch.as_tensor(x) return tensorfromscalar + + +@pytorch_funcify.register(Split) +def pytorch_funcify_Split(op, node, **kwargs): + def inner_fn(x, dim, split_amounts): + return x.split(split_amounts.tolist(), dim=dim.item()) + + return inner_fn diff --git a/pytensor/link/pytorch/dispatch/scalar.py b/pytensor/link/pytorch/dispatch/scalar.py index 1416e58f55..2505ef655a 100644 --- a/pytensor/link/pytorch/dispatch/scalar.py +++ b/pytensor/link/pytorch/dispatch/scalar.py @@ -5,11 +5,17 @@ from pytensor.link.pytorch.dispatch.basic import pytorch_funcify from pytensor.scalar.basic import ( Cast, + Invert, ScalarOp, ) from pytensor.scalar.math import Softplus +@pytorch_funcify.register(Invert) +def pytorch_funcify_invert(op, node, **kwargs): + return torch.bitwise_not + + @pytorch_funcify.register(ScalarOp) def pytorch_funcify_ScalarOp(op, node, **kwargs): """Return pytorch function that implements the same computation as the Scalar Op.