Skip to content

Commit

Permalink
Included Iterators in the library
Browse files Browse the repository at this point in the history
- Added WordIterator, GlobalWordIterator and SentenceIterator to the library

- Added quotes around the generator expression inside target_link_libraries command

- Included tests for iterators
  • Loading branch information
sivaprasad2000 committed Dec 12, 2023
1 parent 7f89761 commit bc88fea
Show file tree
Hide file tree
Showing 5 changed files with 122 additions and 3 deletions.
1 change: 1 addition & 0 deletions bindings/python/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


from .repository import Aggregator, TranslateLocallyLike
from .iterators import *

REPOSITORY = Aggregator(
[
Expand Down
78 changes: 78 additions & 0 deletions bindings/python/iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,78 @@
class WordIterator:
def __init__(self, annotation, sentence_id):
self._sentence_id = sentence_id
self._word_id = -1
self._annotation = annotation

def __iter__(self):
self._word_id = -1
return self

def __next__(self):
self._word_id += 1
if self._word_id >= self._annotation.word_count(self._sentence_id):
raise StopIteration
return self

def surface(self):
range = self.range()
return self._annotation.text[range.begin:range.end]

def range(self):
return self._annotation.word_as_range(self._sentence_id, self._word_id)

def id(self):
return (self._sentence_id, self._word_id)

class GlobalWordIterator:
def __init__(self, annotation):
self._annotation = annotation
self._word_id = -1
self._sentence_id = 0

def __iter__(self):
self._word_id = -1
self._sentence_id = 0
return self

def __next__(self):
self._word_id += 1
if self._word_id >= self._annotation.word_count(self._sentence_id):
self._sentence_id += 1
if self._sentence_id >= self._annotation.sentence_count():
raise StopIteration
self._word_id = 0
return self

def surface(self):
range = self.range()
return self._annotation.text[range.begin:range.end]

def range(self):
return self._annotation.word_as_range(self._sentence_id, self._word_id)

def id(self):
return (self._sentence_id, self._word_id)

class SentenceIterator:
def __init__(self, annotation):
self._annotation = annotation
self._sentence_id = -1

def __iter__(self):
self._sentence_id = -1
return self

def __next__(self):
self._sentence_id += 1
if self._sentence_id >= self._annotation.sentence_count():
raise StopIteration
return self

def words(self):
return WordIterator(self._annotation, self._sentence_id)

def __repr__(self):
range = self._annotation.sentence_as_range(self._sentence_id)
sentence = self._annotation.text[range.begin:range.end]
return f'{sentence}'
2 changes: 1 addition & 1 deletion bindings/python/tests/test_encoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,4 +46,4 @@ def test_basic(service, models):
byte_range = byte.word_as_range(sentence_idx, word_idx)
utf8_to_byte_range = utf8_to_byte.word_as_range(sentence_idx, word_idx)
assert byte_range.begin == utf8_to_byte_range.begin
assert byte_range.end == utf8_to_byte_range.end
assert byte_range.end == utf8_to_byte_range.end
40 changes: 40 additions & 0 deletions bindings/python/tests/test_iterators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
# type: ignore
from slimt import SentenceIterator, WordIterator, GlobalWordIterator

def test_basic(service, models):
source = "Hi, How are you? Its been a long time.\nCan you help me out with some things?"
model = models[1]
response = service.translate(model, [source], html=False)[0]

target = response.target
text = target.text

sen_iter_tgt = SentenceIterator(target)
word_iter_global = GlobalWordIterator(target)

sentence_count = target.sentence_count()
for sentence_idx, word_iter in zip(range(sentence_count), sen_iter_tgt):
word_count = target.word_count(sentence_idx)
for word_idx, word in zip(range(word_count), word_iter.words()):

expected_text_range = target.word_as_range(sentence_idx, word_idx)
reconstructed_text_range = word.range()

# For SentenceIterator and WordIterator
assert expected_text_range.begin == reconstructed_text_range.begin
assert expected_text_range.end == reconstructed_text_range.end

expected_word = text[expected_text_range.begin:expected_text_range.end]
reconstructed_word = word.surface()

assert expected_word == reconstructed_word

word_global = next(word_iter_global)

reconstructed_text_range_glob = word_global.range()
reconstructed_word_glob = word_global.surface()

# For GlobalWordIterator
assert expected_text_range.begin == reconstructed_text_range_glob.begin
assert expected_text_range.end == reconstructed_text_range_glob.end
assert expected_word == reconstructed_word_glob
4 changes: 2 additions & 2 deletions slimt/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -91,8 +91,8 @@ foreach(SLIMT_LIB IN LISTS SLIMT_LIBRARIES)
target_link_libraries(
${SLIMT_LIB}
PUBLIC ${SLIMT_PUBLIC_LIBS}
INTERFACE $<BUILD_INTERFACE:${SLIMT_INTERFACE_LIBS}>
PRIVATE $<BUILD_INTERFACE:${SLIMT_PRIVATE_LIBS}>)
INTERFACE "$<BUILD_INTERFACE:${SLIMT_INTERFACE_LIBS}>"
PRIVATE "$<BUILD_INTERFACE:${SLIMT_PRIVATE_LIBS}>")

target_include_directories(
${SLIMT_LIB}
Expand Down

0 comments on commit bc88fea

Please sign in to comment.