Skip to content

Commit

Permalink
Add Counted mixin class and refactor form signature computation (#178)
Browse files Browse the repository at this point in the history
* Add Counted mixin class and refactor form signature computation

* fixup

* fixup

---------

Co-authored-by: Matthew Scroggs <[email protected]>
  • Loading branch information
connorjward and mscroggs authored Aug 8, 2023
1 parent 83f3408 commit b0d635a
Show file tree
Hide file tree
Showing 8 changed files with 91 additions and 83 deletions.
8 changes: 1 addition & 7 deletions ufl/algorithms/signature.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,6 @@ def compute_terminal_hashdata(expressions, renumbering):
# arguments, and just take repr of the rest of the terminals while
# we're iterating over them
terminal_hashdata = {}
labels = {}
index_numbering = {}
for expression in expressions:
for expr in traverse_unique_terminals(expression):
Expand All @@ -69,12 +68,7 @@ def compute_terminal_hashdata(expressions, renumbering):
data = expr._ufl_signature_data_(renumbering)

elif isinstance(expr, Label):
# Numbering labels as we visit them # TODO: Include in
# renumbering
data = labels.get(expr)
if data is None:
data = "L%d" % len(labels)
labels[expr] = data
data = expr._ufl_signature_data_(renumbering)

elif isinstance(expr, ExprList):
# Not really a terminal but can have 0 operands...
Expand Down
13 changes: 4 additions & 9 deletions ufl/coefficient.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,28 +19,27 @@
from ufl.functionspace import AbstractFunctionSpace, FunctionSpace, MixedFunctionSpace
from ufl.form import BaseForm
from ufl.split_functions import split
from ufl.utils.counted import counted_init
from ufl.utils.counted import Counted
from ufl.duals import is_primal, is_dual

# --- The Coefficient class represents a coefficient in a form ---


class BaseCoefficient(object):
class BaseCoefficient(Counted):
"""UFL form argument type: Parent Representation of a form coefficient."""

# Slots are disabled here because they cause trouble in PyDOLFIN
# multiple inheritance pattern:
# __slots__ = ("_count", "_ufl_function_space", "_repr", "_ufl_shape")
_ufl_noslots_ = True
__slots__ = ()
_globalcount = 0
_ufl_is_abstract_ = True

def __getnewargs__(self):
return (self._ufl_function_space, self._count)

def __init__(self, function_space, count=None):
counted_init(self, count, Coefficient)
Counted.__init__(self, count, Coefficient)

if isinstance(function_space, FiniteElementBase):
# For legacy support for .ufl files using cells, we map
Expand All @@ -57,9 +56,6 @@ def __init__(self, function_space, count=None):
self._repr = "BaseCoefficient(%s, %s)" % (
repr(self._ufl_function_space), repr(self._count))

def count(self):
return self._count

@property
def ufl_shape(self):
"Return the associated UFL shape."
Expand Down Expand Up @@ -111,14 +107,14 @@ class Cofunction(BaseCoefficient, BaseForm):

__slots__ = (
"_count",
"_counted_class",
"_arguments",
"_ufl_function_space",
"ufl_operands",
"_repr",
"_ufl_shape",
"_hash"
)
# _globalcount = 0
_primal = False
_dual = True

Expand Down Expand Up @@ -161,7 +157,6 @@ class Coefficient(FormArgument, BaseCoefficient):
"""UFL form argument type: Representation of a form coefficient."""

_ufl_noslots_ = True
_globalcount = 0
_primal = True
_dual = False

Expand Down
10 changes: 3 additions & 7 deletions ufl/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@
from ufl.core.ufl_type import ufl_type
from ufl.core.terminal import Terminal
from ufl.domain import as_domain
from ufl.utils.counted import counted_init
from ufl.utils.counted import Counted


@ufl_type()
class Constant(Terminal):
class Constant(Terminal, Counted):
_ufl_noslots_ = True
_globalcount = 0

def __init__(self, domain, shape=(), count=None):
Terminal.__init__(self)
counted_init(self, count=count, countedclass=Constant)
Counted.__init__(self, count, Constant)

self._ufl_domain = as_domain(domain)
self._ufl_shape = shape
Expand All @@ -31,9 +30,6 @@ def __init__(self, domain, shape=(), count=None):
self._repr = "Constant({}, {}, {})".format(
repr(self._ufl_domain), repr(self._ufl_shape), repr(self._count))

def count(self):
return self._count

@property
def ufl_shape(self):
return self._ufl_shape
Expand Down
13 changes: 4 additions & 9 deletions ufl/core/multiindex.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
# Modified by Massimiliano Leoni, 2016.


from ufl.utils.counted import counted_init
from ufl.utils.counted import Counted
from ufl.core.ufl_type import ufl_type
from ufl.core.terminal import Terminal

Expand Down Expand Up @@ -70,20 +70,15 @@ def __repr__(self):
return r


class Index(IndexBase):
class Index(IndexBase, Counted):
"""UFL value: An index with no value assigned.
Used to represent free indices in Einstein indexing notation."""
__slots__ = ("_count",)

_globalcount = 0
__slots__ = ("_count", "_counted_class")

def __init__(self, count=None):
IndexBase.__init__(self)
counted_init(self, count, Index)

def count(self):
return self._count
Counted.__init__(self, count, Index)

def __hash__(self):
return hash(("Index", self._count))
Expand Down
58 changes: 47 additions & 11 deletions ufl/form.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,15 @@
from itertools import chain

from ufl.checks import is_scalar_constant_expression
from ufl.constant import Constant
from ufl.constantvalue import Zero
from ufl.core.expr import Expr, ufl_err_str
from ufl.core.ufl_type import UFLType, ufl_type
from ufl.domain import extract_unique_domain, sort_domains
from ufl.equation import Equation
from ufl.integral import Integral
from ufl.utils.counted import Counted
from ufl.utils.sorting import sorted_by_count

# Export list for ufl.classes
__all_classes__ = ["Form", "BaseForm", "ZeroBaseForm"]
Expand Down Expand Up @@ -257,8 +260,9 @@ class Form(BaseForm):
"_arguments",
"_coefficients",
"_coefficient_numbering",
"_constant_numbering",
"_constants",
"_constant_numbering",
"_terminal_numbering",
"_hash",
"_signature",
# --- Dict that external frameworks can place framework-specific
Expand Down Expand Up @@ -289,11 +293,10 @@ def __init__(self, integrals):
self._coefficients = None
self._coefficient_numbering = None
self._constant_numbering = None
self._terminal_numbering = None

from ufl.algorithms.analysis import extract_constants
self._constants = extract_constants(self)
self._constant_numbering = dict(
(c, i) for i, c in enumerate(self._constants))

# Internal variables for caching of hash and signature after
# first request
Expand Down Expand Up @@ -406,8 +409,15 @@ def coefficients(self):
def coefficient_numbering(self):
"""Return a contiguous numbering of coefficients in a mapping
``{coefficient:number}``."""
# cyclic import
from ufl.coefficient import Coefficient

if self._coefficient_numbering is None:
self._analyze_form_arguments()
self._coefficient_numbering = {
expr: num
for expr, num in self.terminal_numbering().items()
if isinstance(expr, Coefficient)
}
return self._coefficient_numbering

def constants(self):
Expand All @@ -416,8 +426,38 @@ def constants(self):
def constant_numbering(self):
"""Return a contiguous numbering of constants in a mapping
``{constant:number}``."""
if self._constant_numbering is None:
self._constant_numbering = {
expr: num
for expr, num in self.terminal_numbering().items()
if isinstance(expr, Constant)
}
return self._constant_numbering

def terminal_numbering(self):
"""Return a contiguous numbering for all counted objects in the form.
The returned object is mapping from terminal to its number (an integer).
The numbering is computed per type so :class:`Coefficient`s,
:class:`Constant`s, etc will each be numbered from zero.
"""
# cyclic import
from ufl.algorithms.analysis import extract_type

if self._terminal_numbering is None:
exprs_by_type = defaultdict(set)
for counted_expr in extract_type(self, Counted):
exprs_by_type[counted_expr._counted_class].add(counted_expr)

numbering = {}
for exprs in exprs_by_type.values():
for i, expr in enumerate(sorted_by_count(exprs)):
numbering[expr] = i
self._terminal_numbering = numbering
return self._terminal_numbering

def signature(self):
"Signature for use with jit cache (independent of incidental numbering of indices etc.)"
if self._signature is None:
Expand Down Expand Up @@ -625,23 +665,19 @@ def _analyze_form_arguments(self):
sorted(set(arguments), key=lambda x: x.number()))
self._coefficients = tuple(
sorted(set(coefficients), key=lambda x: x.count()))
self._coefficient_numbering = dict(
(c, i) for i, c in enumerate(self._coefficients))

