Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

deps: compat with pydantic v1 #102

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 13 additions & 3 deletions .github/workflows/build_publish.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,22 @@ jobs:
runs-on: ubuntu-latest
strategy:
matrix:
python-version: ["3.9", "3.10", "3.11", "3.12"]
python-version: [
"3.9",
"3.10",
"3.11",
"3.12"
]
pydantic-version:
[
"^1.10,<2",
"^2,<3"
]
steps:
# Checkout the repository
- name: Checkout
uses: actions/checkout@v4

# Set python version to 3.11
- name: set python version
uses: actions/setup-python@v4
with:
Expand All @@ -35,7 +44,8 @@ jobs:
run: |
pip install poetry \
&& poetry config virtualenvs.create false \
&& poetry install
&& poetry install \
&& poetry add "pydantic==${{ matrix.pydantic-version }}"

# Ruff
- name: Ruff check
Expand Down
155 changes: 43 additions & 112 deletions poetry.lock

Large diffs are not rendered by default.

6 changes: 3 additions & 3 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "mistralai"
version = "0.4.0"
version = "0.4.1"
description = ""
authors = ["Bam4d <[email protected]>"]
readme = "README.md"
Expand All @@ -18,14 +18,14 @@ disallow_untyped_defs = true
show_error_codes = true
no_implicit_optional = true
warn_return_any = true
warn_unused_ignores = true
warn_unused_ignores = false
exclude = ["docs", "tests", "examples", "tools", "build"]


[tool.poetry.dependencies]
python = "^3.9,<4.0"
orjson = "^3.9.10,<3.11"
pydantic = "^2.5.2,<3"
pydantic = "^1.10,<3"
httpx = "^0.25,<1"


Expand Down
2 changes: 1 addition & 1 deletion src/mistralai/client_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,7 @@ def _parse_tool_choice(self, tool_choice: Union[str, ToolChoice]) -> str:
return tool_choice.value
return tool_choice

def _parse_response_format(self, response_format: Union[Dict[str, Any], ResponseFormat]) -> Dict[str, Any]:
def _parse_response_format(self, response_format: Union[Dict[str, Any], ResponseFormat]) -> Any:
if isinstance(response_format, ResponseFormat):
return response_format.model_dump(exclude_none=True)
return response_format
Expand Down
4 changes: 2 additions & 2 deletions src/mistralai/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ def create(
"model": model,
"training_files": training_files,
"validation_files": validation_files,
"hyperparameters": hyperparameters.dict(),
"hyperparameters": hyperparameters.model_dump(),
"suffix": suffix,
"integrations": integrations,
},
Expand Down Expand Up @@ -120,7 +120,7 @@ async def create(
"model": model,
"training_files": training_files,
"validation_files": validation_files,
"hyperparameters": hyperparameters.dict(),
"hyperparameters": hyperparameters.model_dump(),
"suffix": suffix,
"integrations": integrations,
},
Expand Down
25 changes: 25 additions & 0 deletions src/mistralai/models/base_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
from typing import Any, TypeVar

import pydantic
from pydantic import BaseModel

Model = TypeVar("Model", bound="BaseModel")
IS_V1 = pydantic.VERSION.startswith("1.")


class BackwardCompatibleBaseModel(BaseModel):
def model_dump(self, *args: Any, **kwargs: Any) -> Any:
if IS_V1:
return self.dict(*args, **kwargs)
return super().model_dump(*args, **kwargs) # type: ignore

@classmethod
def model_validate_json(cls: type[Model], *args: Any, **kwargs: Any) -> Model:
if IS_V1:
return cls.parse_raw(*args, **kwargs)
return super().model_validate_json(*args, **kwargs) # type: ignore

def model_dump_json(self, *args: Any, **kwargs: Any) -> str:
if IS_V1:
return self.json(*args, **kwargs)
return super().model_dump_json(*args, **kwargs) # type: ignore
23 changes: 11 additions & 12 deletions src/mistralai/models/chat_completion.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from enum import Enum
from typing import List, Optional

from pydantic import BaseModel

from mistralai.models.base_model import BackwardCompatibleBaseModel
from mistralai.models.common import UsageInfo


