From c65bacf44c12eac543d5c4075f6c8a9621b058bc Mon Sep 17 00:00:00 2001 From: WyattBlue Date: Wed, 18 Oct 2023 17:19:22 -0400 Subject: [PATCH] Add class special form --- auto_editor/lang/palet.py | 170 +++++++++++++++++++++++++------- auto_editor/lib/contracts.py | 20 ++-- auto_editor/lib/data_structs.py | 19 ++++ setup.py | 2 +- 4 files changed, 161 insertions(+), 50 deletions(-) diff --git a/auto_editor/lang/palet.py b/auto_editor/lang/palet.py index 0451cad90..a012dc13f 100644 --- a/auto_editor/lang/palet.py +++ b/auto_editor/lang/palet.py @@ -1043,53 +1043,104 @@ def guard_term(node: Node, n: int, u: int) -> None: def syn_set(env: Env, node: Node) -> None: guard_term(node, 3, 3) - if type(node[1]) is not Sym: - raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}") + if type(node[1]) is Sym: + name = node[1].val + if name not in env: + raise MyError(f"{node[0]}: Can't set variable `{name}` before definition") + env[name] = my_eval(env, node[2]) + return None - name = node[1].val - if name not in env: - raise MyError(f"{node[0]}: Can't set variable `{name}` before definition") - env[name] = my_eval(env, node[2]) + if type(node[1]) is tuple and len(node[1]) == 3 and node[1][0] == Sym("@r"): + base = my_eval(env, node[1][1]) + name = node[1][2].val + for i, item in enumerate(base.attrs[0::2]): + if name == item: + result = my_eval(env, node[2]) + check_args(item, (result,), (1, 1), (base.attrs[i * 2 + 1],)) + base.values[i] = result + return None + + raise MyError(f"{node[0]}: {base.name} has no attribute `{name}`") + + raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}") def syn_incf(env: Env, node: Node) -> None: guard_term(node, 2, 3) - if type(node[1]) is not Sym: - raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}") - name = node[1].val + incre_by = 1 + if len(node) == 3: + incre_by = my_eval(env, node[2]) + if not is_num(incre_by): + raise MyError(f"{node[0]}: Expected number? got: {print_str(incre_by)}") - if env[name] is None: - raise MyError(f"{node[0]}: `{name}` is not defined") - if not is_num(env[name]): - raise MyError(f"{node[0]}: `{name}` is not a number?") + if type(node[1]) is Sym: + name = node[1].val - if len(node) == 3: - if not is_num(num := my_eval(env, node[2])): - raise MyError(f"{node[0]}: Expected number? got: {print_str(num)}") - env[name] += num - else: - env[name] += 1 + if type(env[name]) is NotFound: + raise MyError(f"{node[0]}: `{name}` is not defined") + if not is_num(env[name]): + raise MyError(f"{node[0]}: `{name}` is not a number?") + env[name] += incre_by + return None + + if type(node[1]) is tuple and len(node[1]) == 3 and node[1][0] == Sym("@r"): + base = my_eval(env, node[1][1]) + if type(node[1][2]) is not Sym: + raise MyError(f"{node[0]}: class attribute must be an identifier") + name = node[1][2].val + for i, item in enumerate(base.attrs[0::2]): + if name == item: + if not is_num(base.values[i]): + raise MyError(f"{node[0]}: `{name}` is not a number?") + + check_args( + name, (base.values[i] + incre_by,), (1, 1), (base.attrs[i * 2 + 1],) + ) + base.values[i] += incre_by + return None + raise MyError(f"{node[0]}: {base.name} has no attribute `{name}`") + + raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}") def syn_decf(env: Env, node: Node) -> None: guard_term(node, 2, 3) - if type(node[1]) is not Sym: - raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}") - name = node[1].val + incre_by = 1 + if len(node) == 3: + incre_by = my_eval(env, node[2]) + if not is_num(incre_by): + raise MyError(f"{node[0]}: Expected number? got: {print_str(incre_by)}") - if env[name] is None: - raise MyError(f"{node[0]}: `{name}` is not defined") - if not is_num(env[name]): - raise MyError(f"{node[0]}: `{name}` is not a number?") + if type(node[1]) is Sym: + name = node[1].val - if len(node) == 3: - if not is_num(num := my_eval(env, node[2])): - raise MyError(f"{node[0]}: Expected number? got: {print_str(num)}") - env[name] -= num - else: - env[name] -= 1 + if type(env[name]) is NotFound: + raise MyError(f"{node[0]}: `{name}` is not defined") + if not is_num(env[name]): + raise MyError(f"{node[0]}: `{name}` is not a number?") + env[name] -= incre_by + return None + + if type(node[1]) is tuple and len(node[1]) == 3 and node[1][0] == Sym("@r"): + base = my_eval(env, node[1][1]) + if type(node[1][2]) is not Sym: + raise MyError(f"{node[0]}: class attribute must be an identifier") + name = node[1][2].val + for i, item in enumerate(base.attrs[0::2]): + if name == item: + if not is_num(base.values[i]): + raise MyError(f"{node[0]}: `{name}` is not a number?") + + check_args( + name, (base.values[i] - incre_by,), (1, 1), (base.attrs[i * 2 + 1],) + ) + base.values[i] -= incre_by + return None + raise MyError(f"{node[0]}: {base.name} has no attribute `{name}`") + + raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}") def syn_strappend(env: Env, node: Node) -> None: @@ -1099,7 +1150,7 @@ def syn_strappend(env: Env, node: Node) -> None: raise MyError(f"{node[0]}: Expected identifier, got {print_str(node[1])}") name = node[1].val - if env[name] is None: + if type(env[name]) is NotFound: raise MyError(f"{node[0]}: `{name}` is not defined") if not is_str(env[name]): raise MyError(f"{node[0]}: `{name}` is not a string?") @@ -1344,8 +1395,40 @@ def syn_let_star(env: Env, node: Node) -> Any: return my_eval(inner_env, node[-1]) -def syn_class(env: Env, node: Node) -> Any: - ... +def syn_class(env: Env, node: Node) -> None: + if len(node) < 2: + raise MyError(f"{node[0]}: Expects at least 1 term") + + if type(node[1]) is not Sym: + raise MyError("class name must be an identifier") + + attr_len = len(node) - 2 + attrs: Any = [None] * (attr_len * 2) + contracts = [None] * attr_len + + for i, item in enumerate(node[2:]): + if type(item) is not tuple or len(item) != 2: + raise MyError(f"{node[0]}: Invalid syntax") + + contracts[i] = my_eval(env, item[1]) + attrs[i * 2] = item[0].val + attrs[i * 2 + 1] = contracts[i] + + name = node[1].val + pred = name + "?" + attrs = tuple(attrs) + + env[name] = Proc( + name, + lambda *args: PaletClass(name, attrs, list(args)), + (attr_len, attr_len), + *contracts, + ) + env[pred] = Proc( + pred, + lambda v: type(v) is PaletClass and v.name == name and v.attrs == attrs, + (1, 1), + ) def attr(env: Env, node: Node) -> Any: @@ -1354,6 +1437,15 @@ def attr(env: Env, node: Node) -> Any: if type(node[2]) is not Sym: raise MyError("@r: attribute must be an identifier") + base = my_eval(env, node[1]) + if type(base) is PaletClass: + if type(name := node[2]) is not Sym: + raise MyError("@r: class attribute must be an identifier") + + for i, item in enumerate(base.attrs[0::2]): + if name.val == item: + return base.values[i] + return my_eval(env, (node[2], node[1])) @@ -1379,7 +1471,9 @@ def my_eval(env: Env, node: object) -> Any: raise MyError( f"variable `{node.val}` not found. Did you mean: {mat[0]}" ) - raise MyError(f'variable `{node.val}` not found. Did you mean a string literal.') + raise MyError( + f"variable `{node.val}` not found. Did you mean a string literal." + ) return val if isinstance(node, Method): @@ -1447,7 +1541,7 @@ def my_eval(env: Env, node: object) -> Any: "case": Syntax(syn_case), "let": Syntax(syn_let), "let*": Syntax(syn_let_star), - #"class": Syntax(syn_class), + "class": Syntax(syn_class), "@r": Syntax(attr), # loops "for": Syntax(syn_for), @@ -1527,7 +1621,7 @@ def my_eval(env: Env, node: object) -> Any: "min": Proc("min", lambda *v: min(v), (1, None), is_real), "sin": Proc("sin", math.sin, (1, 1), is_real), "cos": Proc("cos", math.cos, (1, 1), is_real), - "log": Proc("log", math.log, (1, 2), is_real), + "log": Proc("log", math.log, (1, 2), andc(is_real, gt_c(0))), "tan": Proc("tan", math.tan, (1, 1), is_real), "mod": Proc("mod", lambda a, b: a % b, (2, 2), is_int), "modulo": Proc("modulo", lambda a, b: a % b, (2, 2), is_int), diff --git a/auto_editor/lib/contracts.py b/auto_editor/lib/contracts.py index 7bf3b02ce..523f76fd0 100644 --- a/auto_editor/lib/contracts.py +++ b/auto_editor/lib/contracts.py @@ -137,35 +137,33 @@ def is_contract(c: object) -> bool: def andc(*cs: object) -> Proc: - return Proc( - "flat-and/c", lambda v: all([check_contract(c, v) for c in cs]), (1, 1), any_p - ) + name = "(and/c " + " ".join((f"{c}" for c in cs)) + ")" + return Proc(name, lambda v: all((check_contract(c, v) for c in cs)), (1, 1), any_p) def orc(*cs: object) -> Proc: - return Proc( - "flat-or/c", lambda v: any([check_contract(c, v) for c in cs]), (1, 1), any_p - ) + name = "(or/c " + " ".join((f"{c}" for c in cs)) + ")" + return Proc(name, lambda v: any((check_contract(c, v) for c in cs)), (1, 1), any_p) def notc(c: object) -> Proc: - return Proc("flat-not/c", lambda v: not check_contract(c, v), (1, 1), [any_p]) + return Proc("flat-not/c", lambda v: not check_contract(c, v), (1, 1), any_p) def gte_c(n: int | float | Fraction) -> Proc: - return Proc(f"(>=/c {n})", lambda i: i >= n, (1, 1), [is_real]) + return Proc(f"(>=/c {n})", lambda i: i >= n, (1, 1), is_real) def gt_c(n: int | float | Fraction) -> Proc: - return Proc(f"(>/c {n})", lambda i: i > n, (1, 1), [is_real]) + return Proc(f"(>/c {n})", lambda i: i > n, (1, 1), is_real) def lte_c(n: int | float | Fraction) -> Proc: - return Proc(f"(<=/c {n})", lambda i: i <= n, (1, 1), [is_real]) + return Proc(f"(<=/c {n})", lambda i: i <= n, (1, 1), is_real) def lt_c(n: int | float | Fraction) -> Proc: - return Proc(f"( Proc: diff --git a/auto_editor/lib/data_structs.py b/auto_editor/lib/data_structs.py index 24d277dda..371ec0b4a 100644 --- a/auto_editor/lib/data_structs.py +++ b/auto_editor/lib/data_structs.py @@ -1,6 +1,7 @@ from __future__ import annotations from collections.abc import Iterator +from dataclasses import dataclass from fractions import Fraction from io import StringIO from typing import Any @@ -11,6 +12,7 @@ class NotFound: pass + class Env: __slots__ = ("data", "outer") @@ -257,3 +259,20 @@ def print_str(val: object) -> str: return f"'{display_str(val)}" return display_str(val) + + +@dataclass(slots=True) +class PaletClass: + name: str + attrs: tuple + values: list + + def __str__(self) -> str: + result = StringIO() + result.write(f"({self.name}") + for i, val in enumerate(self.values): + result.write(f" #:{self.attrs[i * 2]} {print_str(val)}") + result.write(")") + return result.getvalue() + + __repr__ = __str__ diff --git a/setup.py b/setup.py index 8d795ba80..7d1c6306f 100644 --- a/setup.py +++ b/setup.py @@ -40,7 +40,7 @@ def pip_version(): install_requires=[ "numpy>=1.22.0", "pillow==10.0.1", - "pyav==11.0.1", + "pyav==11.1.0", "ae-ffmpeg==1.1.*", ], python_requires=">=3.10",