Skip to content

Commit

Permalink
Add class special form
Browse files Browse the repository at this point in the history
  • Loading branch information
WyattBlue committed Oct 18, 2023
1 parent e49591c commit c65bacf
Show file tree
Hide file tree
Showing 4 changed files with 161 additions and 50 deletions.
170 changes: 132 additions & 38 deletions auto_editor/lang/palet.py
Original file line number Diff line number Diff line change
Expand Up @@ -1043,53 +1043,104 @@ def guard_term(node: Node, n: int, u: int) -> None:
def syn_set(env: Env, node: Node) -> None:
guard_term(node, 3, 3)

if type(node[1]) is not Sym:
raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}")
if type(node[1]) is Sym:
name = node[1].val
if name not in env:
raise MyError(f"{node[0]}: Can't set variable `{name}` before definition")
env[name] = my_eval(env, node[2])
return None

name = node[1].val
if name not in env:
raise MyError(f"{node[0]}: Can't set variable `{name}` before definition")
env[name] = my_eval(env, node[2])
if type(node[1]) is tuple and len(node[1]) == 3 and node[1][0] == Sym("@r"):
base = my_eval(env, node[1][1])
name = node[1][2].val
for i, item in enumerate(base.attrs[0::2]):
if name == item:
result = my_eval(env, node[2])
check_args(item, (result,), (1, 1), (base.attrs[i * 2 + 1],))
base.values[i] = result
return None

raise MyError(f"{node[0]}: {base.name} has no attribute `{name}`")

raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}")


def syn_incf(env: Env, node: Node) -> None:
guard_term(node, 2, 3)

if type(node[1]) is not Sym:
raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}")
name = node[1].val
incre_by = 1
if len(node) == 3:
incre_by = my_eval(env, node[2])
if not is_num(incre_by):
raise MyError(f"{node[0]}: Expected number? got: {print_str(incre_by)}")

if env[name] is None:
raise MyError(f"{node[0]}: `{name}` is not defined")
if not is_num(env[name]):
raise MyError(f"{node[0]}: `{name}` is not a number?")
if type(node[1]) is Sym:
name = node[1].val

if len(node) == 3:
if not is_num(num := my_eval(env, node[2])):
raise MyError(f"{node[0]}: Expected number? got: {print_str(num)}")
env[name] += num
else:
env[name] += 1
if type(env[name]) is NotFound:
raise MyError(f"{node[0]}: `{name}` is not defined")
if not is_num(env[name]):
raise MyError(f"{node[0]}: `{name}` is not a number?")
env[name] += incre_by
return None

if type(node[1]) is tuple and len(node[1]) == 3 and node[1][0] == Sym("@r"):
base = my_eval(env, node[1][1])
if type(node[1][2]) is not Sym:
raise MyError(f"{node[0]}: class attribute must be an identifier")
name = node[1][2].val
for i, item in enumerate(base.attrs[0::2]):
if name == item:
if not is_num(base.values[i]):
raise MyError(f"{node[0]}: `{name}` is not a number?")

check_args(
name, (base.values[i] + incre_by,), (1, 1), (base.attrs[i * 2 + 1],)
)
base.values[i] += incre_by
return None
raise MyError(f"{node[0]}: {base.name} has no attribute `{name}`")

raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}")


def syn_decf(env: Env, node: Node) -> None:
guard_term(node, 2, 3)

if type(node[1]) is not Sym:
raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}")
name = node[1].val
incre_by = 1
if len(node) == 3:
incre_by = my_eval(env, node[2])
if not is_num(incre_by):
raise MyError(f"{node[0]}: Expected number? got: {print_str(incre_by)}")

if env[name] is None:
raise MyError(f"{node[0]}: `{name}` is not defined")
if not is_num(env[name]):
raise MyError(f"{node[0]}: `{name}` is not a number?")
if type(node[1]) is Sym:
name = node[1].val

