Skip to content

Commit

Permalink
Merge pull request #76 from LlmKira/dev
Browse files Browse the repository at this point in the history
fix(0.5.1): add tokenizer.get_vocab() | Diff generation endpoint / text.novelai.net or api.novelai.net
  • Loading branch information
sudoskys authored Sep 26, 2024
2 parents 9cbed20 + 601456c commit 7c71f0c
Show file tree
Hide file tree
Showing 8 changed files with 58 additions and 37 deletions.
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.5.0"
version = "0.5.1"
description = "NovelAI Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "[email protected]" },
Expand Down
4 changes: 2 additions & 2 deletions src/novelai_python/sdk/ai/_cost.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
import math
import random
from typing import List, Optional, Union
from typing import List, Optional

from pydantic import BaseModel

from novelai_python.sdk.ai._const import map, initialN, initial_n, step, newN
from novelai_python.sdk.ai._enum import Sampler, Model, ModelGroups, get_model_group, ModelTypeAlias
from novelai_python.sdk.ai._enum import Sampler, ModelGroups, get_model_group, ModelTypeAlias


class Args(BaseModel):
Expand Down
25 changes: 18 additions & 7 deletions src/novelai_python/sdk/ai/generate/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,16 @@ async def necessary_headers(self, request_data) -> dict:

@model_validator(mode="after")
def normalize_model(self):
if self.model in [
TextLLMModel.NEO_2B, TextLLMModel.J_6B, TextLLMModel.J_6B_V3, TextLLMModel.J_6B_V4,
TextLLMModel.GENJI_JP_6B, TextLLMModel.GENJI_JP_6B_V2, TextLLMModel.GENJI_PYTHON_6B,
TextLLMModel.EUTERPE_V0, TextLLMModel.EUTERPE_V2, TextLLMModel.KRAKE_V1, TextLLMModel.KRAKE_V2,
TextLLMModel.CASSANDRA, TextLLMModel.COMMENT_BOT, TextLLMModel.INFILL, TextLLMModel.CLIO
]:
self.endpoint = "https://api.novelai.net"
tokenizer = NaiTokenizer(get_tokenizer_model(self.model))
model_group = get_llm_group(self.model)
total_tokens = len(tokenizer.encode(self.input))
total_tokens = tokenizer.total_tokens()
if isinstance(self.input, str):
prompt = tokenizer.encode(self.input)
dtype = "uint32" if self.model in [TextLLMModel.ERATO] else "uint16"
Expand Down Expand Up @@ -238,10 +245,10 @@ def normalize_model(self):

