Skip to content

Commit

Permalink
Create WithTag
Browse files Browse the repository at this point in the history
  • Loading branch information
matthiasdiener committed Oct 7, 2022
1 parent 8a2b06b commit fe827e4
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 2 deletions.
3 changes: 3 additions & 0 deletions loopy/kernel/creation.py
Original file line number Diff line number Diff line change
Expand Up @@ -1348,6 +1348,9 @@ def __init__(self, add_assignment):
self.expr_to_var = {}
super().__init__()

def map_with_tag(self, expr, additional_inames):
return super().map_with_tag(expr)

def map_reduction(self, expr, additional_inames):
additional_inames = additional_inames | frozenset(expr.inames)

Expand Down
9 changes: 7 additions & 2 deletions loopy/statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -636,10 +636,14 @@ class Op(ImmutableRecord):
A :class:`str` representing the kernel name where the operation occurred.
.. attribute:: tags
A :class:`frozenset` of tags to the operation.
"""

def __init__(self, dtype=None, name=None, count_granularity=None,
kernel_name=None):
kernel_name=None, tags=None):
if count_granularity not in CountGranularity.ALL+[None]:
raise ValueError("Op.__init__: count_granularity '%s' is "
"not allowed. count_granularity options: %s"
Expand All @@ -651,7 +655,8 @@ def __init__(self, dtype=None, name=None, count_granularity=None,

super().__init__(dtype=dtype, name=name,
count_granularity=count_granularity,
kernel_name=kernel_name)
kernel_name=kernel_name,
tags=tags)

def __repr__(self):
# Record.__repr__ overridden for consistent ordering and conciseness
Expand Down
46 changes: 46 additions & 0 deletions loopy/symbolic.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,10 @@
# {{{ mappers with support for loopy-specific primitives

class IdentityMapperMixin:
def map_with_tag(self, expr, *args, **kwargs):
new_expr = self.rec(expr.expr, *args, **kwargs)
return WithTag(expr.tags, new_expr)

def map_literal(self, expr, *args, **kwargs):
return expr

Expand Down Expand Up @@ -207,6 +211,12 @@ def map_common_subexpression_uncached(self, expr):


class WalkMapperMixin:
def map_with_tag(self, expr, *args, **kwargs):
if not self.visit(expr, *args, **kwargs):
return

self.rec(expr.expr, *args, **kwargs)

def map_literal(self, expr, *args, **kwargs):
self.visit(expr, *args, **kwargs)

Expand Down Expand Up @@ -273,6 +283,9 @@ class CallbackMapper(IdentityMapperMixin, CallbackMapperBase):


class CombineMapper(CombineMapperBase):
def map_with_tag(self, expr, *args, **kwargs):
return self.rec(expr.expr, *args, **kwargs)

def map_reduction(self, expr, *args, **kwargs):
return self.rec(expr.expr, *args, **kwargs)

Expand All @@ -298,6 +311,10 @@ class ConstantFoldingMapper(ConstantFoldingMapperBase,


class StringifyMapper(StringifyMapperBase):
def map_with_tag(self, expr, *args):
from pymbolic.mapper.stringifier import PREC_NONE
return f"WithTag({expr.tags}, {self.rec(expr.expr, PREC_NONE)}"

def map_literal(self, expr, *args):
return expr.s

Expand Down Expand Up @@ -440,6 +457,10 @@ def map_tagged_variable(self, expr, *args, **kwargs):
def map_loopy_function_identifier(self, expr, *args, **kwargs):
return set()

def map_with_tag(self, expr, *args, **kwargs):
deps = self.rec(expr.expr, *args, **kwargs)
return deps

def map_sub_array_ref(self, expr, *args, **kwargs):
deps = self.rec(expr.subscript, *args, **kwargs)
return deps - set(expr.swept_inames)
Expand Down Expand Up @@ -712,6 +733,31 @@ def copy(self, *, name=None, tags=None):
mapper_method = intern("map_tagged_variable")


class WithTag(LoopyExpressionBase):
"""
Represents a frozenset of tags attached to an :attr:`expr`.
"""

init_arg_names = ("tags", "expr")

def __init__(self, tags, expr):
self.tags = tags
self.expr = expr

def __getinitargs__(self):
return (self.tags, self.expr)

def get_hash(self):
return hash((self.__class__, self.tags, self.expr))

def is_equal(self, other):
return (other.__class__ == self.__class__
and other.tags == self.tags
and other.expr == self.expr)

mapper_method = intern("map_with_tag")


class Reduction(LoopyExpressionBase):
"""
Represents a reduction operation on :attr:`expr` across :attr:`inames`.
Expand Down
31 changes: 31 additions & 0 deletions test/test_statistics.py
Original file line number Diff line number Diff line change
Expand Up @@ -1531,6 +1531,37 @@ def test_no_loop_ops():
assert f64_mul == 1


from pytools.tag import Tag


class MyCostTag(Tag):
pass


class MyCostTag2(Tag):
pass


def test_op_with_tag():
from loopy.symbolic import WithTag
from pymbolic.primitives import Subscript, Variable, Sum

knl = lp.make_kernel(
"{[i]: 0<=i<n}",
[
lp.Assignment("c[i]",
Sum(
(WithTag(frozenset((MyCostTag(),)),
Subscript(Variable("a"), Variable("i"))),
WithTag(frozenset((MyCostTag2(),)),
Subscript(Variable("b"), Variable("i"))))))
])

knl = lp.add_dtypes(knl, {"a": np.float64, "b": np.float64})

_op_map = lp.get_op_map(knl, subgroup_size=32)


if __name__ == "__main__":
if len(sys.argv) > 1:
exec(sys.argv[1])
Expand Down

0 comments on commit fe827e4

Please sign in to comment.