def _compute_renumbering(self):
# Include integration domains and coefficients in renumbering
dn = self.domain_numbering()
cn = self.coefficient_numbering()
cnstn = self.constant_numbering()
tn = self.terminal_numbering()
renumbering = {}
renumbering.update(dn)
renumbering.update(cn)
renumbering.update(cnstn)
renumbering.update(tn)

# Add domains of coefficients, these may include domains not
# among integration domains
k = len(dn)
for c in cn:
for c in self.coefficients():
d = extract_unique_domain(c)
if d is not None and d not in renumbering:
renumbering[d] = k
Expand Down
11 changes: 4 additions & 7 deletions ufl/matrix.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,32 +13,32 @@
from ufl.core.ufl_type import ufl_type
from ufl.argument import Argument
from ufl.functionspace import AbstractFunctionSpace
from ufl.utils.counted import counted_init
from ufl.utils.counted import Counted


# --- The Matrix class represents a matrix, an assembled two form ---

@ufl_type()
class Matrix(BaseForm):
class Matrix(BaseForm, Counted):
"""An assemble linear operator between two function spaces."""

__slots__ = (
"_count",
"_counted_class",
"_ufl_function_spaces",
"ufl_operands",
"_repr",
"_hash",
"_ufl_shape",
"_arguments")
_globalcount = 0