valid_sequences = []
for cell in logit_bias_group_exp:
if not any(token < 0 or token >= total_tokens for token in cell.sequence):
if any(token < 0 or token >= total_tokens for token in cell.sequence):
# 超出范围的 Token
logger.debug(
f"Bias {cell} contains tokens that are out of range and will be ignored."
logger.trace(
f"Bias [{cell}] contains tokens that are out of range and will be ignored."
)
else:
# 将有效偏置组添加到列表
Expand All @@ -258,8 +265,8 @@ def normalize_model(self):
valid_bad_words_ids = []
self.parameters.bad_words_ids = bad_words_ids
for ban_word in self.parameters.bad_words_ids:
if not any(token < 0 or token >= total_tokens for token in ban_word):
logger.warning(
if any(token < 0 or token >= total_tokens for token in ban_word):
logger.trace(
f"Bad word {ban_word} contains tokens that are out of range and will be ignored."
)
else:
Expand All @@ -277,6 +284,10 @@ def normalize_model(self):
self.advanced_setting.num_logprobs = self.logprobs_count
if not self.advanced_setting.max_length:
self.advanced_setting.max_length = 40
if self.parameters.repetition_penalty_range == 0:
self.parameters.repetition_penalty_range = None
if self.parameters.repetition_penalty_slope == 0:
self.parameters.repetition_penalty_slope = None
return self

@classmethod
Expand Down Expand Up @@ -358,7 +369,7 @@ async def request(self,
message = response_data.get("error") or response_data.get(
"message") or f"Server error with status_code {response.status_code}"
status_code = response_data.get("statusCode", response.status_code)
if status_code == 200:
if status_code == 200 or status_code == 201:
output = response_data.get("output", None)
if not output:
raise APIError(
Expand Down
45 changes: 20 additions & 25 deletions src/novelai_python/sdk/ai/generate/_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from enum import Enum
from typing import List, Optional, Tuple, Union

from pydantic import BaseModel, Field, model_validator, ConfigDict
from pydantic import BaseModel, Field, ConfigDict

"""
class StorySettings(BaseModel):
Expand Down Expand Up @@ -77,19 +77,14 @@ class LogitBiasGroup(BaseModel):
ensure_sequence_finish: bool
generate_once: bool

@model_validator(mode="before")
def validate_sequence(cls, value):
print(value)
return value


class AdvanceLLMSetting(BaseModel):
"""
LLM Generation Request Parameters
"""
min_length: Optional[int] = 1
max_length: Optional[int] = None
repetition_penalty: Optional[float] = None
repetition_penalty: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
generate_until_sentence: Optional[bool] = True
use_cache: Optional[bool] = False
use_string: Optional[bool] = False
Expand All @@ -107,36 +102,36 @@ class LLMGenerationParams(BaseModel):
LLM Generation Settings
"""
textGenerationSettingsVersion: Optional[int] = None
temperature: Optional[float] = None
temperature: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
max_length: Optional[int] = Field(default=None, ge=1, le=100) # Max 150(vip3),100(vip2)
min_length: Optional[int] = Field(default=None, ge=1, le=100) # Max 150(vip3),100(vip2)
top_k: Optional[float] = Field(default=None, ge=0)
top_p: Optional[float] = Field(default=None, gt=0, le=1)
top_a: Optional[float] = Field(default=None, ge=0)
typical_p: Optional[float] = Field(default=None, ge=0, le=1)
tail_free_sampling: Optional[float] = Field(default=None, ge=0)
repetition_penalty: Optional[float] = Field(default=None, gt=0)
top_k: Optional[int] = Field(default=None, ge=0)
top_p: Optional[Union[float, int]] = Field(default=None, gt=0, le=1, allow_inf_nan=False)
top_a: Optional[Union[float, int]] = Field(default=None, ge=0, allow_inf_nan=False)
typical_p: Optional[Union[float, int]] = Field(default=None, ge=0, le=1, allow_inf_nan=False)
tail_free_sampling: Optional[Union[float, int]] = Field(default=None, ge=0, allow_inf_nan=False)
repetition_penalty: Optional[Union[float, int]] = Field(default=None, gt=0, allow_inf_nan=False)
repetition_penalty_range: Optional[int] = Field(default=None, ge=0)
repetition_penalty_slope: Optional[float] = Field(default=None, ge=0)
repetition_penalty_slope: Optional[Union[float, int]] = Field(default=None, ge=0, allow_inf_nan=False)

eos_token_id: int = None
bad_words_ids: List[List[int]] = None
logit_bias_groups: Optional[List[LogitBiasGroup]] = []

repetition_penalty_frequency: Optional[float] = None
repetition_penalty_presence: Optional[float] = None
repetition_penalty_frequency: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
repetition_penalty_presence: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
repetition_penalty_whitelist: Optional[List[int]] = None
repetition_penalty_default_whitelist: Optional[bool] = None
cfg_scale: Optional[float] = None
cfg_scale: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
cfg_uc: Optional[str] = None
phrase_rep_pen: PenStyle = PenStyle.Off
top_g: Optional[float] = None
mirostat_tau: Optional[float] = None
mirostat_lr: Optional[float] = None
math1_temp: Optional[float] = None
math1_quad: Optional[float] = None
math1_quad_entropy_scale: Optional[float] = None
min_p: Optional[float] = None
top_g: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
mirostat_tau: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
mirostat_lr: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
math1_temp: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
math1_quad: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
math1_quad_entropy_scale: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)
min_p: Optional[Union[float, int]] = Field(default=None, allow_inf_nan=False)

order: Union[List[int], List[KeyOrderEntry]] = [
KeyOrderEntry(id=Key.Cfg, enabled=False),
Expand Down
2 changes: 1 addition & 1 deletion src/novelai_python/sdk/ai/generate_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,7 @@ async def request(self,
request=request_data, code=status_code, response=message
)
else:
if response.status_code != 200:
if response.status_code not in [200, 201]:
raise APIError(
f"Server error with status code {response.status_code}",
request=request_data, code=response.status_code, response=response.content
Expand Down
7 changes: 7 additions & 0 deletions src/novelai_python/tokenizer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,13 @@ def tokenize_text(self, text: str) -> List[int]:
return self.tokenizer.encode(text).tokens
raise NotImplementedError("Tokenizer does not support token encoding")

def total_tokens(self) -> int:
    """Return the vocabulary size of the wrapped tokenizer.

    Counts entries in ``get_vocab()``; callers use this as the upper
    bound for valid token ids (see the out-of-range bias/ban-word checks).

    :return: number of entries in the tokenizer's vocabulary.
    :raises NotImplementedError: if the wrapped tokenizer backend does
        not expose a vocabulary.
    """
    # Both supported backends expose get_vocab(); the original code
    # duplicated the same branch for each type, collapsed here.
    if isinstance(self.tokenizer, (Tokenizer, SimpleTokenizer)):
        return len(self.tokenizer.get_vocab())
    # Message corrected: the original said "token encoding", a
    # copy-paste from tokenize_text that misdescribed this method.
    raise NotImplementedError("Tokenizer does not support vocab retrieval")

def encode(self, sentence: str) -> List[int]:
if isinstance(self.tokenizer, SimpleTokenizer):
return self.tokenizer.encode(sentence).ids
Expand Down
8 changes: 8 additions & 0 deletions src/novelai_python/tokenizer/clip_simple_tokenizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,11 @@ def decode(self, tokens):
text = ''.join([self.decoder[token] for token in tokens])
text = bytearray([self.byte_decoder[c] for c in text]).decode('utf-8', errors="replace").replace('</w>', ' ')
return text

def get_vocab(self):
    """Return the tokenizer's encoder mapping; ``len()`` of it gives the vocabulary size.

    Presumably maps subword strings to integer ids — it mirrors
    ``self.decoder`` used in ``decode()``; verify against construction.
    """
    return self.encoder


if __name__ == '__main__':
    # Smoke test: build the default tokenizer and print its vocabulary size.
    tokenizer = SimpleTokenizer()
    print(len(tokenizer.get_vocab()))
2 changes: 1 addition & 1 deletion src/novelai_python/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ def try_jsonfy(obj: Union[str, dict, list, tuple], default_when_error=None):
try:
return json.loads(obj)
except Exception as e:
logger.trace(f"Decode Error {obj}")
logger.trace(f"Decode Error {obj} {e}")
if default_when_error is None:
return f"Decode Error {type(obj)}"
else:
Expand Down

0 comments on commit 7c71f0c

Please sign in to comment.