Skip to content

Commit

Permalink
Merge pull request #7 from LlmKira/dev
Browse files Browse the repository at this point in the history
fix:random prompt
  • Loading branch information
sudoskys authored Feb 7, 2024
2 parents d9d85f5 + 4706785 commit 4e6df07
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 21 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ loop.run_until_complete(main())
```python
from novelai_python.utils.random_prompt import RandomPromptGenerator

s = RandomPromptGenerator(nsfw_enabled=False).generate()
s = RandomPromptGenerator(nsfw_enabled=False).random_prompt()
print(s)
```

Expand Down
15 changes: 10 additions & 5 deletions playground/random_prompt.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,16 @@
# @Author : sudoskys
# @File : random_prompt.py
# @Software: PyCharm
import random

from novelai_python.utils.random_prompt import RandomPromptGenerator


print(random.random())
s = RandomPromptGenerator(nsfw_enabled=True).generate()
print(s)
gen = RandomPromptGenerator(nsfw_enabled=True)
print(gen.get_weighted_choice([[1, 35], [2, 20], [3, 7]], []))
print("====")
print(gen.get_weighted_choice([['mss', 30], ['fdd', 50], ['oa', 10]], []))
print("====")
print(gen.get_weighted_choice([['m', 30], ['f', 50], ['o', 10]], ['m']))
print("====")
for i in range(200):
s = RandomPromptGenerator(nsfw_enabled=True).random_prompt()
print(s)
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[project]
name = "novelai-python"
version = "0.2.0"
version = "0.2.1"
description = "Novelai Python Binding With Pydantic"
authors = [
{ name = "sudoskys", email = "[email protected]" },
Expand Down
42 changes: 28 additions & 14 deletions src/novelai_python/utils/random_prompt/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,16 +30,25 @@ def get_weighted_choice(tags, existing_tags: List[str]):
:param existing_tags: a list of existing tags
:return: a tag
"""
valid_tags = [tag for tag in tags if
len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])]
valid_tags = [tag
for tag in tags
if len(tag) < 3 or not tag[2] or any(sub_tag in existing_tags for sub_tag in tag[2])]
total_weight = sum(tagr[1] for tagr in valid_tags if len(tagr) > 1)
if total_weight == 0:
return random.choice(tags)
if isinstance(tags, list):
rd = random.choice(tags)
elif isinstance(tags, str):
rd = tags
else:
raise ValueError('get_weighted_choice: should not reach here')
return rd
random_number = random.randint(1, total_weight)
cumulative_weight = 0
for tag in valid_tags:
cumulative_weight += tag[1]
if random_number <= cumulative_weight:
if isinstance(tag, str):
raise Exception("tag is string")
return tag[0]
raise ValueError('get_weighted_choice: should not reach here')

Expand All @@ -61,7 +70,7 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters,
has_unique_feature = any(feature in features for feature in unique_features)
if random.random() < 0.1 and enable_skin_color:
features.append(self.get_weighted_choice(sinkColor, features))
if random.random() < 0.3:
if random.random() < 0.8:
features.append(self.get_weighted_choice(eyeColors, features))
if random.random() < 0.1:
features.append(self.get_weighted_choice(eyeCharacteristics, features))
Expand All @@ -71,12 +80,12 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters,
features.append(self.get_weighted_choice(hairLength, features))
if random.random() < 0.2:
features.append(self.get_weighted_choice(backHairStyle, features))
if random.random() < 0.2:
if random.random() < 0.1:
features.append(self.get_weighted_choice(hairColors, features))
if random.random() < 0.1:
features.append(self.get_weighted_choice(hairColorExtra, features))
features.append(self.get_weighted_choice(hairColors, features))
if random.random() < 0.1:
if random.random() < 0.12:
features.append(self.get_weighted_choice(hairFeatures, features))
if gender.startswith('f') and random.random() < 0.8:
features.append(self.get_weighted_choice(breastsSize, features))
Expand Down Expand Up @@ -136,9 +145,9 @@ def character_features(self, gender, camera_angle, nsfw_enabled, num_characters,
# 单角色 + nsfw 为 1
possible_actions = action
if nsfw_enabled:
if random.random() < 0.3:
if random.random() < 0.5:
features.append(self.get_weighted_choice(nsfw["action"], features))
if random.random() < 0.25:
if random.random() < 0.5:
features.append(self.get_weighted_choice(nsfw["pussyForeplay"], features))
possible_actions += nsfw["action"] + nsfw["analForeplay"] + nsfw["pussyForeplay"]
if random.random() < 0.5:
Expand Down Expand Up @@ -190,16 +199,21 @@ def random_prompt(self, *,
enable_moods: bool = True,
enable_character: bool = True,
enable_identity: bool = False,
must_appear=None,
):
if must_appear is None:
must_appear = []
tags = []
# 必须出现的标签
tags.extend(must_appear)
if self.nsfw_enabled:
tags.append('nsfw')
if random.random() < 0.1:
tags.append('explicit')
tags.append('lewd')
irs = self.get_weighted_choice([[1, 70], [2, 20], [3, 7], [0, 5]], tags)
if self.nsfw_enabled:
irs = self.get_weighted_choice([[1, 35], [2, 20], [3, 7]], tags)
irs = self.get_weighted_choice([[1, 40], [2, 20], [3, 7]], tags)
if irs == 0:
tags.append('no humans')
if random.random() < 0.3:
Expand All @@ -216,11 +230,11 @@ def random_prompt(self, *,
return ', '.join(tags)
if random.random() < 0.3:
tags.append(self.get_weighted_choice(artStyle, tags))
if random.random() < 0.5:
if random.random() < 0.6:
tags.append("{" + self.get_weighted_choice(rankArtist, tags) + "}")
if random.random() < 0.5:
tags.append("[" + self.get_weighted_choice(rankArtist, tags) + "]")
if random.random() < 0.5:
if random.random() < 0.4:
tags.append("{{" + self.get_weighted_choice(rankArtist, tags) + "}}")
c_count = 0
d_count = 0
Expand Down Expand Up @@ -260,17 +274,17 @@ def random_prompt(self, *,
tags.append(nsfw["ya"])
if d_count > 0:
if c_count > 0:
if random.random() < 0.6:
if random.random() < 0.3:
tags.append(self.get_weighted_choice(nsfw["penis"], tags))
else:
if random.random() < 0.3:
tags.append(self.get_weighted_choice(nsfw["penis"], tags))
if d_count > 0 and g_count > 0:
if random.random() < 0.5:
if random.random() < 0.7:
tags.append(self.get_weighted_choice(nsfw["analSex"], tags))
if g_count > 0:
features = []
if random.random() < 0.3:
if random.random() < 0.2:
features = self.character_features(nsfw['fu'], None, True, irs)
if random.random() < 0.6:
tags.append(self.get_weighted_choice(nsfw["sex"], tags))
Expand Down

0 comments on commit 4e6df07

Please sign in to comment.