From dfdc8dfc69219c525a47090bbf7bda1d725fa219 Mon Sep 17 00:00:00 2001 From: WyattBlue Date: Mon, 26 Feb 2024 03:07:13 -0500 Subject: [PATCH] Make audio, motion, regular procedures --- auto_editor/analyze.py | 82 ------------------------ auto_editor/lang/palet.py | 121 ++++++++++++++++++++++++++--------- auto_editor/lib/contracts.py | 67 +++++++++++++++---- 3 files changed, 146 insertions(+), 124 deletions(-) diff --git a/auto_editor/analyze.py b/auto_editor/analyze.py index 6b442a284..b3d9c5723 100644 --- a/auto_editor/analyze.py +++ b/auto_editor/analyze.py @@ -21,13 +21,10 @@ from auto_editor.lib.data_structs import Sym from auto_editor.render.subtitle import SubtitleParser from auto_editor.utils.cmdkw import ( - ParserError, Required, - parse_with_palet, pAttr, pAttrs, ) -from auto_editor.utils.func import boolop from auto_editor.wavfile import read if TYPE_CHECKING: @@ -38,7 +35,6 @@ from numpy.typing import NDArray from auto_editor.ffwrapper import FileInfo - from auto_editor.lib.data_structs import Env from auto_editor.output import Ensure from auto_editor.utils.bar import Bar from auto_editor.utils.log import Log @@ -412,81 +408,3 @@ def motion(self, s: int, blur: int, width: int) -> NDArray[np.float64]: self.bar.end() return self.cache("motion", mobj, threshold_list[:index]) - - -def edit_method(val: str, filesetup: FileSetup, env: Env) -> NDArray[np.bool_]: - assert isinstance(filesetup, FileSetup) - src = filesetup.src - tb = filesetup.tb - ensure = filesetup.ensure - strict = filesetup.strict - bar = filesetup.bar - temp = filesetup.temp - log = filesetup.log - - if ":" in val: - method, attrs = val.split(":", 1) - else: - method, attrs = val, "" - - levels = Levels(ensure, src, tb, bar, temp, log) - - if method == "none": - return levels.none() - if method == "all/e": - return levels.all() - - try: - obj = parse_with_palet(attrs, builder_map[method], env) - except ParserError as e: - log.error(e) - - try: - if method == "audio": - s = obj["stream"] - if s == "all" or s == Sym("all"): - total_list: NDArray[np.bool_] | None = None - for s in range(len(src.audios)): - audio_list = to_threshold(levels.audio(s), obj["threshold"]) - if total_list is None: - total_list = audio_list - else: - total_list = boolop(total_list, audio_list, np.logical_or) - - if total_list is None: - if strict: - log.error("Input has no audio streams.") - stream_data = levels.all() - else: - stream_data = total_list - else: - assert isinstance(s, int) - stream_data = to_threshold(levels.audio(s), obj["threshold"]) - - assert isinstance(obj["minclip"], int) - assert isinstance(obj["mincut"], int) - - mut_remove_small(stream_data, obj["minclip"], replace=1, with_=0) - mut_remove_small(stream_data, obj["mincut"], replace=0, with_=1) - - return stream_data - - if method == "motion": - return to_threshold( - levels.motion(obj["stream"], obj["blur"], obj["width"]), - obj["threshold"], - ) - - if method == "subtitle": - return levels.subtitle( - obj["pattern"], - obj["stream"], - obj["ignore_case"], - obj["max_count"], - ) - except LevelError as e: - if strict: - log.error(e) - - return levels.all() - raise ValueError("Unreachable") diff --git a/auto_editor/lang/palet.py b/auto_editor/lang/palet.py index 09898a5d6..29981e2ed 100644 --- a/auto_editor/lang/palet.py +++ b/auto_editor/lang/palet.py @@ -18,7 +18,12 @@ import numpy as np from numpy import logical_and, logical_not, logical_or, logical_xor -from auto_editor.analyze import edit_method, mut_remove_large, mut_remove_small +from auto_editor.analyze import ( + LevelError, + mut_remove_large, + mut_remove_small, + to_threshold, +) from auto_editor.lib.contracts import * from auto_editor.lib.data_structs import * from auto_editor.lib.err import MyError @@ -49,7 +54,6 @@ class ClosingError(MyError): LPAREN, RPAREN, LBRAC, RBRAC, LCUR, RCUR, EOF = "(", ")", "[", "]", "{", "}", "EOF" VAL, QUOTE, SEC, DB, DOT, VLIT = "VAL", "QUOTE", "SEC", "DB", "DOT", "VLIT" SEC_UNITS = ("s", "sec", "secs", "second", "seconds") -METHODS = ("audio:", "motion:", "subtitle:") brac_pairs = {LPAREN: RPAREN, LBRAC: RBRAC, LCUR: RCUR} str_escape = { @@ -315,7 +319,6 @@ def get_next_token(self) -> Token: result = "" has_illegal = False - is_method = False def normal() -> bool: return ( @@ -334,21 +337,14 @@ def handle_strings() -> bool: while normal(): result += self.char - if (result + ":") in METHODS: - is_method = True - normal = handle_strings + # if (result + ":") in METHODS: + # is_method = True + # normal = handle_strings if self.char in "'`|\\": has_illegal = True self.advance() - if is_method: - return Token(VAL, Method(result)) - - for method in METHODS: - if result == method[:-1]: - return Token(VAL, Method(result)) - if self.char == ".": # handle `object.method` syntax self.advance() return Token(DOT, (Sym(result), self.get_next_token())) @@ -368,16 +364,6 @@ def handle_strings() -> bool: ############################################################################### -@dataclass(slots=True) -class Method: - val: str - - def __str__(self) -> str: - return f"#" - - __repr__ = __str__ - - class Parser: def __init__(self, lexer: Lexer): self.lexer = lexer @@ -807,7 +793,7 @@ def __call__(self, *args: Any) -> Any: @dataclass(slots=True) -class KeywordProc: +class KeywordUserProc: env: Env name: str parms: list[str] @@ -952,7 +938,7 @@ def syn_define(env: Env, node: Node) -> None: raise MyError(f"{node[0]}: must be an identifier") if kw_only: - env[n] = KeywordProc(env, n, parms, kparms, body, (len(parms), None)) + env[n] = KeywordUserProc(env, n, parms, kparms, body, (len(parms), None)) else: env[n] = UserProc(env, n, parms, (), body) return None @@ -1481,6 +1467,60 @@ def edit_all() -> np.ndarray: return env["@levels"].all() +def edit_audio( + threshold: float = 0.04, stream: object = "all", mincut: int = 6 , minclip: int = 3 +) -> np.ndarray: + if "@levels" not in env or "@filesetup" not in env: + raise MyError("Can't use `audio` if there's no input media") + + levels = env["@levels"] + src = env["@filesetup"].src + + stream_data: NDArray[np.bool_] | None = None + if stream == "all" or stream == Sym("all"): + stream_range = range(0, len(src.audios)) + else: + assert isinstance(stream, int) + stream_range = range(stream, stream + 1) + + try: + for s in stream_range: + audio_list = to_threshold(levels.audio(s), threshold) + if stream_data is None: + stream_data = audio_list + else: + stream_data = boolop(stream_data, audio_list, np.logical_or) + except LevelError as e: + raise MyError(e) + + if stream_data is not None: + mut_remove_small(stream_data, minclip, replace=1, with_=0) + mut_remove_small(stream_data, mincut, replace=0, with_=1) + + return stream_data + + return levels.all() + + +def edit_motion( + threshold: float = 0.02, + stream: int = 0, + blur: int = 9, + width: int = 400, +) -> np.ndarray: + if "@levels" not in env: + raise MyError("Can't use `motion` if there's no input media") + + return to_threshold(env["@levels"].motion(stream, blur, width), threshold) + + +def edit_subtitle(pattern, stream=0, ignore_case=False, max_count=None): + if "@levels" not in env: + raise MyError("Can't use `subtitle` if there's no input media") + + return env["@levels"].subtitle(pattern, stream, ignore_case, max_count) + + def my_eval(env: Env, node: object) -> Any: if type(node) is Sym: val = env.get(node.val) @@ -1494,11 +1534,6 @@ def my_eval(env: Env, node: object) -> Any: ) return val - if isinstance(node, Method): - if "@filesetup" not in env: - raise MyError("Can't use edit methods if there's no input files") - return edit_method(node.val, env["@filesetup"], env) - if type(node) is list: return [my_eval(env, item) for item in node] @@ -1530,7 +1565,21 @@ def my_eval(env: Env, node: object) -> Any: if type(oper) is Syntax: return oper(env, node) - return oper(*(my_eval(env, c) for c in node[1:])) + i = 1 + args: list[Any] = [] + kwargs: dict[str, Any] = {} + while i < len(node): + result = my_eval(env, node[i]) + if type(result) is Keyword: + i += 1 + if i >= len(node): + raise MyError("todo: write good error message") + kwargs[result.val] = my_eval(env, node[i]) + else: + args.append(result) + i += 1 + + return oper(*args, **kwargs) return node @@ -1545,6 +1594,16 @@ def my_eval(env: Env, node: object) -> Any: # edit procedures "none": Proc("none", edit_none, (0, 0)), "all/e": Proc("all/e", edit_all, (0, 0)), + "audio": Proc("audio", edit_audio, (0, 4), + is_threshold, orc(is_nat, Sym("all"), "all"), is_nat, + {"threshold": 0, "stream": 1, "minclip": 2, "mincut": 2} + ), + "motion": Proc("motion", edit_motion, (0, 4), + is_threshold, is_nat, is_nat, is_nat1 + ), + "subtitle": Proc("subtitle", edit_subtitle, (1, 4), + is_str, is_nat, is_bool, orc(is_nat, is_void) + ), # syntax "lambda": Syntax(syn_lambda), "λ": Syntax(syn_lambda), diff --git a/auto_editor/lib/contracts.py b/auto_editor/lib/contracts.py index f3ff7a579..421c96c99 100644 --- a/auto_editor/lib/contracts.py +++ b/auto_editor/lib/contracts.py @@ -5,7 +5,7 @@ from fractions import Fraction from typing import Any -from .data_structs import Sym, print_str +from .data_structs import Sym, display_str, print_str from .err import MyError @@ -47,7 +47,7 @@ def check_contract(c: object, val: object) -> bool: def check_args( - o: str, + name: str, values: list | tuple, arity: tuple[int, int | None], cont: tuple[Any, ...], @@ -56,7 +56,7 @@ def check_args( amount = len(values) assert not (upper is not None and lower > upper) - base = f"`{o}` has an arity mismatch. Expected " + base = f"`{name}` has an arity mismatch. Expected " if lower == upper and len(values) != lower: raise MyError(f"{base}{lower}, got {amount}") @@ -72,11 +72,11 @@ def check_args( check = cont[-1] if i >= len(cont) else cont[i] if not check_contract(check, val): exp = f"{check}" if callable(check) else print_str(check) - raise MyError(f"`{o}` expected a {exp}, got {print_str(val)}") + raise MyError(f"`{name}` expected {exp}, but got {print_str(val)}") class Proc: - __slots__ = ("name", "proc", "arity", "contracts") + __slots__ = ("name", "proc", "arity", "contracts", "kw_contracts") def __init__( self, n: str, p: Callable, a: tuple[int, int | None] = (1, None), *c: Any @@ -84,11 +84,52 @@ def __init__( self.name = n self.proc = p self.arity = a - self.contracts: tuple[Any, ...] = c - def __call__(self, *args: Any) -> Any: - check_args(self.name, args, self.arity, self.contracts) - return self.proc(*args) + if c and type(c[-1]) is dict: + self.kw_contracts: dict[str, int] | None = c[-1] + self.contracts: tuple[Any, ...] = c[:-1] + else: + self.kw_contracts = None + self.contracts = c + + def __call__(self, *args: Any, **kwargs: Any): + lower, upper = self.arity + amount = len(args) + cont = self.contracts + kws = self.kw_contracts + + assert not (upper is not None and lower > upper) + base = f"`{name}` has an arity mismatch. Expected " + + if lower == upper and len(args) != lower: + raise MyError(f"{base}{lower}, got {amount}") + if upper is None and amount < lower: + raise MyError(f"{base}at least {lower}, got {amount}") + if upper is not None and (amount > upper or amount < lower): + raise MyError(f"{base}between {lower} and {upper}, got {amount}") + + if not cont: + return + + if kws is not None: + for key, val in kwargs.items(): + check = cont[-1] if kws[key] >= len(cont) else cont[kws[key]] + if not check_contract(check, val): + exp = f"{check}" if callable(check) else print_str(check) + raise MyError( + f"`{name} #:{key}` expected {exp}, but got {print_str(val)}" + ) + + elif len(kwargs) > 0: + raise MyError("Keyword arguments are not allowed here") + + for i, val in enumerate(values): + check = cont[-1] if i >= len(cont) else cont[i] + if not check_contract(check, val): + exp = f"{check}" if callable(check) else print_str(check) + raise MyError(f"`{name}` expected {exp}, but got {print_str(val)}") + + return self.proc(*args, **kwargs) def __str__(self) -> str: return self.name @@ -137,13 +178,17 @@ def is_contract(c: object) -> bool: is_proc = Contract("procedure?", lambda v: isinstance(v, Proc | Contract)) +def contract_printer(cs: object) -> str: + return " ".join(c.name if isinstance(c, Proc | Contract) else print_str(c) for c in cs) + + def andc(*cs: object) -> Proc: - name = "(and/c " + " ".join(f"{c}" for c in cs) + ")" + name = f"(and/c {contract_printer(cs)})" return Proc(name, lambda v: all(check_contract(c, v) for c in cs), (1, 1), any_p) def orc(*cs: object) -> Proc: - name = "(or/c " + " ".join(f"{c}" for c in cs) + ")" + name = f"(or/c {contract_printer(cs)})" return Proc(name, lambda v: any(check_contract(c, v) for c in cs), (1, 1), any_p)