From fe827e4b621040805cbd01b848e975f488a85c9b Mon Sep 17 00:00:00 2001 From: Matthias Diener Date: Thu, 6 Oct 2022 22:49:44 -0500 Subject: [PATCH] Create WithTag --- loopy/kernel/creation.py | 3 +++ loopy/statistics.py | 9 ++++++-- loopy/symbolic.py | 46 ++++++++++++++++++++++++++++++++++++++++ test/test_statistics.py | 31 +++++++++++++++++++++++++++ 4 files changed, 87 insertions(+), 2 deletions(-) diff --git a/loopy/kernel/creation.py b/loopy/kernel/creation.py index 3fc3dfe9f..42ed0c046 100644 --- a/loopy/kernel/creation.py +++ b/loopy/kernel/creation.py @@ -1348,6 +1348,9 @@ def __init__(self, add_assignment): self.expr_to_var = {} super().__init__() + def map_with_tag(self, expr, additional_inames): + return super().map_with_tag(expr) + def map_reduction(self, expr, additional_inames): additional_inames = additional_inames | frozenset(expr.inames) diff --git a/loopy/statistics.py b/loopy/statistics.py index bdcdb0878..16ba57490 100755 --- a/loopy/statistics.py +++ b/loopy/statistics.py @@ -636,10 +636,14 @@ class Op(ImmutableRecord): A :class:`str` representing the kernel name where the operation occurred. + .. attribute:: tags + + A :class:`frozenset` of tags to the operation. + """ def __init__(self, dtype=None, name=None, count_granularity=None, - kernel_name=None): + kernel_name=None, tags=None): if count_granularity not in CountGranularity.ALL+[None]: raise ValueError("Op.__init__: count_granularity '%s' is " "not allowed. count_granularity options: %s" @@ -651,7 +655,8 @@ def __init__(self, dtype=None, name=None, count_granularity=None, super().__init__(dtype=dtype, name=name, count_granularity=count_granularity, - kernel_name=kernel_name) + kernel_name=kernel_name, + tags=tags) def __repr__(self): # Record.__repr__ overridden for consistent ordering and conciseness diff --git a/loopy/symbolic.py b/loopy/symbolic.py index b6bd1d009..5343d8f75 100644 --- a/loopy/symbolic.py +++ b/loopy/symbolic.py @@ -114,6 +114,10 @@ # {{{ mappers with support for loopy-specific primitives class IdentityMapperMixin: + def map_with_tag(self, expr, *args, **kwargs): + new_expr = self.rec(expr.expr, *args, **kwargs) + return WithTag(expr.tags, new_expr) + def map_literal(self, expr, *args, **kwargs): return expr @@ -207,6 +211,12 @@ def map_common_subexpression_uncached(self, expr): class WalkMapperMixin: + def map_with_tag(self, expr, *args, **kwargs): + if not self.visit(expr, *args, **kwargs): + return + + self.rec(expr.expr, *args, **kwargs) + def map_literal(self, expr, *args, **kwargs): self.visit(expr, *args, **kwargs) @@ -273,6 +283,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase): class CombineMapper(CombineMapperBase): + def map_with_tag(self, expr, *args, **kwargs): + return self.rec(expr.expr, *args, **kwargs) + def map_reduction(self, expr, *args, **kwargs): return self.rec(expr.expr, *args, **kwargs) @@ -298,6 +311,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase, class StringifyMapper(StringifyMapperBase): + def map_with_tag(self, expr, *args): + from pymbolic.mapper.stringifier import PREC_NONE + return f"WithTag({expr.tags}, {self.rec(expr.expr, PREC_NONE)}" + def map_literal(self, expr, *args): return expr.s @@ -440,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs): def map_loopy_function_identifier(self, expr, *args, **kwargs): return set() + def map_with_tag(self, expr, *args, **kwargs): + deps = self.rec(expr.expr, *args, **kwargs) + return deps + def map_sub_array_ref(self, expr, *args, **kwargs): deps = self.rec(expr.subscript, *args, **kwargs) return deps - set(expr.swept_inames) @@ -712,6 +733,31 @@ def copy(self, *, name=None, tags=None): mapper_method = intern("map_tagged_variable") +class WithTag(LoopyExpressionBase): + """ + Represents a frozenset of tags attached to an :attr:`expr`. + """ + + init_arg_names = ("tags", "expr") + + def __init__(self, tags, expr): + self.tags = tags + self.expr = expr + + def __getinitargs__(self): + return (self.tags, self.expr) + + def get_hash(self): + return hash((self.__class__, self.tags, self.expr)) + + def is_equal(self, other): + return (other.__class__ == self.__class__ + and other.tags == self.tags + and other.expr == self.expr) + + mapper_method = intern("map_with_tag") + + class Reduction(LoopyExpressionBase): """ Represents a reduction operation on :attr:`expr` across :attr:`inames`. diff --git a/test/test_statistics.py b/test/test_statistics.py index 4218067fa..06724879f 100644 --- a/test/test_statistics.py +++ b/test/test_statistics.py @@ -1531,6 +1531,37 @@ def test_no_loop_ops(): assert f64_mul == 1 +from pytools.tag import Tag + + +class MyCostTag(Tag): + pass + + +class MyCostTag2(Tag): + pass + + +def test_op_with_tag(): + from loopy.symbolic import WithTag + from pymbolic.primitives import Subscript, Variable, Sum + + knl = lp.make_kernel( + "{[i]: 0<=i 1: exec(sys.argv[1])