Skip to content

Commit

Permalink
Move json and xml parsers to core (#15026)
Browse files Browse the repository at this point in the history
<!-- Thank you for contributing to LangChain!

Please title your PR "<package>: <description>", where <package> is
whichever of langchain, community, core, experimental, etc. is being
modified.

Replace this entire comment with:
  - **Description:** a description of the change, 
  - **Issue:** the issue # it fixes if applicable,
  - **Dependencies:** any dependencies required for this change,
- **Twitter handle:** we announce bigger features on Twitter. If your PR
gets announced, and you'd like a mention, we'll gladly shout you out!

Please make sure your PR is passing linting and testing before
submitting. Run `make format`, `make lint` and `make test` from the root
of the package you've modified to check this locally.

See contribution guidelines for more information on how to write/run
tests, lint, etc: https://python.langchain.com/docs/contributing/

If you're adding a new integration, please include:
1. a test for the integration, preferably unit tests that do not rely on
network access,
2. an example notebook showing its use. It lives in
`docs/docs/integrations` directory.

If no one reviews your PR within a few days, please @-mention one of
@baskaryan, @eyurtsev, @hwchase17.
 -->
  • Loading branch information
nfcampos authored Dec 21, 2023
1 parent d5533b7 commit 71076cc
Show file tree
Hide file tree
Showing 10 changed files with 841 additions and 421 deletions.
4 changes: 4 additions & 0 deletions libs/core/langchain_core/output_parsers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
BaseLLMOutputParser,
BaseOutputParser,
)
from langchain_core.output_parsers.json import SimpleJsonOutputParser
from langchain_core.output_parsers.list import (
CommaSeparatedListOutputParser,
ListOutputParser,
Expand All @@ -14,6 +15,7 @@
BaseCumulativeTransformOutputParser,
BaseTransformOutputParser,
)
from langchain_core.output_parsers.xml import XMLOutputParser

__all__ = [
"BaseLLMOutputParser",
Expand All @@ -26,4 +28,6 @@
"StrOutputParser",
"BaseTransformOutputParser",
"BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser",
"XMLOutputParser",
]
195 changes: 195 additions & 0 deletions libs/core/langchain_core/output_parsers/json.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,195 @@
from __future__ import annotations

import json
import re
from json import JSONDecodeError
from typing import Any, Callable, List, Optional

import jsonpatch # type: ignore[import]

from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser


def _replace_new_line(match: re.Match[str]) -> str:
value = match.group(2)
value = re.sub(r"\n", r"\\n", value)
value = re.sub(r"\r", r"\\r", value)
value = re.sub(r"\t", r"\\t", value)
value = re.sub(r'(?<!\\)"', r"\"", value)

return match.group(1) + value + match.group(3)


def _custom_parser(multiline_string: str) -> str:
"""
The LLM response for `action_input` may be a multiline
string containing unescaped newlines, tabs or quotes. This function
replaces those characters with their escaped counterparts.
(newlines in JSON must be double-escaped: `\\n`)
"""
if isinstance(multiline_string, (bytes, bytearray)):
multiline_string = multiline_string.decode()

multiline_string = re.sub(
r'("action_input"\:\s*")(.*)(")',
_replace_new_line,
multiline_string,
flags=re.DOTALL,
)

return multiline_string


# Adapted from https://github.com/KillianLucas/open-interpreter/blob/main/interpreter/utils/parse_partial_json.py
# MIT License
def parse_partial_json(s: str, *, strict: bool = False) -> Any:
"""Parse a JSON string that may be missing closing braces.
Args:
s: The JSON string to parse.
strict: Whether to use strict parsing. Defaults to False.
Returns:
The parsed JSON object as a Python dictionary.
"""
# Attempt to parse the string as-is.
try:
return json.loads(s, strict=strict)
except json.JSONDecodeError:
pass

# Initialize variables.
new_s = ""
stack = []
is_inside_string = False
escaped = False

# Process each character in the string one at a time.
for char in s:
if is_inside_string:
if char == '"' and not escaped:
is_inside_string = False
elif char == "\n" and not escaped:
char = "\\n" # Replace the newline character with the escape sequence.
elif char == "\\":
escaped = not escaped
else:
escaped = False
else:
if char == '"':
is_inside_string = True
escaped = False
elif char == "{":
stack.append("}")
elif char == "[":
stack.append("]")
elif char == "}" or char == "]":
if stack and stack[-1] == char:
stack.pop()
else:
# Mismatched closing character; the input is malformed.
return None

# Append the processed character to the new string.
new_s += char

# If we're still inside a string at the end of processing,
# we need to close the string.
if is_inside_string:
new_s += '"'

# Close any remaining open structures in the reverse order that they were opened.
for closing_char in reversed(stack):
new_s += closing_char

# Attempt to parse the modified string as JSON.
try:
return json.loads(new_s, strict=strict)
except json.JSONDecodeError:
# If we still can't parse the string as JSON, return None to indicate failure.
return None


def parse_json_markdown(
json_string: str, *, parser: Callable[[str], Any] = json.loads
) -> dict:
"""
Parse a JSON string from a Markdown string.
Args:
json_string: The Markdown string.
Returns:
The parsed JSON object as a Python dictionary.
"""
# Try to find JSON string within triple backticks
match = re.search(r"```(json)?(.*)```", json_string, re.DOTALL)

# If no match found, assume the entire string is a JSON string
if match is None:
json_str = json_string
else:
# If match found, use the content within the backticks
json_str = match.group(2)

