-
Notifications
You must be signed in to change notification settings - Fork 7.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update label_ops.py #13665
Closed
Closed
Update label_ops.py #13665
Changes from all commits
Commits
Show all changes
11 commits
Select commit
Hold shift + click to select a range
f24f5d6
Update label_ops.py
hicricket 2b82362
Update label_ops.py
hicricket 4316ddc
Update label_ops.py
hicricket 6f56fa6
Update label_ops.py
hicricket d09874e
Add files via upload
hicricket 0edd614
Update test_label_ops.py
hicricket 6836fbf
Update test_label_ops.py
hicricket bd413a1
Merge branch 'main' into bo_rec
hicricket b323ce3
Merge branch 'main' into bo_rec
GreatV 7376f43
Update tests/test_label_ops.py
GreatV c84010f
Update tests/test_label_ops.py
GreatV File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,320 @@ | ||
import os | ||
import sys | ||
import pytest | ||
import numpy as np | ||
import json | ||
|
||
# Make the repository root importable so that `ppocr` resolves when the
# tests are executed straight from this directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.dirname(current_dir))
|
||
from ppocr.data.imaug.label_ops import ( | ||
ClsLabelEncode, | ||
DetLabelEncode, | ||
CTCLabelEncode, | ||
AttnLabelEncode, | ||
) | ||
|
||
|
||
# Data generator function
def generate_character_dict(tmp_path, characters):
    """Write *characters* (one per line) to ``char_dict.txt`` under *tmp_path*.

    Returns the file path as a string, suitable for the encoders'
    ``character_dict_path`` argument.
    """
    dict_file = tmp_path / "char_dict.txt"
    content = "\n".join(characters) + "\n"
    dict_file.write_text(content)
    return str(dict_file)
|
||
|
||
# Fixture: ClsLabelEncode
@pytest.fixture
def setup_cls_label_encode():
    """Classification encoder over three known label names."""
    labels = ["label1", "label2", "label3"]
    return ClsLabelEncode(label_list=labels)
|
||
|
||
# Fixture: CTCLabelEncode
@pytest.fixture
def setup_ctc_label_encode(tmp_path):
    """CTC encoder whose dictionary is the ASCII characters a, b, c."""
    dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
    return CTCLabelEncode(max_text_length=10, character_dict_path=dict_path)
|
||
|
||
@pytest.fixture
def setup_ctc_label_encode_chinese(tmp_path):
    """CTC encoder whose dictionary holds four Chinese characters."""
    dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
    return CTCLabelEncode(max_text_length=10, character_dict_path=dict_path)
|
||
|
||
@pytest.fixture
def setup_ctc_label_encode_tibetan(tmp_path):
    """CTC encoder over Tibetan letters, including multi-codepoint stacks."""
    tibetan_chars = ["ཀ", "ཁ", "ག", "ང", "ཀྵ", "ཀྪོ", "ཀྩོ", "ཀྤྲེ", "ཀླཱ", "གྒྲ"]
    dict_path = generate_character_dict(tmp_path, tibetan_chars)
    return CTCLabelEncode(max_text_length=25, character_dict_path=dict_path)
|
||
|
||
# Fixture: AttnLabelEncode
@pytest.fixture
def setup_attn_label_encode(tmp_path):
    """Attention encoder whose dictionary is the ASCII characters a, b, c."""
    dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
    return AttnLabelEncode(max_text_length=10, character_dict_path=dict_path)
|
||
|
||
@pytest.fixture
def setup_attn_label_encode_chinese(tmp_path):
    """Attention encoder whose dictionary holds four Chinese characters."""
    dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
    return AttnLabelEncode(max_text_length=10, character_dict_path=dict_path)
|
||
|
||
# Fixture: DetLabelEncode
@pytest.fixture
def setup_det_label_encode():
    """Detection label encoder with default configuration."""
    return DetLabelEncode()
|
||
|
||
# Test functions
@pytest.mark.parametrize(
    "label, expected",
    [("label1", 0), ("unknown_label", None), ("", None)],
)
def test_cls_label_encode_call(setup_cls_label_encode, label, expected):
    """Known labels map to their index; unknown or empty labels yield None."""
    encoder = setup_cls_label_encode
    result = encoder({"label": label})

    if expected is None:
        assert (
            result is None
        ), f"Expected None for label: {label}, but got {result}"
    else:
        assert (
            result["label"] == expected
        ), f"Expected {expected} for label: {label}, but got {result['label']}"
|
||
|
||
@pytest.mark.parametrize(
    "label, expected",
    [
        ("abc", np.array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0])),
        ("unknown", None),
        ("", None),
        ("a" * 20, None),
    ],
)
def test_ctc_label_encode_call(setup_ctc_label_encode, label, expected):
    """Valid text encodes to 1-based indices padded with zeros, plus a
    length entry; out-of-dictionary, empty, or over-long text yields None."""
    encoder = setup_ctc_label_encode
    result = encoder({"label": label})

    if expected is None:
        assert (
            result is None
        ), f"Expected None for label: {label}, but got {result}"
    else:
        assert np.array_equal(
            result["label"], expected
        ), f"Failed for label: {label}, expected {expected} but got {result['label']}"
        assert result["length"] == len(
            label
        ), f"Expected length {len(label)} but got {result['length']}"
|
||
|
||
@pytest.mark.parametrize(
    "label, expected_result",
    [("你好世界", np.array([1, 2, 3, 4, 0, 0, 0, 0, 0, 0]))],
)
def test_ctc_label_encode_call_valid_text_chinese(
    setup_ctc_label_encode_chinese, label, expected_result
):
    """Chinese text follows the same index-plus-padding scheme as ASCII."""
    encoder = setup_ctc_label_encode_chinese
    result = encoder({"label": label})

    assert np.array_equal(
        result["label"], expected_result
    ), f"Failed for Chinese text: {label}"
    assert result["length"] == len(
        label
    ), f"Expected length {len(label)} but got {result['length']}"