class Function(BaseModel):
class Function(BackwardCompatibleBaseModel):
name: str
description: str
parameters: dict
Expand All @@ -16,12 +15,12 @@ class ToolType(str, Enum):
function = "function"


class FunctionCall(BaseModel):
class FunctionCall(BackwardCompatibleBaseModel):
name: str
arguments: str


class ToolCall(BaseModel):
class ToolCall(BackwardCompatibleBaseModel):
id: str = "null"
type: ToolType = ToolType.function
function: FunctionCall
Expand All @@ -38,19 +37,19 @@ class ToolChoice(str, Enum):
none: str = "none"


class ResponseFormat(BaseModel):
class ResponseFormat(BackwardCompatibleBaseModel):
type: ResponseFormats = ResponseFormats.text


class ChatMessage(BaseModel):
class ChatMessage(BackwardCompatibleBaseModel):
role: str
content: str
name: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
tool_call_id: Optional[str] = None


class DeltaMessage(BaseModel):
class DeltaMessage(BackwardCompatibleBaseModel):
role: Optional[str] = None
content: Optional[str] = None
tool_calls: Optional[List[ToolCall]] = None
Expand All @@ -63,13 +62,13 @@ class FinishReason(str, Enum):
tool_calls = "tool_calls"


class ChatCompletionResponseStreamChoice(BaseModel):
class ChatCompletionResponseStreamChoice(BackwardCompatibleBaseModel):
index: int
delta: DeltaMessage
finish_reason: Optional[FinishReason]


class ChatCompletionStreamResponse(BaseModel):
class ChatCompletionStreamResponse(BackwardCompatibleBaseModel):
id: str
model: str
choices: List[ChatCompletionResponseStreamChoice]
Expand All @@ -78,13 +77,13 @@ class ChatCompletionStreamResponse(BaseModel):
usage: Optional[UsageInfo] = None


class ChatCompletionResponseChoice(BaseModel):
class ChatCompletionResponseChoice(BackwardCompatibleBaseModel):
index: int
message: ChatMessage
finish_reason: Optional[FinishReason]


class ChatCompletionResponse(BaseModel):
class ChatCompletionResponse(BackwardCompatibleBaseModel):
id: str
object: str
created: int
Expand Down
4 changes: 2 additions & 2 deletions src/mistralai/models/common.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Optional

from pydantic import BaseModel
from mistralai.models.base_model import BackwardCompatibleBaseModel


class UsageInfo(BaseModel):
class UsageInfo(BackwardCompatibleBaseModel):
prompt_tokens: int
total_tokens: int
completion_tokens: Optional[int]
7 changes: 3 additions & 4 deletions src/mistralai/models/embeddings.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,16 @@
from typing import List

from pydantic import BaseModel

from mistralai.models.base_model import BackwardCompatibleBaseModel
from mistralai.models.common import UsageInfo


class EmbeddingObject(BaseModel):
class EmbeddingObject(BackwardCompatibleBaseModel):
object: str
embedding: List[float]
index: int


class EmbeddingResponse(BaseModel):
class EmbeddingResponse(BackwardCompatibleBaseModel):
id: str
object: str
data: List[EmbeddingObject]
Expand Down
8 changes: 4 additions & 4 deletions src/mistralai/models/files.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import Literal, Optional

from pydantic import BaseModel
from mistralai.models.base_model import BackwardCompatibleBaseModel


class FileObject(BaseModel):
class FileObject(BackwardCompatibleBaseModel):
id: str
object: str
bytes: int
Expand All @@ -12,12 +12,12 @@ class FileObject(BaseModel):
purpose: Optional[Literal["fine-tune"]] = "fine-tune"


class FileDeleted(BaseModel):
class FileDeleted(BackwardCompatibleBaseModel):
id: str
object: str
deleted: bool


class Files(BaseModel):
class Files(BackwardCompatibleBaseModel):
data: list[FileObject]
object: Literal["list"]
28 changes: 15 additions & 13 deletions src/mistralai/models/jobs.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
from datetime import datetime
from typing import Annotated, List, Literal, Optional, Union
from typing import List, Literal, Optional, Union

from pydantic import BaseModel, Field
from pydantic import Field

from mistralai.models.base_model import BackwardCompatibleBaseModel

class TrainingParameters(BaseModel):

