Skip to content

Commit

Permalink
add docstrings
Browse files Browse the repository at this point in the history
  • Loading branch information
masci committed Oct 23, 2024
1 parent 33b7486 commit 7a766da
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 5 deletions.
2 changes: 1 addition & 1 deletion llama_deploy/apiserver/routers/deployments.py
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ async def get_task_result(
async def get_tasks(
deployment_name: str,
) -> JSONResponse:
"""Get the active sessions in a deployment and service."""
"""Get all the tasks from all the sessions in a given deployment."""
deployment = manager.get_deployment(deployment_name)
if deployment is None:
raise HTTPException(status_code=404, detail="Deployment not found")
Expand Down
21 changes: 19 additions & 2 deletions llama_deploy/client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,17 +3,34 @@


class Client(_BaseClient):
"""Fixme.
"""The Llama Deploy Python client.
Fixme.
The client is gives access to both the asyncio and non-asyncio APIs. To access the sync
API just use methods of `client.sync`.
Example usage:
```py
from llama_deploy.client import Client
# Use the same client instance
c = Client()
async def an_async_function():
status = await client.apiserver.status()
def normal_function():
status = client.sync.apiserver.status()
```
"""

@property
def sync(self) -> "Client":
"""Returns the sync version of the client API."""
return _SyncClient(**self.settings.model_dump())

@property
def apiserver(self) -> ApiServer:
"""Returns the ApiServer model."""
return ApiServer.instance(client=self, id="apiserver")


Expand Down
6 changes: 6 additions & 0 deletions llama_deploy/client/client_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,12 @@


class ClientSettings(BaseSettings):
"""The global settings for a Client instance.
Settings can be manually defined before creating a Client instance, or defined with environment variables having
names prefixed with the string `LLAMA_DEPLOY_`, e.g. `LLAMA_DEPLOY_DISABLE_SSL`.
"""

model_config = SettingsConfigDict(env_prefix="LLAMA_DEPLOY_")

api_server_url: str = "http://localhost:4501"
Expand Down
36 changes: 34 additions & 2 deletions llama_deploy/client/models/apiserver.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,25 @@


class Session(Model):
"""A model representing a session."""

pass


class SessionCollection(Collection):
"""A model representing a collection of session for a given deployment."""

deployment_id: str

async def delete(self, session_id: str) -> None:
"""Deletes the session with the provided `session_id`.
Args:
session_id: The id of the session that will be removed
Raises:
HTTPException: If the session couldn't be found with the id provided.
"""
settings = self.client.settings
delete_url = f"{settings.api_server_url}/deployments/{self.deployment_id}/sessions/delete"

Expand All @@ -33,6 +45,8 @@ async def delete(self, session_id: str) -> None:


class Task(Model):
"""A model representing a task belonging to a given session in the given deployment."""

deployment_id: str
session_id: str

Expand All @@ -51,6 +65,7 @@ async def results(self) -> TaskResult:
return TaskResult.model_validate_json(r.json())

async def events(self) -> AsyncGenerator[dict[str, Any], None]: # pragma: no cover
"""Returns a generator object to consume the events streamed from a service."""
settings = self.client.settings
events_url = f"{settings.api_server_url}/deployments/{self.deployment_id}/tasks/{self.id}/events"

Expand All @@ -72,9 +87,16 @@ async def events(self) -> AsyncGenerator[dict[str, Any], None]: # pragma: no co


class TaskCollection(Collection):
"""A model representing a collection of tasks for a given deployment."""

deployment_id: str

async def run(self, task: TaskDefinition) -> Any:
"""Runs a task and returns the results once it's done.
Args:
task: The definition of the task we want to run.
"""
settings = self.client.settings
run_url = (
f"{settings.api_server_url}/deployments/{self.deployment_id}/tasks/run"
Expand All @@ -91,6 +113,7 @@ async def run(self, task: TaskDefinition) -> Any:
return r.json()

async def create(self, task: TaskDefinition) -> Task:
"""Runs a task returns it immediately, without waiting for the results."""
settings = self.client.settings
create_url = (
f"{settings.api_server_url}/deployments/{self.deployment_id}/tasks/create"
Expand All @@ -115,7 +138,10 @@ async def create(self, task: TaskDefinition) -> Task:


class Deployment(Model):
"""A model representing a deployment."""

async def tasks(self) -> TaskCollection:
"""Returns a collection of tasks from all the sessions in the given deployment."""
settings = self.client.settings
tasks_url = f"{settings.api_server_url}/deployments/{self.id}/tasks"
r = await self.client.request(
Expand All @@ -142,6 +168,7 @@ async def tasks(self) -> TaskCollection:
)

async def sessions(self) -> SessionCollection:
"""Returns a collection of all the sessions in the given deployment."""
settings = self.client.settings
sessions_url = f"{settings.api_server_url}/deployments/{self.id}/sessions"
r = await self.client.request(
Expand All @@ -167,8 +194,10 @@ async def sessions(self) -> SessionCollection:


class DeploymentCollection(Collection):
"""A model representing a collection of deployments currently active."""

async def create(self, config: TextIO) -> Deployment:
"""Creates a deployment"""
"""Creates a new deployment from a deployment file."""
settings = self.client.settings
create_url = f"{settings.api_server_url}/deployments/create"

Expand All @@ -188,7 +217,7 @@ async def create(self, config: TextIO) -> Deployment:
)

async def get(self, deployment_id: str) -> Deployment:
"""Get a deployment by id"""
"""Gets a deployment by id."""
settings = self.client.settings
get_url = f"{settings.api_server_url}/deployments/{deployment_id}"
# Current version of apiserver doesn't returns anything useful in this endpoint, let's just ignore it
Expand All @@ -201,6 +230,8 @@ async def get(self, deployment_id: str) -> Deployment:


class ApiServer(Model):
"""A model representing the API Server instance."""

async def status(self) -> Status:
"""Returns the status of the API Server."""
settings = self.client.settings
Expand Down Expand Up @@ -240,6 +271,7 @@ async def status(self) -> Status:
)

async def deployments(self) -> DeploymentCollection:
"""Returns a collection of deployments currently active in the API Server."""
settings = self.client.settings
status_url = f"{settings.api_server_url}/deployments/"

Expand Down
16 changes: 16 additions & 0 deletions llama_deploy/client/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,16 +9,24 @@


class _Base(BaseModel):
"""The base model provides fields and functionalities common to derived models and collections."""

client: _BaseClient = Field(exclude=True)
_instance_is_sync: bool = PrivateAttr(default=False)

model_config = ConfigDict(arbitrary_types_allowed=True)

def __new__(cls, *args, **kwargs): # type: ignore[no-untyped-def]
"""We prevent the usage of the constructor and force users to call `instance()` instead."""
raise TypeError("Please use instance() instead of direct instantiation")

@classmethod
def instance(cls, make_sync: bool = False, **kwargs: Any) -> Self:
"""Returns an instance of the given model.
Using the class constructor is not possible because we want to alter the class method to
accommodate sync/async usage before creating an instance, and __init__ would be too late.
"""
if make_sync:
cls = _make_sync(cls)

Expand All @@ -36,20 +44,28 @@ class Model(_Base):


class Collection(_Base, Generic[T]):
"""A generic container of items of the same model type."""

items: dict[str, T]

def get(self, id: str) -> T:
"""Returns an item from the collection."""
return self.items[id]

def list(self) -> list[T]:
"""Returns a list of all the items in the collection."""
return [self.get(id) for id in self.items.keys()]


def _make_sync(_class: type[T]) -> type[T]:
"""Wraps the methods of the given model class so that they can be called without `await`."""

class Wrapper(_class): # type: ignore
pass

for name, method in _class.__dict__.items():
# Only wrap async public methods
if asyncio.iscoroutinefunction(method) and not name.startswith("_"):
setattr(Wrapper, name, async_to_sync(method))
# Static type checkers can't assess Wrapper is indeed a type[T], let's promise it is.
return cast(type[T], Wrapper)

0 comments on commit 7a766da

Please sign in to comment.