# Strip whitespace and newlines from the start and end
json_str = json_str.strip()

# handle newlines and other special characters inside the returned value
json_str = _custom_parser(json_str)

# Parse the JSON string into a Python dictionary
parsed = parser(json_str)

return parsed


def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
"""
Parse a JSON string from a Markdown string and check that it
contains the expected keys.
Args:
text: The Markdown string.
expected_keys: The expected keys in the JSON string.
Returns:
The parsed JSON object as a Python dictionary.
"""
try:
json_obj = parse_json_markdown(text)
except json.JSONDecodeError as e:
raise OutputParserException(f"Got invalid JSON object. Error: {e}")
for key in expected_keys:
if key not in json_obj:
raise OutputParserException(
f"Got invalid return object. Expected key `{key}` "
f"to be present, but got {json_obj}"
)
return json_obj


class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object.
When used in streaming mode, it will yield partial JSON objects containing
all the keys that have been returned so far.
In streaming, if `diff` is set to `True`, yields JSONPatch operations
describing the difference between the previous and the current object.
"""

def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch

def parse(self, text: str) -> Any:
text = text.strip()
try:
return parse_json_markdown(text.strip(), parser=parse_partial_json)
except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e

@property
def _type(self) -> str:
return "simple_json_output_parser"
135 changes: 135 additions & 0 deletions libs/core/langchain_core/output_parsers/xml.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,135 @@
import re
import xml.etree.ElementTree as ET
from typing import Any, AsyncIterator, Dict, Iterator, List, Optional, Union

from langchain_core.messages import BaseMessage
from langchain_core.output_parsers.transform import BaseTransformOutputParser
from langchain_core.runnables.utils import AddableDict

XML_FORMAT_INSTRUCTIONS = """The output should be formatted as a XML file.
1. Output should conform to the tags below.
2. If tags are not given, make them on your own.
3. Remember to always open and close all the tags.
As an example, for the tags ["foo", "bar", "baz"]:
1. String "<foo>\n <bar>\n <baz></baz>\n </bar>\n</foo>" is a well-formatted instance of the schema.
2. String "<foo>\n <bar>\n </foo>" is a badly-formatted instance.
3. String "<foo>\n <tag>\n </tag>\n</foo>" is a badly-formatted instance.
Here are the output tags:
```
{tags}
```""" # noqa: E501


class XMLOutputParser(BaseTransformOutputParser):
"""Parse an output using xml format."""

tags: Optional[List[str]] = None
encoding_matcher: re.Pattern = re.compile(
r"<([^>]*encoding[^>]*)>\n(.*)", re.MULTILINE | re.DOTALL
)

def get_format_instructions(self) -> str:
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)

def parse(self, text: str) -> Dict[str, List[Any]]:
text = text.strip("`").strip("xml")
encoding_match = self.encoding_matcher.search(text)
if encoding_match:
text = encoding_match.group(2)

text = text.strip()
if (text.startswith("<") or text.startswith("\n<")) and (
text.endswith(">") or text.endswith(">\n")
):
root = ET.fromstring(text)
return self._root_to_dict(root)
else:
raise ValueError(f"Could not parse output: {text}")

def _transform(
self, input: Iterator[Union[str, BaseMessage]]
) -> Iterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
current_path: List[str] = []
current_path_has_children = False
for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# pass chunk to parser
parser.feed(chunk)
# yield all events
for event, elem in parser.read_events():
if event == "start":
# update current path
current_path.append(elem.tag)
current_path_has_children = False
elif event == "end":
# remove last element from current path
current_path.pop()
# yield element
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
current_path_has_children = True
# close parser
parser.close()

async def _atransform(
self, input: AsyncIterator[Union[str, BaseMessage]]
) -> AsyncIterator[AddableDict]:
parser = ET.XMLPullParser(["start", "end"])
current_path: List[str] = []
current_path_has_children = False
async for chunk in input:
if isinstance(chunk, BaseMessage):
# extract text
chunk_content = chunk.content
if not isinstance(chunk_content, str):
continue
chunk = chunk_content
# pass chunk to parser
parser.feed(chunk)
# yield all events
for event, elem in parser.read_events():
if event == "start":
# update current path
current_path.append(elem.tag)
current_path_has_children = False
elif event == "end":
# remove last element from current path
current_path.pop()
# yield element
if not current_path_has_children:
yield nested_element(current_path, elem)
# prevent yielding of parent element
current_path_has_children = True
# close parser
parser.close()

def _root_to_dict(self, root: ET.Element) -> Dict[str, List[Any]]:
"""Converts xml tree to python dictionary."""
result: Dict[str, List[Any]] = {root.tag: []}
for child in root:
if len(child) == 0:
result[root.tag].append({child.tag: child.text})
else:
result[root.tag].append(self._root_to_dict(child))
return result

@property
def _type(self) -> str:
return "xml"


def nested_element(path: List[str], elem: ET.Element) -> Any:
"""Get nested element from path."""
if len(path) == 0:
return AddableDict({elem.tag: elem.text})
else:
return AddableDict({path[0]: [nested_element(path[1:], elem)]})
2 changes: 2 additions & 0 deletions libs/core/tests/unit_tests/output_parsers/test_imports.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
"StrOutputParser",
"BaseTransformOutputParser",
"BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser",
"XMLOutputParser",
]


Expand Down
Loading

0 comments on commit 71076cc

Please sign in to comment.