|
||
|
||
@pytest.mark.parametrize(
    "label, expected_result",
    [
        (
            "ཀཁགངཀྪོཀྩོཀྤྲེཀླཱགྒྲགྒྲ",
            np.array([1, 2, 3, 4, 6, 7, 8, 9, 10, 10] + [0] * 15),
        ),
        (
            "ཀྵཁགངཀྩོཀྪོ",
            np.array([5, 2, 3, 4, 7, 6] + [0] * 19),
        ),
    ],
)
def test_ctc_label_encode_call_valid_text_tibetan(
    setup_ctc_label_encode_tibetan, label, expected_result
):
    """Multi-codepoint Tibetan stacks must match single dictionary entries,
    so the reported length equals the number of dictionary tokens (non-zero
    entries), not ``len(label)`` in codepoints."""
    encoder = setup_ctc_label_encode_tibetan
    result = encoder({"label": label})

    tokens = expected_result[expected_result != 0]
    assert np.array_equal(
        result["label"], expected_result
    ), f"Failed for Tibetan text: {label}"
    assert result["length"] == len(
        tokens
    ), f"Expected length {len(tokens)} but got {result['length']}"
|
||
|
||
@pytest.mark.parametrize(
    "label, expected_shape, expected_length",
    [
        ("abc", (10,), 3),
        ("unknown", None, None),
        ("", None, None),
        ("a" * 20, None, None),
    ],
)
def test_attn_label_encode_call(
    setup_attn_label_encode, label, expected_shape, expected_length
):
    """Attention encoding wraps text between SOS (index 0) and EOS (last
    dictionary index); invalid text yields None."""
    encoder = setup_attn_label_encode
    result = encoder({"label": label})

    if expected_shape is None:
        assert (
            result is None
        ), f"Expected None for label: {label}, but got {result}"
        return

    encoded = result["label"]
    eos_position = expected_length + 1
    assert (
        encoded.shape == expected_shape
    ), f"Expected shape {expected_shape} but got {encoded.shape}"
    assert (
        encoded[0] == 0
    ), f"Expected SOS token at start but got {encoded[0]}"
    assert encoded[eos_position] == len(encoder.character) - 1, (
        f"Expected EOS token at position {expected_length + 1} but got "
        f"{encoded[expected_length + 1]}"
    )
    assert (
        result["length"] == expected_length
    ), f"Expected length {expected_length} but got {result['length']}"
|
||
|
||
@pytest.mark.parametrize(
    "label, expected_shape, expected_length",
    [("你好世界", (10,), 4)],
)
def test_attn_label_encode_call_valid_text_chinese(
    setup_attn_label_encode_chinese, label, expected_shape, expected_length
):
    """Chinese text gets the same SOS/EOS framing as ASCII text."""
    encoder = setup_attn_label_encode_chinese
    result = encoder({"label": label})

    encoded = result["label"]
    eos_position = expected_length + 1
    assert (
        encoded.shape == expected_shape
    ), f"Expected shape {expected_shape} but got {encoded.shape}"
    assert (
        encoded[0] == 0
    ), f"Expected SOS token at start but got {encoded[0]}"
    assert encoded[eos_position] == len(encoder.character) - 1, (
        f"Expected EOS token at position {expected_length + 1} but got "
        f"{encoded[expected_length + 1]}"
    )
    assert (
        result["length"] == expected_length
    ), f"Expected length {expected_length} but got {result['length']}"
|
||
|
||
@pytest.mark.parametrize(
    "label, expected_texts",
    [
        ('[{"points": [[0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', ["text"]),
        ("[]", None),
        ("", pytest.raises(json.JSONDecodeError)),
        (
            '[{"points": [0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]',
            pytest.raises(json.JSONDecodeError),
        ),
    ],
)
def test_det_label_encode_call(setup_det_label_encode, label, expected_texts):
    """Detection labels parse into polys/texts/ignore_tags.

    ``expected_texts`` is one of:
    - a list of transcriptions: the sample must encode successfully;
    - None: the encoder is expected to drop the sample (empty annotation);
    - a ``pytest.raises`` context manager: the call itself must raise.
    """
    encoder = setup_det_label_encode
    data = {"label": label}

    if isinstance(expected_texts, list):
        encoded_data = encoder(data)
        assert "polys" in encoded_data, "Missing polys in encoded data"
        assert "texts" in encoded_data, "Missing texts in encoded data"
        assert "ignore_tags" in encoded_data, "Missing ignore_tags in encoded data"
        assert (
            encoded_data["texts"] == expected_texts
        ), f"Expected texts {expected_texts} but got {encoded_data['texts']}"
    elif expected_texts is None:
        encoded_data = encoder(data)
        assert (
            encoded_data is None
        ), f"Expected None for label: {label}, but got {encoded_data}"
    else:
        # Exception case. The previous detection via
        # isinstance(expected_texts, type(pytest.raises(Exception))) built a
        # throwaway context manager and relied on pytest's internal class;
        # any unrecognized parameter also silently fell into the None branch.
        with expected_texts:
            encoder(data)
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对于中文等长字符集的语言来说,会不会导致处理速度变得特别慢?建议再考虑一下多语言的兼容性问题。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
特别慢倒是不会，但确实会影响到计算效率，特别是主流的字典行不存在多字符，只能是特殊情况用比较好。我觉得可以保留原来的逻辑：初始化的时候先对字典行进行判断，若存在多字符的行，再用修改后的逻辑。我想的是这样也是个方法。