if len(node) == 3:
if not is_num(num := my_eval(env, node[2])):
raise MyError(f"{node[0]}: Expected number? got: {print_str(num)}")
env[name] -= num
else:
env[name] -= 1
if type(env[name]) is NotFound:
raise MyError(f"{node[0]}: `{name}` is not defined")
if not is_num(env[name]):
raise MyError(f"{node[0]}: `{name}` is not a number?")
env[name] -= incre_by
return None

if type(node[1]) is tuple and len(node[1]) == 3 and node[1][0] == Sym("@r"):
base = my_eval(env, node[1][1])
if type(node[1][2]) is not Sym:
raise MyError(f"{node[0]}: class attribute must be an identifier")
name = node[1][2].val
for i, item in enumerate(base.attrs[0::2]):
if name == item:
if not is_num(base.values[i]):
raise MyError(f"{node[0]}: `{name}` is not a number?")

check_args(
name, (base.values[i] - incre_by,), (1, 1), (base.attrs[i * 2 + 1],)
)
base.values[i] -= incre_by
return None
raise MyError(f"{node[0]}: {base.name} has no attribute `{name}`")

raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}")


def syn_strappend(env: Env, node: Node) -> None:
Expand All @@ -1099,7 +1150,7 @@ def syn_strappend(env: Env, node: Node) -> None:
raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}")
name = node[1].val

if env[name] is None:
if type(env[name]) is NotFound:
raise MyError(f"{node[0]}: `{name}` is not defined")
if not is_str(env[name]):
raise MyError(f"{node[0]}: `{name}` is not a string?")
Expand Down Expand Up @@ -1344,8 +1395,40 @@ def syn_let_star(env: Env, node: Node) -> Any:
return my_eval(inner_env, node[-1])


def syn_class(env: Env, node: Node) -> Any:
...
def syn_class(env: Env, node: Node) -> None:
if len(node) < 2:
raise MyError(f"{node[0]}: Expects at least 1 term")

if type(node[1]) is not Sym:
raise MyError("class name must be an identifier")

attr_len = len(node) - 2
attrs: Any = [None] * (attr_len * 2)
contracts = [None] * attr_len

for i, item in enumerate(node[2:]):
if type(item) is not tuple or len(item) != 2:
raise MyError(f"{node[0]}: Invalid syntax")

contracts[i] = my_eval(env, item[1])
attrs[i * 2] = item[0].val
attrs[i * 2 + 1] = contracts[i]

name = node[1].val
pred = name + "?"
attrs = tuple(attrs)

env[name] = Proc(
name,
lambda *args: PaletClass(name, attrs, list(args)),
(attr_len, attr_len),
*contracts,
)
env[pred] = Proc(
pred,
lambda v: type(v) is PaletClass and v.name == name and v.attrs == attrs,
(1, 1),
)


def attr(env: Env, node: Node) -> Any:
Expand All @@ -1354,6 +1437,15 @@ def attr(env: Env, node: Node) -> Any:
if type(node[2]) is not Sym:
raise MyError("@r: attribute must be an identifier")

base = my_eval(env, node[1])
if type(base) is PaletClass:
if type(name := node[2]) is not Sym:
raise MyError("@r: class attribute must be an identifier")

for i, item in enumerate(base.attrs[0::2]):
if name.val == item:
return base.values[i]

return my_eval(env, (node[2], node[1]))


Expand All @@ -1379,7 +1471,9 @@ def my_eval(env: Env, node: object) -> Any:
raise MyError(
f"variable `{node.val}` not found. Did you mean: {mat[0]}"
)
raise MyError(f'variable `{node.val}` not found. Did you mean a string literal.')
raise MyError(
f"variable `{node.val}` not found. Did you mean a string literal."
)
return val

