diff --git a/beanie/__init__.py b/beanie/__init__.py index 05da5e82..d0cff34c 100644 --- a/beanie/__init__.py +++ b/beanie/__init__.py @@ -21,12 +21,12 @@ DeleteRules, ) from beanie.odm.settings.timeseries import TimeSeriesConfig, Granularity -from beanie.odm.utils.general import init_beanie +from beanie.odm.utils.init import init_beanie from beanie.odm.documents import Document from beanie.odm.views import View from beanie.odm.union_doc import UnionDoc -__version__ = "1.13.1" +__version__ = "1.14.0" __all__ = [ # ODM "Document", diff --git a/beanie/migrations/runner.py b/beanie/migrations/runner.py index a475f135..6cd2c2c7 100644 --- a/beanie/migrations/runner.py +++ b/beanie/migrations/runner.py @@ -4,7 +4,7 @@ from typing import Type, Optional from beanie.odm.documents import Document -from beanie.odm.utils.general import init_beanie +from beanie.odm.utils.init import init_beanie from beanie.migrations.controllers.iterative import BaseMigrationController from beanie.migrations.database import DBHandler from beanie.migrations.models import ( diff --git a/beanie/odm/documents.py b/beanie/odm/documents.py index 6401cbb7..406ceead 100644 --- a/beanie/odm/documents.py +++ b/beanie/odm/documents.py @@ -1,5 +1,4 @@ import asyncio -import inspect from typing import ClassVar, AbstractSet from typing import ( Dict, @@ -16,7 +15,6 @@ from uuid import UUID, uuid4 from bson import ObjectId, DBRef -from motor.motor_asyncio import AsyncIOMotorDatabase from pydantic import ( ValidationError, PrivateAttr, @@ -42,7 +40,6 @@ from beanie.odm.actions import ( EventTypes, wrap_with_actions, - ActionRegistry, ActionDirections, ) from beanie.odm.bulk import BulkWriter, Operation @@ -60,6 +57,8 @@ from beanie.odm.interfaces.detector import ModelType from beanie.odm.interfaces.find import FindInterface from beanie.odm.interfaces.getters import OtherGettersInterface +from beanie.odm.interfaces.inheritance import InheritanceInterface +from beanie.odm.interfaces.setters import SettersInterface from beanie.odm.models import ( InspectionResult, InspectionStatuses, @@ -72,11 +71,8 @@ Set as SetOperator, ) from beanie.odm.queries.update import UpdateMany - -# from beanie.odm.settings.general import DocumentSettings from beanie.odm.settings.document import DocumentSettings from beanie.odm.utils.dump import get_dict -from beanie.odm.utils.relations import detect_link from beanie.odm.utils.self_validation import validate_self_before from beanie.odm.utils.state import ( saved_state_needed, @@ -93,6 +89,8 @@ class Document( BaseModel, + SettersInterface, + InheritanceInterface, FindInterface, AggregateInterface, OtherGettersInterface, @@ -165,6 +163,7 @@ async def get( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> Optional["DocType"]: """ @@ -178,11 +177,13 @@ async def get( """ if not isinstance(document_id, cls.__fields__["id"].type_): document_id = parse_obj_as(cls.__fields__["id"].type_, document_id) + return await cls.find_one( {"_id": document_id}, session=session, ignore_cache=ignore_cache, fetch_links=fetch_links, + with_children=with_children, **pymongo_kwargs, ) @@ -214,12 +215,14 @@ async def insert( await value.insert(link_rule=WriteRules.WRITE) if field_info.link_type in [ LinkTypes.LIST, - LinkTypes.OPTIONAL_LIST + LinkTypes.OPTIONAL_LIST, ]: if isinstance(value, List): for obj in value: if isinstance(obj, Document): - await obj.insert(link_rule=WriteRules.WRITE) + await obj.insert( + link_rule=WriteRules.WRITE + ) result = await self.get_motor_collection().insert_one( get_dict(self, to_db=True), session=session @@ -349,7 +352,7 @@ async def replace( ) if field_info.link_type in [ LinkTypes.LIST, - LinkTypes.OPTIONAL_LIST + LinkTypes.OPTIONAL_LIST, ]: if isinstance(value, List): for obj in value: @@ -406,7 +409,7 @@ async def save( ) if field_info.link_type in [ LinkTypes.LIST, - LinkTypes.OPTIONAL_LIST + LinkTypes.OPTIONAL_LIST, ]: if isinstance(value, List): for obj in value: @@ -672,7 +675,7 @@ async def delete( ) if field_info.link_type in [ LinkTypes.LIST, - LinkTypes.OPTIONAL_LIST + LinkTypes.OPTIONAL_LIST, ]: if isinstance(value, List): for obj in value: @@ -798,90 +801,6 @@ def rollback(self) -> None: else: setattr(self, key, value) - # Initialization - - @classmethod - def init_cache(cls) -> None: - """ - Init model's cache - :return: None - """ - if cls.get_settings().use_cache: - cls._cache = LRUCache( - capacity=cls.get_settings().cache_capacity, - expiration_time=cls.get_settings().cache_expiration_time, - ) - - @classmethod - def init_fields(cls) -> None: - """ - Init class fields - :return: None - """ - if cls._link_fields is None: - cls._link_fields = {} - for k, v in cls.__fields__.items(): - path = v.alias or v.name - setattr(cls, k, ExpressionField(path)) - - link_info = detect_link(v) - if link_info is not None: - cls._link_fields[v.name] = link_info - - cls._hidden_fields = cls.get_hidden_fields() - - @classmethod - async def init_settings( - cls, database: AsyncIOMotorDatabase, allow_index_dropping: bool - ) -> None: - """ - Init document settings (collection and models) - - :param database: AsyncIOMotorDatabase - motor database - :param allow_index_dropping: bool - :return: None - """ - # TODO looks ugly a little. Too many parameters transfers. - cls._document_settings = await DocumentSettings.init( - database=database, - document_model=cls, - allow_index_dropping=allow_index_dropping, - ) - - @classmethod - def init_actions(cls): - """ - Init event-based actions - """ - ActionRegistry.clean_actions(cls) - for attr in dir(cls): - f = getattr(cls, attr) - if inspect.isfunction(f): - if hasattr(f, "has_action"): - ActionRegistry.add_action( - document_class=cls, - event_types=f.event_types, # type: ignore - action_direction=f.action_direction, # type: ignore - funct=f, - ) - - @classmethod - async def init_model( - cls, database: AsyncIOMotorDatabase, allow_index_dropping: bool - ) -> None: - """ - Init wrapper - :param database: AsyncIOMotorDatabase - :param allow_index_dropping: bool - :return: None - """ - await cls.init_settings( - database=database, allow_index_dropping=allow_index_dropping - ) - cls.init_fields() - cls.init_cache() - cls.init_actions() - # Other @classmethod @@ -967,7 +886,7 @@ def dict( @wrap_with_actions(event_type=EventTypes.VALIDATE_ON_SAVE) async def validate_self(self, *args, **kwargs): - # TODO it can be sync, but needs some actions controller improvements + # TODO: it can be sync, but needs some actions controller improvements if self.get_settings().validate_on_save: self.parse_obj(self) diff --git a/beanie/odm/fields.py b/beanie/odm/fields.py index 6ae2c61c..162e0659 100644 --- a/beanie/odm/fields.py +++ b/beanie/odm/fields.py @@ -19,6 +19,7 @@ NE, In, ) +from beanie.odm.utils.parsing import parse_obj def Indexed(typ, index_type=ASCENDING, **kwargs): @@ -155,7 +156,7 @@ def __init__(self, ref: DBRef, model_class: Type[T]): self.model_class = model_class async def fetch(self) -> Union[T, "Link"]: - result = await self.model_class.get(self.ref.id) # type: ignore + result = await self.model_class.get(self.ref.id, with_children=True) # type: ignore return result or self @classmethod @@ -175,7 +176,7 @@ async def fetch_list(cls, links: List["Link"]): "All the links must have the same model class" ) ids.append(link.ref.id) - return await model_class.find(In("_id", ids)).to_list() # type: ignore + return await model_class.find(In("_id", ids), with_children=True).to_list() # type: ignore @classmethod async def fetch_many(cls, links: List["Link"]): @@ -196,7 +197,7 @@ def validate(cls, v: Union[DBRef, T], field: ModelField): if isinstance(v, Link): return v if isinstance(v, dict) or isinstance(v, BaseModel): - return model_class.validate(v) + return parse_obj(model_class, v) new_id = parse_obj_as(model_class.__fields__["id"].type_, v) ref = DBRef(collection=model_class.get_collection_name(), id=new_id) return cls(ref=ref, model_class=model_class) diff --git a/beanie/odm/interfaces/find.py b/beanie/odm/interfaces/find.py index 066e04d2..234254e0 100644 --- a/beanie/odm/interfaces/find.py +++ b/beanie/odm/interfaces/find.py @@ -9,14 +9,16 @@ overload, ClassVar, TypeVar, + Dict, ) - +from collections.abc import Iterable from pydantic import ( BaseModel, ) from pymongo.client_session import ClientSession from beanie.odm.enums import SortDirection +from beanie.odm.interfaces.detector import ModelType from beanie.odm.queries.find import FindOne, FindMany from beanie.odm.settings.base import ItemSettings @@ -30,6 +32,14 @@ class FindInterface: _find_one_query_class: ClassVar[Type] = FindOne _find_many_query_class: ClassVar[Type] = FindMany + _inheritance_inited: bool + _class_id: ClassVar[str] + _children: ClassVar[Dict[str, Type]] + + @classmethod + def get_model_type(cls) -> ModelType: + pass + @classmethod def get_settings(cls) -> ItemSettings: pass @@ -43,6 +53,7 @@ def find_one( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindOne["DocType"]: ... @@ -56,6 +67,7 @@ def find_one( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindOne["DocumentProjectionType"]: ... @@ -68,6 +80,7 @@ def find_one( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> Union[FindOne["DocType"], FindOne["DocumentProjectionType"]]: """ @@ -82,7 +95,7 @@ def find_one( :param **pymongo_kwargs: pymongo native parameters for find operation (if Document class contains links, this parameter must fit the respective parameter of the aggregate MongoDB function) :return: [FindOne](https://roman-right.github.io/beanie/api/queries/#findone) - find query instance """ - args = cls._add_class_id_filter(args) + args = cls._add_class_id_filter(args, with_children) return cls._find_one_query_class(document_model=cls).find_one( *args, projection_model=projection_model, @@ -104,6 +117,7 @@ def find_many( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindMany["DocType"]: ... @@ -120,6 +134,7 @@ def find_many( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindMany["DocumentProjectionType"]: ... @@ -135,6 +150,7 @@ def find_many( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> Union[FindMany["DocType"], FindMany["DocumentProjectionType"]]: """ @@ -151,7 +167,7 @@ def find_many( :param **pymongo_kwargs: pymongo native parameters for find operation (if Document class contains links, this parameter must fit the respective parameter of the aggregate MongoDB function) :return: [FindMany](https://roman-right.github.io/beanie/api/queries/#findmany) - query instance """ - args = cls._add_class_id_filter(args) + args = cls._add_class_id_filter(args, with_children) return cls._find_many_query_class(document_model=cls).find_many( *args, sort=sort, @@ -176,6 +192,7 @@ def find( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindMany["DocType"]: ... @@ -192,6 +209,7 @@ def find( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindMany["DocumentProjectionType"]: ... @@ -207,6 +225,7 @@ def find( session: Optional[ClientSession] = None, ignore_cache: bool = False, fetch_links: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> Union[FindMany["DocType"], FindMany["DocumentProjectionType"]]: """ @@ -221,6 +240,7 @@ def find( session=session, ignore_cache=ignore_cache, fetch_links=fetch_links, + with_children=with_children, **pymongo_kwargs, ) @@ -234,6 +254,7 @@ def find_all( projection_model: None = None, session: Optional[ClientSession] = None, ignore_cache: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindMany["DocType"]: ... @@ -248,6 +269,7 @@ def find_all( projection_model: Optional[Type["DocumentProjectionType"]] = None, session: Optional[ClientSession] = None, ignore_cache: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindMany["DocumentProjectionType"]: ... @@ -261,6 +283,7 @@ def find_all( projection_model: Optional[Type["DocumentProjectionType"]] = None, session: Optional[ClientSession] = None, ignore_cache: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> Union[FindMany["DocType"], FindMany["DocumentProjectionType"]]: """ @@ -282,6 +305,7 @@ def find_all( projection_model=projection_model, session=session, ignore_cache=ignore_cache, + with_children=with_children, **pymongo_kwargs, ) @@ -295,6 +319,7 @@ def all( sort: Union[None, str, List[Tuple[str, SortDirection]]] = None, session: Optional[ClientSession] = None, ignore_cache: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindMany["DocType"]: ... @@ -309,6 +334,7 @@ def all( sort: Union[None, str, List[Tuple[str, SortDirection]]] = None, session: Optional[ClientSession] = None, ignore_cache: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> FindMany["DocumentProjectionType"]: ... @@ -322,6 +348,7 @@ def all( sort: Union[None, str, List[Tuple[str, SortDirection]]] = None, session: Optional[ClientSession] = None, ignore_cache: bool = False, + with_children: bool = False, **pymongo_kwargs, ) -> Union[FindMany["DocType"], FindMany["DocumentProjectionType"]]: """ @@ -334,6 +361,7 @@ def all( projection_model=projection_model, session=session, ignore_cache=ignore_cache, + with_children=with_children, **pymongo_kwargs, ) @@ -348,7 +376,33 @@ async def count(cls) -> int: return await cls.find_all().count() @classmethod - def _add_class_id_filter(cls, args: Tuple): + def _add_class_id_filter(cls, args: Tuple, with_children: bool = False): + # skip if _class_id is already added + if any( + ( + True + for a in args + if isinstance(a, Iterable) and "_class_id" in a + ) + ): + return args + + if ( + cls.get_model_type() == ModelType.Document + and cls._inheritance_inited + ): + if not with_children: + args += ({"_class_id": cls._class_id},) + else: + args += ( + { + "_class_id": { + "$in": [cls._class_id] + + [cname for cname in cls._children.keys()] + } + }, + ) + if cls.get_settings().union_doc: args += ({"_class_id": cls.__name__},) return args diff --git a/beanie/odm/interfaces/getters.py b/beanie/odm/interfaces/getters.py index 372cfd80..929ba1d1 100644 --- a/beanie/odm/interfaces/getters.py +++ b/beanie/odm/interfaces/getters.py @@ -14,10 +14,7 @@ def get_motor_collection(cls) -> AsyncIOMotorCollection: @classmethod def get_collection_name(cls): - input_class = getattr(cls, "Settings", None) - if input_class is None or not hasattr(input_class, "name"): - return cls.__name__ - return input_class.name + return cls.get_settings().name @classmethod def get_bson_encoders(cls): diff --git a/beanie/odm/interfaces/inheritance.py b/beanie/odm/interfaces/inheritance.py new file mode 100644 index 00000000..7ade31eb --- /dev/null +++ b/beanie/odm/interfaces/inheritance.py @@ -0,0 +1,19 @@ +from typing import ( + Type, + Optional, + Dict, + ClassVar, +) + + +class InheritanceInterface: + _children: ClassVar[Dict[str, Type]] + _parent: ClassVar[Optional[Type]] + _inheritance_inited: ClassVar[bool] + _class_id: ClassVar[str] + + @classmethod + def add_child(cls, name: str, clas: Type): + cls._children[name] = clas + if cls._parent is not None: + cls._parent.add_child(name, clas) diff --git a/beanie/odm/interfaces/setters.py b/beanie/odm/interfaces/setters.py new file mode 100644 index 00000000..8f5785ce --- /dev/null +++ b/beanie/odm/interfaces/setters.py @@ -0,0 +1,28 @@ +from typing import ClassVar, Optional + +from beanie.odm.settings.document import DocumentSettings + + +class SettersInterface: + _document_settings: ClassVar[Optional[DocumentSettings]] + + @classmethod + def set_collection(cls, collection): + """ + Collection setter + """ + cls._document_settings.motor_collection = collection + + @classmethod + def set_database(cls, database): + """ + Database setter + """ + cls._document_settings.motor_db = database + + @classmethod + def set_collection_name(cls, name: str): + """ + Collection name setter + """ + cls._document_settings.name = name # type: ignore diff --git a/beanie/odm/queries/find.py b/beanie/odm/queries/find.py index 37b4eb76..3a2990e5 100644 --- a/beanie/odm/queries/find.py +++ b/beanie/odm/queries/find.py @@ -479,7 +479,7 @@ def update_many( Provide search criteria to the [UpdateMany](https://roman-right.github.io/beanie/api/queries/#updatemany) query - :param args: *Mappingp[str,Any] - the modifications to apply. + :param args: *Mapping[str,Any] - the modifications to apply. :param session: Optional[ClientSession] :return: [UpdateMany](https://roman-right.github.io/beanie/api/queries/#updatemany) query """ diff --git a/beanie/odm/settings/base.py b/beanie/odm/settings/base.py index 4ccfb974..913716a9 100644 --- a/beanie/odm/settings/base.py +++ b/beanie/odm/settings/base.py @@ -19,5 +19,7 @@ class ItemSettings(BaseModel): union_doc: Optional[Type] = None + is_root: bool = False + class Config: arbitrary_types_allowed = True diff --git a/beanie/odm/settings/document.py b/beanie/odm/settings/document.py index 237086e0..3c70166e 100644 --- a/beanie/odm/settings/document.py +++ b/beanie/odm/settings/document.py @@ -1,11 +1,8 @@ -import warnings -from typing import Optional, Type, List +from typing import Optional, List -from motor.motor_asyncio import AsyncIOMotorDatabase from pydantic import Field from pymongo import IndexModel -from beanie.exceptions import MongoDBVersionError from beanie.odm.settings.base import ItemSettings from beanie.odm.settings.timeseries import TimeSeriesConfig @@ -28,118 +25,10 @@ class DocumentSettings(ItemSettings): state_management_replace_objects: bool = False validate_on_save: bool = False use_revision: bool = False + single_root_inheritance: bool = False indexes: List[IndexModelField] = Field(default_factory=list) timeseries: Optional[TimeSeriesConfig] = None - @classmethod - async def init( - cls, - database: AsyncIOMotorDatabase, - document_model: Type, - allow_index_dropping: bool, - ) -> "DocumentSettings": - - settings_class = getattr(document_model, "Settings", None) - settings_vars = ( - {} if settings_class is None else dict(vars(settings_class)) - ) - - # deprecated Collection class support - - collection_class = getattr(document_model, "Collection", None) - - if collection_class is not None: - warnings.warn( - "Collection inner class is deprecated, use Settings instead", - DeprecationWarning, - ) - - collection_vars = ( - {} if collection_class is None else dict(vars(collection_class)) - ) - - settings_vars.update(collection_vars) - - # ------------------------------------ # - - document_settings = DocumentSettings.parse_obj(settings_vars) - - document_settings.motor_db = database - - # register in the Union Doc - - if document_settings.union_doc is not None: - document_settings.name = document_settings.union_doc.register_doc( - document_model - ) - - # set a name - - if not document_settings.name: - document_settings.name = document_model.__name__ - - # check mongodb version - build_info = await database.command({"buildInfo": 1}) - mongo_version = build_info["version"] - major_version = int(mongo_version.split(".")[0]) - - if document_settings.timeseries is not None and major_version < 5: - raise MongoDBVersionError( - "Timeseries are supported by MongoDB version 5 and higher" - ) - - # create motor collection - if ( - document_settings.timeseries is not None - and document_settings.name - not in await database.list_collection_names() - ): - - collection = await database.create_collection( - **document_settings.timeseries.build_query( - document_settings.name - ) - ) - else: - collection = database[document_settings.name] - - document_settings.motor_collection = collection - - # indexes - old_indexes = (await collection.index_information()).keys() - new_indexes = ["_id_"] - - # Indexed field wrapped with Indexed() - found_indexes = [ - IndexModel( - [ - ( - fvalue.alias, - fvalue.type_._indexed[0], - ) - ], - **fvalue.type_._indexed[1] - ) - for _, fvalue in document_model.__fields__.items() - if hasattr(fvalue.type_, "_indexed") and fvalue.type_._indexed - ] - - # get indexes from the Collection class - if document_settings.indexes: - found_indexes += document_settings.indexes - - # create indices - if found_indexes: - new_indexes += await collection.create_indexes(found_indexes) - - # delete indexes - # Only drop indexes if the user specifically allows for it - if allow_index_dropping: - for index in set(old_indexes) - set(new_indexes): - await collection.drop_index(index) - - return document_settings - class Config: arbitrary_types_allowed = True diff --git a/beanie/odm/settings/union_doc.py b/beanie/odm/settings/union_doc.py index 16b4406e..af801d87 100644 --- a/beanie/odm/settings/union_doc.py +++ b/beanie/odm/settings/union_doc.py @@ -1,23 +1,5 @@ -from typing import Type - -from motor.motor_asyncio import AsyncIOMotorDatabase - from beanie.odm.settings.base import ItemSettings class UnionDocSettings(ItemSettings): - @classmethod - def init( - cls, doc_class: Type, database: AsyncIOMotorDatabase - ) -> "UnionDocSettings": - settings_class = getattr(doc_class, "Settings", None) - - multi_doc_settings = cls.parse_obj(vars(settings_class)) - - if multi_doc_settings.name is None: - multi_doc_settings.name = doc_class.__name__ - - multi_doc_settings.motor_db = database - multi_doc_settings.motor_collection = database[multi_doc_settings.name] - - return multi_doc_settings + ... diff --git a/beanie/odm/settings/view.py b/beanie/odm/settings/view.py index ac3fdf25..79f5a381 100644 --- a/beanie/odm/settings/view.py +++ b/beanie/odm/settings/view.py @@ -1,33 +1,8 @@ -from inspect import isclass from typing import List, Dict, Any, Union, Type -from motor.motor_asyncio import AsyncIOMotorDatabase - -from beanie.exceptions import ViewHasNoSettings from beanie.odm.settings.base import ItemSettings class ViewSettings(ItemSettings): source: Union[str, Type] pipeline: List[Dict[str, Any]] - - @classmethod - async def init( - cls, view_class: Type, database: AsyncIOMotorDatabase - ) -> "ViewSettings": - settings_class = getattr(view_class, "Settings", None) - if settings_class is None: - raise ViewHasNoSettings("View must have Settings inner class") - - view_settings = cls.parse_obj(vars(settings_class)) - - if view_settings.name is None: - view_settings.name = view_class.__name__ - - if isclass(view_settings.source): - view_settings.source = view_settings.source.get_collection_name() - - view_settings.motor_db = database - view_settings.motor_collection = database[view_settings.name] - - return view_settings diff --git a/beanie/odm/union_doc.py b/beanie/odm/union_doc.py index a3a34f33..bb66e03d 100644 --- a/beanie/odm/union_doc.py +++ b/beanie/odm/union_doc.py @@ -1,7 +1,5 @@ from typing import ClassVar, Type, Dict, Optional -from motor.motor_asyncio import AsyncIOMotorDatabase - from beanie.exceptions import UnionDocNotInited from beanie.odm.interfaces.aggregate import AggregateInterface from beanie.odm.interfaces.detector import DetectionInterface, ModelType @@ -24,11 +22,6 @@ class UnionDoc( def get_settings(cls) -> UnionDocSettings: return cls._settings - @classmethod - def init(cls, database: AsyncIOMotorDatabase): - cls._settings = UnionDocSettings.init(database=database, doc_class=cls) - cls._is_inited = True - @classmethod def register_doc(cls, doc_model: Type): if cls._document_models is None: diff --git a/beanie/odm/utils/encoder.py b/beanie/odm/utils/encoder.py index 8195f21c..c7292499 100644 --- a/beanie/odm/utils/encoder.py +++ b/beanie/odm/utils/encoder.py @@ -90,6 +90,9 @@ def encode_document(self, obj): obj_dict: Dict[str, Any] = {} if obj.get_settings().union_doc is not None: obj_dict["_class_id"] = obj.__class__.__name__ + if obj._inheritance_inited: + obj_dict["_class_id"] = obj._class_id + for k, o in obj._iter(to_dict=False, by_alias=self.by_alias): if k not in self.exclude: if link_fields and k in link_fields: diff --git a/beanie/odm/utils/general.py b/beanie/odm/utils/general.py index 05a752f4..e69de29b 100644 --- a/beanie/odm/utils/general.py +++ b/beanie/odm/utils/general.py @@ -1,89 +0,0 @@ -import asyncio -import importlib -from typing import List, Type, Union, TYPE_CHECKING - -from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorClient -from yarl import URL - -from beanie.odm.interfaces.detector import ModelType - -if TYPE_CHECKING: - from beanie.odm.documents import DocType - from beanie.odm.views import View - - -def get_model(dot_path: str) -> Type["DocType"]: - """ - Get the model by the path in format bar.foo.Model - - :param dot_path: str - dot seprated path to the model - :return: Type[DocType] - class of the model - """ - module_name, class_name = None, None - try: - module_name, class_name = dot_path.rsplit(".", 1) - return getattr(importlib.import_module(module_name), class_name) - - except ValueError: - raise ValueError( - f"'{dot_path}' doesn't have '.' path, eg. path.to.your.model.class" - ) - - except AttributeError: - raise AttributeError( - f"module '{module_name}' has no class called '{class_name}'" - ) - - -async def init_beanie( - database: AsyncIOMotorDatabase = None, - connection_string: str = None, - document_models: List[Union[Type["DocType"], Type["View"], str]] = None, - allow_index_dropping: bool = False, - recreate_views: bool = False, -): - """ - Beanie initialization - - :param database: AsyncIOMotorDatabase - motor database instance - :param connection_string: str - MongoDB connection string - :param document_models: List[Union[Type[DocType], str]] - model classes - or strings with dot separated paths - :param allow_index_dropping: bool - if index dropping is allowed. - Default False - :return: None - """ - if (connection_string is None and database is None) or ( - connection_string is not None and database is not None - ): - raise ValueError( - "connection_string parameter or database parameter must be set" - ) - - if document_models is None: - raise ValueError("document_models parameter must be set") - if connection_string is not None: - database = AsyncIOMotorClient(connection_string)[ - URL(connection_string).path[1:] - ] - - collection_inits = [] - for model in document_models: - if isinstance(model, str): - model = get_model(model) - - if model.get_model_type() == ModelType.UnionDoc: - model.init(database) - - if model.get_model_type() == ModelType.Document: - collection_inits.append( - model.init_model( - database, allow_index_dropping=allow_index_dropping - ) - ) - if model.get_model_type() == ModelType.View: - collection_inits.append( - model.init_view(database, recreate_view=recreate_views) - ) - - await asyncio.gather(*collection_inits) diff --git a/beanie/odm/utils/init.py b/beanie/odm/utils/init.py new file mode 100644 index 00000000..e52e93c2 --- /dev/null +++ b/beanie/odm/utils/init.py @@ -0,0 +1,458 @@ +import importlib +import inspect +from typing import Optional, List, Type, Union + +from motor.motor_asyncio import AsyncIOMotorDatabase, AsyncIOMotorClient +from pydantic import BaseModel +from pymongo import IndexModel +from yarl import URL + +from beanie.exceptions import MongoDBVersionError +from beanie.odm.actions import ActionRegistry +from beanie.odm.cache import LRUCache +from beanie.odm.documents import DocType +from beanie.odm.documents import Document +from beanie.odm.fields import ExpressionField +from beanie.odm.interfaces.detector import ModelType +from beanie.odm.settings.document import DocumentSettings +from beanie.odm.settings.union_doc import UnionDocSettings +from beanie.odm.settings.view import ViewSettings +from beanie.odm.union_doc import UnionDoc +from beanie.odm.utils.relations import detect_link +from beanie.odm.views import View + + +class Output(BaseModel): + class_name: str + collection_name: str + + +class Initializer: + def __init__( + self, + database: AsyncIOMotorDatabase = None, + connection_string: str = None, + document_models: List[ + Union[Type["DocType"], Type["View"], str] + ] = None, + allow_index_dropping: bool = False, + recreate_views: bool = False, + ): + """ + Beanie initializer + + :param database: AsyncIOMotorDatabase - motor database instance + :param connection_string: str - MongoDB connection string + :param document_models: List[Union[Type[DocType], str]] - model classes + or strings with dot separated paths + :param allow_index_dropping: bool - if index dropping is allowed. + Default False + :return: None + """ + self.inited_classes: List[Type] = [] + self.allow_index_dropping = allow_index_dropping + self.recreate_views = recreate_views + + if (connection_string is None and database is None) or ( + connection_string is not None and database is not None + ): + raise ValueError( + "connection_string parameter or database parameter must be set" + ) + + if document_models is None: + raise ValueError("document_models parameter must be set") + if connection_string is not None: + database = AsyncIOMotorClient(connection_string)[ + URL(connection_string).path[1:] + ] + + self.database: AsyncIOMotorDatabase = database + + sort_order = { + ModelType.UnionDoc: 0, + ModelType.Document: 1, + ModelType.View: 2, + } + + self.document_models: List[Union[Type[DocType], Type[View]]] = [ + self.get_model(model) if isinstance(model, str) else model + for model in document_models + ] + + self.document_models.sort( + key=lambda val: sort_order[val.get_model_type()] + ) + + def __await__(self): + for model in self.document_models: + yield from self.init_class(model).__await__() + + # General + + @staticmethod + def get_model(dot_path: str) -> Type["DocType"]: + """ + Get the model by the path in format bar.foo.Model + + :param dot_path: str - dot seprated path to the model + :return: Type[DocType] - class of the model + """ + module_name, class_name = None, None + try: + module_name, class_name = dot_path.rsplit(".", 1) + return getattr(importlib.import_module(module_name), class_name) + + except ValueError: + raise ValueError( + f"'{dot_path}' doesn't have '.' path, eg. path.to.your.model.class" + ) + + except AttributeError: + raise AttributeError( + f"module '{module_name}' has no class called '{class_name}'" + ) + + def init_settings( + self, cls: Union[Type[Document], Type[View], Type[UnionDoc]] + ): + """ + Init Settings + + :param cls: Union[Type[Document], Type[View], Type[UnionDoc]] - Class + to init settings + :return: None + """ + settings_class = getattr(cls, "Settings", None) + settings_vars = ( + {} if settings_class is None else dict(vars(settings_class)) + ) + if issubclass(cls, Document): + cls._document_settings = DocumentSettings.parse_obj(settings_vars) + if issubclass(cls, View): + cls._settings = ViewSettings.parse_obj(settings_vars) + if issubclass(cls, UnionDoc): + cls._settings = UnionDocSettings.parse_obj(settings_vars) + + # Document + + @staticmethod + def set_default_class_vars(cls: Type[Document]): + """ + Set default class variables. + + :param cls: Union[Type[Document], Type[View], Type[UnionDoc]] - Class + to init settings + :return: + """ + cls._children = dict() + cls._parent = None + cls._inheritance_inited = False + cls._class_id = "" + + @staticmethod + def init_cache(cls) -> None: + """ + Init model's cache + :return: None + """ + if cls.get_settings().use_cache: + cls._cache = LRUCache( + capacity=cls.get_settings().cache_capacity, + expiration_time=cls.get_settings().cache_expiration_time, + ) + + @staticmethod + def init_document_fields(cls) -> None: + """ + Init class fields + :return: None + """ + if cls._link_fields is None: + cls._link_fields = {} + for k, v in cls.__fields__.items(): + path = v.alias or v.name + setattr(cls, k, ExpressionField(path)) + + link_info = detect_link(v) + if link_info is not None: + cls._link_fields[v.name] = link_info + + cls._hidden_fields = cls.get_hidden_fields() + + @staticmethod + def init_actions(cls): + """ + Init event-based actions + """ + ActionRegistry.clean_actions(cls) + for attr in dir(cls): + f = getattr(cls, attr) + if inspect.isfunction(f): + if hasattr(f, "has_action"): + ActionRegistry.add_action( + document_class=cls, + event_types=f.event_types, # type: ignore + action_direction=f.action_direction, # type: ignore + funct=f, + ) + + async def init_document_collection(self, cls): + """ + Init collection for the Document-based class + :param cls: + :return: + """ + cls.set_database(self.database) + + document_settings = cls.get_settings() + + # register in the Union Doc + + if document_settings.union_doc is not None: + document_settings.name = document_settings.union_doc.register_doc( + cls + ) + + # set a name + + if not document_settings.name: + document_settings.name = cls.__name__ + + # check mongodb version + build_info = await self.database.command({"buildInfo": 1}) + mongo_version = build_info["version"] + major_version = int(mongo_version.split(".")[0]) + + if document_settings.timeseries is not None and major_version < 5: + raise MongoDBVersionError( + "Timeseries are supported by MongoDB version 5 and higher" + ) + + # create motor collection + if ( + document_settings.timeseries is not None + and document_settings.name + not in await self.database.list_collection_names() + ): + + collection = await self.database.create_collection( + **document_settings.timeseries.build_query( + document_settings.name + ) + ) + else: + collection = self.database[document_settings.name] + + cls.set_collection(collection) + + @staticmethod + async def init_indexes(cls, allow_index_dropping: bool = False): + """ + Async indexes initializer + """ + collection = cls.get_motor_collection() + document_settings = cls.get_settings() + + old_indexes = (await collection.index_information()).keys() + new_indexes = ["_id_"] + + # Indexed field wrapped with Indexed() + found_indexes = [ + IndexModel( + [ + ( + fvalue.alias, + fvalue.type_._indexed[0], + ) + ], + **fvalue.type_._indexed[1], + ) + for _, fvalue in cls.__fields__.items() + if hasattr(fvalue.type_, "_indexed") and fvalue.type_._indexed + ] + + # get indexes from the Collection class + if document_settings.indexes: + found_indexes += document_settings.indexes + + # create indices + if found_indexes: + new_indexes += await collection.create_indexes(found_indexes) + + # delete indexes + # Only drop indexes if the user specifically allows for it + if allow_index_dropping: + for index in set(old_indexes) - set(new_indexes): + await collection.drop_index(index) + + async def init_document(self, cls: Type[Document]) -> Optional[Output]: + """ + Init Document-based class + + :param cls: + :return: + """ + if cls is Document: + return None + + if cls not in self.inited_classes: + self.set_default_class_vars(cls) + self.init_settings(cls) + + bases = [b for b in cls.__bases__ if issubclass(b, Document)] + if len(bases) > 1: + return None + parent = bases[0] + output = await self.init_document(parent) + if cls.get_settings().is_root and ( + parent is Document or not parent.get_settings().is_root + ): + if cls.get_collection_name() is None: + cls.set_collection_name(cls.__name__) + output = Output( + class_name=cls.__name__, + collection_name=cls.get_collection_name(), + ) + elif output is not None: + output.class_name = f"{output.class_name}.{cls.__name__}" + cls._class_id = output.class_name + cls.set_collection_name(output.collection_name) + parent.add_child(cls._class_id, cls) + cls._parent = parent + cls._inheritance_inited = True + + await self.init_document_collection(cls) + await self.init_indexes(cls, self.allow_index_dropping) + self.init_document_fields(cls) + self.init_cache(cls) + self.init_actions(cls) + + self.inited_classes.append(cls) + + return output + + else: + return Output( + class_name=cls._class_id, + collection_name=cls.get_collection_name(), + ) + + # Views + + @staticmethod + def init_view_fields(cls) -> None: + """ + Init class fields + :return: None + """ + for k, v in cls.__fields__.items(): + path = v.alias or v.name + setattr(cls, k, ExpressionField(path)) + + def init_view_collection(self, cls): + """ + Init collection for View + + :param cls: + :return: + """ + view_settings = cls.get_settings() + + if view_settings.name is None: + view_settings.name = cls.__name__ + + if inspect.isclass(view_settings.source): + view_settings.source = view_settings.source.get_collection_name() + + view_settings.motor_db = self.database + view_settings.motor_collection = self.database[view_settings.name] + + async def init_view(self, cls: Type[View]): + """ + Init View-based class + + :param cls: + :return: + """ + self.init_settings(cls) + self.init_view_collection(cls) + self.init_view_fields(cls) + + collection_names = await self.database.list_collection_names() + if self.recreate_views or cls._settings.name not in collection_names: + if cls._settings.name in collection_names: + await cls.get_motor_collection().drop() + + await self.database.command( + { + "create": cls.get_settings().name, + "viewOn": cls.get_settings().source, + "pipeline": cls.get_settings().pipeline, + } + ) + + # Union Doc + + async def init_union_doc(self, cls: Type[UnionDoc]): + """ + Init Union Doc based class + + :param cls: + :return: + """ + self.init_settings(cls) + if cls._settings.name is None: + cls._settings.name = cls.__name__ + + cls._settings.motor_db = self.database + cls._settings.motor_collection = self.database[cls._settings.name] + cls._is_inited = True + + # Final + + async def init_class( + self, cls: Union[Type[Document], Type[View], Type[UnionDoc]] + ): + """ + Init Document, View or UnionDoc based class. + + :param cls: + :return: + """ + if issubclass(cls, Document): + await self.init_document(cls) + + if issubclass(cls, View): + await self.init_view(cls) + + if issubclass(cls, UnionDoc): + await self.init_union_doc(cls) + + +async def init_beanie( + database: AsyncIOMotorDatabase = None, + connection_string: str = None, + document_models: List[Union[Type["DocType"], Type["View"], str]] = None, + allow_index_dropping: bool = False, + recreate_views: bool = False, +): + """ + Beanie initialization + + :param database: AsyncIOMotorDatabase - motor database instance + :param connection_string: str - MongoDB connection string + :param document_models: List[Union[Type[DocType], str]] - model classes + or strings with dot separated paths + :param allow_index_dropping: bool - if index dropping is allowed. + Default False + :return: None + """ + + await Initializer( + database=database, + connection_string=connection_string, + document_models=document_models, + allow_index_dropping=allow_index_dropping, + recreate_views=recreate_views, + ) diff --git a/beanie/odm/utils/parsing.py b/beanie/odm/utils/parsing.py index 10819d56..bf26ff35 100644 --- a/beanie/odm/utils/parsing.py +++ b/beanie/odm/utils/parsing.py @@ -1,5 +1,4 @@ from typing import Any, Type, Union, TYPE_CHECKING - from pydantic import BaseModel from beanie.exceptions import ( @@ -30,6 +29,20 @@ def parse_obj( if class_name not in model._document_models: raise DocWasNotRegisteredInUnionClass return parse_obj(model=model._document_models[class_name], data=data) + if ( + hasattr(model, "get_model_type") + and model.get_model_type() == ModelType.Document + and model._inheritance_inited + ): + if isinstance(data, dict): + class_name = data.get("_class_id") + elif hasattr(data, "_class_id"): + class_name = data._class_id + else: + class_name = None + + if model._children and class_name in model._children: + return parse_obj(model=model._children[class_name], data=data) # if hasattr(model, "_parse_obj_saving_state"): # return model._parse_obj_saving_state(data) # type: ignore diff --git a/beanie/odm/utils/projection.py b/beanie/odm/utils/projection.py index 813e5a09..fae16f22 100644 --- a/beanie/odm/utils/projection.py +++ b/beanie/odm/utils/projection.py @@ -8,15 +8,16 @@ def get_projection( - model: Type[ProjectionModelType], + model: Type[ProjectionModelType], ) -> Optional[Dict[str, int]]: - if ( - hasattr(model, "get_model_type") - and model.get_model_type() == ModelType.UnionDoc - ): + if hasattr(model, "get_model_type") and ( + model.get_model_type() == ModelType.UnionDoc or ( + model.get_model_type() == ModelType.Document and model._inheritance_inited)): return None + if hasattr(model, "Settings"): # MyPy checks settings = getattr(model, "Settings") + if hasattr(settings, "projection"): return getattr(settings, "projection") @@ -27,4 +28,5 @@ def get_projection( for name, field in model.__fields__.items(): document_projection[field.alias] = 1 + return document_projection diff --git a/beanie/odm/views.py b/beanie/odm/views.py index b58ecd6a..a4d30e5f 100644 --- a/beanie/odm/views.py +++ b/beanie/odm/views.py @@ -1,10 +1,8 @@ from typing import ClassVar -from motor.motor_asyncio import AsyncIOMotorDatabase from pydantic import BaseModel from beanie.exceptions import ViewWasNotInitialized -from beanie.odm.fields import ExpressionField from beanie.odm.interfaces.aggregate import AggregateInterface from beanie.odm.interfaces.detector import DetectionInterface, ModelType from beanie.odm.interfaces.find import FindInterface @@ -29,40 +27,6 @@ class View( _settings: ClassVar[ViewSettings] - @classmethod - async def init_view(cls, database, recreate_view: bool): - await cls.init_settings(database) - cls.init_fields() - - collection_names = await database.list_collection_names() - if recreate_view or cls._settings.name not in collection_names: - if cls._settings.name in collection_names: - await cls.get_motor_collection().drop() - - await database.command( - { - "create": cls.get_settings().name, - "viewOn": cls.get_settings().source, - "pipeline": cls.get_settings().pipeline, - } - ) - - @classmethod - async def init_settings(cls, database: AsyncIOMotorDatabase) -> None: - cls._settings = await ViewSettings.init( - database=database, view_class=cls - ) - - @classmethod - def init_fields(cls) -> None: - """ - Init class fields - :return: None - """ - for k, v in cls.__fields__.items(): - path = v.alias or v.name - setattr(cls, k, ExpressionField(path)) - @classmethod def get_settings(cls) -> ViewSettings: """ diff --git a/docs/async_tutorial/inheritance.md b/docs/async_tutorial/inheritance.md new file mode 100644 index 00000000..6afd8333 --- /dev/null +++ b/docs/async_tutorial/inheritance.md @@ -0,0 +1,130 @@ +## Inheritance + +Beanie `Documents` support inheritance as any other Python classes. But there are additional features available, if you mark the root model with parameter `is_root = True` in the inner Settings class. + +This behavior is similar to `UnionDoc`, but you don't need additional entity. +Parent `Document` act like a "controller", that handle proper storing and fetching different type `Document`. +Also, parent `Document` can have some shared attributes which are propagated to all children. +All classes in inheritance chain can be a used as `Link` in foreign `Documents`. + +Depend on the business logic, parent `Document` can be like "abstract" class that is not used to store objects of its type (like in example below), as well as can be a full-fledged entity, like its children. + +## Examples + +Define models + +```py hl_lines="20 20" +from typing import Optional, List +from motor.motor_asyncio import AsyncIOMotorClient +from pydantic import BaseModel +from beanie import Document, Link, init_beanie + + +class Vehicle(Document): + """Inheritance scheme bellow""" + # Vehicle + # / | \ + # / | \ + # Bicycle Bike Car + # \ + # \ + # Bus + # shared attribute for all children + color: str + + class Settings: + is_root = True + + +class Fuelled(BaseModel): + """Just a mixin""" + fuel: Optional[str] + + +class Bicycle(Vehicle): + """Derived from Vehicle, will use its collection""" + frame: int + wheels: int + + +class Bike(Vehicle, Fuelled): + ... + + +class Car(Vehicle, Fuelled): + body: str + + +class Bus(Car, Fuelled): + """Inheritance chain is Vehicle -> Car -> Bus, it is also stored in Vehicle collection""" + seats: int + + +class Owner(Document): + vehicles: Optional[List[Link[Vehicle]]] +``` + +Insert data + +```python +client = AsyncIOMotorClient() +await init_beanie(client.test_db, document_models=[Vehicle, Bicycle, Bike, Car, Bus]) + +bike_1 = await Bike(color='black', fuel='gasoline').insert() + +car_1 = await Car(color='grey', body='sedan', fuel='gasoline').insert() +car_2 = await Car(color='white', body='crossover', fuel='diesel').insert() + +bus_1 = await Bus(color='white', seats=80, body='bus', fuel='diesel').insert() +bus_2 = await Bus(color='yellow', seats=26, body='minibus', fuel='diesel').insert() + +owner = await Owner(name='John', vehicles=[car_1, car_2, bus_1]).insert() +``` + +Query data + +```python +# this query returns vehicles of all types that have white color, becuase `with_children` is True +white_vehicles = await Vehicle.find(Vehicle.color == 'white', with_children=True).to_list() +# [ +# Bicycle(..., color='white', frame=54, wheels=29), +# Car(fuel='diesel', ..., color='white', body='crossover'), +# Bus(fuel='diesel', ..., color='white', body='bus', seats=80) +# ] + +# however it is possible to limit by Vehicle type +cars_only = await Car.find().to_list() +# [ +# Car(fuel='gasoline', ..., color='grey', body='sedan'), +# Car(fuel='diesel', ..., color='white', body='crossover') +# ] + +# if search is based on child, query returns this child type and all sub-children +cars_and_buses = await Car.find(Car.fuel == 'diesel', with_children=True).to_list() +# [ +# Car(fuel='diesel', ..., color='white', body='crossover'), +# Bus(fuel='diesel', ..., color='white', body='bus', seats=80), +# Bus(fuel='diesel', ..., color='yellow', body='minibus', seats=26) +# ] + +# to get a single Document it is not necessary to known its type +# you can query using parent class +await Vehicle.get(bus_2.id, with_children=True) +# returns Bus instance: +# Bus(fuel='diesel', ..., color='yellow', body='minibus', seats=26) + +# re-fetch from DB with resolved links (using aggregation under the hood) +owner = await Owner.get(owner.id, fetch_links=True) +print(owner.vehicles) +# returns +# [ +# Car(fuel='diesel', ..., color='white', body='crossover'), +# Bus(fuel='diesel', ..., color='white', body='bus', seats=80), +# Car(fuel='gasoline', ..., color='grey', body='sedan') +# ] +# the same result will be if owner get without fetching link, and they will be fetched manually later + +# all other operations works the same as simple Documents +await Bike.find().update({"$set": {Bike.color: 'yellow'}}) +await Car.find_one(Car.body == 'sedan') +``` diff --git a/docs/changelog.md b/docs/changelog.md index a5dfb184..37c9dba7 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -2,6 +2,21 @@ Beanie project +## [1.14.0] - 2022-11-04 + +### Feature + +- Multi-model behavior for inherited documents + +### Breaking change + +- The inner class `Collection` is not supported more. Please, use `Settings` instead. + +### Implementation + +- Author - [Vitaliy Ivanov](https://github.com/Vitalium) +- PR + ## [1.13.1] - 2022-10-26 ### Fix @@ -1022,4 +1037,6 @@ how specific type should be presented in the database [1.13.0]: https://pypi.org/project/beanie/1.13.0 -[1.13.1]: https://pypi.org/project/beanie/1.13.1 \ No newline at end of file +[1.13.1]: https://pypi.org/project/beanie/1.13.1 + +[1.14.0]: https://pypi.org/project/beanie/1.14.0 \ No newline at end of file diff --git a/pydoc-markdown.yml b/pydoc-markdown.yml index 1aa31685..3a7fda28 100644 --- a/pydoc-markdown.yml +++ b/pydoc-markdown.yml @@ -75,6 +75,8 @@ renderer: source: docs/async_tutorial/update.md - title: Multi-model pattern source: docs/async_tutorial/multi-model.md + - title: Inheritance + source: docs/async_tutorial/inheritance.md - title: Indexes & collection names source: docs/async_tutorial/collection_setup.md - title: Aggregation diff --git a/pyproject.toml b/pyproject.toml index a2eecedc..36155742 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "beanie" -version = "1.13.1" +version = "1.14.0" description = "Python ODM for MongoDB" authors = ["Roman "] license = "Apache-2.0" diff --git a/tests/migrations/iterative/test_change_subfield.py b/tests/migrations/iterative/test_change_subfield.py index f8844a70..3a30c245 100644 --- a/tests/migrations/iterative/test_change_subfield.py +++ b/tests/migrations/iterative/test_change_subfield.py @@ -22,7 +22,7 @@ class OldNote(Document): title: str tag: OldTag - class Collection: + class Settings: name = "notes" @@ -30,7 +30,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/iterative/test_change_value.py b/tests/migrations/iterative/test_change_value.py index 219fb41a..5685a2c7 100644 --- a/tests/migrations/iterative/test_change_value.py +++ b/tests/migrations/iterative/test_change_value.py @@ -17,7 +17,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/iterative/test_change_value_subfield.py b/tests/migrations/iterative/test_change_value_subfield.py index 97ff7170..d64562a8 100644 --- a/tests/migrations/iterative/test_change_value_subfield.py +++ b/tests/migrations/iterative/test_change_value_subfield.py @@ -17,7 +17,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/iterative/test_pack_unpack.py b/tests/migrations/iterative/test_pack_unpack.py index db4ec9e6..d12d2caf 100644 --- a/tests/migrations/iterative/test_pack_unpack.py +++ b/tests/migrations/iterative/test_pack_unpack.py @@ -23,7 +23,7 @@ class OldNote(Document): tag_name: str tag_color: str - class Collection: + class Settings: name = "notes" @@ -31,7 +31,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/iterative/test_rename_field.py b/tests/migrations/iterative/test_rename_field.py index 5ffef121..c294bcef 100644 --- a/tests/migrations/iterative/test_rename_field.py +++ b/tests/migrations/iterative/test_rename_field.py @@ -17,7 +17,7 @@ class OldNote(Document): name: str tag: Tag - class Collection: + class Settings: name = "notes" @@ -25,7 +25,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" @@ -37,7 +37,7 @@ async def notes(loop, db): note = OldNote(name=str(i), tag=Tag(name="test", color="red")) await note.insert() yield - await OldNote.delete_all() + # await OldNote.delete_all() async def test_migration_rename_field(settings, notes, db): diff --git a/tests/migrations/migrations_for_test/break/20210413211219_break.py b/tests/migrations/migrations_for_test/break/20210413211219_break.py index 8800899b..c84d3c50 100644 --- a/tests/migrations/migrations_for_test/break/20210413211219_break.py +++ b/tests/migrations/migrations_for_test/break/20210413211219_break.py @@ -12,7 +12,7 @@ class OldNote(Document): name: str tag: Tag - class Collection: + class Settings: name = "notes" @@ -20,7 +20,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/change_subfield/20210413152406_change_subfield.py b/tests/migrations/migrations_for_test/change_subfield/20210413152406_change_subfield.py index 3c371427..288dbe15 100644 --- a/tests/migrations/migrations_for_test/change_subfield/20210413152406_change_subfield.py +++ b/tests/migrations/migrations_for_test/change_subfield/20210413152406_change_subfield.py @@ -17,7 +17,7 @@ class OldNote(Document): title: str tag: OldTag - class Collection: + class Settings: name = "notes" @@ -25,7 +25,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/change_subfield_value/20210413143405_change_subfield_value.py b/tests/migrations/migrations_for_test/change_subfield_value/20210413143405_change_subfield_value.py index 4f1cdea2..d09fc94c 100644 --- a/tests/migrations/migrations_for_test/change_subfield_value/20210413143405_change_subfield_value.py +++ b/tests/migrations/migrations_for_test/change_subfield_value/20210413143405_change_subfield_value.py @@ -12,7 +12,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/change_value/20210413115234_change_value.py b/tests/migrations/migrations_for_test/change_value/20210413115234_change_value.py index fb458045..73f10eb0 100644 --- a/tests/migrations/migrations_for_test/change_value/20210413115234_change_value.py +++ b/tests/migrations/migrations_for_test/change_value/20210413115234_change_value.py @@ -12,7 +12,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/free_fall/20210413210446_free_fall.py b/tests/migrations/migrations_for_test/free_fall/20210413210446_free_fall.py index 5eb3c8a3..723ecc0d 100644 --- a/tests/migrations/migrations_for_test/free_fall/20210413210446_free_fall.py +++ b/tests/migrations/migrations_for_test/free_fall/20210413210446_free_fall.py @@ -12,7 +12,7 @@ class OldNote(Document): name: str tag: Tag - class Collection: + class Settings: name = "notes" @@ -20,7 +20,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/many_migrations/20210413170640_1.py b/tests/migrations/migrations_for_test/many_migrations/20210413170640_1.py index f3305514..a044fd1b 100644 --- a/tests/migrations/migrations_for_test/many_migrations/20210413170640_1.py +++ b/tests/migrations/migrations_for_test/many_migrations/20210413170640_1.py @@ -12,7 +12,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/many_migrations/20210413170645_2.py b/tests/migrations/migrations_for_test/many_migrations/20210413170645_2.py index 658262b0..e166d184 100644 --- a/tests/migrations/migrations_for_test/many_migrations/20210413170645_2.py +++ b/tests/migrations/migrations_for_test/many_migrations/20210413170645_2.py @@ -12,7 +12,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/many_migrations/20210413170700_3_skip_backward.py b/tests/migrations/migrations_for_test/many_migrations/20210413170700_3_skip_backward.py index f39fb0db..3476426d 100644 --- a/tests/migrations/migrations_for_test/many_migrations/20210413170700_3_skip_backward.py +++ b/tests/migrations/migrations_for_test/many_migrations/20210413170700_3_skip_backward.py @@ -12,7 +12,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/many_migrations/20210413170709_3_skip_forward.py b/tests/migrations/migrations_for_test/many_migrations/20210413170709_3_skip_forward.py index ea09735b..52feb071 100644 --- a/tests/migrations/migrations_for_test/many_migrations/20210413170709_3_skip_forward.py +++ b/tests/migrations/migrations_for_test/many_migrations/20210413170709_3_skip_forward.py @@ -12,7 +12,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/many_migrations/20210413170728_4_5.py b/tests/migrations/migrations_for_test/many_migrations/20210413170728_4_5.py index 9e257287..a53eb72d 100644 --- a/tests/migrations/migrations_for_test/many_migrations/20210413170728_4_5.py +++ b/tests/migrations/migrations_for_test/many_migrations/20210413170728_4_5.py @@ -12,7 +12,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/many_migrations/20210413170734_6_7.py b/tests/migrations/migrations_for_test/many_migrations/20210413170734_6_7.py index 1cf9cdcb..5fff1882 100644 --- a/tests/migrations/migrations_for_test/many_migrations/20210413170734_6_7.py +++ b/tests/migrations/migrations_for_test/many_migrations/20210413170734_6_7.py @@ -12,7 +12,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/pack_unpack/20210413135927_pack_unpack.py b/tests/migrations/migrations_for_test/pack_unpack/20210413135927_pack_unpack.py index 1374b07f..b30248c6 100644 --- a/tests/migrations/migrations_for_test/pack_unpack/20210413135927_pack_unpack.py +++ b/tests/migrations/migrations_for_test/pack_unpack/20210413135927_pack_unpack.py @@ -13,7 +13,7 @@ class OldNote(Document): tag_name: str tag_color: str - class Collection: + class Settings: name = "notes" @@ -21,7 +21,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/migrations_for_test/remove_index/20210414135045_remove_index.py b/tests/migrations/migrations_for_test/remove_index/20210414135045_remove_index.py index 63ea2c9e..2a7b2deb 100644 --- a/tests/migrations/migrations_for_test/remove_index/20210414135045_remove_index.py +++ b/tests/migrations/migrations_for_test/remove_index/20210414135045_remove_index.py @@ -12,7 +12,7 @@ class OldNote(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" indexes = ["title"] @@ -21,7 +21,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" indexes = [ "_id", diff --git a/tests/migrations/migrations_for_test/rename_field/20210407203225_rename_field.py b/tests/migrations/migrations_for_test/rename_field/20210407203225_rename_field.py index c7d286be..7c8e000d 100644 --- a/tests/migrations/migrations_for_test/rename_field/20210407203225_rename_field.py +++ b/tests/migrations/migrations_for_test/rename_field/20210407203225_rename_field.py @@ -12,7 +12,7 @@ class OldNote(Document): name: str tag: Tag - class Collection: + class Settings: name = "notes" @@ -20,7 +20,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/test_break.py b/tests/migrations/test_break.py index 6cf1cee6..ee73f961 100644 --- a/tests/migrations/test_break.py +++ b/tests/migrations/test_break.py @@ -16,7 +16,7 @@ class OldNote(Document): name: str tag: Tag - class Collection: + class Settings: name = "notes" @@ -24,7 +24,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/test_directions.py b/tests/migrations/test_directions.py index 84bd9834..3104a9bf 100644 --- a/tests/migrations/test_directions.py +++ b/tests/migrations/test_directions.py @@ -16,7 +16,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/test_free_fall.py b/tests/migrations/test_free_fall.py index 8f240f17..1712cb3b 100644 --- a/tests/migrations/test_free_fall.py +++ b/tests/migrations/test_free_fall.py @@ -16,7 +16,7 @@ class OldNote(Document): name: str tag: Tag - class Collection: + class Settings: name = "notes" @@ -24,7 +24,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/migrations/test_remove_indexes.py b/tests/migrations/test_remove_indexes.py index 234ac123..0b8132b9 100644 --- a/tests/migrations/test_remove_indexes.py +++ b/tests/migrations/test_remove_indexes.py @@ -16,7 +16,7 @@ class OldNote(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" indexes = ["title"] @@ -25,7 +25,7 @@ class Note(Document): title: str tag: Tag - class Collection: + class Settings: name = "notes" diff --git a/tests/odm/conftest.py b/tests/odm/conftest.py index b21ee9df..65c5d1c9 100644 --- a/tests/odm/conftest.py +++ b/tests/odm/conftest.py @@ -4,7 +4,7 @@ import pytest -from beanie.odm.utils.general import init_beanie +from beanie.odm.utils.init import init_beanie from tests.odm.models import ( DocumentTestModel, DocumentTestModelWithIndexFlagsAliases, @@ -40,7 +40,7 @@ HouseWithRevision, WindowWithRevision, YardWithRevision, - DocumentWithActions2, + DocumentWithActions2, Vehicle, Bicycle, Bike, Car, Bus, Owner, ) from tests.odm.views import TestView from tests.odm.models import ( @@ -174,6 +174,7 @@ async def init(loop, db): WindowWithRevision, YardWithRevision, DocumentWithActions2, + Vehicle, Bicycle, Bike, Car, Bus, Owner ] await init_beanie( database=db, diff --git a/tests/odm/documents/test_inheritance.py b/tests/odm/documents/test_inheritance.py new file mode 100644 index 00000000..cdd7b17c --- /dev/null +++ b/tests/odm/documents/test_inheritance.py @@ -0,0 +1,100 @@ +from beanie import Link +from tests.odm.models import ( + Vehicle, + Bicycle, + Bike, + Car, + Bus, TunedDocument, Owner, +) + + +class TestInheritance: + async def test_inheritance(self, db): + bicycle_1 = await Bicycle(color="white", frame=54, wheels=29).insert() + bicycle_2 = await Bicycle(color="red", frame=52, wheels=28).insert() + + bike_1 = await Bike(color="black", fuel="gasoline").insert() + + car_1 = await Car(color="grey", body="sedan", fuel="gasoline").insert() + car_2 = await Car( + color="white", body="crossover", fuel="diesel" + ).insert() + + bus_1 = await Bus( + color="white", seats=80, body="bus", fuel="diesel" + ).insert() + bus_2 = await Bus( + color="yellow", seats=26, body="minibus", fuel="diesel" + ).insert() + + white_vehicles = await Vehicle.find(Vehicle.color == "white", + with_children=True).to_list() + + cars_only = await Car.find().to_list() + cars_and_buses = await Car.find(Car.fuel == "diesel", + with_children=True).to_list() + + big_bicycles = await Bicycle.find(Bicycle.wheels > 28).to_list() + + await Bike.find().update({"$set": {Bike.color: "yellow"}}) + sedan = await Car.find_one(Car.body == "sedan") + + sedan.color = "yellow" + await sedan.save() + + # get using Vehicle should return Bike instance + updated_bike = await Vehicle.get(bike_1.id, with_children=True) + + assert isinstance(sedan, Car) + + assert isinstance(updated_bike, Bike) + assert updated_bike.color == "yellow" + + assert Vehicle._parent is TunedDocument + assert Bus._parent is Car + + assert len(big_bicycles) == 1 + assert big_bicycles[0].wheels > 28 + + assert len(white_vehicles) == 3 + assert len(cars_only) == 2 + + assert {Car, Bus} == set(i.__class__ for i in cars_and_buses) + assert {Bicycle, Car, Bus} == set(i.__class__ for i in white_vehicles) + + white_vehicles_2 = await Car.find(Vehicle.color == "white").to_list() + assert len(white_vehicles_2) == 1 + + for i in cars_and_buses: + assert i.fuel == "diesel" + + for e in (bicycle_1, bicycle_2, bike_1, car_1, car_2, bus_1, bus_2): + assert isinstance(e, Vehicle) + await e.delete() + + async def test_links(self, db): + car_1 = await Car(color="grey", body="sedan", fuel="gasoline").insert() + car_2 = await Car( + color="white", body="crossover", fuel="diesel" + ).insert() + + bus_1 = await Bus( + color="white", seats=80, body="bus", fuel="diesel" + ).insert() + + owner = await Owner(name="John").insert() + owner.vehicles = [car_1, car_2, bus_1] + await owner.save() + + # re-fetch from DB w/o links + owner = await Owner.get(owner.id) + assert {Link} == set(i.__class__ for i in owner.vehicles) + await owner.fetch_all_links() + assert {Car, Bus} == set(i.__class__ for i in owner.vehicles) + + # re-fetch from DB with resolved links + owner = await Owner.get(owner.id, fetch_links=True) + assert {Car, Bus} == set(i.__class__ for i in owner.vehicles) + + for e in (owner, car_1, car_2, bus_1): + await e.delete() diff --git a/tests/odm/models.py b/tests/odm/models.py index 1c38f502..161166db 100644 --- a/tests/odm/models.py +++ b/tests/odm/models.py @@ -80,12 +80,9 @@ class DocumentTestModelWithCustomCollectionName(Document): test_list: List[SubDocument] test_str: str - class Collection: + class Settings: name = "custom" - # class Settings: - # name = "custom" - class DocumentTestModelWithSimpleIndex(Document): test_int: Indexed(int) @@ -110,7 +107,7 @@ class DocumentTestModelWithComplexIndex(Document): test_list: List[SubDocument] test_str: str - class Collection: + class Settings: name = "docs_with_index" indexes = [ "test_int", @@ -124,20 +121,6 @@ class Collection: ), ] - # class Settings: - # name = "docs_with_index" - # indexes = [ - # "test_int", - # [ - # ("test_int", pymongo.ASCENDING), - # ("test_str", pymongo.DESCENDING), - # ], - # IndexModel( - # [("test_str", pymongo.DESCENDING)], - # name="test_string_index_DESCENDING", - # ), - # ] - class DocumentTestModelWithDroppedIndex(Document): test_int: int @@ -478,3 +461,57 @@ class HouseWithRevision(Document): class Settings: use_revision = True use_state_management = True + + +class TunedDocument(Document): + # some common settings for all models in the file + class Settings: + is_root = True + use_state_management = True + + +# classes for inheritance test +class Vehicle(TunedDocument): + """Root parent for testing flat inheritance""" + + # Vehicle + # / | \ + # / | \ + # Bicycle Bike Car + # \ + # \ + # Bus + color: str + + @after_event(Insert) + def on_object_create(self): + # this event will be triggered for all children too (self will have corresponding type) + ... + + +class Bicycle(Vehicle): + frame: int + wheels: int + + +class Fuelled(BaseModel): + """Just a mixin""" + + fuel: Optional[str] + + +class Car(Vehicle, Fuelled): + body: str + + +class Bike(Vehicle, Fuelled): + ... + + +class Bus(Car, Fuelled): + seats: int + + +class Owner(Document): + name: str + vehicles: List[Link[Vehicle]] = [] diff --git a/tests/test_beanie.py b/tests/test_beanie.py index 0ece1421..c534bd19 100644 --- a/tests/test_beanie.py +++ b/tests/test_beanie.py @@ -2,4 +2,4 @@ def test_version(): - assert __version__ == "1.13.1" + assert __version__ == "1.14.0"