class TrainingParameters(BackwardCompatibleBaseModel):
training_steps: int = Field(1800, le=10000, ge=1)
learning_rate: float = Field(1.0e-4, le=1, ge=1.0e-8)


class WandbIntegration(BaseModel):
class WandbIntegration(BackwardCompatibleBaseModel):
type: Literal["wandb"] = "wandb"
project: str
name: Union[str, None] = None
Expand All @@ -20,11 +22,11 @@ class WandbIntegrationIn(WandbIntegration):
api_key: str


Integration = Annotated[Union[WandbIntegration], Field(discriminator="type")]
IntegrationIn = Annotated[Union[WandbIntegrationIn], Field(discriminator="type")]
Integration = Union[WandbIntegration]
IntegrationIn = Union[WandbIntegrationIn]


class JobMetadata(BaseModel):
class JobMetadata(BackwardCompatibleBaseModel):
object: Literal["job.metadata"] = "job.metadata"
training_steps: int
train_tokens_per_step: int
Expand All @@ -34,7 +36,7 @@ class JobMetadata(BaseModel):
expected_duration_seconds: Optional[int]


class Job(BaseModel):
class Job(BackwardCompatibleBaseModel):
id: str
hyperparameters: TrainingParameters
fine_tuned_model: Union[str, None]
Expand All @@ -57,25 +59,25 @@ class Job(BaseModel):
integrations: List[Integration] = []


class Event(BaseModel):
class Event(BackwardCompatibleBaseModel):
name: str
data: Union[dict, None] = None
created_at: int


class Metric(BaseModel):
class Metric(BackwardCompatibleBaseModel):
train_loss: Union[float, None] = None
valid_loss: Union[float, None] = None
valid_mean_token_accuracy: Union[float, None] = None


class Checkpoint(BaseModel):
class Checkpoint(BackwardCompatibleBaseModel):
metrics: Metric
step_number: int
created_at: int


class JobQueryFilter(BaseModel):
class JobQueryFilter(BackwardCompatibleBaseModel):
page: int = 0
page_size: int = 100
model: Optional[str] = None
Expand All @@ -93,6 +95,6 @@ class DetailedJob(Job):
estimated_start_time: Optional[int] = None


class Jobs(BaseModel):
class Jobs(BackwardCompatibleBaseModel):
data: list[Job] = []
object: Literal["list"]
8 changes: 4 additions & 4 deletions src/mistralai/models/models.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
from typing import List, Optional

from pydantic import BaseModel
from mistralai.models.base_model import BackwardCompatibleBaseModel


class ModelPermission(BaseModel):
class ModelPermission(BackwardCompatibleBaseModel):
id: str
object: str
created: int
Expand All @@ -18,7 +18,7 @@ class ModelPermission(BaseModel):
is_blocking: Optional[bool] = False


class ModelCard(BaseModel):
class ModelCard(BackwardCompatibleBaseModel):
id: str
object: str
created: int
Expand All @@ -28,6 +28,6 @@ class ModelCard(BaseModel):
permission: List[ModelPermission] = []


class ModelList(BaseModel):
class ModelList(BackwardCompatibleBaseModel):
object: str
data: List[ModelCard]
6 changes: 3 additions & 3 deletions tests/test_files.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ def test_create_file(self, client):
expected_response_file = FileObject.model_validate_json(mock_file_response_payload())
client._client.request.return_value = mock_response(
200,
expected_response_file.json(),
expected_response_file.model_dump_json(),
)

response_file = client.files.create(b"file_content")
Expand All @@ -36,7 +36,7 @@ def test_retrieve(self, client):
expected_response_file = FileObject.model_validate_json(mock_file_response_payload())
client._client.request.return_value = mock_response(
200,
expected_response_file.json(),
expected_response_file.model_dump_json(),
)

response_file = client.files.retrieve("file_id")
Expand Down Expand Up @@ -86,7 +86,7 @@ def test_list_files(self, client):

def test_delete_file(self, client):
expected_response_file = FileDeleted.model_validate_json(mock_file_deleted_response_payload())
client._client.request.return_value = mock_response(200, expected_response_file.json())
client._client.request.return_value = mock_response(200, expected_response_file.model_dump_json())

response_file = client.files.delete("file_id")

Expand Down
Loading
Loading