if isinstance(node, Method):
Expand Down Expand Up @@ -1447,7 +1541,7 @@ def my_eval(env: Env, node: object) -> Any:
"case": Syntax(syn_case),
"let": Syntax(syn_let),
"let*": Syntax(syn_let_star),
#"class": Syntax(syn_class),
"class": Syntax(syn_class),
"@r": Syntax(attr),
# loops
"for": Syntax(syn_for),
Expand Down Expand Up @@ -1527,7 +1621,7 @@ def my_eval(env: Env, node: object) -> Any:
"min": Proc("min", lambda *v: min(v), (1, None), is_real),
"sin": Proc("sin", math.sin, (1, 1), is_real),
"cos": Proc("cos", math.cos, (1, 1), is_real),
"log": Proc("log", math.log, (1, 2), is_real),
"log": Proc("log", math.log, (1, 2), andc(is_real, gt_c(0))),
"tan": Proc("tan", math.tan, (1, 1), is_real),
"mod": Proc("mod", lambda a, b: a % b, (2, 2), is_int),
"modulo": Proc("modulo", lambda a, b: a % b, (2, 2), is_int),
Expand Down
20 changes: 9 additions & 11 deletions auto_editor/lib/contracts.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,35 +137,33 @@ def is_contract(c: object) -> bool:


def andc(*cs: object) -> Proc:
return Proc(
"flat-and/c", lambda v: all([check_contract(c, v) for c in cs]), (1, 1), any_p
)
name = "(and/c " + " ".join((f"{c}" for c in cs)) + ")"
return Proc(name, lambda v: all((check_contract(c, v) for c in cs)), (1, 1), any_p)


def orc(*cs: object) -> Proc:
return Proc(
"flat-or/c", lambda v: any([check_contract(c, v) for c in cs]), (1, 1), any_p
)
name = "(or/c " + " ".join((f"{c}" for c in cs)) + ")"
return Proc(name, lambda v: any((check_contract(c, v) for c in cs)), (1, 1), any_p)


def notc(c: object) -> Proc:
return Proc("flat-not/c", lambda v: not check_contract(c, v), (1, 1), [any_p])
return Proc("flat-not/c", lambda v: not check_contract(c, v), (1, 1), any_p)


def gte_c(n: int | float | Fraction) -> Proc:
return Proc(f"(>=/c {n})", lambda i: i >= n, (1, 1), [is_real])
return Proc(f"(>=/c {n})", lambda i: i >= n, (1, 1), is_real)


def gt_c(n: int | float | Fraction) -> Proc:
return Proc(f"(>/c {n})", lambda i: i > n, (1, 1), [is_real])
return Proc(f"(>/c {n})", lambda i: i > n, (1, 1), is_real)


def lte_c(n: int | float | Fraction) -> Proc:
return Proc(f"(<=/c {n})", lambda i: i <= n, (1, 1), [is_real])
return Proc(f"(<=/c {n})", lambda i: i <= n, (1, 1), is_real)


def lt_c(n: int | float | Fraction) -> Proc:
return Proc(f"(</c {n})", lambda i: i < n, (1, 1), [is_real])
return Proc(f"(</c {n})", lambda i: i < n, (1, 1), is_real)


def between_c(n: Any, m: Any) -> Proc:
Expand Down
19 changes: 19 additions & 0 deletions auto_editor/lib/data_structs.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

from collections.abc import Iterator
from dataclasses import dataclass
from fractions import Fraction
from io import StringIO
from typing import Any
Expand All @@ -11,6 +12,7 @@
class NotFound:
pass


class Env:
__slots__ = ("data", "outer")

Expand Down Expand Up @@ -257,3 +259,20 @@ def print_str(val: object) -> str:
return f"'{display_str(val)}"

return display_str(val)


@dataclass(slots=True)
class PaletClass:
name: str
attrs: tuple
values: list

def __str__(self) -> str:
result = StringIO()
result.write(f"({self.name}")
for i, val in enumerate(self.values):
result.write(f" #:{self.attrs[i * 2]} {print_str(val)}")
result.write(")")
return result.getvalue()

__repr__ = __str__
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def pip_version():
install_requires=[
"numpy>=1.22.0",
"pillow==10.0.1",
"pyav==11.0.1",
"pyav==11.1.0",
"ae-ffmpeg==1.1.*",
],
python_requires=">=3.10",
Expand Down

0 comments on commit c65bacf

Please sign in to comment.