Skip to content

Commit

Permalink
Fix default getting overridden to None
Browse files Browse the repository at this point in the history
  • Loading branch information
aravind10x committed Dec 28, 2024
1 parent e036c67 commit d1f044f
Show file tree
Hide file tree
Showing 2 changed files with 11 additions and 13 deletions.
16 changes: 9 additions & 7 deletions src/ragbuilder/core/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ def __init__(
n_trials: Optional[int] = None,
log_config: Optional[LogConfig] = None
):
ConfigStore.set_default_llm(default_llm)
ConfigStore.set_default_embeddings(default_embeddings)
ConfigStore.set_default_n_trials(n_trials)
ConfigStore.set_default_llm(default_llm) if default_llm else None
ConfigStore.set_default_embeddings(default_embeddings) if default_embeddings else None
ConfigStore.set_default_n_trials(n_trials) if n_trials else None
self._log_config = log_config or LogConfig()
self.data_ingest_config = data_ingest_config
self.retrieval_config = retrieval_config
Expand All @@ -61,7 +61,9 @@ def __init__(
self._optimization_results = OptimizationResults()
self._test_dataset_manager = TestDatasetManager(
self._log_config,
db_path=self.data_ingest_config.database_path if self.data_ingest_config else DEFAULT_DB_PATH
db_path=(self.data_ingest_config.database_path
if self.data_ingest_config and self.data_ingest_config.database_path
else DEFAULT_DB_PATH)
)

@classmethod
Expand All @@ -74,9 +76,9 @@ def from_source_with_defaults(cls,
log_config: Optional[LogConfig] = None
) -> 'RAGBuilder':
"""Create RAGBuilder instance with default configuration"""
ConfigStore.set_default_llm(default_llm)
ConfigStore.set_default_embeddings(default_embeddings)
ConfigStore.set_default_n_trials(n_trials)
ConfigStore.set_default_llm(default_llm) if default_llm else None
ConfigStore.set_default_embeddings(default_embeddings) if default_embeddings else None
ConfigStore.set_default_n_trials(n_trials) if n_trials else None

data_ingest_config = DataIngestOptionsConfig.with_defaults(
input_source=input_source,
Expand Down
8 changes: 2 additions & 6 deletions src/ragbuilder/core/config_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def __new__(cls):
@classmethod
def set_default_llm(cls, llm_config: Optional[Union[Dict[str, Any], LLMConfig, BaseChatModel, BaseLLM]]) -> None:
"""Store default LLM configuration or instance"""
if llm_config is None:
cls._default_llm = None
elif isinstance(llm_config, dict):
if isinstance(llm_config, dict):
cls._default_llm = LLMConfig(
type=LLMType.OPENAI,
model_kwargs=llm_config
Expand All @@ -55,9 +53,7 @@ def get_default_llm(cls) -> LLMConfig:
@classmethod
def set_default_embeddings(cls, embedding_config: Optional[Union[Dict[str, Any], EmbeddingConfig, Embeddings]]) -> None:
"""Store default Embedding configuration or instance"""
if embedding_config is None:
cls._default_embeddings = None
elif isinstance(embedding_config, dict):
if isinstance(embedding_config, dict):
cls._default_embeddings = EmbeddingConfig(
type=EmbeddingType.OPENAI,
model_kwargs=embedding_config
Expand Down

0 comments on commit d1f044f

Please sign in to comment.