Skip to content

Commit

Permalink
Merge pull request #97 from LlmKira/dev
Browse files Browse the repository at this point in the history
🎨 refactor(generate_voice): update speaker attributes and API usage
sudoskys authored Jan 4, 2025
2 parents 52ab424 + e5981ba commit 904bf85
Showing 17 changed files with 408 additions and 366 deletions.
2 changes: 1 addition & 1 deletion playground/generate_voice.py
Original file line number Diff line number Diff line change
@@ -23,7 +23,7 @@ async def generate_voice(text: str):
try:
voice_gen = VoiceGenerate.build(
text=text,
voice_engine=VoiceSpeakerV1.Crina, # VoiceSpeakerV2.Ligeia,
speaker=VoiceSpeakerV2.Ligeia, # VoiceSpeakerV2.Ligeia,
)
result = await voice_gen.request(
session=credential
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.7.2"
version = "0.7.3"
description = "NovelAI Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "coldlando@hotmail.com" },
8 changes: 8 additions & 0 deletions src/novelai_python/_exceptions.py
Original file line number Diff line number Diff line change
@@ -74,6 +74,14 @@ class ConcurrentGenerationError(APIError):
pass


class DataSerializationError(APIError):
"""
DataSerializationError is raised when the API data is not serializable.
"""

pass


class AuthError(APIError):
"""
AuthError is raised when the API returns an error.
139 changes: 75 additions & 64 deletions src/novelai_python/sdk/ai/augment_image/__init__.py
Original file line number Diff line number Diff line change
@@ -11,7 +11,7 @@
from io import BytesIO
from typing import Optional, Union, IO, Any
from urllib.parse import urlparse
from zipfile import ZipFile
from zipfile import ZipFile, BadZipFile

import curl_cffi
import httpx
@@ -25,13 +25,16 @@
from novelai_python.sdk.ai._enum import Model
from ._enum import ReqType, Moods
from ...schema import ApiBaseModel
from ...._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError
from ...._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError, DataSerializationError
from ...._response.ai.generate_image import ImageGenerateResp, RequestParams
from ....credential import CredentialBase
from ....utils import try_jsonfy


class AugmentImageInfer(ApiBaseModel):
"""
https://docs.novelai.net/image/directortools.html
"""
_endpoint: str = PrivateAttr("https://image.novelai.net")

@property
@@ -210,96 +213,104 @@ async def request(self,
:param session: session
:return:
"""
# Data Build
# Prepare request data
request_data = self.model_dump(mode="json", exclude_none=True)
async with session if isinstance(session, AsyncSession) else await session.get_session() as sess:
# Header
sess.headers.update(await self.necessary_headers(request_data))
if override_headers:
sess.headers.clear()
sess.headers.update(override_headers)

# Log the request data (sanitize sensitive content)
try:
_log_data = deepcopy(request_data)
_log_data.update({
"image": "base64 data"
})
logger.debug(f"Request Data: {_log_data}")
del _log_data
if self.image:
_log_data["image"] = "base64 data hidden"
logger.debug(f"Request Data: {json.dumps(_log_data, indent=2)}")
except Exception as e:
logger.warning(f"Error when print log data: {e}")
logger.warning(f"Failed to log request data: {e}")

# Perform request and handle response
try:
assert hasattr(sess, "post"), "session must have post method."
self.ensure_session_has_post_method(sess)
response = await sess.post(
self.base_url,
data=json.dumps(request_data).encode("utf-8")
)
if response.headers.get('Content-Type') not in ['binary/octet-stream', 'application/x-zip-compressed']:
logger.warning(
f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
if (
response.headers.get('Content-Type') not in ['binary/octet-stream',
'application/x-zip-compressed']
or response.status_code >= 400
):
error_message = await self.handle_error_response(response, request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400, 401, 402]:
# 400 : validation error
# 401 : unauthorized
# 402 : payment required
# 409 : conflict
raise AuthError(message, request=request_data, code=status_code, response=_msg)
if status_code in [409]:
# conflict error
raise APIError(message, request=request_data, code=status_code, response=_msg)
if status_code in [429]:
# concurrent error
raise AuthError(message, request=request_data, code=status_code, response=error_message)
elif status_code == 409:
raise APIError(message, request=request_data, code=status_code, response=error_message)
elif status_code == 429:
raise ConcurrentGenerationError(
message=message,
request=request_data,
code=status_code,
response=_msg
)
raise APIError(message, request=request_data, code=status_code, response=_msg)
zip_file = ZipFile(BytesIO(response.content))
unzip_content = []
with zip_file as zf:
file_list = zf.namelist()
if not file_list:
raise APIError(
message="No file in zip",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
response=error_message,
)
for filename in file_list:
data = zip_file.read(filename)
unzip_content.append((filename, data))
return ImageGenerateResp(
meta=RequestParams(
endpoint=self.base_url,
raw_request=request_data,
),
files=unzip_content
)
else:
raise APIError(message, request=request_data, code=status_code, response=error_message)

# Unpack the ZIP response
try:
zip_file = ZipFile(BytesIO(response.content))
unzip_content = []
with zip_file as zf:
file_list = zf.namelist()
if not file_list:
raise DataSerializationError(
message="The ZIP response contains no files.",
request=request_data,
response=try_jsonfy(response.content),
code=response.status_code,
)
for filename in file_list:
data = zip_file.read(filename)
unzip_content.append((filename, data))
return ImageGenerateResp(
meta=RequestParams(
endpoint=self.base_url,
raw_request=request_data,
),
files=unzip_content,
)
except BadZipFile as e:
# Invalid ZIP file - indicate serialization error
logger.exception("The response content is not a valid ZIP file.")
raise DataSerializationError(
message="Invalid ZIP file received from the API.",
request=request_data,
response={},
code=response.status_code,
) from e
except Exception as e:
logger.exception("Unexpected error while unpacking ZIP response.")
raise DataSerializationError(
message="An unexpected error occurred while processing ZIP data.",
request=request_data,
response={},
code=response.status_code,
) from e
except curl_cffi.requests.errors.RequestsError as exc:
logger.exception(exc)
raise SessionHttpError(
"An AsyncSession RequestsError occurred, maybe SSL error. Try again later!") from exc
raise SessionHttpError("A RequestsError occurred (e.g., SSL error). Try again later.")
except httpx.HTTPError as exc:
logger.exception(exc)
raise SessionHttpError("An HTTPError occurred, maybe SSL error. Try again later!") from exc
raise SessionHttpError("An HTTP error occurred. Try again later.")
except APIError as e:
raise e
except Exception as e:
logger.opt(exception=e).exception("An Unexpected error occurred")
raise e
logger.opt(exception=e).exception("Unexpected error occurred during the request.")
raise Exception("An unexpected error occurred.") from e
2 changes: 1 addition & 1 deletion src/novelai_python/sdk/ai/generate/__init__.py
Original file line number Diff line number Diff line change
@@ -369,7 +369,7 @@ async def request(self,
logger.debug(f"LLM request data: {json.dumps(request_data)}")
# Request
try:
assert hasattr(sess, "post"), "session must have post method."
self.ensure_session_has_post_method(sess)
response = await sess.post(
self.base_url,
json=request_data,
176 changes: 102 additions & 74 deletions src/novelai_python/sdk/ai/generate_image/__init__.py
Original file line number Diff line number Diff line change
@@ -12,7 +12,7 @@
from io import BytesIO
from typing import Optional, Union, Tuple, List
from urllib.parse import urlparse
from zipfile import ZipFile
from zipfile import ZipFile, BadZipFile

import curl_cffi
import cv2
@@ -32,7 +32,7 @@
get_model_group, ModelGroups, get_supported_params, get_modifiers, ImageBytesTypeAlias
from .schema import Character, V4Prompt, V4NegativePrompt, PositionMap
from ...schema import ApiBaseModel
from ...._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError
from ...._exceptions import APIError, AuthError, ConcurrentGenerationError, SessionHttpError, DataSerializationError
from ...._response.ai.generate_image import ImageGenerateResp, RequestParams
from ....credential import CredentialBase
from ....utils import try_jsonfy
@@ -403,10 +403,19 @@ def model_post_init(self, *args) -> None:
# Add negative prompt based on ucPreset
if self.parameters.ucPreset is not None:
uc_preset = self.parameters.ucPreset

# If ucPreset is Enum, get the value
if isinstance(self.parameters.ucPreset, Enum):
uc_preset = self.parameters.ucPreset.value
default_negative_prompt = get_default_uc_preset(self.model, uc_preset)

# Lowres means we don't found any negative prompt.
# If the default negative prompt is lowres, and the user has set a negative prompt,
# then the default negative prompt is not added.
if self.parameters.negative_prompt and default_negative_prompt == "lowres":
default_negative_prompt = ""

# Combine the negative prompt preset and the user's negative prompt
self.parameters.negative_prompt = ", ".join(
filter(None, [default_negative_prompt, self.parameters.negative_prompt])
)
@@ -986,113 +995,132 @@ async def necessary_headers(self, request_data) -> dict:
retry=retry_if_exception(lambda e: hasattr(e, "code") and str(e.code) == "500"),
reraise=True
)
async def request(self,
session: Union[AsyncSession, "CredentialBase"],
*,
override_headers: Optional[dict] = None,
) -> ImageGenerateResp:
async def request(
self,
session: Union[AsyncSession, "CredentialBase"],
*,
override_headers: Optional[dict] = None,
) -> ImageGenerateResp:
"""
**Generate images using NovelAI's diffusion models.**
According to our Terms of Service, all generation requests must be initiated by a human action. Automating text or image generation to create excessive load on our systems is not allowed.
:param override_headers: the headers to override
:param session: session
:return:
According to our Terms of Service, all generation requests must be initiated by a human action. Automating text
or image generation to create excessive load on our systems is not allowed.
:param override_headers: Headers to override the default headers.
:param session: Async session object or credential-based session.
:return: ImageGenerateResp containing the response data and metadata.
:raises AuthError: If the request is unauthorized.
:raises APIError: If the API returns an error.
:raises ConcurrentGenerationError: If the request is rate-limited.
:raises SessionHttpError: If an HTTP error occurs.
:raises DataSerializationError: If an error occurs while processing the response data.
"""
# Data Build
# Prepare request data
request_data = self.model_dump(mode="json", exclude_none=True)
async with session if isinstance(session, AsyncSession) else await session.get_session() as sess:
# Header
sess.headers.update(await self.necessary_headers(request_data))
if override_headers:
sess.headers.clear()
sess.headers.update(override_headers)

# Log the request data (sanitize sensitive content)
try:
_log_data = deepcopy(request_data)
if self.parameters.image:
_log_data["parameters"]["image"] = "base64 data"
_log_data["parameters"]["image"] = "base64 data hidden"
if self.parameters.mask:
_log_data["parameters"]["mask"] = "base64 data"
_log_data["parameters"]["mask"] = "base64 data hidden"
if self.parameters.reference_image_multiple:
_log_data["parameters"]["reference_image_multiple"] = ["base64 data"] * len(
self.parameters.reference_image_multiple)
_log_data["parameters"]["reference_image_multiple"] = ["base64 data hidden"] * len(
self.parameters.reference_image_multiple
)
logger.debug(f"Request Data: {json.dumps(_log_data, indent=2)}")
del _log_data
except Exception as e:
logger.warning(f"Error when print log data: {e}")
logger.warning(f"Failed to log request data: {e}")

# Perform request and handle response
try:
assert hasattr(sess, "post"), "session must have post method."
self.ensure_session_has_post_method(sess)
response = await sess.post(
self.base_url,
data=json.dumps(request_data).encode("utf-8")
data=json.dumps(request_data).encode("utf-8"),
)
if response.headers.get('Content-Type') not in ['binary/octet-stream', 'application/x-zip-compressed']:
logger.warning(
f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
# Validate response content type and status code
if (
response.headers.get("Content-Type")
not in ["binary/octet-stream", "application/x-zip-compressed"]
or response.status_code >= 400
):
error_message = await self.handle_error_response(response, request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400, 401, 402]:
# 400 : validation error
# 401 : unauthorized
# 402 : payment required
# 409 : conflict
raise AuthError(message, request=request_data, code=status_code, response=_msg)
if status_code in [409]:
# conflict error
raise APIError(message, request=request_data, code=status_code, response=_msg)
if status_code in [429]:
# concurrent error
raise AuthError(message, request=request_data, code=status_code, response=error_message)
elif status_code == 409:
raise APIError(message, request=request_data, code=status_code, response=error_message)
elif status_code == 429:
raise ConcurrentGenerationError(
message=message,
request=request_data,
code=status_code,
response=_msg
response=error_message,
)
raise APIError(message, request=request_data, code=status_code, response=_msg)
zip_file = ZipFile(BytesIO(response.content))
unzip_content = []
with zip_file as zf:
file_list = zf.namelist()
if not file_list:
raise APIError(
message="No file in zip",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
for filename in file_list:
data = zip_file.read(filename)
unzip_content.append((filename, data))
return ImageGenerateResp(
meta=RequestParams(
endpoint=self.base_url,
raw_request=request_data,
),
files=unzip_content
)
else:
raise APIError(message, request=request_data, code=status_code, response=error_message)

# Unpack the ZIP response
try:
zip_file = ZipFile(BytesIO(response.content))
unzip_content = []
with zip_file as zf:
file_list = zf.namelist()
if not file_list:
raise DataSerializationError(
message="The ZIP response contains no files.",
request=request_data,
response=try_jsonfy(response.content),
code=response.status_code,
)
for filename in file_list:
data = zip_file.read(filename)
unzip_content.append((filename, data))
return ImageGenerateResp(
meta=RequestParams(
endpoint=self.base_url,
raw_request=request_data,
),
files=unzip_content,
)
except BadZipFile as e:
# Invalid ZIP file - indicate serialization error
logger.exception("The response content is not a valid ZIP file.")
raise DataSerializationError(
message="Invalid ZIP file received from the API.",
request=request_data,
response={},
code=response.status_code,
) from e
except Exception as e:
logger.exception("Unexpected error while unpacking ZIP response.")
raise DataSerializationError(
message="An unexpected error occurred while processing ZIP data.",
request=request_data,
response={},
code=response.status_code,
) from e
except curl_cffi.requests.errors.RequestsError as exc:
logger.exception(exc)
raise SessionHttpError("An AsyncSession RequestsError occurred, maybe SSL error. Try again later!")
raise SessionHttpError("A RequestsError occurred (e.g., SSL error). Try again later.")
except httpx.HTTPError as exc:
logger.exception(exc)
raise SessionHttpError("An HTTPError occurred, maybe SSL error. Try again later!")
raise SessionHttpError("An HTTP error occurred. Try again later.")
except APIError as e:
raise e
except Exception as e:
logger.opt(exception=e).exception("An Unexpected error occurred")
raise e
logger.opt(exception=e).exception("Unexpected error occurred during the request.")
raise Exception("An unexpected error occurred.") from e
33 changes: 10 additions & 23 deletions src/novelai_python/sdk/ai/generate_image/suggest_tags.py
Original file line number Diff line number Diff line change
@@ -77,41 +77,28 @@ async def request(self,
if override_headers:
session.headers.clear()
session.headers.update(override_headers)
logger.debug("SuggestTags")
try:
assert hasattr(session, "get"), "session must have get method."
response = await session.get(
url=self.base_url + "?" + "&".join([f"{k}={v}" for k, v in request_data.items()])
)
if "application/json" not in response.headers.get('Content-Type') or response.status_code != 200:
logger.warning(
f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
if (
"application/json" not in response.headers.get('Content-Type')
or response.status_code != 200
):
error_message = await self.handle_error_response(response=response, request_data=request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400, 401, 402]:
# 400 : validation error
# 401 : unauthorized
# 402 : payment required
# 409 : conflict
raise AuthError(message, request=request_data, code=status_code, response=_msg)
raise AuthError(message, request=request_data, code=status_code, response=error_message)
if status_code in [500]:
# An unknown error occured.
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=error_message)
raise APIError(message, request=request_data, code=status_code, response=error_message)
return SuggestTagsResp.model_validate(response.json())
except curl_cffi.requests.errors.RequestsError as exc:
logger.exception(exc)
2 changes: 1 addition & 1 deletion src/novelai_python/sdk/ai/generate_stream.py
Original file line number Diff line number Diff line change
@@ -95,7 +95,7 @@ async def request(self,
logger.debug(f"StreamLLM request data: {request_data}")
# Request
try:
assert hasattr(sess, "post"), "session must have post method."
self.ensure_session_has_post_method(sess)
response = await sess.post(
self.base_url,
json=request_data,
79 changes: 31 additions & 48 deletions src/novelai_python/sdk/ai/generate_voice/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
# @Author : sudoskys
from enum import Enum
from typing import Optional, Union, Literal
from urllib.parse import urlparse

@@ -11,7 +12,7 @@
from pydantic import ConfigDict, PrivateAttr, Field, model_validator
from tenacity import wait_random, retry, stop_after_attempt, retry_if_exception

from ._enum import VoiceSpeakerV1, VoiceSpeakerV2
from ._enum import VoiceSpeakerV1, VoiceSpeakerV2, Speaker
from ...schema import ApiBaseModel
from ...._exceptions import APIError, SessionHttpError
from ...._response.ai.generate_voice import VoiceResponse
@@ -80,43 +81,36 @@ async def necessary_headers(self, request_data) -> dict:
@classmethod
def build(cls,
text: str,
voice_engine: Union[VoiceSpeakerV1, VoiceSpeakerV2, str],
speaker: Union[VoiceSpeakerV2, VoiceSpeakerV1, Speaker, str],
*,
opus: bool = False
opus: bool = True
) -> "VoiceGenerate":
"""
生成图片
:param opus: unknown
:param text: str
:param voice_engine: VoiceSpeakerV1 or VoiceSpeakerV2 or str
:param speaker: Speaker import from novelai_python.sdk.ai.generate_voice._enum
:return: VoiceGenerate instance
:raises: ValueError
"""
if isinstance(voice_engine, str):
if isinstance(speaker, Enum):
speaker = speaker.value

if isinstance(speaker, str):
return cls(
text=text,
voice=-1,
seed=voice_engine,
opus=opus,
version="v2"
)
if isinstance(voice_engine, VoiceSpeakerV2):
return cls(
text=text,
seed=voice_engine.value.seed,
voice=voice_engine.value.sid,
seed=speaker,
opus=opus,
version="v2"
)
if isinstance(voice_engine, VoiceSpeakerV1):
return cls(
text=text,
seed=voice_engine.value.seed,
voice=voice_engine.value.sid,
opus=opus,
version="v1"
)
raise ValueError("Invalid voice engine")
return cls(
text=text,
seed=speaker.seed,
voice=speaker.voice,
opus=opus,
version=speaker.version
)

@retry(
wait=wait_random(min=1, max=3),
@@ -147,35 +141,24 @@ async def request(self,
logger.debug(f"Voice request data: {request_data}")
# Request
try:
assert hasattr(sess, "get"), "session must have get method."
response = await sess.get(
self.base_url,
params=request_data
self.ensure_session_has_post_method(sess)
response = await sess.post(
url=self.base_url,
json=request_data
)
header_type = response.headers.get('Content-Type')
if header_type not in ['audio/mpeg', 'audio/ogg', 'audio/opus']:
logger.warning(
f"Error with content type: {header_type} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content: {header_type} with code: {response.status_code}",
request=request_data,
code=response.status_code,
response="UnJsoned content"
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
if (
header_type not in ['audio/mpeg', 'audio/ogg', 'audio/opus', 'audio/wav', 'audio/webm']
or response.status_code >= 400
):

error_message = await self.handle_error_response(response=response, request_data=request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400]:
# Validation tts version error
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=error_message)
raise APIError(message, request=request_data, code=status_code, response=error_message)
return VoiceResponse(
meta=request_data,
audio=response.content,
29 changes: 14 additions & 15 deletions src/novelai_python/sdk/ai/generate_voice/_enum.py
Original file line number Diff line number Diff line change
@@ -1,37 +1,36 @@
from typing import Optional
from enum import Enum

from pydantic import BaseModel
from enum import Enum


class Speaker(BaseModel):
"""
Speaker for /ai/generated_voice
"""
sid: int = -1
seed: Optional[str] = None
voice: int = -1
seed: str = "kurumuz12"
name: str
category: str

@property
def version(self):
return "v2" if self.sid is None else "v1"
return "v2" if self.voice == -1 else "v1"


class VoiceSpeakerV1(Enum):
"""
Speaker for /ai/generated_voice
"""
Cyllene = Speaker(sid=17, name="Cyllene", category="female")
Leucosia = Speaker(sid=95, name="Leucosia", category="female")
Crina = Speaker(sid=44, name="Crina", category="female")
Hespe = Speaker(sid=80, name="Hespe", category="female")
Ida = Speaker(sid=106, name="Ida", category="female")
Alseid = Speaker(sid=6, name="Alseid", category="male")
Daphnis = Speaker(sid=10, name="Daphnis", category="male")
Echo = Speaker(sid=16, name="Echo", category="male")
Thel = Speaker(sid=41, name="Thel", category="male")
Nomios = Speaker(sid=77, name="Nomios", category="male")
Cyllene = Speaker(voice=17, name="Cyllene", category="female")
Leucosia = Speaker(voice=95, name="Leucosia", category="female")
Crina = Speaker(voice=44, name="Crina", category="female")
Hespe = Speaker(voice=80, name="Hespe", category="female")
Ida = Speaker(voice=106, name="Ida", category="female")
Alseid = Speaker(voice=6, name="Alseid", category="male")
Daphnis = Speaker(voice=10, name="Daphnis", category="male")
Echo = Speaker(voice=16, name="Echo", category="male")
Thel = Speaker(voice=41, name="Thel", category="male")
Nomios = Speaker(voice=77, name="Nomios", category="male")
# SeedInput = Speaker(sid=-1, name="Seed Input", category="custom")


128 changes: 66 additions & 62 deletions src/novelai_python/sdk/ai/upscale.py
Original file line number Diff line number Diff line change
@@ -4,10 +4,11 @@
# @File : upscale.py
import base64
import json
from copy import deepcopy
from io import BytesIO
from typing import Optional, Union
from urllib.parse import urlparse
from zipfile import ZipFile
from zipfile import ZipFile, BadZipFile

import curl_cffi
import httpx
@@ -17,7 +18,7 @@
from tenacity import wait_random, retry, stop_after_attempt, retry_if_exception

from ..schema import ApiBaseModel
from ..._exceptions import APIError, AuthError, SessionHttpError
from ..._exceptions import APIError, AuthError, SessionHttpError, DataSerializationError
from ..._response.ai.upscale import UpscaleResp
from ...credential import CredentialBase
from ...utils import try_jsonfy
@@ -104,86 +105,89 @@ async def request(self,
:param session: session
:return:
"""
# Data Build
# Prepare request data
request_data = self.model_dump(mode="json", exclude_none=True)
async with session if isinstance(session, AsyncSession) else await session.get_session() as sess:
# Header
sess.headers.update(await self.necessary_headers(request_data))
if override_headers:
sess.headers.clear()
sess.headers.update(override_headers)

# Log the request data (sanitize sensitive content)
try:
_log_data = request_data.copy()
_log_data.update({"image": "base64 data"}) if isinstance(_log_data.get("image"), str) else None
logger.info(f"Upscale request data: {_log_data}")
_log_data = deepcopy(request_data)
if self.image:
_log_data["image"] = "base64 data hidden"
logger.debug(f"Request Data: {json.dumps(_log_data, indent=2)}")
except Exception as e:
logger.warning(f"Error when print log data: {e}")
logger.warning(f"Failed to log request data: {e}")

# Perform request and handle response
try:
assert hasattr(sess, "post"), "session must have post method."
self.ensure_session_has_post_method(sess)
response = await sess.post(
self.base_url,
data=json.dumps(request_data).encode("utf-8")
)
if response.headers.get('Content-Type') not in ['binary/octet-stream', 'application/x-zip-compressed']:
logger.warning(
f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
if (
response.headers.get('Content-Type') not in ['binary/octet-stream',
'application/x-zip-compressed']
or response.status_code >= 400
):
error_message = await self.handle_error_response(response=response, request_data=request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400, 401, 402]:
# 400 : validation error
# 401 : unauthorized
# 402 : payment required
# 409 : conflict
raise AuthError(message, request=request_data, code=status_code, response=_msg)
raise AuthError(message, request=request_data, code=status_code, response=error_message)
if status_code in [409]:
# conflict error
raise APIError(message, request=request_data, code=status_code, response=_msg)
"""
if status_code in [429]:
# concurrent error
raise ConcurrentGenerationError(
message=message,
request=request_data,
code=status_code,
response=_msg
)
"""
raise APIError(message, request=request_data, code=status_code, response=_msg)
zip_file = ZipFile(BytesIO(response.content))
unzip_content = []
with zip_file as zf:
file_list = zf.namelist()
if not file_list:
raise APIError(
message="No file in zip",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
for filename in file_list:
data = zip_file.read(filename)
unzip_content.append((filename, data))
return UpscaleResp(
meta=UpscaleResp.RequestParams(
endpoint=self.base_url,
raw_request=request_data,
),
files=unzip_content[0]
)
raise APIError(message, request=request_data, code=status_code, response=error_message)
raise APIError(message, request=request_data, code=status_code, response=error_message)

# Unpack the ZIP response
try:
zip_file = ZipFile(BytesIO(response.content))
unzip_content = []
with zip_file as zf:
file_list = zf.namelist()
if not file_list:
raise DataSerializationError(
message="The ZIP response contains no files.",
request=request_data,
response=try_jsonfy(response.content),
code=response.status_code,
)
for filename in file_list:
data = zip_file.read(filename)
unzip_content.append((filename, data))
return UpscaleResp(
meta=UpscaleResp.RequestParams(
endpoint=self.base_url,
raw_request=request_data,
),
files=unzip_content[0]
)
except BadZipFile as e:
# Invalid ZIP file - indicate serialization error
logger.exception("The response content is not a valid ZIP file.")
raise DataSerializationError(
message="Invalid ZIP file received from the API.",
request=request_data,
response={},
code=response.status_code,
) from e
except Exception as e:
logger.exception("Unexpected error while unpacking ZIP response.")
raise DataSerializationError(
message="An unexpected error occurred while processing ZIP data.",
request=request_data,
response={},
code=response.status_code,
) from e
except curl_cffi.requests.errors.RequestsError as exc:
logger.exception(exc)
raise SessionHttpError("An AsyncSession RequestsError occurred, maybe SSL error. Try again later!")
55 changes: 55 additions & 0 deletions src/novelai_python/sdk/schema.py
Original file line number Diff line number Diff line change
@@ -7,14 +7,22 @@
from typing import Optional, Union

from curl_cffi.requests import AsyncSession
from loguru import logger
from pydantic import BaseModel, PrivateAttr

from ..credential import CredentialBase
from ..utils import try_jsonfy


class ApiBaseModel(BaseModel, ABC):
_endpoint: Optional[str] = PrivateAttr()

@property
@abstractmethod
def base_url(self):
logger.error("ApiBaseModel.base_url must be overridden")
return f"{self.endpoint.strip('/')}/need-to-override"

@property
def endpoint(self):
return self._endpoint
@@ -27,6 +35,53 @@ def endpoint(self, value):
async def necessary_headers(self, request_data) -> dict:
raise NotImplementedError()

@staticmethod
def ensure_session_has_post_method(session):
if not hasattr(session, "post"):
raise AttributeError("SESSION_MUST_HAVE_POST_METHOD")

@staticmethod
def ensure_session_has_get_method(session):
if not hasattr(session, "get"):
raise AttributeError("SESSION_MUST_HAVE_GET_METHOD")

async def handle_error_response(
self,
response,
request_data: dict,
content_hint: str = "Response content too long",
max_content_length: int = 50
) -> dict:
"""
Common method to handle error response
:param response: HTTP response
:param request_data: request data
:param content_hint: hint for content too long
:param max_content_length: max content length
:return: dict of error message
"""
logger.debug(
f"\n[novelai-python] Unexpected response:\n"
f" - URL : {self.base_url}\n"
f" - Content-Type: {response.headers.get('Content-Type', 'N/A')}\n"
f" - Status : {response.status_code}\n"
)
try:
# 尝试解析 JSON 响应
error_message = response.json()
except Exception as e:
# 如果解析 JSON 失败,则记录日志,并尝试显示短内容
logger.warning(f"Failed to parse error response: {e}")
error_message = {
"statusCode": response.status_code,
"message": try_jsonfy(response.content)
if len(response.content) < max_content_length
else content_hint,
}
# 日志记录解析出的错误消息
logger.trace(f"Parsed error message: {error_message}")
return error_message

@abstractmethod
async def request(self,
session: Union[AsyncSession, CredentialBase],
31 changes: 8 additions & 23 deletions src/novelai_python/sdk/user/information.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,6 @@
from ..._exceptions import APIError, AuthError, SessionHttpError
from ..._response.user.information import InformationResp
from ...credential import CredentialBase
from ...utils import try_jsonfy


class Information(ApiBaseModel):
@@ -65,39 +64,25 @@ async def request(self,
sess.headers.update(override_headers)
logger.debug("Information")
try:
assert hasattr(sess, "get"), "session must have get method."
self.ensure_session_has_get_method(sess)
response = await sess.get(
self.base_url
)
if "application/json" not in response.headers.get('Content-Type') or response.status_code != 200:
logger.warning(
f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
error_message = await self.handle_error_response(response=response, request_data=request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400, 401, 402]:
# 400 : validation error
# 401 : unauthorized
# 402 : payment required
# 409 : conflict
raise AuthError(message, request=request_data, code=status_code, response=_msg)
raise AuthError(message, request=request_data, code=status_code, response=error_message)
if status_code in [500]:
# An unknown error occured.
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=error_message)
raise APIError(message, request=request_data, code=status_code, response=error_message)

return InformationResp.model_validate(response.json())
except curl_cffi.requests.errors.RequestsError as exc:
logger.exception(exc)
31 changes: 8 additions & 23 deletions src/novelai_python/sdk/user/login.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,7 @@
from ..._exceptions import APIError, SessionHttpError
from ..._response.user.login import LoginResp
from ...credential import CredentialBase
from ...utils import try_jsonfy, encode_access_key
from ...utils import encode_access_key


class Login(ApiBaseModel):
@@ -76,38 +76,23 @@ async def request(self,
sess.headers.update(override_headers)
logger.debug("Fetching login-credential")
try:
assert hasattr(sess, "post"), "session must have get method."
self.ensure_session_has_post_method(sess)
response = await sess.post(
self.base_url,
data=json.dumps(request_data).encode("utf-8")
)
if "application/json" not in response.headers.get('Content-Type') or response.status_code != 201:
logger.warning(
f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
error_message = await self.handle_error_response(response=response, request_data=request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400, 401]:
# 400 : A validation error occured.
# 401 : Access Key is incorrect.
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=error_message)
if status_code in [500]:
# An unknown error occured.
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=error_message)
raise APIError(message, request=request_data, code=status_code, response=error_message)
return LoginResp.model_validate(response.json())
except curl_cffi.requests.errors.RequestsError as exc:
logger.exception(exc)
30 changes: 7 additions & 23 deletions src/novelai_python/sdk/user/subscription.py
Original file line number Diff line number Diff line change
@@ -16,7 +16,6 @@
from ..._exceptions import APIError, AuthError, SessionHttpError
from ..._response.user.subscription import SubscriptionResp
from ...credential import CredentialBase
from ...utils import try_jsonfy


class Subscription(ApiBaseModel):
@@ -73,39 +72,24 @@ async def request(self,
sess.headers.update(override_headers)
logger.debug("Subscription")
try:
assert hasattr(sess, "get"), "session must have get method."
self.ensure_session_has_get_method(sess)
response = await sess.get(
url=self.base_url,
)
if "application/json" not in response.headers.get('Content-Type') or response.status_code != 200:
logger.warning(
f"Error with content type: {response.headers.get('Content-Type')} and code: {response.status_code}"
)
try:
_msg = response.json()
except Exception as e:
logger.warning(e)
if not isinstance(response.content, str) and len(response.content) < 50:
raise APIError(
message=f"Unexpected content type: {response.headers.get('Content-Type')}",
request=request_data,
code=response.status_code,
response=try_jsonfy(response.content)
)
else:
_msg = {"statusCode": response.status_code, "message": response.content}
status_code = _msg.get("statusCode", response.status_code)
message = _msg.get("message", "Unknown error")
error_message = await self.handle_error_response(response=response, request_data=request_data)
status_code = error_message.get("statusCode", response.status_code)
message = error_message.get("message", "Unknown error")
if status_code in [400, 401, 402]:
# 400 : validation error
# 401 : unauthorized
# 402 : payment required
# 409 : conflict
raise AuthError(message, request=request_data, code=status_code, response=_msg)
raise AuthError(message, request=request_data, code=status_code, response=error_message)
if status_code in [500]:
# An unknown error occured.
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=_msg)
raise APIError(message, request=request_data, code=status_code, response=error_message)
raise APIError(message, request=request_data, code=status_code, response=error_message)
return SubscriptionResp.model_validate(response.json())
except curl_cffi.requests.errors.RequestsError as exc:
logger.exception(exc)
10 changes: 9 additions & 1 deletion src/novelai_python/tokenizer/__init__.py
Original file line number Diff line number Diff line change
@@ -25,7 +25,15 @@ def download_file(url, destination_path, session):
:return: None
:raises ValueError: If the downloaded file size doesn't match the Content-Length header.
"""
response = session.get(url, headers={"Content-Type": "application/json"})
response = session.get(
url,
timeout=30,
headers={
"User-Agent": "novelai-python/0.1.0",
"Accept-Encoding": "gzip, deflate, br",
"Accept": "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
}
)
response.raise_for_status()
with open(destination_path, "wb") as f:
f.write(response.content)
17 changes: 11 additions & 6 deletions tests/test_generate_voice.py
Original file line number Diff line number Diff line change
@@ -13,14 +13,19 @@ async def test_request():
mock_response = mock.MagicMock()
mock_response.content = b'audio_content'
mock_response.headers = {'Content-Type': 'audio/mpeg'}
mock_response.status_code = 200
mock_response.json = AsyncMock(return_value={"statusCode": 200, "message": "Success"})

# 使用 AsyncMock 模拟异步方法
session = mock.MagicMock(spec=AsyncSession)
session.get = AsyncMock(return_value=mock_response)
session.post = AsyncMock(return_value=mock_response)
session.headers = {}

# Mock '__aenter__' 和 '__aexit__',以兼容异步上下文管理器
session.__aenter__ = AsyncMock(return_value=session)
session.__aexit__ = AsyncMock(return_value=None)

# 创建 VoiceGenerate 对象
voice_generate = VoiceGenerate(
text="Hello, world!",
voice=-1,
@@ -30,12 +35,12 @@ async def test_request():
)

# Act
result = await voice_generate.request(session=session)
result = await voice_generate.request(session=session, override_headers=None)

# Assert
session.get.assert_called_once_with(
voice_generate.base_url,
params=voice_generate.model_dump(mode="json", exclude_none=True)
session.post.assert_called_once_with(
url=voice_generate.base_url,
json=voice_generate.model_dump(mode="json", exclude_none=True)
)
assert isinstance(result, VoiceResponse)
assert result.audio == b'audio_content'
assert result.audio == b'audio_content'

0 comments on commit 904bf85

Please sign in to comment.