-
Notifications
You must be signed in to change notification settings - Fork 7.8k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Update label_ops.py #13665
Update label_ops.py #13665
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这个PR是用来解决什么问题,能提供更详细一点的描述和说明?
我在训练藏文数据集的时候,对于字典中的多个字符组合的叠写字符无法正常匹配,训练标签只能匹配字典中的单字符 |
|
建议在代码添加一段英文注释说明一下,修复codestyle,删除掉上面被注释掉的代码。还要测试一下中英文正常的训练是否受影响。如果都没问题就可以合入了。直接在这个PR上继续commit就行。 感谢您的贡献。 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
感谢您的贡献,需要等我抽时间测试一下才能合入。
建议添加一个单测 import os
import sys
import pytest
import numpy as np
import json
# Import modules
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
from ppocr.data.imaug.label_ops import (
ClsLabelEncode,
DetLabelEncode,
CTCLabelEncode,
AttnLabelEncode,
)
# Data generator function
def generate_character_dict(tmp_path, characters):
character_dict = tmp_path / "char_dict.txt"
character_dict.write_text("\n".join(characters) + "\n")
return str(character_dict)
# Fixture: ClsLabelEncode
@pytest.fixture
def setup_cls_label_encode():
return ClsLabelEncode(label_list=["label1", "label2", "label3"])
# Fixture: CTCLabelEncode
@pytest.fixture
def setup_ctc_label_encode(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
@pytest.fixture
def setup_ctc_label_encode_chinese(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
@pytest.fixture
def setup_ctc_label_encode_tibetan(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["ཀ", "ཁ", "ག", "ང", "ཀྵ"])
return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
# Fixture: AttnLabelEncode
@pytest.fixture
def setup_attn_label_encode(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
return AttnLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
@pytest.fixture
def setup_attn_label_encode_chinese(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
return AttnLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
# Fixture: DetLabelEncode
@pytest.fixture
def setup_det_label_encode():
return DetLabelEncode()
# Test functions
@pytest.mark.parametrize("label, expected", [
("label1", 0),
("unknown_label", None),
("", None),
])
def test_cls_label_encode_call(setup_cls_label_encode, label, expected):
encoder = setup_cls_label_encode
data = {"label": label}
encoded_data = encoder(data)
if expected is not None:
assert encoded_data["label"] == expected, f"Expected {expected} for label: {label}, but got {encoded_data['label']}"
else:
assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"
@pytest.mark.parametrize("label, expected", [
("abc", np.array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0])),
("unknown", None),
("", None),
("a" * 20, None),
])
def test_ctc_label_encode_call(setup_ctc_label_encode, label, expected):
encoder = setup_ctc_label_encode
data = {"label": label}
encoded_data = encoder(data)
if expected is not None:
assert np.array_equal(encoded_data["label"], expected), f"Failed for label: {label}, expected {expected} but got {encoded_data['label']}"
assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
else:
assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"
@pytest.mark.parametrize("label, expected_result", [
("你好世界", np.array([1, 2, 3, 4, 0, 0, 0, 0, 0, 0])),
])
def test_ctc_label_encode_call_valid_text_chinese(setup_ctc_label_encode_chinese, label, expected_result):
encoder = setup_ctc_label_encode_chinese
data = {"label": label}
encoded_data = encoder(data)
assert np.array_equal(encoded_data["label"], expected_result), f"Failed for Chinese text: {label}"
assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
@pytest.mark.parametrize("label, expected_result", [
("ཀཁགང", np.array([1, 2, 3, 4, 0, 0, 0, 0, 0, 0])),
("ཀྵཁགང", np.array([5, 2, 3, 4, 0, 0, 0, 0, 0, 0])),
])
def test_ctc_label_encode_call_valid_text_tibetan(setup_ctc_label_encode_tibetan, label, expected_result):
encoder = setup_ctc_label_encode_tibetan
data = {"label": label}
encoded_data = encoder(data)
assert np.array_equal(encoded_data["label"], expected_result), f"Failed for Tibetan text: {label}"
# assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
@pytest.mark.parametrize("label, expected_shape, expected_length", [
("abc", (10,), 3),
("unknown", None, None),
("", None, None),
("a" * 20, None, None),
])
def test_attn_label_encode_call(setup_attn_label_encode, label, expected_shape, expected_length):
encoder = setup_attn_label_encode
data = {"label": label}
encoded_data = encoder(data)
if expected_shape is not None:
assert encoded_data["label"].shape == expected_shape, f"Expected shape {expected_shape} but got {encoded_data['label'].shape}"
assert encoded_data["label"][0] == 0, f"Expected SOS token at start but got {encoded_data['label'][0]}"
assert encoded_data["label"][expected_length + 1] == len(encoder.character) - 1, f"Expected EOS token at position {expected_length + 1} but got {encoded_data['label'][expected_length + 1]}"
assert encoded_data["length"] == expected_length, f"Expected length {expected_length} but got {encoded_data['length']}"
else:
assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"
@pytest.mark.parametrize("label, expected_shape, expected_length", [
("你好世界", (10,), 4),
])
def test_attn_label_encode_call_valid_text_chinese(setup_attn_label_encode_chinese, label, expected_shape, expected_length):
encoder = setup_attn_label_encode_chinese
data = {"label": label}
encoded_data = encoder(data)
assert encoded_data["label"].shape == expected_shape, f"Expected shape {expected_shape} but got {encoded_data['label'].shape}"
assert encoded_data["label"][0] == 0, f"Expected SOS token at start but got {encoded_data['label'][0]}"
assert encoded_data["label"][expected_length + 1] == len(encoder.character) - 1, f"Expected EOS token at position {expected_length + 1} but got {encoded_data['label'][expected_length + 1]}"
assert encoded_data["length"] == expected_length, f"Expected length {expected_length} but got {encoded_data['length']}"
@pytest.mark.parametrize("label, expected_texts", [
('[{"points": [[0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', ["text"]),
("[]", None),
("", pytest.raises(json.JSONDecodeError)),
('[{"points": [0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', pytest.raises(json.JSONDecodeError)),
])
def test_det_label_encode_call(setup_det_label_encode, label, expected_texts):
encoder = setup_det_label_encode
data = {"label": label}
if isinstance(expected_texts, list):
encoded_data = encoder(data)
assert "polys" in encoded_data, "Missing polys in encoded data"
assert "texts" in encoded_data, "Missing texts in encoded data"
assert "ignore_tags" in encoded_data, "Missing ignore_tags in encoded data"
assert encoded_data["texts"] == expected_texts, f"Expected texts {expected_texts} but got {encoded_data['texts']}"
elif isinstance(expected_texts, type(pytest.raises(Exception))):
with expected_texts:
encoder(data)
else:
encoded_data = encoder(data)
assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}" |
我按照您的意思添加并在项目中成功运行了该文件,终端没有任何报错信息 |
把那个单测文件也提交上来吧,针对藏文叠写字符再多添加几个单测。@hicricket |
我尝试添加了藏文叠字,发现len(label)不为字符串全字匹配的文本数,增加了最大字符数为25后将数值补0为25长度,得到以下报错
我修改后的代码如下: import os
import sys
import pytest
import numpy as np
import json
# Import modules
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))
from ppocr.data.imaug.label_ops import (
ClsLabelEncode,
DetLabelEncode,
CTCLabelEncode,
AttnLabelEncode,
)
print(sys.path) # 添加这行来检查路径
# Data generator function
def generate_character_dict(tmp_path, characters):
character_dict = tmp_path / "char_dict.txt"
character_dict.write_text("\n".join(characters) + "\n")
return str(character_dict)
# Fixture: ClsLabelEncode
@pytest.fixture
def setup_cls_label_encode():
return ClsLabelEncode(label_list=["label1", "label2", "label3"])
# Fixture: CTCLabelEncode
@pytest.fixture
def setup_ctc_label_encode(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
@pytest.fixture
def setup_ctc_label_encode_chinese(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
@pytest.fixture
def setup_ctc_label_encode_tibetan(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["ཀ", "ཁ", "ག", "ང", "ཀྵ", "ཀྪོ", "ཀྩོ", "ཀྤྲེ", "ཀླཱ", "གྒྲ"])
print(f"Character dictionary path: {character_dict_path}")
with open(character_dict_path, 'r', encoding='utf-8') as f:
print(f"Character dictionary content:\n{f.read()}")
return CTCLabelEncode(max_text_length=25, character_dict_path=character_dict_path)
# Fixture: AttnLabelEncode
@pytest.fixture
def setup_attn_label_encode(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
return AttnLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
@pytest.fixture
def setup_attn_label_encode_chinese(tmp_path):
character_dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
return AttnLabelEncode(max_text_length=10, character_dict_path=character_dict_path)
# Fixture: DetLabelEncode
@pytest.fixture
def setup_det_label_encode():
return DetLabelEncode()
# Test functions
@pytest.mark.parametrize("label, expected", [
("label1", 0),
("unknown_label", None),
("", None),
])
def test_cls_label_encode_call(setup_cls_label_encode, label, expected):
encoder = setup_cls_label_encode
data = {"label": label}
encoded_data = encoder(data)
if expected is not None:
assert encoded_data["label"] == expected, f"Expected {expected} for label: {label}, but got {encoded_data['label']}"
else:
assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"
@pytest.mark.parametrize("label, expected", [
("abc", np.array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0])),
("unknown", None),
("", None),
("a" * 20, None),
])
def test_ctc_label_encode_call(setup_ctc_label_encode, label, expected):
encoder = setup_ctc_label_encode
data = {"label": label}
encoded_data = encoder(data)
if expected is not None:
assert np.array_equal(encoded_data["label"], expected), f"Failed for label: {label}, expected {expected} but got {encoded_data['label']}"
assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
else:
assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"
@pytest.mark.parametrize("label, expected_result", [
("你好世界", np.array([1, 2, 3, 4, 0, 0, 0, 0, 0, 0])),
])
def test_ctc_label_encode_call_valid_text_chinese(setup_ctc_label_encode_chinese, label, expected_result):
encoder = setup_ctc_label_encode_chinese
data = {"label": label}
encoded_data = encoder(data)
assert np.array_equal(encoded_data["label"], expected_result), f"Failed for Chinese text: {label}"
assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
@pytest.mark.parametrize("label, expected_result", [
("ཀཁགངཀྪོཀྩོཀྤྲེཀླཱགྒྲགྒྲ", np.array([1, 2, 3, 4, 6, 7, 8, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
("ཀྵཁགངཀྩོཀྪོ", np.array([5, 2, 3, 4, 7, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
])
def test_ctc_label_encode_call_valid_text_tibetan(setup_ctc_label_encode_tibetan, label, expected_result):
encoder = setup_ctc_label_encode_tibetan
data = {"label": label}
encoded_data = encoder(data)
print(f"Encoded data for label '{label}': {encoded_data}")
assert np.array_equal(encoded_data["label"], expected_result), f"Failed for Tibetan text: {label}"
assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
@pytest.mark.parametrize("label, expected_shape, expected_length", [
("abc", (10,), 3),
("unknown", None, None),
("", None, None),
("a" * 20, None, None),
])
def test_attn_label_encode_call(setup_attn_label_encode, label, expected_shape, expected_length):
encoder = setup_attn_label_encode
data = {"label": label}
encoded_data = encoder(data)
if expected_shape is not None:
assert encoded_data["label"].shape == expected_shape, f"Expected shape {expected_shape} but got {encoded_data['label'].shape}"
assert encoded_data["label"][0] == 0, f"Expected SOS token at start but got {encoded_data['label'][0]}"
assert encoded_data["label"][expected_length + 1] == len(encoder.character) - 1, f"Expected EOS token at position {expected_length + 1} but got {encoded_data['label'][expected_length + 1]}"
assert encoded_data["length"] == expected_length, f"Expected length {expected_length} but got {encoded_data['length']}"
else:
assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"
@pytest.mark.parametrize("label, expected_shape, expected_length", [
("你好世界", (10,), 4),
])
def test_attn_label_encode_call_valid_text_chinese(setup_attn_label_encode_chinese, label, expected_shape, expected_length):
encoder = setup_attn_label_encode_chinese
data = {"label": label}
encoded_data = encoder(data)
assert encoded_data["label"].shape == expected_shape, f"Expected shape {expected_shape} but got {encoded_data['label'].shape}"
assert encoded_data["label"][0] == 0, f"Expected SOS token at start but got {encoded_data['label'][0]}"
assert encoded_data["label"][expected_length + 1] == len(encoder.character) - 1, f"Expected EOS token at position {expected_length + 1} but got {encoded_data['label'][expected_length + 1]}"
assert encoded_data["length"] == expected_length, f"Expected length {expected_length} but got {encoded_data['length']}"
@pytest.mark.parametrize("label, expected_texts", [
('[{"points": [[0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', ["text"]),
("[]", None),
("", pytest.raises(json.JSONDecodeError)),
('[{"points": [0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', pytest.raises(json.JSONDecodeError)),
])
def test_det_label_encode_call(setup_det_label_encode, label, expected_texts):
encoder = setup_det_label_encode
data = {"label": label}
if isinstance(expected_texts, list):
encoded_data = encoder(data)
assert "polys" in encoded_data, "Missing polys in encoded data"
assert "texts" in encoded_data, "Missing texts in encoded data"
assert "ignore_tags" in encoded_data, "Missing ignore_tags in encoded data"
assert encoded_data["texts"] == expected_texts, f"Expected texts {expected_texts} but got {encoded_data['texts']}"
elif isinstance(expected_texts, type(pytest.raises(Exception))):
with expected_texts:
encoder(data)
else:
encoded_data = encoder(data)
assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}" 也就是说我之前的训练都是把最大字符数x3以保证训练标签在最大范围内 |
应该,可以调整一下判断适配藏文。 assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}" assert encoded_data["length"] == len(expected_result[expected_result != 0]), f"Expected length {len(expected_result[expected_result != 0])} but got {encoded_data['length']}" |
可以的,成功了:19 passed, 1 warning in 1.72s |
# logger.warning('{} is not in dict'.format(char)) | ||
continue | ||
text_list.append(self.dict[char]) | ||
"""Full word matching dictionary line""" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
对于中文等长字符集的语言来说,会不会导致处理速度变得特别慢?建议再考虑一下多语言的兼容性问题。
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
特别慢倒是不会 但确实会影响到计算效率 特别是主流的字典行不存在多字符 只能是特殊情况用比较好 我觉得可以保留原来的逻辑 初始化的时候先对字典行进行判断 若存在多字符的行再用修改后的逻辑 这样也是个方法我想的是
@hicricket 该PR在PaddleX的例行监控流水线阻塞了15小时,可能影响到了PaddleX的使用,建议排查一下原因哈,流水线链接:https://xly.bce.baidu.com/paddlepaddle/PaddleXOpenSource/newipipe/detail/11368647/job/27277988 |
对于藏文字典我使用过单字符的字典,长文本识别效果不太理想,所以把藏文叠字的所有常见组合作为字典,大概有1w行,训练时间并没有太大影响,增加类别数确实提升了叠字特别是3-4个字符叠起来的字的识别准确率,这也算是一种尝试吧 |
Resolve the issue of labels not matching dictionary characters in full during training