Skip to content

Commit

Permalink
openai[minor]: change to secretstr (langchain-ai#16803)
Browse files Browse the repository at this point in the history
  • Loading branch information
efriis authored and raghavdixit99 committed Feb 1, 2024
1 parent e728f5c commit 13409bc
Show file tree
Hide file tree
Showing 10 changed files with 162 additions and 63 deletions.
5 changes: 2 additions & 3 deletions libs/partners/openai/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,9 @@ all: help
# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

test:
poetry run pytest $(TEST_FILE)
integration_tests: TEST_FILE=tests/integration_tests/

tests:
test tests integration_tests:
poetry run pytest $(TEST_FILE)


Expand Down
28 changes: 18 additions & 10 deletions libs/partners/openai/langchain_openai/chat_models/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,12 +3,12 @@

import logging
import os
from typing import Any, Callable, Dict, List, Union
from typing import Any, Callable, Dict, List, Optional, Union

import openai
from langchain_core.outputs import ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_openai.chat_models.base import ChatOpenAI

Expand Down Expand Up @@ -71,9 +71,9 @@ class AzureChatOpenAI(ChatOpenAI):
"""
openai_api_version: str = Field(default="", alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Union[str, None] = None
azure_ad_token: Optional[SecretStr] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
Expand Down Expand Up @@ -111,11 +111,14 @@ def validate_environment(cls, values: Dict) -> Dict:
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
values["openai_api_key"] = (
openai_api_key = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
Expand All @@ -131,8 +134,9 @@ def validate_environment(cls, values: Dict) -> Dict:
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN")
values["azure_ad_token"] = (
convert_to_secret_str(azure_ad_token) if azure_ad_token else None
)

values["openai_api_type"] = get_from_dict_or_env(
Expand Down Expand Up @@ -168,8 +172,12 @@ def validate_environment(cls, values: Dict) -> Dict:
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment_name"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"azure_ad_token": values["azure_ad_token"].get_secret_value()
if values["azure_ad_token"]
else None,
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
Expand Down
16 changes: 8 additions & 8 deletions libs/partners/openai/langchain_openai/chat_models/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,10 +52,11 @@
ToolMessageChunk,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.pydantic_v1 import BaseModel, Field, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.runnables import Runnable
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
Expand Down Expand Up @@ -240,10 +241,7 @@ def is_lc_serializable(cls) -> bool:
"""What sampling temperature to use."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
# When updating this to use a SecretStr
# Check for classes that derive from this class (as some of them
# may assume openai_api_key is a str)
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service
Expand Down Expand Up @@ -321,8 +319,8 @@ def validate_environment(cls, values: Dict) -> Dict:
if values["n"] > 1 and values["streaming"]:
raise ValueError("n must be 1 when streaming.")

values["openai_api_key"] = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
values["openai_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "openai_api_key", "OPENAI_API_KEY")
)
# Check OPENAI_ORGANIZATION for backwards compatibility.
values["openai_organization"] = (
Expand All @@ -341,7 +339,9 @@ def validate_environment(cls, values: Dict) -> Dict:
)

client_params = {
"api_key": values["openai_api_key"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
Expand Down
26 changes: 17 additions & 9 deletions libs/partners/openai/langchain_openai/embeddings/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from typing import Callable, Dict, Optional, Union

import openai
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_openai.embeddings.base import OpenAIEmbeddings

Expand Down Expand Up @@ -39,9 +39,9 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
If given sets the base client URL to include `/deployments/{azure_deployment}`.
Note: this means you won't be able to use non-deployment endpoints.
"""
openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Union[str, None] = None
azure_ad_token: Optional[SecretStr] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
Expand All @@ -64,11 +64,14 @@ def validate_environment(cls, values: Dict) -> Dict:
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
values["openai_api_key"] = (
openai_api_key = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
Expand All @@ -92,8 +95,9 @@ def validate_environment(cls, values: Dict) -> Dict:
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN")
values["azure_ad_token"] = (
convert_to_secret_str(azure_ad_token) if azure_ad_token else None
)
# Azure OpenAI embedding models allow a maximum of 16 texts
# at a time in each batch
Expand Down Expand Up @@ -122,8 +126,12 @@ def validate_environment(cls, values: Dict) -> Dict:
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"azure_ad_token": values["azure_ad_token"].get_secret_value()
if values["azure_ad_token"]
else None,
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
Expand Down
25 changes: 20 additions & 5 deletions libs/partners/openai/langchain_openai/embeddings/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,18 @@
import openai
import tiktoken
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, Extra, Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain_core.pydantic_v1 import (
BaseModel,
Extra,
Field,
SecretStr,
root_validator,
)
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -70,7 +80,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
openai_proxy: Optional[str] = None
embedding_ctx_length: int = 8191
"""The maximum number of tokens to embed at once."""
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_organization: Optional[str] = Field(default=None, alias="organization")
"""Automatically inferred from env var `OPENAI_ORG_ID` if not provided."""
Expand Down Expand Up @@ -152,9 +162,12 @@ def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["openai_api_key"] = get_from_dict_or_env(
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
Expand Down Expand Up @@ -196,7 +209,9 @@ def validate_environment(cls, values: Dict) -> Dict:
"please use the `AzureOpenAIEmbeddings` class."
)
client_params = {
"api_key": values["openai_api_key"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
Expand Down
35 changes: 18 additions & 17 deletions libs/partners/openai/langchain_openai/llms/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,18 +2,11 @@

import logging
import os
from typing import (
Any,
Callable,
Dict,
List,
Mapping,
Union,
)
from typing import Any, Callable, Dict, List, Mapping, Optional, Union

import openai
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env

from langchain_openai.llms.base import BaseOpenAI

Expand Down Expand Up @@ -52,9 +45,9 @@ class AzureOpenAI(BaseOpenAI):
"""
openai_api_version: str = Field(default="", alias="api_version")
"""Automatically inferred from env var `OPENAI_API_VERSION` if not provided."""
openai_api_key: Union[str, None] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `AZURE_OPENAI_API_KEY` if not provided."""
azure_ad_token: Union[str, None] = None
azure_ad_token: Optional[SecretStr] = None
"""Your Azure Active Directory token.
Automatically inferred from env var `AZURE_OPENAI_AD_TOKEN` if not provided.
Expand Down Expand Up @@ -92,17 +85,21 @@ def validate_environment(cls, values: Dict) -> Dict:
# Check OPENAI_KEY for backwards compatibility.
# TODO: Remove OPENAI_API_KEY support to avoid possible conflict when using
# other forms of azure credentials.
values["openai_api_key"] = (
openai_api_key = (
values["openai_api_key"]
or os.getenv("AZURE_OPENAI_API_KEY")
or os.getenv("OPENAI_API_KEY")
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)

values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"
)
values["azure_ad_token"] = values["azure_ad_token"] or os.getenv(
"AZURE_OPENAI_AD_TOKEN"
azure_ad_token = values["azure_ad_token"] or os.getenv("AZURE_OPENAI_AD_TOKEN")
values["azure_ad_token"] = (
convert_to_secret_str(azure_ad_token) if azure_ad_token else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
Expand Down Expand Up @@ -150,8 +147,12 @@ def validate_environment(cls, values: Dict) -> Dict:
"api_version": values["openai_api_version"],
"azure_endpoint": values["azure_endpoint"],
"azure_deployment": values["deployment_name"],
"api_key": values["openai_api_key"],
"azure_ad_token": values["azure_ad_token"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"azure_ad_token": values["azure_ad_token"].get_secret_value()
if values["azure_ad_token"]
else None,
"azure_ad_token_provider": values["azure_ad_token_provider"],
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
Expand Down
22 changes: 14 additions & 8 deletions libs/partners/openai/langchain_openai/llms/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,8 +27,12 @@
)
from langchain_core.language_models.llms import BaseLLM
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, root_validator
from langchain_core.utils import get_from_dict_or_env, get_pydantic_field_names
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain_core.utils.utils import build_extra_kwargs

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -104,10 +108,7 @@ def lc_attributes(self) -> Dict[str, Any]:
"""Generates best_of completions server-side and returns the "best"."""
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
"""Holds any model parameters valid for `create` call not explicitly specified."""
# When updating this to use a SecretStr
# Check for classes that derive from this class (as some of them
# may assume openai_api_key is a str)
openai_api_key: Optional[str] = Field(default=None, alias="api_key")
openai_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `OPENAI_API_KEY` if not provided."""
openai_api_base: Optional[str] = Field(default=None, alias="base_url")
"""Base URL path for API requests, leave blank if not using a proxy or service
Expand Down Expand Up @@ -175,9 +176,12 @@ def validate_environment(cls, values: Dict) -> Dict:
if values["streaming"] and values["best_of"] > 1:
raise ValueError("Cannot stream results when best_of > 1.")

values["openai_api_key"] = get_from_dict_or_env(
openai_api_key = get_from_dict_or_env(
values, "openai_api_key", "OPENAI_API_KEY"
)
values["openai_api_key"] = (
convert_to_secret_str(openai_api_key) if openai_api_key else None
)
values["openai_api_base"] = values["openai_api_base"] or os.getenv(
"OPENAI_API_BASE"
)
Expand All @@ -194,7 +198,9 @@ def validate_environment(cls, values: Dict) -> Dict:
)

client_params = {
"api_key": values["openai_api_key"],
"api_key": values["openai_api_key"].get_secret_value()
if values["openai_api_key"]
else None,
"organization": values["openai_organization"],
"base_url": values["openai_api_base"],
"timeout": values["request_timeout"],
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def _get_embeddings(**kwargs: Any) -> AzureOpenAIEmbeddings:
return AzureOpenAIEmbeddings(
azure_deployment=DEPLOYMENT_NAME,
api_version=OPENAI_API_VERSION,
openai_api_base=OPENAI_API_BASE,
azure_endpoint=OPENAI_API_BASE,
openai_api_key=OPENAI_API_KEY,
**kwargs,
)
Expand Down Expand Up @@ -109,7 +109,7 @@ def test_azure_openai_embedding_with_empty_string() -> None:
openai.AzureOpenAI(
api_version=OPENAI_API_VERSION,
api_key=OPENAI_API_KEY,
base_url=embedding.openai_api_base,
azure_endpoint=OPENAI_API_BASE,
azure_deployment=DEPLOYMENT_NAME,
) # type: ignore
.embeddings.create(input="", model="text-embedding-ada-002")
Expand Down
Loading

0 comments on commit 13409bc

Please sign in to comment.