Skip to content

Commit

Permalink
[fix] add_space
Browse files Browse the repository at this point in the history
  • Loading branch information
huyiwen committed Jun 15, 2024
1 parent 0ff9488 commit af0162d
Show file tree
Hide file tree
Showing 3 changed files with 15 additions and 6 deletions.
17 changes: 12 additions & 5 deletions utilization/chat_templates.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,18 +5,18 @@

def add_space(
msg: str,
auto_leading_space: bool,
remove_space_between: bool,
context: str,
auto_leading_space: bool = True,
remove_space_between: bool = True,
starts: Optional[List[str]] = None,
ends: Optional[List[str]] = None
) -> str:
if starts is None or ends is None or remove_space_between is False:
context_ends_special = False
msg_starts_special = False
else:
context_ends_special = any(context.endswith(e) for e in ends)
msg_starts_special = any(msg.startswith(s) for s in starts)
context_ends_special = any(context.endswith(e) for e in ends if len(e) > 0)
msg_starts_special = any(msg.startswith(s) for s in starts if len(s) > 0)
if (auto_leading_space and msg and context)\
and not (context[-1].isspace() or msg[0].isspace())\
and not (context_ends_special and msg_starts_special):
Expand All @@ -30,7 +30,14 @@ def smart_space(parts: List[str], auto_leading_space: bool, remove_space_between
rendered = ""
for part in parts:
if part:
rendered += add_space(part, auto_leading_space, remove_space_between, rendered, starts, ends)
rendered += add_space(
part,
rendered,
auto_leading_space=auto_leading_space,
remove_space_between=remove_space_between,
starts=starts,
ends=ends
)
return rendered


Expand Down
2 changes: 1 addition & 1 deletion utilization/model/model_utils/conversation.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,7 +159,7 @@ def _get_segs(self, conversations: List["Conversation"], max_turns: int = 1) ->
for seg in (system, examples, source, target):
if len(seg) > 0:
if len(result) > 0:
seg = add_space(seg, True, result[-1])
seg = add_space(seg, result[-1])
elif self.final_lstrip:
seg = seg.lstrip()
result += (seg,)
Expand Down
2 changes: 2 additions & 0 deletions utilization/utils/generation_args.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,8 @@ def resolve_generation_args(
# overrides
if key in extra_generation_args:
extra = extra_generation_args.pop(key)
if value is None and not details.nullable:
continue
if callable(extra):
overrided = extra(value, details)
for new_key, new_value in overrided.items():
Expand Down

0 comments on commit af0162d

Please sign in to comment.