Skip to content

Commit

Permalink
Bug fix&tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Dan-wanna-M committed Sep 28, 2024
1 parent 9f4b2c8 commit f384add
Show file tree
Hide file tree
Showing 5 changed files with 66 additions and 20 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ requires = ["setuptools>=61.0"]
build-backend = "setuptools.build_meta"
[project]
name = "formatron"
version = "0.4.3"
version = "0.4.4"
authors = [
{name = "Xintong Sun", email = "[email protected]"},
]
Expand Down
2 changes: 1 addition & 1 deletion src/formatron/formats/json.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def builtin_literal(current: typing.Type, nonterminal: str):
result = []
for i, arg in enumerate(args):
if isinstance(arg, str):
new_items.append(f'"\\\"{repr(arg)[1:-1]}\\\""')
new_items.append(f'"\\"{repr(arg)[1:-1]}\\""')
elif isinstance(arg, bool):
new_items.append(f'"{str(arg).lower()}"')
elif isinstance(arg, int):
Expand Down
63 changes: 46 additions & 17 deletions tests/snapshots/snap_test_formatter.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@
start ::= \'Today, I want to eat \' __choice_0_0_food \'\\n\' "My food\'s ID is " __choice_3_0_ID \'.\\n\' "\\nWhat\'s more, indentations\\nare handled\\nappropriately." \'My weight is 14.4kg and my color is pink. This is my personal info json: \' __json_4_0_json \'\\n\';'''

snapshots['test_formatter 2'] = '''Today, I want to eat banana
My food's ID is sweet.
My food's ID is a.
What's more, indentations
are handled
appropriately.My weight is 14.4kg and my color is pink. This is my personal info json: {"name":"Van","weight":1.4,"color":"red"}
'''

snapshots['test_formatter 3'] = {
'ID': GenericRepr("<re.Match object; span=(0, 5), match='sweet'>"),
'ID': GenericRepr("<re.Match object; span=(0, 1), match='a'>"),
'food': 'banana',
'json': GenericRepr("Test(name='Van', weight=1.4, color='red')")
}
Expand Down Expand Up @@ -118,23 +118,23 @@
start ::= __json_0_0_json '\\n';'''

snapshots['test_formatter_dict_inference 2'] = '''{"name":"example","gender":"male"}
snapshots['test_formatter_dict_inference 2'] = '''{"name":"admin","gender":"male"}
'''

snapshots['test_formatter_dict_inference 3'] = {
'json': {
'gender': 'male',
'name': 'example'
'name': 'admin'
}
}

snapshots['test_formatter_json_schema 1'] = '''{"name":"mahmood","age":18}
snapshots['test_formatter_json_schema 1'] = '''{"name":"123","age":180}
'''

snapshots['test_formatter_json_schema 2'] = {
'json': {
'age': 18,
'name': 'mahmood'
'age': 180,
'name': '123'
}
}

Expand Down Expand Up @@ -169,7 +169,7 @@
snapshots['test_formatter_str 1'] = '''__str_0_0 ::= #e'.*?(?:\\\\.)';
start ::= __str_0_0 '\\n';'''

snapshots['test_formatter_str 2'] = '''请问这个词的典故是什么?如果没有,请提供上上文,便可以。如果提到的“典故”指的是文学作品,那么这个词可能是:A lost book, a lost song, or a lost play.
snapshots['test_formatter_str 2'] = '''🤔" I replied.
'''

