Skip to content

Commit

Permalink
Feat/inheritance (#395)
Browse files Browse the repository at this point in the history
* Multi-model behavior for inherited documents

Co-authored-by: vitalium <[email protected]>
Co-authored-by: Vitaliy Ivanov <[email protected]>
  • Loading branch information
3 people authored Nov 4, 2022
1 parent ff9be63 commit 1ec3415
Show file tree
Hide file tree
Showing 51 changed files with 963 additions and 466 deletions.
4 changes: 2 additions & 2 deletions beanie/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,12 +21,12 @@
DeleteRules,
)
from beanie.odm.settings.timeseries import TimeSeriesConfig, Granularity
from beanie.odm.utils.general import init_beanie
from beanie.odm.utils.init import init_beanie
from beanie.odm.documents import Document
from beanie.odm.views import View
from beanie.odm.union_doc import UnionDoc

__version__ = "1.13.1"
__version__ = "1.14.0"
__all__ = [
# ODM
"Document",
Expand Down
2 changes: 1 addition & 1 deletion beanie/migrations/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import Type, Optional

from beanie.odm.documents import Document
from beanie.odm.utils.general import init_beanie
from beanie.odm.utils.init import init_beanie
from beanie.migrations.controllers.iterative import BaseMigrationController
from beanie.migrations.database import DBHandler
from beanie.migrations.models import (
Expand Down
111 changes: 15 additions & 96 deletions beanie/odm/documents.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import asyncio
import inspect
from typing import ClassVar, AbstractSet
from typing import (
Dict,
Expand All @@ -16,7 +15,6 @@
from uuid import UUID, uuid4

from bson import ObjectId, DBRef
from motor.motor_asyncio import AsyncIOMotorDatabase
from pydantic import (
ValidationError,
PrivateAttr,
Expand All @@ -42,7 +40,6 @@
from beanie.odm.actions import (
EventTypes,
wrap_with_actions,
ActionRegistry,
ActionDirections,
)
from beanie.odm.bulk import BulkWriter, Operation
Expand All @@ -60,6 +57,8 @@
from beanie.odm.interfaces.detector import ModelType
from beanie.odm.interfaces.find import FindInterface
from beanie.odm.interfaces.getters import OtherGettersInterface
from beanie.odm.interfaces.inheritance import InheritanceInterface
from beanie.odm.interfaces.setters import SettersInterface
from beanie.odm.models import (
InspectionResult,
InspectionStatuses,
Expand All @@ -72,11 +71,8 @@
Set as SetOperator,
)
from beanie.odm.queries.update import UpdateMany

# from beanie.odm.settings.general import DocumentSettings
from beanie.odm.settings.document import DocumentSettings
from beanie.odm.utils.dump import get_dict
from beanie.odm.utils.relations import detect_link
from beanie.odm.utils.self_validation import validate_self_before
from beanie.odm.utils.state import (
saved_state_needed,
Expand All @@ -93,6 +89,8 @@

class Document(
BaseModel,
SettersInterface,
InheritanceInterface,
FindInterface,
AggregateInterface,
OtherGettersInterface,
Expand Down Expand Up @@ -165,6 +163,7 @@ async def get(
session: Optional[ClientSession] = None,
ignore_cache: bool = False,
fetch_links: bool = False,
with_children: bool = False,
**pymongo_kwargs,
) -> Optional["DocType"]:
"""
Expand All @@ -178,11 +177,13 @@ async def get(
"""
if not isinstance(document_id, cls.__fields__["id"].type_):
document_id = parse_obj_as(cls.__fields__["id"].type_, document_id)

return await cls.find_one(
{"_id": document_id},
session=session,
ignore_cache=ignore_cache,
fetch_links=fetch_links,
with_children=with_children,
**pymongo_kwargs,
)

Expand Down Expand Up @@ -214,12 +215,14 @@ async def insert(
await value.insert(link_rule=WriteRules.WRITE)
if field_info.link_type in [
LinkTypes.LIST,
LinkTypes.OPTIONAL_LIST
LinkTypes.OPTIONAL_LIST,
]:
if isinstance(value, List):
for obj in value:
if isinstance(obj, Document):
await obj.insert(link_rule=WriteRules.WRITE)
await obj.insert(
link_rule=WriteRules.WRITE
)

result = await self.get_motor_collection().insert_one(
get_dict(self, to_db=True), session=session
Expand Down Expand Up @@ -349,7 +352,7 @@ async def replace(
)
if field_info.link_type in [
LinkTypes.LIST,
LinkTypes.OPTIONAL_LIST
LinkTypes.OPTIONAL_LIST,
]:
if isinstance(value, List):
for obj in value:
Expand Down Expand Up @@ -406,7 +409,7 @@ async def save(
)
if field_info.link_type in [
LinkTypes.LIST,
LinkTypes.OPTIONAL_LIST
LinkTypes.OPTIONAL_LIST,
]:
if isinstance(value, List):
for obj in value:
Expand Down Expand Up @@ -672,7 +675,7 @@ async def delete(
)
if field_info.link_type in [
LinkTypes.LIST,
LinkTypes.OPTIONAL_LIST
LinkTypes.OPTIONAL_LIST,
]:
if isinstance(value, List):
for obj in value:
Expand Down Expand Up @@ -798,90 +801,6 @@ def rollback(self) -> None:
else:
setattr(self, key, value)

# Initialization

@classmethod
def init_cache(cls) -> None:
"""
Init model's cache
:return: None
"""
if cls.get_settings().use_cache:
cls._cache = LRUCache(
capacity=cls.get_settings().cache_capacity,
expiration_time=cls.get_settings().cache_expiration_time,
)

@classmethod
def init_fields(cls) -> None:
"""
Init class fields
:return: None
"""
if cls._link_fields is None:
cls._link_fields = {}
for k, v in cls.__fields__.items():
path = v.alias or v.name
setattr(cls, k, ExpressionField(path))

link_info = detect_link(v)
if link_info is not None:
cls._link_fields[v.name] = link_info

cls._hidden_fields = cls.get_hidden_fields()

@classmethod
async def init_settings(
cls, database: AsyncIOMotorDatabase, allow_index_dropping: bool
) -> None:
"""
Init document settings (collection and models)
:param database: AsyncIOMotorDatabase - motor database
:param allow_index_dropping: bool
:return: None
"""
# TODO looks ugly a little. Too many parameters transfers.
cls._document_settings = await DocumentSettings.init(
database=database,
document_model=cls,
allow_index_dropping=allow_index_dropping,
)

@classmethod
def init_actions(cls):
"""
Init event-based actions
"""
ActionRegistry.clean_actions(cls)
for attr in dir(cls):
f = getattr(cls, attr)
if inspect.isfunction(f):
if hasattr(f, "has_action"):
ActionRegistry.add_action(
document_class=cls,
event_types=f.event_types, # type: ignore
action_direction=f.action_direction, # type: ignore
funct=f,
)

@classmethod
async def init_model(
cls, database: AsyncIOMotorDatabase, allow_index_dropping: bool
) -> None:
"""
Init wrapper
:param database: AsyncIOMotorDatabase
:param allow_index_dropping: bool
:return: None
"""
await cls.init_settings(
database=database, allow_index_dropping=allow_index_dropping
)
cls.init_fields()
cls.init_cache()
cls.init_actions()

# Other

@classmethod
Expand Down Expand Up @@ -967,7 +886,7 @@ def dict(

@wrap_with_actions(event_type=EventTypes.VALIDATE_ON_SAVE)
async def validate_self(self, *args, **kwargs):
# TODO it can be sync, but needs some actions controller improvements
# TODO: it can be sync, but needs some actions controller improvements
if self.get_settings().validate_on_save:
self.parse_obj(self)

Expand Down
7 changes: 4 additions & 3 deletions beanie/odm/fields.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
NE,
In,
)
from beanie.odm.utils.parsing import parse_obj


def Indexed(typ, index_type=ASCENDING, **kwargs):
Expand Down Expand Up @@ -155,7 +156,7 @@ def __init__(self, ref: DBRef, model_class: Type[T]):
self.model_class = model_class

async def fetch(self) -> Union[T, "Link"]:
result = await self.model_class.get(self.ref.id) # type: ignore
result = await self.model_class.get(self.ref.id, with_children=True) # type: ignore
return result or self

@classmethod
Expand All @@ -175,7 +176,7 @@ async def fetch_list(cls, links: List["Link"]):
"All the links must have the same model class"
)
ids.append(link.ref.id)
return await model_class.find(In("_id", ids)).to_list() # type: ignore
return await model_class.find(In("_id", ids), with_children=True).to_list() # type: ignore

@classmethod
async def fetch_many(cls, links: List["Link"]):
Expand All @@ -196,7 +197,7 @@ def validate(cls, v: Union[DBRef, T], field: ModelField):
if isinstance(v, Link):
return v
if isinstance(v, dict) or isinstance(v, BaseModel):
return model_class.validate(v)
return parse_obj(model_class, v)
new_id = parse_obj_as(model_class.__fields__["id"].type_, v)
ref = DBRef(collection=model_class.get_collection_name(), id=new_id)
return cls(ref=ref, model_class=model_class)
Expand Down
Loading

0 comments on commit 1ec3415

Please sign in to comment.