Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Cerebras LLM #96

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 53 additions & 0 deletions src/openagi/llms/cerebras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from typing import Any
from langchain_core.messages import HumanMessage
from openagi.exception import OpenAGIException
from openagi.llms.base import LLMBaseModel, LLMConfigModel
from openagi.utils.yamlParse import read_from_env

try:
from langchain_cerebras import ChatCerebras
except ImportError:
raise OpenAGIException("Install langchain-cerebras with cmd `pip install langchain-cerebras`")

class CerebrasConfigModel(LLMConfigModel):
    """Configuration model for the Cerebras chat API (cloud.cerebras.ai).

    Available models (per Cerebras docs): llama-3.3-70b, llama-3.1-70b,
    llama-3.1-8b.
    NOTE(review): the documented names above use hyphens after "llama",
    but the default below is "llama3.1-8b" — confirm which spelling the
    Cerebras API actually expects.
    """

    # Required API key for cloud.cerebras.ai.
    cerebras_api_key: str
    # Model id sent to the API; see class docstring for known options.
    model_name: str = "llama3.1-8b"
    # Sampling temperature forwarded to ChatCerebras.
    temperature: float = 0.7

class CerebrasModel(LLMBaseModel):
    """Cerebras LLM implementation of the LLMBaseModel."""

    # Expected to hold a CerebrasConfigModel; typed Any to match the
    # sibling LLM implementations in this package.
    config: Any

    def load(self):
        """Initialize the ChatCerebras client from `self.config` and return it."""
        self.llm = ChatCerebras(
            api_key=self.config.cerebras_api_key,
            model_name=self.config.model_name,
            temperature=self.config.temperature,
        )
        return self.llm

    def run(self, input_data: str):
        """Run the Cerebras model on `input_data` and return the response text.

        Lazily initializes the client on first use, then sends the input as a
        single HumanMessage.

        Raises:
            ValueError: if the client could not be initialized.
        """
        # getattr guard: `llm` is only assigned in load(), so reading
        # `self.llm` directly would raise AttributeError on first call if the
        # base class does not predeclare the attribute.
        if getattr(self, "llm", None) is None:
            self.load()
        if not self.llm:
            raise ValueError("`llm` attribute not set.")
        message = HumanMessage(content=input_data)
        # `.invoke` is the supported entry point; calling the chat model
        # directly (`self.llm([message])`) is deprecated in langchain-core.
        response = self.llm.invoke([message])
        return response.content

    @staticmethod
    def load_from_env_config() -> CerebrasConfigModel:
        """Load Cerebras configuration from environment variables.

        CEREBRAS_API_KEY is required. Cerebras_MODEL and Cerebras_TEMP are
        optional: when unset they are omitted so the defaults declared on
        CerebrasConfigModel apply, instead of passing None (which would fail
        pydantic validation of the `str`/`float` fields).
        NOTE(review): the optional variable names are mixed-case
        ("Cerebras_MODEL") unlike CEREBRAS_API_KEY — consider adding
        upper-case aliases, but keep reading the old names for compatibility.
        """
        config = {
            "cerebras_api_key": read_from_env("CEREBRAS_API_KEY", raise_exception=True),
        }
        model_name = read_from_env("Cerebras_MODEL", raise_exception=False)
        if model_name:
            config["model_name"] = model_name
        temperature = read_from_env("Cerebras_TEMP", raise_exception=False)
        if temperature:
            config["temperature"] = temperature
        return CerebrasConfigModel(**config)
Loading