fix: change enums usage to work on all supported python versions #329

Open · wants to merge 13 commits into base: main

Changes from 12 commits
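The change is easiest to see in isolation. With a str mixin, f-string rendering of an enum member differs across interpreters: Python 3.11 changed Enum.__format__ for enums with mixed-in types to use Enum.__str__, which includes the class name. A minimal sketch of the behavior this PR works around (Label here mirrors the connector's enum):

    from enum import Enum

    class Label(str, Enum):
        CHUNK = "Chunk"

    # Python <= 3.10: format() and f-strings defer to str, rendering the value:
    #   f"{Label.CHUNK}" == "Chunk"
    # Python >= 3.11: Enum.__format__ uses Enum.__str__, rendering the name:
    #   f"{Label.CHUNK}" == "Label.CHUNK"
    # Interpolating .value explicitly behaves the same on every version:
    assert f"{Label.CHUNK.value}" == "Chunk"

Queries built with f-strings therefore changed silently on newer interpreters; dropping the str mixin and always interpolating .value makes the rendering explicit.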
1 change: 1 addition & 0 deletions .github/workflows/unit_tests.yml
@@ -122,4 +122,5 @@ jobs:
 make install-base
 make install-test
 pip install unstructured
+python -c "from unstructured.nlp.tokenize import download_nltk_packages; download_nltk_packages()"
 make unit-test
3 changes: 2 additions & 1 deletion CHANGELOG.md
@@ -1,10 +1,11 @@
-## 0.3.13-dev3
+## 0.3.13-dev4

 ### Fixes

 * **Fix Snowflake Uploader error**
 * **Fix SQL Uploader Stager timestamp error**
 * **Migrate Discord Sourced Connector to v2**
+* **Fix Neo4j Uploader string enum error**

 ### Enhancements
18 changes: 13 additions & 5 deletions test/integration/connectors/test_neo4j.py
@@ -199,13 +199,15 @@ async def validate_uploaded_graph(upload_file: Path):
     try:
         nodes_count = len((await driver.execute_query("MATCH (n) RETURN n"))[0])
         chunk_nodes_count = len(
-            (await driver.execute_query(f"MATCH (n: {Label.CHUNK}) RETURN n"))[0]
+            (await driver.execute_query(f"MATCH (n: {Label.CHUNK.value}) RETURN n"))[0]
         )
         document_nodes_count = len(
-            (await driver.execute_query(f"MATCH (n: {Label.DOCUMENT}) RETURN n"))[0]
+            (await driver.execute_query(f"MATCH (n: {Label.DOCUMENT.value}) RETURN n"))[0]
         )
         element_nodes_count = len(
-            (await driver.execute_query(f"MATCH (n: {Label.UNSTRUCTURED_ELEMENT}) RETURN n"))[0]
+            (await driver.execute_query(f"MATCH (n: {Label.UNSTRUCTURED_ELEMENT.value}) RETURN n"))[
+                0
+            ]
         )
         with check:
             assert nodes_count == expected_nodes_count
@@ -217,12 +219,18 @@ async def validate_uploaded_graph(upload_file: Path):
             assert element_nodes_count == expected_element_count

         records, _, _ = await driver.execute_query(
-            f"MATCH ()-[r:{Relationship.PART_OF_DOCUMENT}]->(:{Label.DOCUMENT}) RETURN r"
+            f"""
+            MATCH ()-[r:{Relationship.PART_OF_DOCUMENT.value}]->(:{Label.DOCUMENT.value})
+            RETURN r
+            """
         )
         part_of_document_count = len(records)

         records, _, _ = await driver.execute_query(
-            f"MATCH (:{Label.CHUNK})-[r:{Relationship.NEXT_CHUNK}]->(:{Label.CHUNK}) RETURN r"
+            f"""
+            MATCH (:{Label.CHUNK.value})-[r:{Relationship.NEXT_CHUNK.value}]->(:{Label.CHUNK.value})
+            RETURN r
+            """
         )
         next_chunk_count = len(records)
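With the str mixin gone, a member interpolated directly into Cypher renders its qualified name on every interpreter, so a forgotten .value now fails consistently instead of only on newer Pythons. Roughly what the two forms produce (a sketch, not captured test output):

    f"MATCH (n: {Label.CHUNK}) RETURN n"
    # 'MATCH (n: Label.CHUNK) RETURN n'   wrong label
    f"MATCH (n: {Label.CHUNK.value}) RETURN n"
    # 'MATCH (n: Chunk) RETURN n'         the intended query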
2 changes: 1 addition & 1 deletion unstructured_ingest/__version__.py
@@ -1 +1 @@
-__version__ = "0.3.13-dev3"  # pragma: no cover
+__version__ = "0.3.13-dev4"  # pragma: no cover
24 changes: 12 additions & 12 deletions unstructured_ingest/v2/processes/connectors/neo4j.py
@@ -105,7 +105,7 @@ def run(  # type: ignore
         output_filepath.parent.mkdir(parents=True, exist_ok=True)

         with open(output_filepath, "w") as file:
-            json.dump(_GraphData.from_nx(nx_graph).model_dump(), file, indent=4)
+            file.write(_GraphData.from_nx(nx_graph).model_dump_json())

         return output_filepath
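The writer change pairs with the enum change: in pydantic v2, model_dump() (Python mode) leaves plain Enum members in the resulting dict, which the stdlib json encoder rejects with a TypeError, while model_dump_json() (JSON mode) serializes them by value. A minimal sketch, assuming pydantic v2 and a simplified stand-in for the connector's _Edge model:

    from enum import Enum
    from pydantic import BaseModel

    class Relationship(Enum):
        NEXT_CHUNK = "NEXT_CHUNK"

    class _Edge(BaseModel):
        relationship: Relationship

    edge = _Edge(relationship=Relationship.NEXT_CHUNK)
    edge.model_dump()       # {'relationship': <Relationship.NEXT_CHUNK: 'NEXT_CHUNK'>}
                            # json.dump() would raise TypeError on the Enum member
    edge.model_dump_json()  # '{"relationship":"NEXT_CHUNK"}'

One side effect: the old call wrote indent=4 output; model_dump_json(indent=4) would preserve the pretty-printing if anything downstream depends on it.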

@@ -196,7 +196,7 @@ def from_nx(cls, nx_graph: "MultiDiGraph") -> _GraphData:


 class _Node(BaseModel):
-    model_config = ConfigDict(use_enum_values=True)
+    model_config = ConfigDict()

     id_: str = Field(default_factory=lambda: str(uuid.uuid4()))
     labels: list[Label] = Field(default_factory=list)
@@ -207,20 +207,20 @@ def __hash__(self):


 class _Edge(BaseModel):
-    model_config = ConfigDict(use_enum_values=True)
+    model_config = ConfigDict()

     source_id: str
     destination_id: str
     relationship: Relationship


-class Label(str, Enum):
+class Label(Enum):
     UNSTRUCTURED_ELEMENT = "UnstructuredElement"
     CHUNK = "Chunk"
     DOCUMENT = "Document"


-class Relationship(str, Enum):
+class Relationship(Enum):
     PART_OF_DOCUMENT = "PART_OF_DOCUMENT"
     PART_OF_CHUNK = "PART_OF_CHUNK"
     NEXT_CHUNK = "NEXT_CHUNK"
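Removing use_enum_values=True is the other half of the fix: with it set, pydantic replaces each enum field with its raw value during validation, so the stored attribute is already a str and a later .value access would raise AttributeError; without it, fields keep the Enum member and .value works uniformly. A small illustration, assuming pydantic v2 and the Label enum above:

    from pydantic import BaseModel, ConfigDict

    class Coerced(BaseModel):
        model_config = ConfigDict(use_enum_values=True)
        label: Label

    class Kept(BaseModel):
        label: Label

    Coerced(label=Label.CHUNK).label       # 'Chunk' (already a str, no .value)
    Kept(label=Label.CHUNK).label          # <Label.CHUNK: 'Chunk'>
    Kept(label=Label.CHUNK).label.value    # 'Chunk'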
@@ -263,23 +263,23 @@ async def run_async(self, path: Path, file_data: FileData, **kwargs) -> None:  #
     async def _create_uniqueness_constraints(self, client: AsyncDriver) -> None:
         for label in Label:
             logger.info(
-                f"Adding id uniqueness constraint for nodes labeled '{label}'"
+                f"Adding id uniqueness constraint for nodes labeled '{label.value}'"
                 " if it does not already exist."
             )
-            constraint_name = f"{label.lower()}_id"
+            constraint_name = f"{label.value.lower()}_id"
             await client.execute_query(
                 f"""
                 CREATE CONSTRAINT {constraint_name} IF NOT EXISTS
-                FOR (n: {label}) REQUIRE n.id IS UNIQUE
+                FOR (n: {label.value}) REQUIRE n.id IS UNIQUE
                 """
             )

     async def _delete_old_data_if_exists(self, file_data: FileData, client: AsyncDriver) -> None:
         logger.info(f"Deleting old data for the record '{file_data.identifier}' (if present).")
         _, summary, _ = await client.execute_query(
             f"""
-            MATCH (n: {Label.DOCUMENT} {{id: $identifier}})
-            MATCH (n)--(m: {Label.CHUNK}|{Label.UNSTRUCTURED_ELEMENT})
+            MATCH (n: {Label.DOCUMENT.value} {{id: $identifier}})
+            MATCH (n)--(m: {Label.CHUNK.value}|{Label.UNSTRUCTURED_ELEMENT.value})
             DETACH DELETE m""",
             identifier=file_data.identifier,
         )
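For a concrete reading of the constraint loop: with .value the interpolations produce valid identifiers and labels. For Label.CHUNK (a sketch of the rendered strings, not captured output):

    label = Label.CHUNK
    f"{label.value.lower()}_id"   # 'chunk_id'
    # Rendered statement:
    # CREATE CONSTRAINT chunk_id IF NOT EXISTS
    # FOR (n: Chunk) REQUIRE n.id IS UNIQUE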
@@ -349,7 +349,7 @@ async def _execute_queries(

     @staticmethod
     def _create_nodes_query(nodes: list[_Node], labels: tuple[Label, ...]) -> tuple[str, dict]:
-        labels_string = ", ".join(labels)
+        labels_string = ", ".join([label.value for label in labels])
         logger.info(f"Preparing MERGE query for {len(nodes)} nodes labeled '{labels_string}'.")
         query_string = f"""
             UNWIND $nodes AS node
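The list comprehension here is not cosmetic: str.join accepted the old members only because Label subclassed str; with a plain Enum it raises TypeError, so the values must be extracted first:

    labels = (Label.DOCUMENT, Label.CHUNK)
    # ", ".join(labels)  raises TypeError: sequence item 0: expected str instance, Label found
    ", ".join([label.value for label in labels])   # 'Document, Chunk'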
@@ -366,7 +366,7 @@ def _create_edges_query(edges: list[_Edge], relationship: Relationship) -> tuple
             UNWIND $edges AS edge
             MATCH (u {{id: edge.source}})
             MATCH (v {{id: edge.destination}})
-            MERGE (u)-[:{relationship}]->(v)
+            MERGE (u)-[:{relationship.value}]->(v)
             """
         parameters = {
             "edges": [