forked from BartoszMilewski/XOperad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Forest.hs
46 lines (38 loc) · 1.55 KB
/
Forest.hs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE GADTs #-}
module Forest where
import Data.Kind (Type)
import Data.Proxy

import Data.Constraint

import Numbers
-- | A forest of trees drawn from a branch-indexed family @f@.
--
-- * @f n@ is a tree with @n@ branches (inputs).
-- * @Forest f i o@ is a forest of @o@ trees with @i@ branches in total.
--
-- 'Cons' puts a tree with @i1@ inputs in front of a forest of @n@ trees
-- whose inputs total @i2@. The result is a forest of @1 + n@ trees with
-- @i1 + i2@ inputs:
--
-- >     i1        i2                 i1 + i2
-- >    \ /     | | | |          \ /   | | | |
-- >     |      | |  |     ~>     |    | |  |
-- >     1        n                  1 + n
data Forest f i o where
  -- | The empty forest: no trees, no inputs.
  Nil  :: Forest f Z Z
  -- | Prepend one tree; its inputs are added to those of the rest.
  Cons :: f i1 -> Forest f i2 n -> Forest f (i1 + i2) (S n)
-- | Debug rendering. Each tree is written as @V@ followed by @", "@, and the
-- end of the forest is written as @.@ (two trees render as @V, V, .@).
-- The trees' contents are not shown.
instance Show (Forest f n m) where
  showsPrec _ forest = case forest of
    Nil         -> showString "."
    Cons _ rest -> showString "V" . showString ", " . shows rest
-- | Nat-indexed families whose index can be recovered at runtime as a
-- singleton.
--
-- The kind is written with 'Type' rather than @*@. Under 'StarIsType' the
-- two are equivalent, but 'Type' keeps working if the module switches to
-- NoStarIsType to use @*@ as a type-level operator.
class Graded (f :: Nat -> Type) where
  -- | Recover the type-level index @n@ of a value as an 'SNat'.
  grade :: f n -> SNat n
-- | Total number of inputs of a forest, as a singleton.
--
-- The caller supplies the per-tree grading function. Each tree's grade is
-- combined with the grade of the remaining forest via 'plus', and an empty
-- forest yields 'SZ'. The signature ties the result to the forest's input
-- index @n@.
inputs :: forall f n m. (forall k. f k -> SNat k) -> Forest f n m -> SNat n
inputs gradeOf = go
  where
    -- Worker over the forest; 'f' and 'gradeOf' come from the outer scope.
    go :: forall j o. Forest f j o -> SNat j
    go Nil              = SZ
    go (Cons tree rest) = gradeOf tree `plus` go rest
-- | A forest made of @n@ copies of the same tree, so @n * i@ inputs and
-- @n@ trees in total.
--
-- NOTE(review): on GHC >= 8.6, an unqualified @*@ in a type means 'Type'
-- unless NoStarIsType is enabled. Confirm that @n * i@ resolves to the
-- intended type-level multiplication for the target compiler.
replicateF :: SNat n -> f i -> Forest f (n * i) n
replicateF count tree = case count of
  SZ       -> Nil
  SS fewer -> Cons tree (replicateF fewer tree)
-- | Split a forest of @m + n@ trees into its first @m@ trees and its
-- remaining @n@ trees.
--
-- The singletons @m@ and @n@ fix how many trees go to each half. How the
-- @i@ inputs divide between the halves is not known statically. The result
-- is therefore passed to a continuation @k@ that is polymorphic in the two
-- input counts @i1@ and @i2@ and is given the evidence @i ~ (i1 + i2)@.
splitForest :: forall m n i f r. SNat m -> SNat n -> Forest f i (m + n) ->
  -- to simulate exists i1 i2, we need to CPS this function
  (forall i1 i2. (i ~ (i1 + i2)) => (Forest f i1 m, Forest f i2 n) -> r) -> r
-- No trees requested for the left half: it is 'Nil' and the whole forest
-- becomes the right half.
splitForest SZ _ fs k = k (Nil, fs)
-- Peel off the head tree @t@, split the tail into @m_1@ and @n@ trees
-- recursively, then put @t@ back on the front of the left fragment.
splitForest (SS (sm :: SNat m_1)) sn (Cons (t :: f i1) (ts :: Forest f i2 (m_1 + n))) k =
  splitForest sm sn ts (\((m_frag :: Forest f i3 m_1), (n_frag :: Forest f i4 n)) ->
    -- The recursive continuation supplies i2 ~ (i3 + i4), and Cons gave
    -- i ~ (i1 + i2). The new left fragment has i1 + i3 inputs, so k needs
    -- i ~ ((i1 + i3) + i4); the Dict from plusAssoc bridges the two.
    -- NOTE(review): assumes plusAssoc yields the associativity equation for
    -- these three proxies; confirm its exact statement in Numbers.
    case plusAssoc (Proxy :: Proxy i1) (Proxy :: Proxy i3) (Proxy :: Proxy i4) of Dict -> k (Cons t m_frag, n_frag))