Skip to content

Commit

Permalink
Add keyword procedures
Browse files Browse the repository at this point in the history
  • Loading branch information
WyattBlue committed Sep 20, 2023
1 parent 0219e0f commit 47a52e9
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 41 deletions.
136 changes: 98 additions & 38 deletions auto_editor/lang/palet.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import cmath
import math
import random
from dataclasses import dataclass
from difflib import get_close_matches
from fractions import Fraction
from functools import reduce
Expand All @@ -17,6 +18,7 @@
from typing import TYPE_CHECKING

import numpy as np
from numpy import logical_and, logical_not, logical_or, logical_xor

from auto_editor.analyze import edit_method, mut_remove_large, mut_remove_small
from auto_editor.lib.contracts import *
Expand Down Expand Up @@ -64,17 +66,10 @@ class ClosingError(MyError):
}


@dataclass(slots=True)
class Token:
__slots__ = ("type", "value")

def __init__(self, type: str, value: Any):
self.type = type
self.value = value

def __str__(self) -> str:
return f"(Token {print_str(self.type)} {print_str(self.value)})"

__repr__ = __str__
type: str
value: Any


class Lexer:
Expand Down Expand Up @@ -374,14 +369,12 @@ def handle_strings() -> bool:
###############################################################################


@dataclass(slots=True)
class Method:
__slots__ = "val"

def __init__(self, val: str):
self.val = val
val: str

def __str__(self) -> str:
return f'(Method "{self.val}")'
return f"#<method:{self.val}>"

__repr__ = __str__

Expand Down Expand Up @@ -474,17 +467,16 @@ def check_args(
) -> None:
lower, upper = arity
amount = len(values)
if upper is not None and lower > upper:
raise ValueError("lower must be less than upper")
if lower == upper and len(values) != lower:
raise MyError(f"{o}: Arity mismatch. Expected {lower}, got {amount}")

assert not (upper is not None and lower > upper)
base = f"`{o}` has an arity mismatch. Expected "

if lower == upper and len(values) != lower:
raise MyError(f"{base}{lower}, got {amount}")
if upper is None and amount < lower:
raise MyError(f"{o}: Arity mismatch. Expected at least {lower}, got {amount}")
raise MyError(f"{base}at least {lower}, got {amount}")
if upper is not None and (amount > upper or amount < lower):
raise MyError(
f"{o}: Arity mismatch. Expected between {lower} and {upper}, got {amount}"
)
raise MyError(f"{base}between {lower} and {upper}, got {amount}")

