From f384add2f9a9bd52dea42f1cf57e3627202efc5b Mon Sep 17 00:00:00 2001 From: Huanghe Date: Fri, 27 Sep 2024 22:44:33 -0500 Subject: [PATCH] Bug fix&tests --- pyproject.toml | 2 +- src/formatron/formats/json.py | 2 +- tests/snapshots/snap_test_formatter.py | 63 +++++++++++++++++------- tests/snapshots/snap_test_grammar_gen.py | 2 +- tests/test_formatter.py | 17 +++++++ 5 files changed, 66 insertions(+), 20 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 7a6734b4..80b6ce78 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -3,7 +3,7 @@ requires = ["setuptools>=61.0"] build-backend = "setuptools.build_meta" [project] name = "formatron" -version = "0.4.3" +version = "0.4.4" authors = [ {name = "Xintong Sun", email = "xs28@rice.edu"}, ] diff --git a/src/formatron/formats/json.py b/src/formatron/formats/json.py index 09c3bb19..491e9134 100644 --- a/src/formatron/formats/json.py +++ b/src/formatron/formats/json.py @@ -140,7 +140,7 @@ def builtin_literal(current: typing.Type, nonterminal: str): result = [] for i, arg in enumerate(args): if isinstance(arg, str): - new_items.append(f'"\\\"{repr(arg)[1:-1]}\\\""') + new_items.append(f'"\\"{repr(arg)[1:-1]}\\""') elif isinstance(arg, bool): new_items.append(f'"{str(arg).lower()}"') elif isinstance(arg, int): diff --git a/tests/snapshots/snap_test_formatter.py b/tests/snapshots/snap_test_formatter.py index fe6adca0..6fd74d01 100644 --- a/tests/snapshots/snap_test_formatter.py +++ b/tests/snapshots/snap_test_formatter.py @@ -41,7 +41,7 @@ start ::= \'Today, I want to eat \' __choice_0_0_food \'\\n\' "My food\'s ID is " __choice_3_0_ID \'.\\n\' "\\nWhat\'s more, indentations\\nare handled\\nappropriately." \'My weight is 14.4kg and my color is pink. This is my personal info json: \' __json_4_0_json \'\\n\';''' snapshots['test_formatter 2'] = '''Today, I want to eat banana -My food's ID is sweet. +My food's ID is a. What's more, indentations are handled @@ -49,7 +49,7 @@ ''' snapshots['test_formatter 3'] = { - 'ID': GenericRepr(""), + 'ID': GenericRepr(""), 'food': 'banana', 'json': GenericRepr("Test(name='Van', weight=1.4, color='red')") } @@ -118,23 +118,23 @@ start ::= __json_0_0_json '\\n';''' -snapshots['test_formatter_dict_inference 2'] = '''{"name":"example","gender":"male"} +snapshots['test_formatter_dict_inference 2'] = '''{"name":"admin","gender":"male"} ''' snapshots['test_formatter_dict_inference 3'] = { 'json': { 'gender': 'male', - 'name': 'example' + 'name': 'admin' } } -snapshots['test_formatter_json_schema 1'] = '''{"name":"mahmood","age":18} +snapshots['test_formatter_json_schema 1'] = '''{"name":"123","age":180} ''' snapshots['test_formatter_json_schema 2'] = { 'json': { - 'age': 18, - 'name': 'mahmood' + 'age': 180, + 'name': '123' } } @@ -169,7 +169,7 @@ snapshots['test_formatter_str 1'] = '''__str_0_0 ::= #e'.*?(?:\\\\.)'; start ::= __str_0_0 '\\n';''' -snapshots['test_formatter_str 2'] = '''请问这个词的典故是什么?如果没有,请提供上上文,便可以。如果提到的“典故”指的是文学作品,那么这个词可能是:A lost book, a lost song, or a lost play. +snapshots['test_formatter_str 2'] = '''🤔" I replied. ''' snapshots['test_formatter_str 3'] = { @@ -215,7 +215,7 @@ start ::= __json_0_0_json '\\n';''' -snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "A", "active": true}, {"id": 2, "name": "B", "active": true}, {"id": 3, "name": "C", "active": true}, {"id": 4, "name": "D", "active": true}] +snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "Tom", "active": true}, {"id": 2, "name": "Mike", "active": true}, {"id": 3, "name": "John", "active": true}] ''' snapshots['test_formatter_top_level_array_json_schema 3'] = { @@ -223,22 +223,51 @@ { 'active': True, 'id': 1, - 'name': 'A' + 'name': 'Tom' }, { 'active': True, 'id': 2, - 'name': 'B' + 'name': 'Mike' }, { 'active': True, 'id': 3, - 'name': 'C' - }, - { - 'active': True, - 'id': 4, - 'name': 'D' + 'name': 'John' } ] } + +snapshots['test_grammar_literal 1'] = '''integer ::= #"-?(0|[1-9]\\\\d*)"; +number ::= #"-?(0|[1-9]\\\\d*)(\\\\.\\\\d+)?([eE][+-]?\\\\d+)?"; +string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\'; +boolean ::= "true"|"false"; +null ::= "null"; +array ::= array_begin (json_value (comma json_value)*)? array_end; +object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end; +json_value ::= number|string|boolean|null|array|object; +comma ::= #"[ \t +\r]*,[ \t +\r]*"; +colon ::= #"[ \t +\r]*:[ \t +\r]*"; +object_begin ::= #"\\\\{[ \t +\r]*"; +object_end ::= #"[ \t +\r]*\\\\}"; +array_begin ::= #"\\\\[[ \t +\r]*"; +array_end ::= #"[ \t +\r]*\\\\]"; +__json_0_0_json ::= object_begin \'"a"\' colon __json_0_0_json_a object_end; +__json_0_0_json_a ::= "\\"114\\"" | "\\"514\\""; + +start ::= __json_0_0_json '\\n';''' + +snapshots['test_grammar_literal 2'] = '''{"a":"114"} +''' + +snapshots['test_grammar_literal 3'] = { + 'json': GenericRepr("A(a='114')") +} diff --git a/tests/snapshots/snap_test_grammar_gen.py b/tests/snapshots/snap_test_grammar_gen.py index ed8f9644..2cd5e45f 100644 --- a/tests/snapshots/snap_test_grammar_gen.py +++ b/tests/snapshots/snap_test_grammar_gen.py @@ -157,7 +157,7 @@ start_e_1 ::= string; start_e_0 ::= array_begin (start_e_0_value (comma start_e_0_value)*)? array_end; start_e_0_value ::= number; -start_c ::= "\\"114\'\\"\\"" | "\\"514\\"" | "true" | "\\"1919\\"" | "\\"810\\""; +start_c ::= "\\"114\\\'"\\"" | "\\"514\\"" | "true" | "\\"1919\\"" | "\\"810\\""; start_b ::= start_b_required?; start_b_required ::= integer; start_a ::= start_a_required?; diff --git a/tests/test_formatter.py b/tests/test_formatter.py index 806968f6..6747001a 100644 --- a/tests/test_formatter.py +++ b/tests/test_formatter.py @@ -1,3 +1,4 @@ +from typing import Literal from formatron.schemas import json_schema from formatron.schemas.dict_inference import infer_mapping from formatron.formatter import FormatterBuilder @@ -157,3 +158,19 @@ def add(a: int, b: int, /, *, c: int): snapshot.assert_match( pipeline.generate("This is a random json: ", token_count=256, args=formatron.integrations.RWKV.PIPELINE_ARGS(top_p=0.5))) snapshot.assert_match(pipeline.formatter.captures) + +def test_grammar_literal(snapshot): + FormatterBuilder._formatter_builder_counter = 0 + f = FormatterBuilder() + class A(formatron.schemas.pydantic.ClassSchema): + a: Literal['114', '514'] + f.append_line( + f"{f.json(A, capture_name='json')}") + model = RWKV( + "assets/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth", 'cuda fp16') + pipeline = formatron.integrations.RWKV.PIPELINE(model, "rwkv_vocab_v20230424", f) + np.random.seed(42) + snapshot.assert_match(pipeline.formatter.grammar_str) + snapshot.assert_match( + pipeline.generate("This is a random json: ", token_count=256, args=formatron.integrations.RWKV.PIPELINE_ARGS(top_p=0.5))) + snapshot.assert_match(pipeline.formatter.captures) \ No newline at end of file