From 3ac751bc9fbd00f848a54b2524ddda73535672ed Mon Sep 17 00:00:00 2001 From: pranavm Date: Fri, 1 Nov 2024 11:33:40 -0700 Subject: [PATCH] Speeds up datatype constraints tests Makes datatype constraints tests significantly faster by not evaluating inputs. --- tripy/tests/constraints/object_builders.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/tripy/tests/constraints/object_builders.py b/tripy/tests/constraints/object_builders.py index 7dbfe1ad7..57d6cc766 100644 --- a/tripy/tests/constraints/object_builders.py +++ b/tripy/tests/constraints/object_builders.py @@ -26,7 +26,6 @@ def tensor_builder(init, dtype, namespace): if init is None: out = tp.ones(dtype=namespace[dtype], shape=(3, 2)) - out.eval() return out elif not isinstance(init, tp.Tensor): return init @@ -34,7 +33,8 @@ def tensor_builder(init, dtype, namespace): out = init if dtype is not None: out = tp.cast(out, dtype=namespace[dtype]) - out.eval() + # Need to evaluate when casting because we run into MLIR-TRT bugs while deriving upper bounds. + out.eval() return out @@ -47,8 +47,6 @@ def tensor_list_builder(init, dtype, namespace): out = [tp.ones(shape=(3, 2), dtype=namespace[dtype]) for _ in range(2)] else: out = [tp.cast(tens, dtype=namespace[dtype]) for tens in init] - for t in out: - t.eval() return out @@ -132,7 +130,7 @@ def default_builder(init, dtype, namespace): "pad": {"pad": [(0, 1), (1, 0)]}, "permute": {"perm": [1, 0]}, "prod": {"dim": 0}, - "quantize": {"scale": tp.Tensor([1, 1, 1]), "dim": 0}, + "quantize": {"input": tp.ones((3, 2)), "scale": tp.Tensor([1, 1, 1]), "dim": 0}, "repeat": {"repeats": 2, "dim": 0}, "reshape": {"shape": [6]}, "resize": {