Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Arithmetic coding demo #631

Open
wants to merge 17 commits into
base: main
Choose a base branch
from
157 changes: 157 additions & 0 deletions examples/arithmetic-coding.dx
Original file line number Diff line number Diff line change
@@ -0,0 +1,157 @@
'# Lossless compression
Based on the implementation of [rANS](https://github.com/j-towns/ans-notes/blob/master/rans.py) by James Townsend.

-- prelude additions

-- Narrows a Word64 to a Word8 with `internalCast`.
-- NOTE(review): presumably this keeps only the least-significant byte, as the
-- name suggests -- confirm against the semantics of `internalCast`.
def lowWord' (x : Word64) : Word8 =
  internalCast Word8 x
-- Example: W8ToI (lowWord' (IToW64 100))

-- Integer division and remainder for Word64, forwarded to the `%idiv` and
-- `%irem` primitives.
-- NOTE(review): confirm whether these primitives treat their operands as
-- signed or unsigned; the coder state can use all 64 bits.
instance Integral Word64
  idiv = \num den. %idiv num den
  rem  = \num den. %irem num den

-- Multiplication for Word64, forwarded to the `%imul` primitive, with the
-- multiplicative identity built from the Int literal 1.
instance Mul Word64
  mul = \lhs rhs. %imul lhs rhs
  one = IToW64 1

-- Equality for Word64: the `%ieq` primitive yields a Word8 flag, which
-- `W8ToB` converts to a Bool.
instance Eq Word64
  (==) = \lhs rhs. W8ToB (%ieq lhs rhs)

-- Ordering for Word64 via the `%igt` / `%ilt` primitives; each yields a Word8
-- flag that `W8ToB` converts to a Bool.
-- NOTE(review): confirm whether `%igt` / `%ilt` compare as signed or unsigned;
-- values with the top bit set would order differently under a signed compare.
instance Ord Word64
  (>) = \lhs rhs. W8ToB (%igt lhs rhs)
  (<) = \lhs rhs. W8ToB (%ilt lhs rhs)

-- Number of low state bits used to select a symbol slot (see p_mask, Interval).
p_prec = 3
-- Bit width used for the coder state word.
s_prec = 64
-- Number of bits moved between the state word and the stack per
-- renormalisation step in `pop` / `push`.
t_prec = 32
-- 2^p_prec: the number of slots in Interval.
p_int: Int = %shl 1 p_prec
-- Mask selecting the low p_prec bits of a word.
p_mask: Word64 = (one .<<. p_prec) - one
-- Mask selecting the low t_prec bits of a word.
t_mask: Word64 = (one .<<. t_prec) - one
-- Lower bound on the state word; `pop` refills the state while it is below this.
s_min: Word64 = one .<<. (s_prec - t_prec)
-- NOTE(review): this shifts a Word64 by s_prec = 64, i.e. by the full word
-- width; 2^64 is not representable in a Word64, so confirm what this evaluates
-- to. It is not referenced by the definitions in this section.
s_max: Word64 = one .<<. s_prec
-- Symbol index type: one entry per letter, indexed by offset from 'a'.
Alphabet = Fin 26
-- Index type with p_int slots, used by the slot-to-symbol lookup table.
Interval = Fin p_int
-- A message is the state word paired with a stack of words.
Message = (Word64 & List Word64)

'Utilities

-- Remainder of x by y, computed as rem (y + rem x y) y.
def mod' (x: Word64) (y: Word64) : Word64 =
  r = rem x y
  rem (y + r) y

-- Offset of character `c` from 'a' (so 'a' maps to 0, 'b' to 1, ...).
def charToIdx (c: Word8) : Int =
  base = W8ToI 'a'
  W8ToI c - base
-- Character at offset `i` from 'a' (so 0 maps to 'a', 1 to 'b', ...).
def idxToChar (i: Int) : Word8 =
  base = W8ToI 'a'
  IToW8 (i + base)

-- Start offsets for each symbol, given per-symbol counts `ps`.
-- A symbol with a non-zero count receives the running total of the counts
-- seen before it; a symbol with a zero count receives zero and leaves the
-- running total unchanged.
def get_cs (ps: Alphabet=>Word64) : Alphabet=>Word64 =
  withState zero \running.
    for sym.
      count = ps.sym
      case count > zero of
        True ->
          start = get running
          running := start + count
          start
        False -> zero

-- Counts how many times each letter occurs in `str`.
-- Every character is mapped to its Alphabet index with `charToIdx`, and the
-- matching entry of an initially all-zero table is incremented by one.
-- NOTE(review): this one can also be computed with a parallel Accum.
def get_ps (str: (Fin l)=>Word8) : Alphabet=>Word64 =
  a: Alphabet => Word64 = zero
  yieldState a \ref. for i.
    i' = (charToIdx str.i)@_
    ref!i' := (get ref).i' + one

-- Builds the slot-to-symbol lookup table from per-symbol counts `ps`.
-- Walking the alphabet in order, each symbol's character is appended to a list
-- once per unit of its count; the first p_int entries of that list are then
-- read out as an Interval-indexed table.
-- The list is indexed with `unsafeFromOrdinal`, so it must hold at least
-- p_int entries.
def get_cs_map (ps: Alphabet=>Word64) : Interval=>Word8 =
  empty: List Word8 = AsList 0 []
  slots = yieldState empty \acc.
    for sym.
      reps = W8ToI $ lowWord' ps.sym
      letter = idxToChar (ordinal sym)
      boundedIter reps 0 \_.
        acc := (get acc) <> (AsList 1 [letter])
        Continue
  (AsList _ flat) = slots
  for k:Interval. flat.(unsafeFromOrdinal _ (ordinal k))

-- a string to prep the statistics
-- The string has 8 characters, which equals p_int (1 shifted left by
-- p_prec = 3); get_cs_map reads exactly p_int entries from the table it
-- expands from these counts.
xs' = "abbccddc"
(AsList l xs) = xs'
-- Per-letter occurrence counts of the string above.
ps = get_ps xs
-- Per-letter start offsets derived from the counts.
cs = get_cs ps
-- Lookup table from a slot in Interval to the letter occupying that slot.
cs_map = get_cs_map ps

-- Looks up the statistics of character `x`: the pair (start offset, count),
-- read from the top-level tables `cs` and `ps`.
def g (x: Word8) : (Word64 & Word64) =
  sym = (charToIdx x)@Alphabet
  start = cs.sym
  count = ps.sym
  (start, count)

-- Maps a slot value to its symbol and that symbol's statistics.
-- The Word8 narrowing of `s'` is used as an index into `cs_map`; the result
-- is the character found there paired with `g` applied to that character.
def f (s': Word64) : (Word8 & (Word64 & Word64)) =
  slot = (W8ToI (lowWord' s'))@Interval
  sym = cs_map.slot
  (sym, g sym)

-- Splits a list into its first element and the list of remaining elements.
-- The first element is read at index 0, so the list must be non-empty.
def stack_pop ((AsList l' t'): List Word64) : (Word64 & List Word64) =
  top = t'.(0@_)
  restLen = l' - 1
  rest = slice t' 1 (Fin restLen)
  (top, AsList _ rest)

-- Returns the list `t` with `t_top` placed in front of it.
def stack_push (t_top: Word64) (t: List Word64) : (List Word64) =
  front = AsList 1 [t_top]
  front <> t

'Coding Interface

-- Decodes one symbol from the message (s, t).
-- Returns the updated message paired with the decoded character.
def pop ((s, t): Message) : (Message & Word8) =
  -- The low p_prec bits of the state select the symbol slot.
  slot = s .&. p_mask
  (sym, (start, count)) = f slot
  shrunk = count * (s .>>. p_prec) + slot - start
  -- Refill: while the state is below s_min, shift it left by t_prec bits and
  -- add the word taken from the top of the stack.
  -- TODO: use a while loop, not a do-while
  refilled = case shrunk < s_min of
    True ->
      yieldState (shrunk, t) \st.
        stateRef = fstRef st
        stackRef = sndRef st
        while do
          (top, rest) = stack_pop (get stackRef)
          stackRef := rest
          stateRef := ((get stateRef) .<<. t_prec) + top
          (get stateRef) < s_min
    False -> (shrunk, t)
  (refilled, sym)

-- Encodes character `x` onto the message (s, t) and returns the new message.
-- `g x` supplies the symbol's start offset `c` and count `p`.
def push ((s, t): Message) (x: Word8) : Message =
  (c, p) = g x
  -- The state must be spilled to the stack while s >= p * 2^(s_prec - p_prec).
  -- That bound is tested as (s >> (s_prec - p_prec)) >= p rather than by
  -- shifting p left: p .<<. (s_prec - p_prec) does not fit in a Word64 once p
  -- reaches 2^p_prec (one symbol holding every count), which would turn the
  -- bound into 0 and keep the loop running forever.
  -- NOTE(review): assumes `.>>.` on Word64 shifts in zero bits, as the other
  -- right shifts of the state in this file already do -- confirm.
  shift = s_prec - p_prec
  (s', t') = case (s .>>. shift) >= p of
    True ->
      yieldState (s, t) \m'.
        sRef = fstRef m'
        tRef = sndRef m'
        while do
          -- Move the low t_prec bits of the state onto the stack.
          tRef := stack_push ((get sRef) .&. t_mask) (get tRef)
          sRef := (get sRef) .>>. t_prec
          ((get sRef) .>>. shift) >= p
    False -> (s, t)
  s'' = ((idiv s' p) .<<. p_prec) + (mod' s' p) + c
  (s'', t')


'Demo

-- initialize message
m_init: Message = (s_min, AsList 0 [])
xs_init' = "abbcbcdccdccacbbacccdabbbaccccd"
(AsList l' xs_init) = xs_init'

-- Encode: push every character of xs_init, in index order, onto the message.
m' = yieldState m_init \msgRef.
  for k:(Fin l').
    msgRef := push (get msgRef) xs_init.k

-- Decode: pop l' symbols from the encoded message. Each popped symbol is
-- prepended to the output list, so the list ends up in the reverse of the
-- order in which the symbols were popped.
init_args: (Message & List Word8) = (m', AsList 0 [])
(_, xs_decoded) = yieldState init_args \ref.
  msgRef = fstRef ref
  outRef = sndRef ref
  for k:(Fin l').
    (rest, sym) = pop (get msgRef)
    msgRef := rest
    outRef := (AsList 1 [sym]) <> (get outRef)
  get ref

-- Compare the decoded characters against the original input.
(AsList _ xs'') = xs_decoded
:p xs'' == xs_init