Skip to content

Commit

Permalink
Merge pull request #4 from amoffat/feature/recursive-types
Browse files Browse the repository at this point in the history
recursive types support
  • Loading branch information
amoffat authored Sep 11, 2024
2 parents d8b0b32 + 5e77730 commit 6c1a4e2
Show file tree
Hide file tree
Showing 11 changed files with 394 additions and 69 deletions.
4 changes: 4 additions & 0 deletions .flake8
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[flake8]
max-line-length = 80
select = C,E,F,W,B,B950
ignore = E501
4 changes: 3 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,6 @@ __pycache__/
/.coverage
/docs/build
TODO.md
/demo.py
/demo.py
.stignore
.pytest_cache/
4 changes: 4 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# 0.6.0 - 9/10/24

- Recursive types

# 0.5.0 - 9/3/24

- Handle `Optional` arguments
Expand Down
73 changes: 73 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,79 @@ print(like_inception) # Prints a movie similar to inception

```

## Recursive types

Manifest can handle self-referential (recursive) types. For example, each `Character` has a `social_graph`, and each `SocialGraph` is in turn composed of `Character` objects.

```python
from dataclasses import dataclass
from pprint import pprint

from manifest import ai


@dataclass
class Character:
    name: str
    occupation: str
    social_graph: "SocialGraph"


@dataclass
class SocialGraph:
    friends: list[Character]
    enemies: list[Character]


@ai
def get_character_social_graph(character_name: str) -> SocialGraph:
    """For a given fictional character, return their social graph, resolving
    each friend and enemy's social graph recursively."""


graph = get_character_social_graph("Walter White")
pprint(graph)

```

```
SocialGraph(
friends=[
Character(
name='Jesse Pinkman',
occupation='Meth Manufacturer',
social_graph=SocialGraph(
friends=[Character(name='Walter White', occupation='Chemistry Teacher', social_graph=SocialGraph(friends=[], enemies=[]))],
enemies=[Character(name='Hank Schrader', occupation='DEA Agent', social_graph=SocialGraph(friends=[], enemies=[]))]
)
),
Character(
name='Saul Goodman',
occupation='Lawyer',
social_graph=SocialGraph(friends=[Character(name='Walter White', occupation='Chemistry Teacher', social_graph=SocialGraph(friends=[], enemies=[]))], enemies=[])
)
],
enemies=[
Character(
name='Hank Schrader',
occupation='DEA Agent',
social_graph=SocialGraph(
friends=[Character(name='Marie Schrader', occupation='Radiologic Technologist', social_graph=SocialGraph(friends=[], enemies=[]))],
enemies=[Character(name='Walter White', occupation='Meth Manufacturer', social_graph=SocialGraph(friends=[], enemies=[]))]
)
),
Character(
name='Gus Fring',
occupation='Businessman',
social_graph=SocialGraph(
friends=[Character(name='Mike Ehrmantraut', occupation='Fixer', social_graph=SocialGraph(friends=[], enemies=[]))],
enemies=[Character(name='Walter White', occupation='Meth Manufacturer', social_graph=SocialGraph(friends=[], enemies=[]))]
)
)
]
)
```

# How does it work?

Manifest relies heavily on runtime metadata, such as a function's name,
Expand Down
20 changes: 16 additions & 4 deletions manifest/decorator.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from functools import wraps
from io import BytesIO
from types import UnionType
from typing import Any, Callable, get_origin, get_type_hints
from typing import Any, Callable, cast, get_origin, get_type_hints

from manifest import exc, initialize, parser, serde, tmpl
from manifest.llm.base import LLM
Expand Down Expand Up @@ -31,6 +31,8 @@ def ai(*decorator_args, **decorator_kwargs) -> Callable:
def outer(fn: Callable) -> Callable:
"""Decorator that wraps a function and produces a function that, when
executed, will call the LLM to provide the return value."""
caller_frame = inspect.currentframe().f_back.f_back # type: ignore
caller_ns = cast(dict[str, Any], caller_frame.f_globals) # type: ignore

ants = get_type_hints(fn)
name = fn.__name__
Expand All @@ -40,7 +42,11 @@ def outer(fn: Callable) -> Callable:
# restricted to hydrating only these types.
type_registry: dict[str, Any] = {}
for value in ants.values():
extract_type_registry(type_registry, value)
extract_type_registry(
registry=type_registry,
obj=value,
caller_ns=caller_ns,
)

try:
return_type = ants["return"]
Expand All @@ -63,7 +69,10 @@ def inner(*exec_args, **exec_kwargs) -> Any:
# class may know about what the LLM service supports, for example,
# if the type needs to be wrapped in an envelope. We do the same
# with deserialization, later in the call.
return_type_spec = llm.serialize(return_type)
return_type_spec = llm.serialize(
return_type=return_type,
caller_ns=caller_ns,
)
return_type_spec_json = json.dumps(return_type_spec, indent=2)

service: Service = llm.service()
Expand Down Expand Up @@ -129,7 +138,10 @@ def inner(*exec_args, **exec_kwargs) -> Any:
args.append(
{
"name": arg_name,
"schema": serde.serialize(arg_type),
"schema": serde.serialize(
data_type=arg_type,
caller_ns=caller_ns,
),
"value": arg_value,
"src": src,
}
Expand Down
6 changes: 5 additions & 1 deletion manifest/llm/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,11 @@ def service() -> "Service":

@staticmethod
@abstractmethod
def serialize(return_type: Type | UnionType) -> Any:
def serialize(
*,
return_type: Type | UnionType,
caller_ns: dict[str, Any],
) -> Any:
"""Serialize the return type of a function into jsonschema"""

@staticmethod
Expand Down
15 changes: 13 additions & 2 deletions manifest/llm/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,15 +29,22 @@ def service() -> Service:
return Service.OPENAI

@staticmethod
def serialize(return_type: Type | UnionType) -> Any:
def serialize(
*,
return_type: Type | UnionType,
caller_ns: dict[str, Any],
) -> Any:
"""Serialize our return type to jsonschema, while also wrapping it in
an envelope because OpenAI does not support top-level simple types
(type: string), only objects."""
ResponseEnvelope = make_dataclass(
"ResponseEnvelope",
[("contents", return_type)],
)
return serde.serialize(ResponseEnvelope)
return serde.serialize(
data_type=ResponseEnvelope,
caller_ns=caller_ns,
)

@staticmethod
def deserialize(
Expand All @@ -48,7 +55,11 @@ def deserialize(
"""Deserialize our data from OpenAI into the expected return type,
while taking care to strip off the envelope beforehand."""
data = data["contents"]
old_defs = schema.pop("$defs", None)
schema = schema["properties"]["contents"]
if old_defs:
schema["$defs"] = old_defs

obj = serde.deserialize(
schema=schema,
data=data,
Expand Down
Loading

0 comments on commit 6c1a4e2

Please sign in to comment.