diff --git a/pyproject.toml b/pyproject.toml index 0f4d94260..f24ad417d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -28,7 +28,7 @@ dependencies = [ "aiohttp-jinja2==1.5", "aioipfs @ git+https://github.com/aleph-im/aioipfs.git@d671c79b2871bb4d6c8877ba1e7f3ffbe7d20b71", "alembic==1.12.1", - "aleph-message==0.4.9", + "aleph-message @ git+https://github.com/1yam/aleph-message@1yam-512-VM", "aleph-nuls2==0.1", "aleph-p2p-client @ git+https://github.com/aleph-im/p2p-service-client-python@2c04af39c566217f629fd89505ffc3270fba8676", "aleph-pytezos==3.13.4", @@ -157,9 +157,8 @@ dependencies = [ "isort==5.13.2", "check-sdist==0.1.3", "sqlalchemy[mypy]==1.4.41", - "yamlfix==1.16.1", - # because of aleph messages otherwise yamlfix install a too new version - "pydantic>=1.10.5,<2.0.0", + "yamlfix>=1.17", + "pydantic>=2,<3.0.0", "pyproject-fmt==2.2.1", "types-aiofiles", "types-protobuf", diff --git a/src/aleph/chains/chain_data_service.py b/src/aleph/chains/chain_data_service.py index 7eadf1afe..2e83c2b1e 100644 --- a/src/aleph/chains/chain_data_service.py +++ b/src/aleph/chains/chain_data_service.py @@ -1,4 +1,5 @@ import asyncio +import json from typing import Any, Dict, List, Mapping, Optional, Self, Set, Type, Union, cast import aio_pika.abc @@ -69,7 +70,7 @@ async def prepare_sync_event_payload( messages=[OnChainMessage.from_orm(message) for message in messages] ), ) - archive_content: bytes = archive.json().encode("utf-8") + archive_content: bytes = archive.model_dump_json().encode("utf-8") ipfs_cid = await self.storage_service.add_file( session=session, file_content=archive_content, engine=ItemType.ipfs @@ -166,7 +167,9 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An ) try: - payload = cast(GenericMessageEvent, payload_model.parse_obj(tx.content)) + payload = cast( + GenericMessageEvent, payload_model.model_validate(tx.content) + ) except ValidationError: raise InvalidContent(f"Incompatible tx content for {tx.chain}/{tx.hash}") 
@@ -189,7 +192,7 @@ def _get_tx_messages_smart_contract_protocol(tx: ChainTxDb) -> List[Dict[str, An item_hash=ItemHash(payload.content), metadata=None, ) - item_content = content.json(exclude_none=True) + item_content = json.dumps(content.model_dump(exclude_none=True)) else: item_content = payload.content diff --git a/src/aleph/chains/ethereum.py b/src/aleph/chains/ethereum.py index 1036166a8..a3cc2e9eb 100644 --- a/src/aleph/chains/ethereum.py +++ b/src/aleph/chains/ethereum.py @@ -364,7 +364,7 @@ async def packer(self, config: Config): account, int(gas_price * 1.1), nonce, - sync_event_payload.json(), + sync_event_payload.model_dump_json(), ) LOGGER.info("Broadcast %r on %s" % (response, CHAIN_NAME)) diff --git a/src/aleph/chains/indexer_reader.py b/src/aleph/chains/indexer_reader.py index 808aab374..c5bff3c53 100644 --- a/src/aleph/chains/indexer_reader.py +++ b/src/aleph/chains/indexer_reader.py @@ -93,7 +93,7 @@ def make_events_query( model = SyncEvent event_type_str = "syncEvents" - fields = "\n".join(model.__fields__.keys()) + fields = "\n".join(model.model_fields.keys()) params: Dict[str, Any] = { "blockchain": f'"{blockchain.value}"', "limit": limit, @@ -146,8 +146,8 @@ async def _query(self, query: str, model: Type[T]) -> T: response = await self.http_session.post("/", json={"query": query}) response.raise_for_status() - response_json = await response.json() - return model.parse_obj(response_json) + response_json = await response.json() + return model.model_validate(response_json) async def fetch_account_state( self, @@ -196,7 +196,7 @@ def indexer_event_to_chain_tx( if isinstance(indexer_event, MessageEvent): protocol = ChainSyncProtocol.SMART_CONTRACT protocol_version = 1 - content = indexer_event.dict() + content = indexer_event.model_dump() else: sync_message = aleph_json.loads(indexer_event.message) diff --git a/src/aleph/chains/nuls2.py b/src/aleph/chains/nuls2.py index 6062b5b0a..ff971ff16 100644 --- a/src/aleph/chains/nuls2.py +++ 
b/src/aleph/chains/nuls2.py @@ -210,7 +210,7 @@ async def packer(self, config: Config): # Required to apply update to the files table in get_chaindata session.commit() - content = sync_event_payload.json() + content = sync_event_payload.model_dump_json() tx = await prepare_transfer_tx( address, [(target_addr, CHEAP_UNIT_FEE)], @@ -248,7 +248,7 @@ async def get_transactions( "pagination": 500, }, ) as resp: - jres = await resp.json() + jres = await resp.json() for tx in sorted(jres["transactions"], key=itemgetter("height")): if remark is not None and tx["remark"] != remark: continue diff --git a/src/aleph/chains/tezos.py b/src/aleph/chains/tezos.py index 59f5690a2..5e17b412a 100644 --- a/src/aleph/chains/tezos.py +++ b/src/aleph/chains/tezos.py @@ -139,7 +139,7 @@ def make_graphql_query( async def get_indexer_status(http_session: aiohttp.ClientSession) -> SyncStatus: response = await http_session.post("/", json={"query": make_graphql_status_query()}) response.raise_for_status() - response_json = await response.json() + response_json = await response.json() return SyncStatus(response_json["data"]["indexStatus"]["status"]) @@ -160,9 +160,9 @@ async def fetch_messages( response = await http_session.post("/", json={"query": query}) response.raise_for_status() - response_json = await response.json() + response_json = await response.json() - return IndexerResponse[IndexerMessageEvent].parse_obj(response_json) + return IndexerResponse[IndexerMessageEvent].model_validate(response_json) def indexer_event_to_chain_tx( @@ -176,7 +176,7 @@ def indexer_event_to_chain_tx( publisher=indexer_event.source, protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=indexer_event.payload.dict(), + content=indexer_event.payload.model_dump(), ) return chain_tx diff --git a/src/aleph/db/models/messages.py b/src/aleph/db/models/messages.py index 0d2f8e894..eb06ae12f 100644 --- a/src/aleph/db/models/messages.py +++ 
b/src/aleph/db/models/messages.py @@ -14,7 +14,6 @@ StoreContent, ) from pydantic import ValidationError -from pydantic.error_wrappers import ErrorWrapper from sqlalchemy import ( ARRAY, TIMESTAMP, @@ -62,14 +61,14 @@ def validate_message_content( content_dict: Dict[str, Any], ) -> BaseContent: content_type = CONTENT_TYPE_MAP[message_type] - content = content_type.parse_obj(content_dict) + content = content_type.model_validate(content_dict) # Validate that the content time can be converted to datetime. This will # raise a ValueError and be caught # TODO: move this validation in aleph-message try: _ = dt.datetime.fromtimestamp(content_dict["time"]) except ValueError as e: - raise ValidationError([ErrorWrapper(e, loc="time")], model=content_type) from e + raise ValidationError.from_exception_data(title=content_type.__name__, line_errors=[{"type": "value_error", "loc": ("time",), "input": content_dict["time"], "ctx": {"error": e}}]) from e return content diff --git a/src/aleph/handlers/content/vm.py b/src/aleph/handlers/content/vm.py index 71f3a3dc4..7aa8a6a90 100644 --- a/src/aleph/handlers/content/vm.py +++ b/src/aleph/handlers/content/vm.py @@ -182,7 +182,7 @@ def vm_message_to_db(message: MessageDb) -> VmBaseDb: if content.on.message: vm.message_triggers = [ - subscription.dict() for subscription in content.on.message + subscription.model_dump() for subscription in content.on.message ] vm.code_volume = CodeVolumeDb( diff --git a/src/aleph/schemas/api/accounts.py b/src/aleph/schemas/api/accounts.py index ca904d3af..86c7df649 100644 --- a/src/aleph/schemas/api/accounts.py +++ b/src/aleph/schemas/api/accounts.py @@ -2,7 +2,7 @@ from decimal import Decimal from typing import List -from pydantic import BaseModel, Field +from pydantic import BaseModel, ConfigDict, Field from aleph.types.files import FileType from aleph.types.sort_order import SortOrder @@ -40,8 +40,7 @@ class GetAccountFilesResponseItem(BaseModel): class GetAccountFilesResponse(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) address: str total_size: int diff --git a/src/aleph/schemas/api/messages.py 
b/src/aleph/schemas/api/messages.py index 8c7f96c99..eb7aaf457 100644 --- a/src/aleph/schemas/api/messages.py +++ b/src/aleph/schemas/api/messages.py @@ -25,10 +25,8 @@ ProgramContent, StoreContent, ) -from pydantic import BaseModel, Field -from pydantic.generics import GenericModel +from pydantic import BaseModel, ConfigDict, Field -import aleph.toolkit.json as aleph_json from aleph.db.models import MessageDb from aleph.types.message_status import ErrorCode, MessageStatus @@ -39,26 +37,26 @@ class MessageConfirmation(BaseModel): """Format of the result when a message has been confirmed on a blockchain""" - class Config: - orm_mode = True - json_encoders = {dt.datetime: lambda d: d.timestamp()} + model_config = ConfigDict( + from_attributes=True, json_encoders={dt.datetime: lambda d: d.timestamp()} + ) chain: Chain height: int hash: str -class BaseMessage(GenericModel, Generic[MType, ContentType]): - class Config: - orm_mode = True - json_loads = aleph_json.loads - json_encoders = {dt.datetime: lambda d: d.timestamp()} +class BaseMessage(BaseModel, Generic[MType, ContentType]): + model_config = ConfigDict( + from_attributes=True, + json_encoders={dt.datetime: lambda d: d.timestamp()}, + ) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MType - item_content: Optional[str] + item_content: Optional[str] = None item_type: ItemType item_hash: str time: dt.datetime @@ -133,7 +131,7 @@ def format_message(message: MessageDb) -> AlephMessage: def format_message_dict(message: Dict[str, Any]) -> AlephMessage: message_type = message.get("type") message_cls = MESSAGE_CLS_DICT[message_type] - return message_cls.parse_obj(message) + return message_cls.model_validate(message) class BaseMessageStatus(BaseModel): @@ -145,45 +143,41 @@ class BaseMessageStatus(BaseModel): # We already have a model for the validation of pending messages, but this one # is only used for formatting and does not try to be smart. 
class PendingMessage(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MessageType - item_content: Optional[str] + item_content: Optional[str] = None item_type: ItemType item_hash: str time: dt.datetime channel: Optional[str] = None - content: Optional[Dict[str, Any]] + content: Optional[Dict[str, Any]] = None reception_time: dt.datetime class PendingMessageStatus(BaseMessageStatus): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) status: MessageStatus = MessageStatus.PENDING messages: List[PendingMessage] class ProcessedMessageStatus(BaseMessageStatus): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) status: MessageStatus = MessageStatus.PROCESSED message: AlephMessage class ForgottenMessage(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MessageType item_type: ItemType item_hash: str @@ -201,7 +195,7 @@ class RejectedMessageStatus(BaseMessageStatus): status: MessageStatus = MessageStatus.REJECTED message: Mapping[str, Any] error_code: ErrorCode - details: Any + details: Any = None MessageWithStatus = Union[ @@ -213,9 +207,9 @@ class RejectedMessageStatus(BaseMessageStatus): class MessageListResponse(BaseModel): - class Config: - json_encoders = {dt.datetime: lambda d: d.timestamp()} - json_loads = aleph_json.loads + model_config = ConfigDict( + json_encoders={dt.datetime: lambda d: d.timestamp()}, + ) messages: List[AlephMessage] pagination_page: int diff --git a/src/aleph/schemas/base_messages.py b/src/aleph/schemas/base_messages.py index 71aa1df72..4b1c79f58 100644 --- a/src/aleph/schemas/base_messages.py +++ b/src/aleph/schemas/base_messages.py @@ -4,11 +4,10 @@ import datetime as dt from hashlib import sha256 -from 
typing import Any, Generic, Mapping, Optional, TypeVar, cast +from typing import Any, Generic, Optional, TypeVar, cast from aleph_message.models import BaseContent, Chain, ItemType, MessageType -from pydantic import root_validator, validator -from pydantic.generics import GenericModel +from pydantic import BaseModel, ValidationInfo, field_validator, model_validator from aleph.toolkit.timestamp import timestamp_to_datetime from aleph.utils import item_type_from_hash @@ -17,7 +16,7 @@ ContentType = TypeVar("ContentType", bound=BaseContent) -class AlephBaseMessage(GenericModel, Generic[MType, ContentType]): +class AlephBaseMessage(BaseModel, Generic[MType, ContentType]): """ The base structure of an Aleph message. All the fields of this class appear in all the representations @@ -26,53 +25,52 @@ class AlephBaseMessage(GenericModel, Generic[MType, ContentType]): sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MType - item_content: Optional[str] + item_content: Optional[str] = None item_type: ItemType item_hash: str time: dt.datetime channel: Optional[str] = None content: Optional[ContentType] = None - @root_validator() - def check_item_type(cls, values): + @model_validator(mode="after") + def check_item_type(self): """ Checks that the item hash of the message matches the one inferred from the hash. Only applicable to storage/ipfs item types. 
""" - item_type_value = values.get("item_type") + item_type_value = self.item_type if item_type_value is None: raise ValueError("Could not determine item type") item_type = ItemType(item_type_value) - if item_type == ItemType.inline: - return values - - item_hash = values.get("item_hash") - if item_hash is None: - raise ValueError("Could not determine item hash") + if item_type != ItemType.inline: + item_hash = self.item_hash + if item_hash is None: + raise ValueError("Could not determine item hash") - expected_item_type = item_type_from_hash(item_hash) - if item_type != expected_item_type: - raise ValueError( - f"Expected {expected_item_type} based on hash but item type is {item_type}." - ) - return values + expected_item_type = item_type_from_hash(item_hash) + if item_type != expected_item_type: + raise ValueError( + f"Expected {expected_item_type} based on hash but item type is {item_type}." + ) + return self - @validator("item_hash") - def check_item_hash(cls, v: Any, values: Mapping[str, Any]): + @field_validator("item_hash", mode="after") + @classmethod + def check_item_hash(cls, v: Any, info: ValidationInfo): """ For inline item types, check that the item hash is equal to the hash of the item content. """ - item_type = values.get("item_type") + item_type = info.data.get("item_type") if item_type is None: raise ValueError("Could not determine item type") if item_type == ItemType.inline: - item_content = cast(Optional[str], values.get("item_content")) + item_content = cast(Optional[str], info.data.get("item_content")) if item_content is None: raise ValueError("Could not find inline item content") @@ -90,8 +88,9 @@ def check_item_hash(cls, v: Any, values: Mapping[str, Any]): raise ValueError(f"Unknown item type: '{item_type}'") return v - @validator("time", pre=True) - def check_time(cls, v, values): + @field_validator("time", mode="before") + @classmethod + def check_time(cls, v: Any, info: ValidationInfo): """ Parses the time field as a UTC datetime. 
Contrary to the default datetime validator, this implementation raises an exception if the time field is diff --git a/src/aleph/schemas/chains/indexer_response.py b/src/aleph/schemas/chains/indexer_response.py index f78b2ba82..6081769a1 100644 --- a/src/aleph/schemas/chains/indexer_response.py +++ b/src/aleph/schemas/chains/indexer_response.py @@ -6,7 +6,7 @@ from enum import Enum from typing import List, Protocol, Tuple -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, Field, field_validator class GenericMessageEvent(Protocol): @@ -43,11 +43,12 @@ class AccountEntityState(BaseModel): pending: List[Tuple[dt.datetime, dt.datetime]] processed: List[Tuple[dt.datetime, dt.datetime]] - @validator("pending", "processed", pre=True, each_item=True) - def split_datetime_ranges(cls, v): - if isinstance(v, str): - return v.split("/") - return v + @field_validator("pending", "processed", mode="before") + @classmethod + def split_datetime_ranges(cls, values): + return [ + value.split("/") if isinstance(value, str) else value for value in values + ] class IndexerAccountStateResponseData(BaseModel): diff --git a/src/aleph/schemas/chains/sync_events.py b/src/aleph/schemas/chains/sync_events.py index ab4304847..3c30bbe17 100644 --- a/src/aleph/schemas/chains/sync_events.py +++ b/src/aleph/schemas/chains/sync_events.py @@ -2,27 +2,26 @@ from typing import Annotated, List, Literal, Optional, Union from aleph_message.models import Chain, ItemHash, ItemType, MessageType -from pydantic import BaseModel, Field, validator +from pydantic import BaseModel, ConfigDict, Field, field_validator from aleph.types.chain_sync import ChainSyncProtocol from aleph.types.channel import Channel class OnChainMessage(BaseModel): - class Config: - orm_mode = True + model_config = ConfigDict(from_attributes=True) sender: str chain: Chain - signature: Optional[str] + signature: Optional[str] = None type: MessageType - item_content: Optional[str] + item_content: Optional[str] = None 
item_type: ItemType item_hash: ItemHash time: float channel: Optional[Channel] = None - @validator("time", pre=True) + @field_validator("time", mode="before") def check_time(cls, v, values): if isinstance(v, dt.datetime): return v.timestamp() diff --git a/src/aleph/schemas/chains/tezos_indexer_response.py b/src/aleph/schemas/chains/tezos_indexer_response.py index d9204d427..4283dfa4b 100644 --- a/src/aleph/schemas/chains/tezos_indexer_response.py +++ b/src/aleph/schemas/chains/tezos_indexer_response.py @@ -2,8 +2,7 @@ from enum import Enum from typing import Generic, List, TypeVar -from pydantic import BaseModel, Field -from pydantic.generics import GenericModel +from pydantic import BaseModel, ConfigDict, Field PayloadType = TypeVar("PayloadType") @@ -24,7 +23,7 @@ class IndexerStats(BaseModel): total_events: int = Field(alias="totalEvents") -class IndexerEvent(GenericModel, Generic[PayloadType]): +class IndexerEvent(BaseModel, Generic[PayloadType]): source: str timestamp: dt.datetime block_level: int = Field(alias="blockLevel") @@ -34,8 +33,7 @@ class IndexerEvent(GenericModel, Generic[PayloadType]): class MessageEventPayload(BaseModel): - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) timestamp: float addr: str @@ -67,11 +65,11 @@ def timestamp_seconds(self) -> float: IndexerEventType = TypeVar("IndexerEventType", bound=IndexerEvent) -class IndexerResponseData(GenericModel, Generic[IndexerEventType]): +class IndexerResponseData(BaseModel, Generic[IndexerEventType]): index_status: IndexerStatus = Field(alias="indexStatus") stats: IndexerStats events: List[IndexerEventType] -class IndexerResponse(GenericModel, Generic[IndexerEventType]): +class IndexerResponse(BaseModel, Generic[IndexerEventType]): data: IndexerResponseData[IndexerEventType] diff --git a/src/aleph/schemas/pending_messages.py b/src/aleph/schemas/pending_messages.py index e37a6313f..28b741562 100644 --- a/src/aleph/schemas/pending_messages.py 
+++ b/src/aleph/schemas/pending_messages.py @@ -29,7 +29,7 @@ ProgramContent, StoreContent, ) -from pydantic import ValidationError, root_validator +from pydantic import ValidationError, model_validator import aleph.toolkit.json as aleph_json from aleph.exceptions import UnknownHashError @@ -45,8 +45,9 @@ class BasePendingMessage(AlephBaseMessage, Generic[MType, ContentType]): A raw Aleph message, as sent by users to the Aleph network. """ - @root_validator(pre=True) - def load_content(cls, values): + @model_validator(mode="before") + @classmethod + def load_content(cls, values: Any): """ Preload inline content. We let the CCN populate this field later on for ipfs and storage item types. diff --git a/src/aleph/services/ipfs/service.py b/src/aleph/services/ipfs/service.py index 8dd222d24..7b84867c5 100644 --- a/src/aleph/services/ipfs/service.py +++ b/src/aleph/services/ipfs/service.py @@ -165,7 +165,7 @@ async def add_file(self, file_content: bytes): resp = await session.post(url, data=data) resp.raise_for_status() - return await resp.json() + return await resp.json() async def sub(self, topic: str): ipfs_client = self.ipfs_client diff --git a/src/aleph/services/p2p/http.py b/src/aleph/services/p2p/http.py index 53b92d537..5654d444f 100644 --- a/src/aleph/services/p2p/http.py +++ b/src/aleph/services/p2p/http.py @@ -28,7 +28,7 @@ async def api_get_request(base_uri, method, timeout=1): if resp.status != 200: result = None else: - result = await resp.json() + result = await resp.json() except ( TimeoutError, asyncio.TimeoutError, diff --git a/src/aleph/types/message_processing_result.py b/src/aleph/types/message_processing_result.py index 705d6146d..4a0980d3a 100644 --- a/src/aleph/types/message_processing_result.py +++ b/src/aleph/types/message_processing_result.py @@ -32,7 +32,7 @@ def item_hash(self) -> str: def to_dict(self) -> Dict[str, Any]: return { "status": self.status.value, - "message": format_message(self.message).dict(), + 
"message": format_message(self.message).model_dump(), } diff --git a/src/aleph/web/controllers/accounts.py b/src/aleph/web/controllers/accounts.py index 141b08285..392c91681 100644 --- a/src/aleph/web/controllers/accounts.py +++ b/src/aleph/web/controllers/accounts.py @@ -75,7 +75,7 @@ async def get_account_balance(request: web.Request): return web.json_response( text=GetAccountBalanceResponse( address=address, balance=balance, locked_amount=total_cost - ).json() + ).model_dump_json() ) @@ -83,7 +83,7 @@ async def get_account_files(request: web.Request) -> web.Response: address = _get_address_from_request(request) try: - query_params = GetAccountFilesQueryParams.parse_obj(request.query) + query_params = GetAccountFilesQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) @@ -112,4 +112,4 @@ async def get_account_files(request: web.Request) -> web.Response: pagination_total=nb_files, pagination_per_page=query_params.pagination, ) - return web.json_response(text=response.json()) + return web.json_response(text=response.model_dump_json()) diff --git a/src/aleph/web/controllers/aggregates.py b/src/aleph/web/controllers/aggregates.py index 09b25b0a6..66b910850 100644 --- a/src/aleph/web/controllers/aggregates.py +++ b/src/aleph/web/controllers/aggregates.py @@ -3,7 +3,7 @@ from typing import Dict, List, Optional from aiohttp import web -from pydantic import BaseModel, ValidationError, validator +from pydantic import BaseModel, ValidationError, field_validator from sqlalchemy import select from aleph.db.accessors.aggregates import get_aggregates_by_owner, refresh_aggregate @@ -22,10 +22,8 @@ class AggregatesQueryParams(BaseModel): with_info: bool = False value_only: bool = False - @validator( - "keys", - pre=True, - ) + @field_validator("keys", mode="before") + @classmethod def split_str(cls, v): if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) @@ -40,7 +38,7 @@ async def 
address_aggregate(request: web.Request) -> web.Response: address: str = request.match_info["address"] try: - query_params = AggregatesQueryParams.parse_obj(request.query) + query_params = AggregatesQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity( text=e.json(), content_type="application/json" diff --git a/src/aleph/web/controllers/main.py b/src/aleph/web/controllers/main.py index 7fcb7aa8e..5b5c5b79e 100644 --- a/src/aleph/web/controllers/main.py +++ b/src/aleph/web/controllers/main.py @@ -98,7 +98,7 @@ async def ccn_metric(request: web.Request) -> web.Response: """Fetch metrics for CCN node id""" session_factory: DbSessionFactory = get_session_factory_from_request(request) - query_params = Metrics.parse_obj(request.query) + query_params = Metrics.model_validate(request.query) node_id = _get_node_id_from_request(request) @@ -124,7 +124,7 @@ async def crn_metric(request: web.Request) -> web.Response: """Fetch Metric for crn.""" session_factory: DbSessionFactory = get_session_factory_from_request(request) - query_params = Metrics.parse_obj(request.query) + query_params = Metrics.model_validate(request.query) node_id = _get_node_id_from_request(request) diff --git a/src/aleph/web/controllers/messages.py b/src/aleph/web/controllers/messages.py index 68d7f2c15..8794742dc 100644 --- a/src/aleph/web/controllers/messages.py +++ b/src/aleph/web/controllers/messages.py @@ -6,7 +6,14 @@ import aiohttp.web_ws from aiohttp import WSMsgType, web from aleph_message.models import Chain, ItemHash, MessageType -from pydantic import BaseModel, Field, ValidationError, root_validator, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) import aleph.toolkit.json as aleph_json from aleph.db.accessors.messages import ( @@ -110,19 +117,20 @@ class BaseMessageQueryParams(BaseModel): default=None, description="Accepted values for the 'item_hash' field." 
) - @root_validator - def validate_field_dependencies(cls, values): - start_date = values.get("start_date") - end_date = values.get("end_date") + @model_validator(mode="after") + def validate_field_dependencies(self): + start_date = self.start_date + end_date = self.end_date if start_date and end_date and (end_date < start_date): raise ValueError("end date cannot be lower than start date.") - start_block = values.get("start_block") - end_block = values.get("end_block") + start_block = self.start_block + end_block = self.end_block if start_block and end_block and (end_block < start_block): raise ValueError("end block cannot be lower than start block.") - return values - @validator( + return self + + @field_validator( "hashes", "addresses", "refs", @@ -133,15 +141,15 @@ def validate_field_dependencies(cls, values): "channels", "message_types", "tags", - pre=True, + mode="before", ) + @classmethod def split_str(cls, v): if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) return v - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) class MessageQueryParams(BaseMessageQueryParams): @@ -237,7 +245,7 @@ async def view_messages_list(request: web.Request) -> web.Response: """Messages list view with filters""" try: - query_params = MessageQueryParams.parse_obj(request.query) + query_params = MessageQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) @@ -246,7 +254,7 @@ async def view_messages_list(request: web.Request) -> web.Response: if url_page_param := request.match_info.get("page"): query_params.page = int(url_page_param) - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = query_params.page pagination_per_page = query_params.pagination @@ -277,10 +285,10 @@ async def _send_history_to_ws( session=session, pagination=history, include_confirmations=True, - 
**query_params.dict(exclude_none=True), + **query_params.model_dump(exclude_none=True), ) for message in messages: - await ws.send_str(format_message(message).json()) + await ws.send_str(format_message(message).model_dump_json()) def message_matches_filters( @@ -357,7 +365,7 @@ async def _process_message(mq_message: aio_pika.abc.AbstractMessage): if message_matches_filters(message=message, query_params=query_params): try: - await ws.send_str(message.json()) + await ws.send_str(message.model_dump_json()) except ConnectionResetError: # We can detect the WS closing in this task in addition to the main one. # The main task will also detect the close event. @@ -382,7 +390,7 @@ async def messages_ws(request: web.Request) -> web.WebSocketResponse: mq_channel = await get_mq_ws_channel_from_request(request=request, logger=LOGGER) try: - query_params = WsMessageQueryParams.parse_obj(request.query) + query_params = WsMessageQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) @@ -521,7 +529,7 @@ async def view_message(request: web.Request): session=session, status_db=message_status_db ) - return web.json_response(text=message_with_status.json()) + return web.json_response(text=message_with_status.model_dump_json()) async def view_message_content(request: web.Request): diff --git a/src/aleph/web/controllers/metrics.py b/src/aleph/web/controllers/metrics.py index bd812c706..c0c79a7fd 100644 --- a/src/aleph/web/controllers/metrics.py +++ b/src/aleph/web/controllers/metrics.py @@ -117,7 +117,7 @@ async def fetch_reference_total_messages() -> Optional[int]: async with session.get( urljoin(url, "metrics.json"), raise_for_status=True ) as resp: - data = await resp.json() + data = await resp.json() return int(data["pyaleph_status_sync_messages_total"]) except aiohttp.ClientResponseError: LOGGER.warning("ETH height could not be obtained") diff --git a/src/aleph/web/controllers/p2p.py 
b/src/aleph/web/controllers/p2p.py index 88939d724..7c1310429 100644 --- a/src/aleph/web/controllers/p2p.py +++ b/src/aleph/web/controllers/p2p.py @@ -112,7 +112,7 @@ async def pub_json(request: web.Request): pub_status = PublicationStatus.from_failures(failed_publications) return web.json_response( - text=pub_status.json(), + text=pub_status.model_dump_json(), status=500 if pub_status == "error" else 200, ) @@ -125,7 +125,7 @@ class PubMessageRequest(BaseModel): @shielded async def pub_message(request: web.Request): try: - request_data = PubMessageRequest.parse_obj(await request.json()) + request_data = PubMessageRequest.model_validate(await request.json()) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) except ValueError: @@ -142,4 +142,6 @@ async def pub_message(request: web.Request): ) status_code = broadcast_status_to_http_status(broadcast_status) - return web.json_response(text=broadcast_status.json(), status=status_code) + return web.json_response( + text=broadcast_status.model_dump_json(), status=status_code + ) diff --git a/src/aleph/web/controllers/posts.py b/src/aleph/web/controllers/posts.py index 773d7f6df..4443a832c 100644 --- a/src/aleph/web/controllers/posts.py +++ b/src/aleph/web/controllers/posts.py @@ -2,7 +2,14 @@ from aiohttp import web from aleph_message.models import ItemHash -from pydantic import BaseModel, Field, ValidationError, root_validator, validator +from pydantic import ( + BaseModel, + ConfigDict, + Field, + ValidationError, + field_validator, + model_validator, +) from sqlalchemy import select from aleph.db.accessors.posts import ( @@ -82,30 +89,25 @@ class PostQueryParams(BaseModel): "-1 means most recent messages first, 1 means older messages first.", ) - @root_validator - def validate_field_dependencies(cls, values): - start_date = values.get("start_date") - end_date = values.get("end_date") + @model_validator(mode="after") + def validate_field_dependencies(self): + start_date = 
self.start_date + end_date = self.end_date if start_date and end_date and (end_date < start_date): raise ValueError("end date cannot be lower than start date.") - return values - - @validator( - "addresses", - "hashes", - "refs", - "post_types", - "channels", - "tags", - pre=True, + + return self + + @field_validator( + "addresses", "hashes", "refs", "post_types", "channels", "tags", mode="before" ) - def split_str(cls, v): + @classmethod + def split_str(cls, v) -> List[str]: if isinstance(v, str): return v.split(LIST_FIELD_SEPARATOR) return v - class Config: - allow_population_by_field_name = True + model_config = ConfigDict(populate_by_name=True) def merged_post_to_dict(merged_post: MergedPost) -> Dict[str, Any]: @@ -173,7 +175,7 @@ def merged_post_v0_to_dict( def get_query_params(request: web.Request) -> PostQueryParams: try: - query_params = PostQueryParams.parse_obj(request.query) + query_params = PostQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) @@ -188,7 +190,7 @@ async def view_posts_list_v0(request: web.Request) -> web.Response: query_string = request.query_string query_params = get_query_params(request) - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = query_params.page pagination_per_page = query_params.pagination @@ -231,7 +233,7 @@ async def view_posts_list_v1(request) -> web.Response: query_string = request.query_string try: - query_params = PostQueryParams.parse_obj(request.query) + query_params = PostQueryParams.model_validate(request.query) except ValidationError as e: raise web.HTTPUnprocessableEntity(text=e.json(indent=4)) @@ -239,7 +241,7 @@ async def view_posts_list_v1(request) -> web.Response: if path_page: query_params.page = path_page - find_filters = query_params.dict(exclude_none=True) + find_filters = query_params.model_dump(exclude_none=True) pagination_page = 
query_params.page pagination_per_page = query_params.pagination diff --git a/src/aleph/web/controllers/programs.py b/src/aleph/web/controllers/programs.py index c5427ccf2..e898e7aca 100644 --- a/src/aleph/web/controllers/programs.py +++ b/src/aleph/web/controllers/programs.py @@ -1,5 +1,5 @@ from aiohttp import web -from pydantic import BaseModel, ValidationError +from pydantic import BaseModel, ConfigDict, ValidationError from aleph.db.accessors.messages import get_programs_triggered_by_messages from aleph.types.db_session import DbSessionFactory @@ -8,9 +8,7 @@ class GetProgramQueryFields(BaseModel): sort_order: SortOrder = SortOrder.DESCENDING - - class Config: - extra = "forbid" + model_config = ConfigDict(extra="forbid") async def get_programs_on_message(request: web.Request) -> web.Response: diff --git a/src/aleph/web/controllers/storage.py b/src/aleph/web/controllers/storage.py index cb0ff2c69..8ce01d5a5 100644 --- a/src/aleph/web/controllers/storage.py +++ b/src/aleph/web/controllers/storage.py @@ -224,7 +224,7 @@ async def _check_and_add_file( raise web.HTTPUnprocessableEntity(reason="Store message content needed") try: - message_content = StoreContent.parse_raw(message.item_content) + message_content = StoreContent.model_validate_json(message.item_content) if message_content.item_hash != file_hash: raise web.HTTPUnprocessableEntity( reason=f"File hash does not match ({file_hash} != {message_content.item_hash})" @@ -328,7 +328,7 @@ async def storage_add_file(request: web.Request): metadata.file.read() if isinstance(metadata, FileField) else metadata ) try: - storage_metadata = StorageMetadata.parse_raw(metadata_bytes) + storage_metadata = StorageMetadata.model_validate_json(metadata_bytes) except ValidationError as e: raise web.HTTPUnprocessableEntity( reason=f"Could not decode metadata: {e.json()}" diff --git a/src/aleph/web/controllers/utils.py b/src/aleph/web/controllers/utils.py index 6cbcc0ebb..7e0a4ac7b 100644 --- a/src/aleph/web/controllers/utils.py 
+++ b/src/aleph/web/controllers/utils.py @@ -296,11 +296,11 @@ async def pub_on_p2p_topics( class BroadcastStatus(BaseModel): publication_status: PublicationStatus - message_status: Optional[MessageStatus] + message_status: Optional[MessageStatus] = None def broadcast_status_to_http_status(broadcast_status: BroadcastStatus) -> int: - if broadcast_status.publication_status == "error": + if broadcast_status.publication_status.status == "error": return 500 message_status = broadcast_status.message_status @@ -311,7 +311,7 @@ def broadcast_status_to_http_status(broadcast_status: BroadcastStatus) -> int: def format_pending_message_dict(pending_message: BasePendingMessage) -> Dict[str, Any]: - pending_message_dict = pending_message.dict(exclude_none=True) + pending_message_dict = pending_message.model_dump(exclude_none=True) pending_message_dict["time"] = pending_message_dict["time"].timestamp() return pending_message_dict diff --git a/tests/api/test_balance.py b/tests/api/test_balance.py index c37d1b44c..e3abf315f 100644 --- a/tests/api/test_balance.py +++ b/tests/api/test_balance.py @@ -23,5 +23,5 @@ async def test_get_balance( response = await ccn_api_client.get(MESSAGES_URI) assert response.status == 200, await response.text() data = await response.json() - assert data["balance"] == user_balance.balance - assert data["locked_amount"] == 2002.4666666666667 + assert data["balance"] == str(user_balance.balance) + assert data["locked_amount"] == "2002.46666666666669698315672576427459716796875" diff --git a/tests/api/test_list_messages.py b/tests/api/test_list_messages.py index 6e9b2582a..b191e3a7b 100644 --- a/tests/api/test_list_messages.py +++ b/tests/api/test_list_messages.py @@ -14,7 +14,6 @@ from aleph_message.models.execution.volume import ( ImmutableVolume, ParentVolume, - PersistentVolumeSizeMib, VolumePersistence, ) @@ -228,16 +227,32 @@ async def test_get_messages_filter_by_tags( assert messages[0]["item_hash"] == amend_message_db.item_hash 
-@pytest.mark.parametrize("type_field", ("msgType", "msgTypes")) @pytest.mark.asyncio -async def test_get_by_message_type(fixture_messages, ccn_api_client, type_field: str): +async def test_get_by_deprecated_message_type(fixture_messages, ccn_api_client): messages_by_type = defaultdict(list) for message in fixture_messages: messages_by_type[message["type"]].append(message) for message_type, expected_messages in messages_by_type.items(): response = await ccn_api_client.get( - MESSAGES_URI, params={type_field: message_type} + MESSAGES_URI, params={"msgType": message_type} + ) + assert response.status == 200, await response.text() + messages = (await response.json())["messages"] + assert set(msg["item_hash"] for msg in messages) == set( + msg["item_hash"] for msg in expected_messages + ) + + +@pytest.mark.asyncio +async def test_get_by_message_type(fixture_messages, ccn_api_client): + messages_by_type = defaultdict(list) + for message in fixture_messages: + messages_by_type[message["type"]].append(message) + + for message_type, expected_messages in messages_by_type.items(): + response = await ccn_api_client.get( + MESSAGES_URI, params={"msgTypes": [message_type]} ) assert response.status == 200, await response.text() messages = (await response.json())["messages"] @@ -517,7 +532,7 @@ def instance_message_fixture() -> MessageDb: ) ), persistence=VolumePersistence("host"), - size_mib=PersistentVolumeSizeMib(1024), + size_mib=1024, ), volumes=[ ImmutableVolume( @@ -527,7 +542,7 @@ def instance_message_fixture() -> MessageDb: use_latest=True, ) ], - ).dict(), + ).model_dump(), size=3000, time=timestamp_to_datetime(1686572207.89381), channel=Channel("TEST"), diff --git a/tests/chains/test_chain_data_service.py b/tests/chains/test_chain_data_service.py index b34a5522c..722d8573f 100644 --- a/tests/chains/test_chain_data_service.py +++ b/tests/chains/test_chain_data_service.py @@ -1,4 +1,5 @@ import datetime as dt +import json import pytest from aleph_message.models import ( 
@@ -44,7 +45,7 @@ async def mock_add_file( session: DbSession, file_content: bytes, engine: ItemType = ItemType.ipfs ) -> str: content = file_content - archive = OnChainSyncEventPayload.parse_raw(content) + archive = OnChainSyncEventPayload.model_validate_json(content) assert archive.version == 1 assert len(archive.content.messages) == len(messages) @@ -86,7 +87,7 @@ async def test_smart_contract_protocol_ipfs_store( publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW", protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=payload.dict(), + content=payload.model_dump(), ) chain_data_service = ChainDataService( @@ -112,7 +113,7 @@ async def test_smart_contract_protocol_ipfs_store( assert pending_message.channel is None assert pending_message.item_content - message_content = StoreContent.parse_raw(pending_message.item_content) + message_content = StoreContent.model_validate_json(pending_message.item_content) assert message_content.item_hash == payload.message_content assert message_content.item_type == ItemType.ipfs assert message_content.address == payload.addr @@ -135,7 +136,7 @@ async def test_smart_contract_protocol_regular_message( timestamp=1668611900, addr="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6", msgtype="POST", - msgcontent=content.json(), + msgcontent=json.dumps(content.model_dump()), ) tx = ChainTxDb( @@ -146,7 +147,7 @@ async def test_smart_contract_protocol_regular_message( publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW", protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=payload.dict(), + content=payload.model_dump(mode="json"), ) chain_data_service = ChainDataService( @@ -172,7 +173,7 @@ async def test_smart_contract_protocol_regular_message( assert pending_message.channel is None assert pending_message.item_content - message_content = PostContent.parse_raw(pending_message.item_content) + message_content = PostContent.model_validate_json(pending_message.item_content) assert message_content.address == content.address 
assert message_content.time == content.time assert message_content.ref == content.ref diff --git a/tests/chains/test_tezos.py b/tests/chains/test_tezos.py index aa0d93e4c..bb9593cb6 100644 --- a/tests/chains/test_tezos.py +++ b/tests/chains/test_tezos.py @@ -139,4 +139,4 @@ def test_indexer_event_to_aleph_message(message_type: str, message_content: str) assert tx.protocol == ChainSyncProtocol.SMART_CONTRACT assert tx.protocol_version == 1 - assert tx.content == indexer_event.payload.dict() + assert tx.content == indexer_event.payload.model_dump() diff --git a/tests/conftest.py b/tests/conftest.py index 2a9370c96..a074cb568 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -310,7 +310,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): Insert volume references in the DB to make the program processable. """ - content = InstanceContent.parse_raw(message.item_content) + content = InstanceContent.model_validate_json(message.item_content) volumes = get_volume_refs(content) created = pytz.utc.localize(dt.datetime(2023, 1, 1)) diff --git a/tests/db/test_cost.py b/tests/db/test_cost.py index 1226aeae0..f7f09ca67 100644 --- a/tests/db/test_cost.py +++ b/tests/db/test_cost.py @@ -65,7 +65,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): """ if message.item_content: - content = InstanceContent.parse_raw(message.item_content) + content = InstanceContent.model_validate_json(message.item_content) volumes = get_volume_refs(content) created = pytz.utc.localize(dt.datetime(2023, 1, 1)) diff --git a/tests/message_processing/test_process_confidential.py b/tests/message_processing/test_process_confidential.py index 8ac51aa90..3e016d7fa 100644 --- a/tests/message_processing/test_process_confidential.py +++ b/tests/message_processing/test_process_confidential.py @@ -194,7 +194,7 @@ def get_volume_refs(content: ExecutableContent) -> List[ImmutableVolume]: def insert_volume_refs(session: DbSession, message: PendingMessageDb): 
item_content = message.item_content if message.item_content is not None else "" - content = InstanceContent.parse_raw(item_content) + content = InstanceContent.model_validate_json(item_content) volumes = get_volume_refs(content) created = pytz.utc.localize(dt.datetime(2023, 1, 1)) diff --git a/tests/message_processing/test_process_instances.py b/tests/message_processing/test_process_instances.py index 55594f29d..fefd59749 100644 --- a/tests/message_processing/test_process_instances.py +++ b/tests/message_processing/test_process_instances.py @@ -22,6 +22,7 @@ ProgramContent, ) from aleph_message.models.execution.volume import ImmutableVolume, ParentVolume +from aleph_message.utils import Gigabytes, gigabyte_to_mebibyte from more_itertools import one from sqlalchemy import text @@ -154,7 +155,7 @@ def user_balance(session_factory: DbSessionFactory) -> AlephBalanceDb: balance = AlephBalanceDb( address="0x9319Ad3B7A8E0eE24f2E639c40D8eD124C5520Ba", chain=Chain.ETH, - balance=Decimal(22_192), + balance=Decimal(50_000), eth_height=0, ) @@ -227,7 +228,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): """ assert message.item_content - content = InstanceContent.parse_raw(message.item_content) + content = InstanceContent.model_validate_json(message.item_content) volumes = get_volume_refs(content) created = pytz.utc.localize(dt.datetime(2023, 1, 1)) @@ -375,7 +376,9 @@ async def test_process_instance_missing_volumes( assert rejected_message.error_code == ErrorCode.VM_VOLUME_NOT_FOUND if fixture_instance_message.item_content: - content = InstanceContent.parse_raw(fixture_instance_message.item_content) + content = InstanceContent.model_validate_json( + fixture_instance_message.item_content + ) volume_refs = set(volume.ref for volume in get_volume_refs(content)) assert isinstance(rejected_message.details, dict) assert set(rejected_message.details["errors"]) == volume_refs @@ -453,7 +456,9 @@ async def test_get_volume_size( session.commit() if 
fixture_instance_message.item_content: - content = InstanceContent.parse_raw(fixture_instance_message.item_content) + content = InstanceContent.model_validate_json( + fixture_instance_message.item_content + ) with session_factory() as session: volume_size = get_volume_size(session=session, content=content) assert volume_size == 21512585216 @@ -469,7 +474,9 @@ async def test_get_additional_storage_price( session.commit() if fixture_instance_message.item_content: - content = InstanceContent.parse_raw(fixture_instance_message.item_content) + content = InstanceContent.model_validate_json( + fixture_instance_message.item_content + ) with session_factory() as session: additional_price = get_additional_storage_price( content=content, session=session @@ -487,7 +494,9 @@ async def test_get_compute_cost( session.commit() if fixture_instance_message.item_content: - content = InstanceContent.parse_raw(fixture_instance_message.item_content) + content = InstanceContent.model_validate_json( + fixture_instance_message.item_content + ) with session_factory() as session: price: Decimal = compute_cost(content=content, session=session) assert price == Decimal("2001.8") @@ -509,7 +518,7 @@ async def test_compare_cost_view_with_cost_function( _ = [message async for message in pipeline] assert fixture_instance_message.item_content - content = InstanceContent.parse_raw(fixture_instance_message.item_content) + content = InstanceContent.model_validate_json(fixture_instance_message.item_content) with session_factory() as session: cost_from_function: Decimal = compute_cost(session=session, content=content) cost_from_view = session.execute( @@ -518,3 +527,48 @@ async def test_compare_cost_view_with_cost_function( ).scalar_one() assert Decimal(str(cost_from_view)) == cost_from_function + + +@pytest.mark.asyncio +async def test_persistent_volume_500_GB( + session_factory: DbSessionFactory, + message_processor: PendingMessageProcessor, + fixture_instance_message: PendingMessageDb, + user_balance: 
AlephBalanceDb, +): + assert fixture_instance_message.item_content + content_dict = json.loads(fixture_instance_message.item_content) + # Update the 'statistics' store volume to 500 GB (converted to MiB) + for volume in content_dict["volumes"]: + if volume.get("persistence") == "store" and volume.get("name") == "statistics": + volume["size_mib"] = gigabyte_to_mebibyte(Gigabytes(500)) + + fixture_instance_message.item_content = json.dumps(content_dict) + + with session_factory() as session: + session.merge(fixture_instance_message) # Persist the updated volume size + insert_volume_refs(session, fixture_instance_message) + session.commit() + + pipeline = message_processor.make_pipeline() + _ = [message async for message in pipeline] + + with session_factory() as session: + instance = get_instance( + session=session, item_hash=fixture_instance_message.item_hash + ) + assert instance is not None + + persistent_volume = next( + ( + vol + for vol in instance.volumes + if isinstance(vol, PersistentVolumeDb) + and vol.persistence == "store" + and vol.name == "statistics" + ), + None, + ) + + assert persistent_volume is not None, "PersistentVolume not found" + assert persistent_volume.size_mib == gigabyte_to_mebibyte(Gigabytes(500)) diff --git a/tests/message_processing/test_process_pending_txs.py b/tests/message_processing/test_process_pending_txs.py index d9f3b6086..5e951dfe4 100644 --- a/tests/message_processing/test_process_pending_txs.py +++ b/tests/message_processing/test_process_pending_txs.py @@ -136,7 +136,7 @@ async def _process_smart_contract_tx( publisher="KT1BfL57oZfptdtMFZ9LNakEPvuPPA2urdSW", protocol=ChainSyncProtocol.SMART_CONTRACT, protocol_version=1, - content=payload.dict(), + content=payload.model_dump(), ) pending_tx = PendingTxDb(tx=tx) @@ -214,7 +214,7 @@ async def test_process_pending_smart_contract_tx_post( type="my-type", address="KT1VBeLD7hzKpj17aRJ3Kc6QQFeikCEXi7W6", time=1000, - ).json(), + ).model_dump_json(), ) await _process_smart_contract_tx( diff --git 
a/tests/message_processing/test_process_programs.py b/tests/message_processing/test_process_programs.py index 9c00c557d..f6a631a0c 100644 --- a/tests/message_processing/test_process_programs.py +++ b/tests/message_processing/test_process_programs.py @@ -117,7 +117,7 @@ def insert_volume_refs(session: DbSession, message: PendingMessageDb): """ assert message.item_content - content = ProgramContent.parse_raw(message.item_content) + content = ProgramContent.model_validate_json(message.item_content) volumes = get_volumes_with_ref(content) created = pytz.utc.localize(dt.datetime(2023, 1, 1)) @@ -298,7 +298,7 @@ async def test_process_program_missing_volumes( assert rejected_message.error_code == ErrorCode.VM_VOLUME_NOT_FOUND assert program_message.item_content - content = ProgramContent.parse_raw(program_message.item_content) + content = ProgramContent.model_validate_json(program_message.item_content) volume_refs = set(volume.ref for volume in get_volumes_with_ref(content)) assert isinstance(rejected_message.details, dict) assert set(rejected_message.details["errors"]) == volume_refs diff --git a/tests/schemas/test_pending_messages.py b/tests/schemas/test_pending_messages.py index e48ee076e..34477ddda 100644 --- a/tests/schemas/test_pending_messages.py +++ b/tests/schemas/test_pending_messages.py @@ -170,7 +170,7 @@ def test_parse_program_message(): content = json.loads(message_dict["item_content"]) assert message.content.address == content["address"] assert message.content.time == content["time"] - assert message.content.code.dict(exclude_none=True) == content["code"] + assert message.content.code.model_dump(exclude_none=True) == content["code"] assert message.content.type == content["type"]