-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.rkt
195 lines (166 loc) · 6.85 KB
/
utils.rkt
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
#lang racket
(require ffi/unsafe)
(require sham/jit)
(require "ast.rkt")
;; (require "ast.rkt"
;; "prelude/type-defines.rkt"
;; "prelude/template-format.rkt"
;; )
;; (require (for-syntax racket/syntax))
(provide (all-defined-out))
;; Conversions between ordinary reals and the log-space representation
;; used for 'prob values (a prob p is stored as (log p)).

;; Convert a log-space prob back to an ordinary real.
(define (prob->real x) (exp x))

;; Convert an ordinary real to its log-space prob.
;; The argument is made inexact first so that an exact 0 maps to -inf.0
;; (rather than raising "log: undefined for 0"), and so the result is always
;; a flonum suitable for storage as a _double.
;; NOTE(review): negative inputs yield a complex logarithm; presumably callers
;; only pass non-negative values — confirm.
(define (real->prob x) (log (exact->inexact x)))

;; Convert a natural number to its log-space prob.
(define (nat->prob x) (real->prob x))
;; (define (logsumexp2 a b)
;; (if (> a b)
;; (+ a (log (exp (- b a))))
;; (+ b (log (exp (- a b))))))
;; (define (one-of-type t)
;; (if (equal? t 'prob)
;; (real->prob 1.0)
;; 1.0))
;; (define (zero-of-type t)
;; (if (equal? t 'prob)
;; (real->prob 0.0)
;; 0.0))
;; (define logspace-add
;; (λ args
;; (real->prob (apply + (map prob->real args)))))
;; (define (replicate-vector n i)
;; (build-vector n (const i)))
;; (define (read-vector-from-csv fname)
;; (call-with-input-file fname
;; (lambda (in)
;; (for/vector [(s (in-lines in))]
;; (string->number s)))))
;; (define (get-cmd-argument i)
;; (vector-ref (current-command-line-arguments) i))
;; ;; utils for running jit
;; (define (compile-hakaru file-name info)
;; (define module-env (compile-file file-name info))
;; ;; (jit-dump-module module-env)
;; (define init-rng (jit-get-function 'init-rng module-env))
;; (init-rng)
;; (define prog (jit-get-function 'prog module-env))
;; prog)
;; (define (get-array-function-sym f t)
;; (define ts (get-type-string t))
;; (string->symbol
;; (match f
;; ['get-index (format get-index-fun-format ts)]
;; ['get-size (format get-array-size-fun-format ts)]
;; ['get-data (format get-array-data-fun-format ts)]
;; ['set-index! (format set-index-fun-format ts)]
;; ['make (format make-array-fun-format ts)]
;; ['new (format new-size-array-fun-format ts)])))
;; (define (get-pair-function-sym f t)
;; (define ts (get-type-string t))
;; (string->symbol
;; (match f
;; ['make (format make-pair-fun-format ts)]
;; ['car (format pair-car-fun-format ts)]
;; ['cdr (format pair-cdr-fun-format ts)]
;; ['set-car! (format pair-set-car-fun-format ts)]
;; ['set-cdr! (format pair-set-cdr-fun-format ts)])))
;; (define (get-function-sym module-env sym t)
;; (jit-get-function (if (equal? (car t) 'pair)
;; (get-pair-function-sym sym t)
;; (get-array-function-sym sym t))
;; module-env))
;; (define (get-racket-type t)
;; (match t
;; ['real _double]
;; ['prob _double]
;; ['nat _uint64]
;; ['unit _uint64]
;; [else _pointer]))
;; (define (rkt->jit module-env type val)
;; (match type
;; [`(array ,t)
;; (define size (length val))
;; (define arr ((get-function-sym module-env 'new type) size))
;; (for ([j (in-range size)]
;; [v val])
;; ((get-function-sym module-env 'set-index! type)
;; arr j (rkt->jit module-env t v)))
;; arr]
;; [`(pair ,ta ,tb)
;; (define tav (rkt->jit module-env ta (car val)))
;; (define tbv (rkt->jit module-env tb (cdr val)))
;; ((get-function-sym module-env 'make type) tav tbv)]
;; ['prob ((jit-get-function 'real2prob module-env) (exact->inexact val))]
;; ['real (exact->inexact val)]
;; [else val]))
;; (define (jit->rkt module-env type val)
;; (match type
;; [`(array ,t)
;; (define size ((get-function module-env 'get-size type) val))
;; (define data ((get-function module-env 'get-data type) val))
;; (printf "array-size: ~a\n" size)
;; (for/list ([j (in-range size)])
;; (jit->rkt module-env t ((get-function module-env 'get-index type) val j)))]
;; [`(pair ,ta ,tb)
;; (define tav
;; (jit->rkt module-env ta ((get-function 'module-env 'car type) val)))
;; (define tbv
;; (jit->rkt module-env tb ((get-function 'module-env 'cdr type) val)))
;; (cons tav tbv)]
;; [`(pointer ,t)
;; (jit->rkt module-env t val)]
;; ['prob ((jit-get-function 'prob2real module-env) val)]
;; [else val]))
;; Map a Hakaru scalar type tag to the FFI ctype used to store it in a C block:
;; 'prob and 'real are both stored as _double, 'nat as _uint64.
;; Any other tag raises a match error.
(define (rkt-type tag)
  (match tag
    [(or 'prob 'real) _double]
    ['nat             _uint64]))
;; Copy the Racket list ARR into a freshly allocated C block whose elements
;; have the ctype associated with TYPE (see rkt-type). No length is stored.
(define (make-fixed-hakrit-array arr type)
  (let ([elem-type (rkt-type type)])
    (list->cblock arr elem-type)))
;; Build a length-prefixed ("sized") C array from the Racket list ARR.
;; Layout: slot 0 holds the element count as a _uint64; slots 1..n hold the
;; elements, stored with the ctype for TYPE (see rkt-type).
;; The header shares the element stride, which assumes the element ctype is
;; 8 bytes wide like _uint64 (true for the 'nat, 'prob and 'real mappings).
;; An empty ARR yields a block whose size header is 0.
(define (make-sized-hakrit-array arr type)
  (define elem-type (rkt-type type))
  (define size (length arr))
  ;; One extra slot for the length header.
  (define ret (malloc (add1 size) elem-type))
  (ptr-set! ret _uint64 0 size)
  (for ([v (in-list arr)]
        [i (in-naturals 1)])
    (ptr-set! ret elem-type i v))
  ret)
;; Read element INDEX of a fixed (unprefixed) array whose elements have the
;; ctype for TYPE. No bounds checking is performed.
(define (fixed-hakrit-array-ref arr type index)
  (let ([elem-type (rkt-type type)])
    (ptr-ref arr elem-type index)))

;; Overwrite element INDEX of a fixed array with VALUE. No bounds checking.
(define (fixed-hakrit-array-set! arr type index value)
  (let ([elem-type (rkt-type type)])
    (ptr-set! arr elem-type index value)))
;; Accessors for length-prefixed ("sized") arrays: slot 0 is a _uint64 length
;; header, and element I lives in slot I+1.

;; Raise a range error, attributed to WHO, when the exact integer INDEX is not
;; within [0, size) of the sized array ARR. A non-integer INDEX is left for
;; ptr-ref/ptr-set! to reject.
(define (sized-hakrit-array-check-index who arr index)
  (define size (sized-hakrit-array-size arr))
  (when (and (exact-integer? index) (not (< -1 index size)))
    (raise-range-error who "sized hakrit array" "" index arr 0 (sub1 size))))

;; Read element INDEX (0-based, bounds-checked against the header).
(define (sized-hakrit-array-ref arr type index)
  (sized-hakrit-array-check-index 'sized-hakrit-array-ref arr index)
  (ptr-ref arr (rkt-type type) (add1 index)))

;; Overwrite element INDEX (0-based, bounds-checked) with VALUE; the check
;; prevents writing past the end of the allocation.
(define (sized-hakrit-array-set! arr type index value)
  (sized-hakrit-array-check-index 'sized-hakrit-array-set! arr index)
  (ptr-set! arr (rkt-type type) (add1 index) value))

;; Number of elements, read from the _uint64 header in slot 0.
(define (sized-hakrit-array-size arr) (ptr-ref arr _uint64 0))
;; Copy the payload of a sized array into a Racket list. The element count is
;; taken from the _uint64 header in slot 0; elements are read from slots 1..n.
(define (sized-hakrit-array->racket-list ptr type)
  (define elem-type (rkt-type type))
  (define n (sized-hakrit-array-size ptr))
  (for/list ([slot (in-range 1 (add1 n))])
    (ptr-ref ptr elem-type slot)))
;; Copy the first SIZE elements of a fixed array into a Racket list. The
;; caller supplies SIZE because fixed arrays carry no length header.
(define (fixed-hakrit-array->racket-list ptr type size)
  (let ([elem-type (rkt-type type)])
    (cblock->list ptr elem-type size)))
;; Type-specialized wrappers over the generic hakrit-array helpers; each one
;; fixes the TYPE argument to 'nat or 'real.

;; --- nat arrays (fixed size, no length header) ---
(define (nat-array elems)
  (make-fixed-hakrit-array elems 'nat))
(define (nat-array-ref arr i)
  (fixed-hakrit-array-ref arr 'nat i))
(define (nat-array-set! arr i v)
  (fixed-hakrit-array-set! arr 'nat i v))

;; --- nat arrays (length-prefixed) ---
(define (sized-nat-array elems)
  (make-sized-hakrit-array elems 'nat))
(define (sized-nat-array-ref arr i)
  (sized-hakrit-array-ref arr 'nat i))
(define (sized-nat-array-set! arr i v)
  (sized-hakrit-array-set! arr 'nat i v))

;; --- real arrays (fixed size, no length header) ---
(define (real-array elems)
  (make-fixed-hakrit-array elems 'real))
(define (real-array-ref arr i)
  (fixed-hakrit-array-ref arr 'real i))
(define (real-array-set! arr i v)
  (fixed-hakrit-array-set! arr 'real i v))

;; --- real arrays (length-prefixed) ---
(define (sized-real-array elems)
  (make-sized-hakrit-array elems 'real))
(define (sized-real-array-ref arr i)
  (sized-hakrit-array-ref arr 'real i))
(define (sized-real-array-set! arr i v)
  (sized-hakrit-array-set! arr 'real i v))
;; (define (real-array lst)
;; (make-fixed-hakrit-array lst 'real))
;; (define (real-array-ref arr index)
;; (fixed-hakrit-array-ref arr 'real index))
;; (define (real-array-set! arr index val)
;; (fixed-hakrit-array-set! arr 'real index val))
;; A pair of arrays is a two-slot C block of pointers: slot 0 holds A and
;; slot 1 holds B.
;; NOTE(review): the block is allocated with malloc's default mode for
;; _pointer, so the GC does not trace the pointers stored in it; presumably
;; callers keep A and B alive (and unmoved) while the pair is in use — confirm.
(define (cons-array-pair a b)
  ;; Size the slots by _pointer rather than _uint64 so the layout is right on
  ;; platforms where pointers are not 8 bytes wide.
  (define p (malloc 2 _pointer))
  (ptr-set! p _pointer 0 a)
  (ptr-set! p _pointer 1 b)
  p)

;; First pointer of an array pair (slot 0).
(define (car-array-pair pr) (ptr-ref pr _pointer 0))

;; Second pointer of an array pair (slot 1).
(define (cdr-array-pair pr) (ptr-ref pr _pointer 1))
(module+ test
  (require rackunit)

  ;; Fixed real arrays: read, overwrite, read back, and convert to a list.
  (define tfa (make-fixed-hakrit-array '(1.0 2.0 3.0 4.0 5.0 6.0 7.0 8.0 9.0 0.0) 'real))
  (check-= (fixed-hakrit-array-ref tfa 'real 2) 3.0 0.0001)
  (fixed-hakrit-array-set! tfa 'real 2 42.0)
  (check-= (fixed-hakrit-array-ref tfa 'real 2) 42.0 0.0001)
  (check-equal? (fixed-hakrit-array->racket-list tfa 'real 3) '(1.0 2.0 42.0))

  ;; Sized real arrays: the length header is readable, element indices are
  ;; 0-based, and list conversion drops the header.
  (define tsa (make-sized-hakrit-array '(1.0 2.0 3.0) 'real))
  (check-equal? (sized-hakrit-array-size tsa) 3)
  (check-= (sized-hakrit-array-ref tsa 'real 1) 2.0 0.0001)
  (sized-hakrit-array-set! tsa 'real 1 7.5)
  (check-equal? (sized-hakrit-array->racket-list tsa 'real) '(1.0 7.5 3.0))

  ;; Nat wrappers over the generic helpers.
  (define tna (sized-nat-array '(4 5 6)))
  (check-equal? (sized-hakrit-array-size tna) 3)
  (check-equal? (sized-nat-array-ref tna 2) 6)
  (define tfn (nat-array '(10 20)))
  (nat-array-set! tfn 0 11)
  (check-equal? (nat-array-ref tfn 0) 11))
;; interpreter utils
;; Interpreter-side arrays are plain Racket vectors.
;; Coerce VAL to such an array: a vector is returned as-is (shared, not
;; copied); a list is converted into a fresh vector.
(define (make-interp-hakrit-array val)
  (if (vector? val)
      val
      (list->vector val)))

;; Element access for interpreter arrays (an alias of vector-ref).
(define interp-hakr-array-ref vector-ref)