def __getnewargs__(self):
return (self._ufl_function_spaces[0], self._ufl_function_spaces[1],
self._count)

def __init__(self, row_space, column_space, count=None):
BaseForm.__init__(self)
counted_init(self, count, Matrix)
Counted.__init__(self, count, Matrix)

if not isinstance(row_space, AbstractFunctionSpace):
raise ValueError("Expecting a FunctionSpace as the row space.")
Expand All @@ -52,9 +52,6 @@ def __init__(self, row_space, column_space, count=None):
self._hash = None
self._repr = f"Matrix({self._ufl_function_spaces[0]!r}, {self._ufl_function_spaces[1]!r}, {self._count!r})"

def count(self):
return self._count

def ufl_function_spaces(self):
"Get the tuple of function spaces of this coefficient."
return self._ufl_function_spaces
Expand Down
45 changes: 21 additions & 24 deletions ufl/utils/counted.py
Original file line number Diff line number Diff line change
@@ -1,43 +1,40 @@
# -*- coding: utf-8 -*-
"Utilites for types with a global unique counter attached to each object."
"Mixin class for types with a global unique counter attached to each object."

# Copyright (C) 2008-2016 Martin Sandve Alnæs
#
# This file is part of UFL (https://www.fenicsproject.org)
#
# SPDX-License-Identifier: LGPL-3.0-or-later

import itertools

def counted_init(self, count=None, countedclass=None):
"Initialize a counted object, see ExampleCounted below for how to use."

if countedclass is None:
countedclass = type(self)
class Counted:
"""Mixin class for globally counted objects."""

if count is None:
count = countedclass._globalcount
# Mixin classes do not work well with __slots__ so _count must be
# added to the __slots__ of the inheriting class
__slots__ = ()

self._count = count
_counter = None

if self._count >= countedclass._globalcount:
countedclass._globalcount = self._count + 1
def __init__(self, count=None, counted_class=None):
"""Initialize the Counted instance.
:arg count: The object count, if ``None`` defaults to the next value
according to the global counter (per type).
:arg counted_class: Class to attach the global counter too. If ``None``
then ``type(self)`` will be used.
class ExampleCounted(object):
"""An example class for classes of objects identified by a global counter.
"""
# create a new counter for each subclass
counted_class = counted_class or type(self)
if counted_class._counter is None:
counted_class._counter = itertools.count()

Mimic this class to create globally counted objects within a single type.
"""
# Store the count for each object
__slots__ = ("_count",)
self._count = count if count is not None else next(counted_class._counter)
self._counted_class = counted_class

# Store a global counter with the class
_globalcount = 0

# Call counted_init with an optional constructor argument and the class
def __init__(self, count=None):
counted_init(self, count, ExampleCounted)

# Make the count accessible
def count(self):
return self._count
Loading

0 comments on commit b0d635a

Please sign in to comment.