snapshots['test_formatter_str 3'] = {
Expand Down Expand Up @@ -215,30 +215,59 @@
start ::= __json_0_0_json '\\n';'''

snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "A", "active": true}, {"id": 2, "name": "B", "active": true}, {"id": 3, "name": "C", "active": true}, {"id": 4, "name": "D", "active": true}]
snapshots['test_formatter_top_level_array_json_schema 2'] = '''[{"id": 1, "name": "Tom", "active": true}, {"id": 2, "name": "Mike", "active": true}, {"id": 3, "name": "John", "active": true}]
'''

snapshots['test_formatter_top_level_array_json_schema 3'] = {
'json': [
{
'active': True,
'id': 1,
'name': 'A'
'name': 'Tom'
},
{
'active': True,
'id': 2,
'name': 'B'
'name': 'Mike'
},
{
'active': True,
'id': 3,
'name': 'C'
},
{
'active': True,
'id': 4,
'name': 'D'
'name': 'John'
}
]
}

snapshots['test_grammar_literal 1'] = '''integer ::= #"-?(0|[1-9]\\\\d*)";
number ::= #"-?(0|[1-9]\\\\d*)(\\\\.\\\\d+)?([eE][+-]?\\\\d+)?";
string ::= #\'"([^\\\\\\\\"\\u0000-\\u001f]|\\\\\\\\["\\\\\\\\bfnrt/]|\\\\\\\\u[0-9A-Fa-f]{4})*"\';
boolean ::= "true"|"false";
null ::= "null";
array ::= array_begin (json_value (comma json_value)*)? array_end;
object ::= object_begin (string colon json_value (comma string colon json_value)*)? object_end;
json_value ::= number|string|boolean|null|array|object;
comma ::= #"[ \t
\r]*,[ \t
\r]*";
colon ::= #"[ \t
\r]*:[ \t
\r]*";
object_begin ::= #"\\\\{[ \t
\r]*";
object_end ::= #"[ \t
\r]*\\\\}";
array_begin ::= #"\\\\[[ \t
\r]*";
array_end ::= #"[ \t
\r]*\\\\]";
__json_0_0_json ::= object_begin \'"a"\' colon __json_0_0_json_a object_end;
__json_0_0_json_a ::= "\\"114\\"" | "\\"514\\"";
start ::= __json_0_0_json '\\n';'''

snapshots['test_grammar_literal 2'] = '''{"a":"114"}
'''

snapshots['test_grammar_literal 3'] = {
'json': GenericRepr("A(a='114')")
}
2 changes: 1 addition & 1 deletion tests/snapshots/snap_test_grammar_gen.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@
start_e_1 ::= string;
start_e_0 ::= array_begin (start_e_0_value (comma start_e_0_value)*)? array_end;
start_e_0_value ::= number;
start_c ::= "\\"114\'\\"\\"" | "\\"514\\"" | "true" | "\\"1919\\"" | "\\"810\\"";
start_c ::= "\\"114\\\'"\\"" | "\\"514\\"" | "true" | "\\"1919\\"" | "\\"810\\"";
start_b ::= start_b_required?;
start_b_required ::= integer;
start_a ::= start_a_required?;
Expand Down
17 changes: 17 additions & 0 deletions tests/test_formatter.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from typing import Literal
from formatron.schemas import json_schema
from formatron.schemas.dict_inference import infer_mapping
from formatron.formatter import FormatterBuilder
Expand Down Expand Up @@ -157,3 +158,19 @@ def add(a: int, b: int, /, *, c: int):
snapshot.assert_match(
pipeline.generate("This is a random json: ", token_count=256, args=formatron.integrations.RWKV.PIPELINE_ARGS(top_p=0.5)))
snapshot.assert_match(pipeline.formatter.captures)

def test_grammar_literal(snapshot):
    """Snapshot the grammar, generated text and captures for a ``Literal`` field.

    Builds a formatter whose only line is a JSON extractor for a schema with
    a single ``Literal['114', '514']`` attribute, then records three
    snapshots in order: the grammar string, the text generated by the RWKV
    pipeline, and the formatter's captures.
    """
    # Reset the class-level counter before constructing the builder;
    # presumably this keeps generated nonterminal names stable between tests.
    FormatterBuilder._formatter_builder_counter = 0
    builder = FormatterBuilder()

    # The class name and the field name both appear in the recorded
    # snapshots, so they must stay exactly as they are.
    class A(formatron.schemas.pydantic.ClassSchema):
        a: Literal['114', '514']

    json_extractor = builder.json(A, capture_name='json')
    builder.append_line(f"{json_extractor}")

    rwkv_model = RWKV(
        "assets/RWKV-5-World-0.4B-v2-20231113-ctx4096.pth", 'cuda fp16')
    rwkv_pipeline = formatron.integrations.RWKV.PIPELINE(
        rwkv_model, "rwkv_vocab_v20230424", builder)

    # Seed NumPy's global RNG before generation so that the sampled output
    # can be compared against a stored snapshot.
    np.random.seed(42)

    # Snapshot 1: the grammar string exposed by the pipeline's formatter.
    snapshot.assert_match(rwkv_pipeline.formatter.grammar_str)

    # Snapshot 2: the text generated under the grammar constraint.
    sampling_args = formatron.integrations.RWKV.PIPELINE_ARGS(top_p=0.5)
    generated_text = rwkv_pipeline.generate(
        "This is a random json: ", token_count=256, args=sampling_args)
    snapshot.assert_match(generated_text)

    # Snapshot 3: the captures held by the formatter after generation.
    snapshot.assert_match(rwkv_pipeline.formatter.captures)

0 comments on commit f384add

Please sign in to comment.