Commit
Fix puctuations and upgrade to version 0.8 (#48)
* Fix puctuations and upgrade to version 0.8

* Minor fixes

* Fix style

* Fix style
pkufool authored Aug 4, 2023
1 parent 32a34d6 commit 63b18db
Showing 6 changed files with 76 additions and 32 deletions.
2 changes: 1 addition & 1 deletion CMakeLists.txt
@@ -1,7 +1,7 @@
 cmake_minimum_required(VERSION 3.12 FATAL_ERROR)
 project(textsearch)

-set(TS_VERSION "0.7")
+set(TS_VERSION "0.8")

 set(CMAKE_ARCHIVE_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
 set(CMAKE_LIBRARY_OUTPUT_DIRECTORY "${CMAKE_BINARY_DIR}/lib")
20 changes: 13 additions & 7 deletions examples/libriheavy/matching.py
@@ -99,6 +99,8 @@ def get_params() -> AttributeDict:
         "max_duration": 30,
         "expected_duration": (5, 20),
         "max_error_rate": 0.20,
+        # output
+        "output_post_texts": False,
     }
 )

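The new flag defaults to off and is read later as params.output_post_texts. AttributeDict (defined in textsearch/python/textsearch/utils.py and shown further down this page) is a thin dict subclass exposing keys as attributes; a minimal sketch of how the flag travels:

class AttributeDict(dict):  # simplified stand-in for the class in textsearch.utils
    def __getattr__(self, key):
        return self[key]

params = AttributeDict({"output_post_texts": False})
print(params.output_post_texts)  # False: post_texts are now opt-in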
@@ -330,6 +332,7 @@ def split(


 def write(
+    params: AttributeDict,
     batch_cuts: List[MonoCut],
     results,
     cuts_writer: SequentialJsonlWriter,
@@ -358,6 +361,14 @@ def write(
         for seg in segments:
             id = f"{current_cut.id}_{cut_segment_index[current_cut.id]}"
             cut_segment_index[current_cut.id] += 1
+            custom = {
+                "texts": [seg["ref"], seg["hyp"]],
+                "pre_texts": [seg["pre_ref"], seg["pre_hyp"]],
+                "begin_byte": seg["begin_byte"],
+                "end_byte": seg["end_byte"],
+            }
+            if params.output_post_texts:
+                custom["post_texts"] = [seg["post_ref"], seg["post_hyp"]]
             supervision = SupervisionSegment(
                 id=id,
                 channel=current_cut.supervisions[cut_indexes[1]].channel,
@@ -366,13 +377,7 @@
                 recording_id=current_cut.recording.id,
                 start=0,
                 duration=seg["duration"],
-                custom={
-                    "texts": [seg["ref"], seg["hyp"]],
-                    "pre_texts": [seg["pre_ref"], seg["pre_hyp"]],
-                    "post_texts": [seg["post_ref"], seg["post_ref"]],
-                    "begin_byte": seg["begin_byte"],
-                    "end_byte": seg["end_byte"],
-                },
+                custom=custom,
             )
             cut = MonoCut(
                 id,
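A compact sketch of the assembly logic these two hunks introduce, using made-up segment values. Note that the removed inline dict wrote seg["post_ref"] twice, so the refactor also fixes post_texts to carry the hypothesis side:

seg = {
    "ref": "REF", "hyp": "HYP",
    "pre_ref": "PRE_R", "pre_hyp": "PRE_H",
    "post_ref": "POST_R", "post_hyp": "POST_H",
    "begin_byte": 0, "end_byte": 3,
}
output_post_texts = True  # stands in for params.output_post_texts

custom = {
    "texts": [seg["ref"], seg["hyp"]],
    "pre_texts": [seg["pre_ref"], seg["pre_hyp"]],
    "begin_byte": seg["begin_byte"],
    "end_byte": seg["end_byte"],
}
if output_post_texts:
    custom["post_texts"] = [seg["post_ref"], seg["post_hyp"]]  # hyp side, not ref twice

print(custom["post_texts"])  # ['POST_R', 'POST_H']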
@@ -423,6 +428,7 @@ def process_one_batch(
        logging.warning("Splitted data is empty.")
        return
    write(
+        params=params,
        batch_cuts=batch_cuts,
        results=splited_data,
        cuts_writer=cuts_writer,
9 changes: 7 additions & 2 deletions examples/libriheavy/matching_parallel.py
@@ -51,7 +51,7 @@ def get_args():
     parser.add_argument(
         "--batch-size",
         type=int,
-        default=100,
+        default=50,
         help="""The number of cuts in a batch.
         """,
     )
@@ -245,6 +245,7 @@ def splitter(


 def writer(
+    params: AttributeDict,
     write_queue: Queue,
     cuts_writer: SequentialJsonlWriter,
 ):
@@ -266,7 +267,10 @@ def writer(
             results = item["segments"]
             batch_cuts = item["cuts"]
             write(
-                batch_cuts=batch_cuts, results=results, cuts_writer=cuts_writer
+                params=params,
+                batch_cuts=batch_cuts,
+                results=results,
+                cuts_writer=cuts_writer,
             )
         except Exception as e:
             logging.error(f"Writer caught {type(e)}: e")
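For orientation, a self-contained sketch of the queue-fed writer thread that params is now threaded through (simplified: the real items carry lhotse cuts and alignment results, and the real writer also takes cuts_writer and forwards everything to write() from matching.py):

from queue import Queue
from threading import Thread

def writer(params: dict, write_queue: Queue) -> None:
    while True:
        item = write_queue.get()
        if item is None:  # sentinel marking the end of the stream
            break
        # the real writer unpacks item["cuts"] / item["segments"]
        # and forwards them to write(params=params, ...)
        print(params["output_post_texts"], item)

params = {"output_post_texts": False}
q: Queue = Queue()
t = Thread(target=writer, args=(params, q))
t.start()
q.put({"cuts": [], "segments": []})
q.put(None)
t.join()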
@@ -367,6 +371,7 @@ def main():
     writer_thread = Thread(
         target=writer,
         args=(
+            params,
             write_queue,
             cuts_writer,
         ),
5 changes: 1 addition & 4 deletions pyproject.toml
@@ -10,7 +10,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "fasttextsearch"
-version = "0.7"
+version = "0.8"
 authors = [
   { name="Next-gen Kaldi development team", email="[email protected]" },
 ]
@@ -33,6 +33,3 @@ classifiers = [

 [tool.black]
 line-length = 80
-
-
-
61 changes: 45 additions & 16 deletions textsearch/python/textsearch/match.py
@@ -33,7 +33,12 @@
 )
 from .suffix_array import create_suffix_array
 from .datatypes import SourcedText, TextSource, Transcript
-from .utils import is_overlap, is_punctuation, row_ids_to_row_splits
+from .utils import (
+    PUCTUATIONS,
+    is_overlap,
+    is_punctuation,
+    row_ids_to_row_splits,
+)


 def get_longest_increasing_pairs(
@@ -241,7 +246,7 @@ def add_segments(
                     target_end,
                 )
             else:
-                segments.append((query_start, query_end, target_start, target_end,))
+                segments.append((query_start, query_end, target_start, target_end))
         else:
             add_segments(
                 query_start,
@@ -682,7 +687,7 @@ def _get_segment_candidates(
         "(?<!Mr|Mrs|Dr|Ms|Prof|Pro|Capt|Gen|Sen|Rev|Hon|St)\."
     )
     # the largest length of the patterns.
-    period_pattern_length = 4
+    period_pattern_length = 5

     for i, align in enumerate(aligns):
         matched = align["ref"] == align["hyp"]
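Two notes on this hunk. The bump to 5 presumably accounts for the longest abbreviation ("Prof"/"Capt", four characters) plus the period itself. Also, Python's built-in re module raises "look-behind requires fixed-width pattern" for a variable-width lookbehind written exactly as excerpted above, so the full source may build the pattern differently; a sketch that chains one fixed-width lookbehind per abbreviation (list copied from the diff) reproduces the intended behavior:

import re

abbrevs = ["Mr", "Mrs", "Dr", "Ms", "Prof", "Pro", "Capt", "Gen", "Sen", "Rev", "Hon", "St"]
# chain one fixed-width lookbehind per abbreviation: match "." only when it
# does not terminate one of the known abbreviations
period_patterns = re.compile("".join(f"(?<!{a})" for a in abbrevs) + r"\.")

m = period_patterns.search("Mr. Smith went home.")
print(m.start() if m else None)  # 19: the sentence-final period, not the one after "Mr"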
@@ -741,7 +746,10 @@ def _get_segment_candidates(
                     ]
                 ]
             )
-            if period_patterns.search(tmp) is not None:
+            if current_token != "." or (
+                current_token == "."
+                and period_patterns.search(tmp) is not None
+            ):
                 prev_punctuation = punctuation_score
                 break
         else:
@@ -764,7 +772,10 @@
                     ]
                 ]
             )
-            if period_patterns.search(tmp) is not None:
+            if current_token != "." or (
+                current_token == "."
+                and period_patterns.search(tmp) is not None
+            ):
                 succ_punctuation = punctuation_score
                 break
         else:
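The same condition appears in both hunks (looking backward for prev_punctuation and forward for succ_punctuation): any punctuation other than a period now scores a break immediately, and only a period has to pass the abbreviation filter. A small sketch of the decision, with period_patterns truncated to three abbreviations for brevity:

import re

period_patterns = re.compile(r"(?<!Mr)(?<!Mrs)(?<!Dr)\.")

def breaks_here(current_token: str, tmp: str) -> bool:
    # mirrors the new condition in the diff
    return current_token != "." or (
        current_token == "." and period_patterns.search(tmp) is not None
    )

print(breaks_here(",", "went home,"))  # True: non-period punctuation always breaks
print(breaks_here(".", "went home."))  # True: a genuine sentence-final period
print(breaks_here(".", "met Mr."))     # False: the period belongs to "Mr."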
@@ -788,16 +799,16 @@

         if target_source.has_punctuation:
             if prev_punctuation > 0 or i == 0:
-                begin_scores.append((i, begin_score,))
+                begin_scores.append((i, begin_score))
             if succ_punctuation > 0 or i == len(aligns) - 1:
-                end_scores.append((i, end_score,))
+                end_scores.append((i, end_score))
         else:
             if matched and (prev_silence >= silence_length_to_break or i == 0):
-                begin_scores.append((i, begin_score,))
+                begin_scores.append((i, begin_score))
             if matched and (
                 succ_silence >= silence_length_to_break or i == len(aligns) - 1
             ):
-                end_scores.append((i, end_score,))
+                end_scores.append((i, end_score))

     # (start, end, score)
     begin_list: List[Tuple[int, int, float]] = []
@@ -875,7 +886,7 @@ def _get_segment_candidates(
                 item_q,
                 (
                     point_score + matched_score - error_score + duration_score,
-                    (item[0], end_scores[ind][0],),
+                    (item[0], end_scores[ind][0]),
                 ),
             )
             if len(item_q) > num_of_best_position:
@@ -956,7 +967,7 @@ def _get_segment_candidates(
                 item_q,
                 (
                     point_score + matched_score - error_score + duration_score,
-                    (begin_scores[ind][0], item[0],),
+                    (begin_scores[ind][0], item[0]),
                 ),
             )
             if len(item_q) > num_of_best_position:
@@ -1080,7 +1091,23 @@ def _split_into_segments(

     for seg in segments:
         begin_pos = aligns[seg[0]]["ref_pos"]
-        end_pos = aligns[seg[1]]["ref_pos"] + 1
+        while begin_pos >= 1:
+            current_token = chr(target_source.binary_text[begin_pos - 1])
+            if current_token in PUCTUATIONS["left"]:
+                begin_pos -= 1
+            else:
+                break
+
+        end_pos = aligns[seg[1]]["ref_pos"]
+        while end_pos + 1 < target_source.binary_text.size:
+            current_token = chr(target_source.binary_text[end_pos + 1])
+            if (
+                current_token in PUCTUATIONS["right"]
+                or current_token in PUCTUATIONS["eos"]
+            ):
+                end_pos += 1
+            else:
+                break

         preceding_index = seg[0] if seg[0] == 0 else seg[0] - 1
         succeeding_index = seg[1] if seg[1] == len(aligns) - 1 else seg[1] + 1
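A standalone sketch of the new boundary extension: opening punctuation to the left of the matched reference span is pulled into the segment, as are closing punctuation and end-of-sentence marks to the right. Here numpy stands in for TextSource.binary_text, consistent with the .size attribute used above:

import numpy as np

PUCTUATIONS = {  # trimmed copy of the dict added in textsearch/utils.py below
    "eos": set(".?!。?!"),
    "left": set("\"'(<《【“"),
    "right": set("\"')>》】”"),
}

text = np.frombuffer(b'He said "hello there." Next', dtype=np.uint8)
begin_pos, end_pos = 9, 19  # inclusive byte range of: hello there

while begin_pos >= 1 and chr(text[begin_pos - 1]) in PUCTUATIONS["left"]:
    begin_pos -= 1  # absorb the opening quote

while end_pos + 1 < text.size and (
    chr(text[end_pos + 1]) in PUCTUATIONS["right"]
    or chr(text[end_pos + 1]) in PUCTUATIONS["eos"]
):
    end_pos += 1  # absorb the period and the closing quote

print(bytes(text[begin_pos : end_pos + 1]).decode())  # "hello there."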
@@ -1105,11 +1132,13 @@
             end_time = aligns[succeeding_index]["hyp_time"]

         hyp_begin_pos = aligns[seg[0]]["hyp_pos"]
-        hyp_end_pos = aligns[seg[1]]["hyp_pos"] + 1
+        hyp_end_pos = aligns[seg[1]]["hyp_pos"]
         hyp = "".join(
             [
                 chr(i)
-                for i in query_source.binary_text[hyp_begin_pos:hyp_end_pos]
+                for i in query_source.binary_text[
+                    hyp_begin_pos : hyp_end_pos + 1
+                ]
             ]
         )

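hyp_end_pos is now stored as an inclusive index, with the +1 moved into the slice; the post-text slices in the following hunks start at end_pos + 1 and hyp_end_pos + 1 for the same reason, so the boundary character is not duplicated in the context. A toy illustration of the inclusive convention:

binary_text = b"hello world"
hyp_begin_pos, hyp_end_pos = 6, 10  # inclusive indices spanning "world"

hyp = "".join(chr(i) for i in binary_text[hyp_begin_pos : hyp_end_pos + 1])
print(hyp)  # world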
@@ -1131,7 +1160,7 @@
             [
                 chr(i)
                 for i in target_source.binary_text[
-                    end_pos : end_pos + preceding_context_length
+                    end_pos + 1 : end_pos + preceding_context_length
                 ]
             ]
         )
@@ -1151,7 +1180,7 @@
             [
                 chr(i)
                 for i in query_source.binary_text[
-                    hyp_end_pos : hyp_begin_pos + preceding_context_length
+                    hyp_end_pos + 1 : hyp_begin_pos + preceding_context_length
                 ]
             ]
         )
11 changes: 9 additions & 2 deletions textsearch/python/textsearch/utils.py
@@ -9,6 +9,13 @@

 Pathlike = Union[str, Path]

+PUCTUATIONS = {
+    "all": set("',.;?!():-<>/\",。;?!():-《》【】”“"),
+    "eos": set(".?!。?!"),
+    "left": set("\"'(<《【“"),
+    "right": set("\"')>》】”"),
+}
+

 class AttributeDict(dict):
     def __getattr__(self, key):
@@ -161,8 +168,8 @@ def is_punctuation(c: str, eos_only: bool = False) -> bool:
         If True the punctuations are only those indicating end of a sentence (.?! for now).
     """
     if eos_only:
-        return c in ".?!"
-    return c in ',.;?!():-<>-/"'
+        return c in PUCTUATIONS["eos"]
+    return c in PUCTUATIONS["all"]


 def str2bool(v):
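A quick check of the refactored helper against the new sets; compared with the old string literals, the CJK punctuation marks are the main addition (self-contained copy of the relevant pieces):

PUCTUATIONS = {
    "all": set("',.;?!():-<>/\",。;?!():-《》【】”“"),
    "eos": set(".?!。?!"),
}

def is_punctuation(c: str, eos_only: bool = False) -> bool:
    if eos_only:
        return c in PUCTUATIONS["eos"]
    return c in PUCTUATIONS["all"]

print(is_punctuation("。", eos_only=True))  # True: CJK full stop now ends a sentence
print(is_punctuation("《"))                 # True: CJK bracket is in the "all" set
print(is_punctuation("-", eos_only=True))   # False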
