Skip to content

Commit

Permalink
Merge pull request #3 from neu-pml/develop
Browse files Browse the repository at this point in the history
Develop
  • Loading branch information
esennesh authored Dec 14, 2024
2 parents 065e1eb + 74a8723 commit 7a2b181
Show file tree
Hide file tree
Showing 5 changed files with 714 additions and 11 deletions.
2 changes: 0 additions & 2 deletions discopy/cartesian.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@
from discopy.cat import factory
from discopy.monoidal import Ty, assert_isatomic


@factory
class Diagram(symmetric.Diagram):
"""
Expand Down Expand Up @@ -171,5 +170,4 @@ def __call__(self, other):
return self.cod.ar.copy(self(other.dom), len(other.cod))
return super().__call__(other)


Id = Diagram.id
2 changes: 1 addition & 1 deletion discopy/cat.py
Original file line number Diff line number Diff line change
Expand Up @@ -567,7 +567,7 @@ def free_symbols(self) -> "set[sympy.Symbol]":
def recursive_free_symbols(data):
if isinstance(data, Mapping):
data = data.values()
if isinstance(data, Iterable):
if not isinstance(data, str) and isinstance(data, Iterable):
# Handles numpy 0-d arrays, which are actually not iterable.
if not hasattr(data, "shape") or data.shape != ():
return set().union(*map(recursive_free_symbols, data))
Expand Down
75 changes: 67 additions & 8 deletions discopy/python.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@
from dataclasses import dataclass

from discopy.cat import Composable, assert_iscomposable
from discopy.monoidal import Whiskerable
from discopy.monoidal import PRO, Whiskerable

Ty = tuple[type, ...]

Expand Down Expand Up @@ -82,6 +82,58 @@ def is_tuple(typ: type) -> bool:
"""
return getattr(typ, "__origin__", typ) is tuple

class product:
def __init__(self, *functions):
if not functions:
raise TypeError(repr(type(self).__name__) +
' needs at least one argument')

_functions = []
doms = []
for function, dom in functions:
if not callable(function):
raise TypeError(repr(type(self).__name__) +
' arguments must be callable')
if isinstance(function, product):
_functions = _functions + function.__wrapped__
doms = doms + function._doms
else:
_functions.append(function)
doms.append(dom)
self.__wrapped__ = _functions
self._doms = doms

def __call__(self, *args):
i = 0
result = ()
for func, dom in zip(self.__wrapped__, self._doms):
val = tuplify(func(*args[i:i+len(dom)]))
result = result + val
i += len(dom)
return untuplify(result)

class compose:
def __init__(self, *functions):
if not functions:
raise TypeError(repr(type(self).__name__) +
' needs at least one argument')

_functions = []
for function in reversed(functions):
if not callable(function):
raise TypeError(repr(type(self).__name__) +
' arguments must be callable')

if isinstance(function, compose):
_functions = _functions + function.__wrapped__
else:
_functions.append(function)
self.__wrapped__ = _functions

def __call__(self, *values):
for func in self.__wrapped__:
values = func(*tuplify(values))
return values

@dataclass
class Function(Composable[Ty], Whiskerable):
Expand Down Expand Up @@ -119,7 +171,7 @@ def id(dom: Ty) -> Function:
The identity function on a given tuple of types :code:`dom`.
Parameters:
dom (python.Ty) : The typle of types on which to take the identity.
dom (python.Ty) : The tuple of types on which to take the identity.
"""
return Function(lambda *xs: untuplify(xs), dom, dom)

Expand All @@ -131,8 +183,13 @@ def then(self, other: Function) -> Function:
other : The other function to compose in sequence.
"""
assert_iscomposable(self, other)
return Function(
lambda *args: other(*tuplify(self(*args))), self.dom, other.cod)

if self.inside == untuplify:
return other
if other.inside == untuplify:
return self
function = compose(other.inside, self.inside)
return Function(function, self.dom, other.cod)

def __call__(self, *xs):
return self.inside(*xs)
Expand All @@ -144,10 +201,12 @@ def tensor(self, other: Function) -> Function:
Parameters:
other : The other function to compose in sequence.
"""
def inside(*xs):
left, right = xs[:len(self.dom)], xs[len(self.dom):]
return untuplify(tuplify(self(*left)) + tuplify(other(*right)))
return Function(inside, self.dom + other.dom, self.cod + other.cod)
if self.dom == PRO(0) and self.inside == untuplify:
return other
if other.dom == PRO(0) and other.inside == untuplify:
return self
prod = product((self.inside, self.dom), (other.inside, other.dom))
return Function(prod, self.dom + other.dom, self.cod + other.cod)

@staticmethod
def swap(x: Ty, y: Ty) -> Function:
Expand Down
Loading

0 comments on commit 7a2b181

Please sign in to comment.