if cont is None:
return
Expand All @@ -493,7 +485,7 @@ def check_args(
check = cont[-1] if i >= len(cont) else cont[i]
if not check_contract(check, val):
exp = f"{check}" if callable(check) else print_str(check)
raise MyError(f"{o} expected a {exp}, got {print_str(val)}")
raise MyError(f"`{o}` expected a {exp}, got {print_str(val)}")


is_cont = Contract("contract?", is_contract)
Expand Down Expand Up @@ -575,7 +567,7 @@ def _sqrt(v: Number) -> Number:
def _xor(*vals: Any) -> bool | BoolList:
if is_boolarr(vals[0]):
check_args("xor", vals, (2, None), [is_boolarr])
return reduce(lambda a, b: boolop(a, b, np.logical_xor), vals)
return reduce(lambda a, b: boolop(a, b, logical_xor), vals)
check_args("xor", vals, (2, None), [is_bool])
return reduce(lambda a, b: a ^ b, vals)

Expand Down Expand Up @@ -829,6 +821,63 @@ def __call__(self, *args: Any) -> Any:
return my_eval(inner_env, self.body[-1])


@dataclass(slots=True)
class KeywordProc:
env: Env
name: str
parms: list[str]
kw_parms: list[str]
body: list
arity: tuple[int, None]
contracts: list[Any] | None = None

def __call__(self, *args: Any) -> Any:
env = {}

for i, parm in enumerate(self.parms):
if type(args[i]) is Keyword:
raise MyError(f"Invalid keyword `{args[i]}`")
env[parm] = args[i]

remain_args = args[len(self.parms) :]

allow_pos = True
pos_index = 0
key = ""
for arg in remain_args:
if type(arg) is Keyword:
if key:
raise MyError("Expected value for keyword but got another keyword")
key = arg.val
allow_pos = False
elif key:
env[key] = arg
key = ""
else:
if not allow_pos:
raise MyError("Positional argument not allowed here")
if pos_index >= len(self.kw_parms):
base = f"`{self.name}` has an arity mismatch. Expected"
upper = len(self.parms) + len(self.kw_parms)
raise MyError(f"{base} at most {upper}")

env[self.kw_parms[pos_index]] = arg
pos_index += 1

inner_env = Env(env, self.env)

for item in self.body[0:-1]:
my_eval(inner_env, item)

return my_eval(inner_env, self.body[-1])

def __str__(self) -> str:
return self.name

def __repr__(self) -> str:
return f"#<kw-proc:{self.name}>"


class Syntax:
__slots__ = "syn"

Expand Down Expand Up @@ -898,13 +947,29 @@ def syn_define(env: Env, node: list) -> None:

n = term[0].val
parms: list[str] = []
for item in term[1:]:
if type(item) is not Sym:
raise MyError(f"{node[0]}: must be an identifier")
kparms: list[str] = []
kw_only = False

parms.append(f"{item}")
for item in term[1:]:
if kw_only:
if type(item) is Sym:
raise MyError(f"{node[0]}: {item} must be a keyword")
if type(item) is not Keyword:
raise MyError(f"{node[0]}: must be an identifier or keyword")
kparms.append(item.val)
else:
if type(item) is Keyword:
kw_only = True
kparms.append(item.val)
elif type(item) is Sym:
parms.append(item.val)
else:
raise MyError(f"{node[0]}: must be an identifier")

env[n] = UserProc(env, n, parms, body)
if kw_only:
env[n] = KeywordProc(env, n, parms, kparms, body, (len(parms), None))
else:
env[n] = UserProc(env, n, parms, body)
return None

elif type(node[1]) is not Sym:
Expand Down Expand Up @@ -1124,7 +1189,7 @@ def syn_and(env: Env, node: list) -> Any:
if is_boolarr(first):
vals = [first] + [my_eval(env, n) for n in node[2:]]
check_args(node[0], vals, (2, None), [is_boolarr])
return reduce(lambda a, b: boolop(a, b, np.logical_and), vals)
return reduce(lambda a, b: boolop(a, b, logical_and), vals)

raise MyError(f"{node[0]} expects (or/c bool? bool-array?)")

Expand All @@ -1148,7 +1213,7 @@ def syn_or(env: Env, node: list) -> Any:
if is_boolarr(first):
vals = [first] + [my_eval(env, n) for n in node[2:]]
check_args(node[0], vals, (2, None), [is_boolarr])
return reduce(lambda a, b: boolop(a, b, np.logical_or), vals)
return reduce(lambda a, b: boolop(a, b, logical_or), vals)

raise MyError(f"{node[0]} expects (or/c bool? bool-array?)")

Expand Down Expand Up @@ -1390,12 +1455,7 @@ def my_eval(env: Env, node: object) -> Any:
"begin": Proc("begin", lambda *x: x[-1] if x else None, (0, None)),
"void": Proc("void", lambda *v: None, (0, 0)),
# control / b-arrays
"not": Proc(
"not",
lambda v: not v if type(v) is bool else np.logical_not(v),
(1, 1),
[bool_or_barr],
),
"not": Proc("not", lambda v: not v if type(v) is bool else logical_not(v), (1, 1), [bool_or_barr]),
"and": Syntax(syn_and),
"or": Syntax(syn_or),
"xor": Proc("xor", _xor, (2, None), [bool_or_barr]),
Expand Down
24 changes: 21 additions & 3 deletions resources/scripts/scope.pal
Original file line number Diff line number Diff line change
@@ -1,13 +1,11 @@
#!/usr/bin/env auto-editor palet

#lang palet
; Enforce lexical scoping

; Enforce lexical scoping
(define (f x) (lambda (y) (+ x y)))
(assert (equal? ((f 10) 12) 22))

; Test that variables do not leak scope

(define (outer a)
(define (inner1 b)
(define (inner2 c) c)
Expand All @@ -21,6 +19,26 @@
(assert (not (var-exists? 'b)))
(assert (not (var-exists? 'c)))

; Test keyword arguments
(define (f1 a b c) (vector a b c))
(define (f2 a #:b #:c) (vector a b c))
;(define (f3 a #:b [#:c 0]) (vector a b c))
;(define (f4 [a 2] [#:b 1] [#:c 0]) (vector a b c))

; Invalid defines
; (define (f [a 2] b c) (vector a b c))
; (define (f a #:b c) (vector a b c))
; (define (f a [#:b 1] #:c) (vector a b c))
; (define (f [a 2] #:b [#:c 0]) (vector a b c))
; (define (f a a #:b) (void))
; (define (f a #:a #:b) (void))

(assert (equal? (f1 3 2 1) #(3 2 1)))
(assert (equal? (f2 3 2 1) #(3 2 1)))
(assert (equal? (f2 3 2 #:c 1) #(3 2 1)))
(assert (equal? (f2 3 #:b 2 #:c 1) #(3 2 1)))
(assert (equal? (f2 3 #:c 1 #:b 2) #(3 2 1)))

; Test `let` and `let*`

(assert (equal? (let ([x 5]) x) 5))
Expand Down

0 comments on commit 47a52e9

Please sign in to comment.