Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Update label_ops.py #13665

Closed
wants to merge 11 commits into from
Closed

Update label_ops.py #13665

wants to merge 11 commits into from

Conversation

hicricket
Copy link

Resolve the issue of labels not matching dictionary characters in full during training

Copy link
Collaborator

@GreatV GreatV left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个PR是用来解决什么问题,能提供更详细一点的描述和说明?

@hicricket hicricket closed this Aug 15, 2024
@hicricket
Copy link
Author

这个PR是用来解决什么问题,能提供更详细一点的描述和说明?

我在训练藏文数据集的时候,对于字典中的多个字符组合的叠写字符无法正常匹配,训练标签只能匹配字典中的单字符

@hicricket
Copy link
Author

这个PR是用来解决什么问题,能提供更详细一点的描述和说明?
逻辑改为全字匹配,训练实测,解决了训练标签于字典行多字符索引无法匹配问题

@GreatV GreatV reopened this Aug 15, 2024
@hicricket hicricket closed this Aug 15, 2024
@hicricket hicricket reopened this Aug 15, 2024
@hicricket hicricket closed this Aug 15, 2024
@GreatV
Copy link
Collaborator

GreatV commented Aug 15, 2024

建议在代码添加一段英文注释说明一下,修复codestyle,删除掉上面被注释掉的代码。还要测试一下中英文正常的训练是否受影响。如果都没问题就可以合入了。直接在这个PR上继续commit就行。

感谢您的贡献。

@hicricket hicricket reopened this Aug 15, 2024
@hicricket hicricket requested a review from GreatV August 15, 2024 09:50
Copy link
Collaborator

@GreatV GreatV left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

感谢您的贡献,需要等我抽时间测试一下才能合入。

@GreatV
Copy link
Collaborator

GreatV commented Aug 19, 2024

建议添加一个单测tests/test_label_ops.py来测试修改的有效性,以及不会改变之前行为。

import os
import sys
import pytest
import numpy as np
import json

# Import modules
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))

from ppocr.data.imaug.label_ops import (
    ClsLabelEncode,
    DetLabelEncode,
    CTCLabelEncode,
    AttnLabelEncode,
)

# Data generator function
def generate_character_dict(tmp_path, characters):
    character_dict = tmp_path / "char_dict.txt"
    character_dict.write_text("\n".join(characters) + "\n")
    return str(character_dict)

# Fixture: ClsLabelEncode
@pytest.fixture
def setup_cls_label_encode():
    return ClsLabelEncode(label_list=["label1", "label2", "label3"])

# Fixture: CTCLabelEncode
@pytest.fixture
def setup_ctc_label_encode(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
    return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

@pytest.fixture
def setup_ctc_label_encode_chinese(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
    return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

@pytest.fixture
def setup_ctc_label_encode_tibetan(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["ཀ", "ཁ", "ག", "ང", "ཀྵ"])
    return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

# Fixture: AttnLabelEncode
@pytest.fixture
def setup_attn_label_encode(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
    return AttnLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

@pytest.fixture
def setup_attn_label_encode_chinese(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
    return AttnLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

# Fixture: DetLabelEncode
@pytest.fixture
def setup_det_label_encode():
    return DetLabelEncode()

# Test functions
@pytest.mark.parametrize("label, expected", [
    ("label1", 0),
    ("unknown_label", None),
    ("", None),
])
def test_cls_label_encode_call(setup_cls_label_encode, label, expected):
    encoder = setup_cls_label_encode
    data = {"label": label}
    encoded_data = encoder(data)
    
    if expected is not None:
        assert encoded_data["label"] == expected, f"Expected {expected} for label: {label}, but got {encoded_data['label']}"
    else:
        assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"

@pytest.mark.parametrize("label, expected", [
    ("abc", np.array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0])),
    ("unknown", None),
    ("", None),
    ("a" * 20, None),
])
def test_ctc_label_encode_call(setup_ctc_label_encode, label, expected):
    encoder = setup_ctc_label_encode
    data = {"label": label}
    encoded_data = encoder(data)
    
    if expected is not None:
        assert np.array_equal(encoded_data["label"], expected), f"Failed for label: {label}, expected {expected} but got {encoded_data['label']}"
        assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
    else:
        assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"

@pytest.mark.parametrize("label, expected_result", [
    ("你好世界", np.array([1, 2, 3, 4, 0, 0, 0, 0, 0, 0])),
])
def test_ctc_label_encode_call_valid_text_chinese(setup_ctc_label_encode_chinese, label, expected_result):
    encoder = setup_ctc_label_encode_chinese
    data = {"label": label}
    encoded_data = encoder(data)
    
    assert np.array_equal(encoded_data["label"], expected_result), f"Failed for Chinese text: {label}"
    assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"

@pytest.mark.parametrize("label, expected_result", [
    ("ཀཁགང", np.array([1, 2, 3, 4, 0, 0, 0, 0, 0, 0])),
    ("ཀྵཁགང", np.array([5, 2, 3, 4, 0, 0, 0, 0, 0, 0])),
])
def test_ctc_label_encode_call_valid_text_tibetan(setup_ctc_label_encode_tibetan, label, expected_result):
    encoder = setup_ctc_label_encode_tibetan
    data = {"label": label}
    encoded_data = encoder(data)
        
    assert np.array_equal(encoded_data["label"], expected_result), f"Failed for Tibetan text: {label}"
    # assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"

@pytest.mark.parametrize("label, expected_shape, expected_length", [
    ("abc", (10,), 3),
    ("unknown", None, None),
    ("", None, None),
    ("a" * 20, None, None),
])
def test_attn_label_encode_call(setup_attn_label_encode, label, expected_shape, expected_length):
    encoder = setup_attn_label_encode
    data = {"label": label}
    encoded_data = encoder(data)
    
    if expected_shape is not None:
        assert encoded_data["label"].shape == expected_shape, f"Expected shape {expected_shape} but got {encoded_data['label'].shape}"
        assert encoded_data["label"][0] == 0, f"Expected SOS token at start but got {encoded_data['label'][0]}"
        assert encoded_data["label"][expected_length + 1] == len(encoder.character) - 1, f"Expected EOS token at position {expected_length + 1} but got {encoded_data['label'][expected_length + 1]}"
        assert encoded_data["length"] == expected_length, f"Expected length {expected_length} but got {encoded_data['length']}"
    else:
        assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"

@pytest.mark.parametrize("label, expected_shape, expected_length", [
    ("你好世界", (10,), 4),
])
def test_attn_label_encode_call_valid_text_chinese(setup_attn_label_encode_chinese, label, expected_shape, expected_length):
    encoder = setup_attn_label_encode_chinese
    data = {"label": label}
    encoded_data = encoder(data)
    
    assert encoded_data["label"].shape == expected_shape, f"Expected shape {expected_shape} but got {encoded_data['label'].shape}"
    assert encoded_data["label"][0] == 0, f"Expected SOS token at start but got {encoded_data['label'][0]}"
    assert encoded_data["label"][expected_length + 1] == len(encoder.character) - 1, f"Expected EOS token at position {expected_length + 1} but got {encoded_data['label'][expected_length + 1]}"
    assert encoded_data["length"] == expected_length, f"Expected length {expected_length} but got {encoded_data['length']}"

@pytest.mark.parametrize("label, expected_texts", [
    ('[{"points": [[0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', ["text"]),
    ("[]", None),
    ("", pytest.raises(json.JSONDecodeError)),
    ('[{"points": [0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', pytest.raises(json.JSONDecodeError)),
])
def test_det_label_encode_call(setup_det_label_encode, label, expected_texts):
    encoder = setup_det_label_encode
    data = {"label": label}
    
    if isinstance(expected_texts, list):
        encoded_data = encoder(data)
        assert "polys" in encoded_data, "Missing polys in encoded data"
        assert "texts" in encoded_data, "Missing texts in encoded data"
        assert "ignore_tags" in encoded_data, "Missing ignore_tags in encoded data"
        assert encoded_data["texts"] == expected_texts, f"Expected texts {expected_texts} but got {encoded_data['texts']}"
    elif isinstance(expected_texts, type(pytest.raises(Exception))):
        with expected_texts:
            encoder(data)
    else:
        encoded_data = encoder(data)
        assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"

@hicricket
Copy link
Author

我按照您的意思添加并在项目中成功运行了该文件,终端没有任何报错信息

@GreatV
Copy link
Collaborator

GreatV commented Aug 19, 2024

把那个单测文件也提交上来吧,针对藏文叠写字符再多添加几个单测。@hicricket

@hicricket
Copy link
Author

hicricket commented Aug 19, 2024

我尝试添加了藏文叠字,发现len(label)不为字符串全字匹配的文本数,增加了最大字符数为25后将数值补0为25长度,得到以下报错

FAILED tests/test_label_ops.py::test_ctc_label_encode_call_valid_text_tibetan[\u0f40\u0f41\u0f42\u0f44\u0f40\u0faa\u0f7c\u0f40\u0fa9\u0f7c\u0f40\u0fa4\u0fb2\u0f7a\u0f40\u0fb3\u0f71\u0f42\u0f92\u0fb2\u0f42\u0f92\u0fb2-expected_result0] - AssertionError: Expected length 23 but got 10
FAILED tests/test_label_ops.py::test_ctc_label_encode_call_valid_text_tibetan[\u0f40\u0fb5\u0f41\u0f42\u0f44\u0f40\u0fa9\u0f7c\u0f40\u0faa\u0f7c-expected_result1] - AssertionError: Expected length 11 but got 6
============================================================================================== 2 failed, 17 passed, 1 warning in 1.79s

我修改后的代码如下:

import os
import sys
import pytest
import numpy as np
import json

# Import modules
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(os.path.abspath(os.path.join(current_dir, "..")))

from ppocr.data.imaug.label_ops import (
    ClsLabelEncode,
    DetLabelEncode,
    CTCLabelEncode,
    AttnLabelEncode,
)
print(sys.path)  # 添加这行来检查路径
# Data generator function
def generate_character_dict(tmp_path, characters):
    character_dict = tmp_path / "char_dict.txt"
    character_dict.write_text("\n".join(characters) + "\n")
    return str(character_dict)

# Fixture: ClsLabelEncode
@pytest.fixture
def setup_cls_label_encode():
    return ClsLabelEncode(label_list=["label1", "label2", "label3"])

# Fixture: CTCLabelEncode
@pytest.fixture
def setup_ctc_label_encode(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
    return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

@pytest.fixture
def setup_ctc_label_encode_chinese(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
    return CTCLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

@pytest.fixture
def setup_ctc_label_encode_tibetan(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["ཀ", "ཁ", "ག", "ང", "ཀྵ", "ཀྪོ", "ཀྩོ", "ཀྤྲེ", "ཀླཱ", "གྒྲ"])
    print(f"Character dictionary path: {character_dict_path}")
    with open(character_dict_path, 'r', encoding='utf-8') as f:
        print(f"Character dictionary content:\n{f.read()}")    
    return CTCLabelEncode(max_text_length=25, character_dict_path=character_dict_path)

# Fixture: AttnLabelEncode
@pytest.fixture
def setup_attn_label_encode(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["a", "b", "c"])
    return AttnLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

@pytest.fixture
def setup_attn_label_encode_chinese(tmp_path):
    character_dict_path = generate_character_dict(tmp_path, ["你", "好", "世", "界"])
    return AttnLabelEncode(max_text_length=10, character_dict_path=character_dict_path)

# Fixture: DetLabelEncode
@pytest.fixture
def setup_det_label_encode():
    return DetLabelEncode()

# Test functions
@pytest.mark.parametrize("label, expected", [
    ("label1", 0),
    ("unknown_label", None),
    ("", None),
])
def test_cls_label_encode_call(setup_cls_label_encode, label, expected):
    encoder = setup_cls_label_encode
    data = {"label": label}
    encoded_data = encoder(data)
    
    if expected is not None:
        assert encoded_data["label"] == expected, f"Expected {expected} for label: {label}, but got {encoded_data['label']}"
    else:
        assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"

@pytest.mark.parametrize("label, expected", [
    ("abc", np.array([1, 2, 3, 0, 0, 0, 0, 0, 0, 0])),
    ("unknown", None),
    ("", None),
    ("a" * 20, None),
])
def test_ctc_label_encode_call(setup_ctc_label_encode, label, expected):
    encoder = setup_ctc_label_encode
    data = {"label": label}
    encoded_data = encoder(data)
    
    if expected is not None:
        assert np.array_equal(encoded_data["label"], expected), f"Failed for label: {label}, expected {expected} but got {encoded_data['label']}"
        assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
    else:
        assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"

@pytest.mark.parametrize("label, expected_result", [
    ("你好世界", np.array([1, 2, 3, 4, 0, 0, 0, 0, 0, 0])),
])
def test_ctc_label_encode_call_valid_text_chinese(setup_ctc_label_encode_chinese, label, expected_result):
    encoder = setup_ctc_label_encode_chinese
    data = {"label": label}
    encoded_data = encoder(data)
    
    assert np.array_equal(encoded_data["label"], expected_result), f"Failed for Chinese text: {label}"
    assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"

@pytest.mark.parametrize("label, expected_result", [
    ("ཀཁགངཀྪོཀྩོཀྤྲེཀླཱགྒྲགྒྲ", np.array([1, 2, 3, 4, 6, 7, 8, 9, 10, 10, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
    ("ཀྵཁགངཀྩོཀྪོ", np.array([5, 2, 3, 4, 7, 6, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])),
])
def test_ctc_label_encode_call_valid_text_tibetan(setup_ctc_label_encode_tibetan, label, expected_result):
    encoder = setup_ctc_label_encode_tibetan
    data = {"label": label}
    encoded_data = encoder(data)
    print(f"Encoded data for label '{label}': {encoded_data}")        
    assert np.array_equal(encoded_data["label"], expected_result), f"Failed for Tibetan text: {label}"
    assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"

@pytest.mark.parametrize("label, expected_shape, expected_length", [
    ("abc", (10,), 3),
    ("unknown", None, None),
    ("", None, None),
    ("a" * 20, None, None),
])
def test_attn_label_encode_call(setup_attn_label_encode, label, expected_shape, expected_length):
    encoder = setup_attn_label_encode
    data = {"label": label}
    encoded_data = encoder(data)
    
    if expected_shape is not None:
        assert encoded_data["label"].shape == expected_shape, f"Expected shape {expected_shape} but got {encoded_data['label'].shape}"
        assert encoded_data["label"][0] == 0, f"Expected SOS token at start but got {encoded_data['label'][0]}"
        assert encoded_data["label"][expected_length + 1] == len(encoder.character) - 1, f"Expected EOS token at position {expected_length + 1} but got {encoded_data['label'][expected_length + 1]}"
        assert encoded_data["length"] == expected_length, f"Expected length {expected_length} but got {encoded_data['length']}"
    else:
        assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"

@pytest.mark.parametrize("label, expected_shape, expected_length", [
    ("你好世界", (10,), 4),
])
def test_attn_label_encode_call_valid_text_chinese(setup_attn_label_encode_chinese, label, expected_shape, expected_length):
    encoder = setup_attn_label_encode_chinese
    data = {"label": label}
    encoded_data = encoder(data)
    
    assert encoded_data["label"].shape == expected_shape, f"Expected shape {expected_shape} but got {encoded_data['label'].shape}"
    assert encoded_data["label"][0] == 0, f"Expected SOS token at start but got {encoded_data['label'][0]}"
    assert encoded_data["label"][expected_length + 1] == len(encoder.character) - 1, f"Expected EOS token at position {expected_length + 1} but got {encoded_data['label'][expected_length + 1]}"
    assert encoded_data["length"] == expected_length, f"Expected length {expected_length} but got {encoded_data['length']}"

@pytest.mark.parametrize("label, expected_texts", [
    ('[{"points": [[0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', ["text"]),
    ("[]", None),
    ("", pytest.raises(json.JSONDecodeError)),
    ('[{"points": [0,0],[1,0],[1,1],[0,1]], "transcription": "text"}]', pytest.raises(json.JSONDecodeError)),
])
def test_det_label_encode_call(setup_det_label_encode, label, expected_texts):
    encoder = setup_det_label_encode
    data = {"label": label}
    
    if isinstance(expected_texts, list):
        encoded_data = encoder(data)
        assert "polys" in encoded_data, "Missing polys in encoded data"
        assert "texts" in encoded_data, "Missing texts in encoded data"
        assert "ignore_tags" in encoded_data, "Missing ignore_tags in encoded data"
        assert encoded_data["texts"] == expected_texts, f"Expected texts {expected_texts} but got {encoded_data['texts']}"
    elif isinstance(expected_texts, type(pytest.raises(Exception))):
        with expected_texts:
            encoder(data)
    else:
        encoded_data = encoder(data)
        assert encoded_data is None, f"Expected None for label: {label}, but got {encoded_data}"

也就是说我之前的训练都是把最大字符数x3以保证训练标签在最大范围内

@GreatV
Copy link
Collaborator

GreatV commented Aug 19, 2024

应该,可以调整一下判断适配藏文。

assert encoded_data["length"] == len(label), f"Expected length {len(label)} but got {encoded_data['length']}"
assert encoded_data["length"] == len(expected_result[expected_result != 0]), f"Expected length {len(expected_result[expected_result != 0])} but got {encoded_data['length']}"

@hicricket
Copy link
Author

可以的,成功了:19 passed, 1 warning in 1.72s

tests/test_label_ops.py Outdated Show resolved Hide resolved
tests/test_label_ops.py Outdated Show resolved Hide resolved
# logger.warning('{} is not in dict'.format(char))
continue
text_list.append(self.dict[char])
"""Full word matching dictionary line"""
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

对于中文等长字符集的语言来说,会不会导致处理速度变得特别慢?建议再考虑一下多语言的兼容性问题。

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

特别慢倒是不会 但确实会影响到计算效率 特别是主流的字典行不存在多字符 只能是特殊情况用比较好 我觉得可以保留原来的逻辑 初始化的时候先对字典行进行判断 若存在多字符的行再用修改后的逻辑 这样也是个方法我想的是

@GreatV GreatV requested a review from jzhang533 August 20, 2024 03:31
@changdazhou
Copy link
Collaborator

@hicricket 该PR在PaddleX的例行监控流水线阻塞了15小时,可能影响到了PaddleX的使用,建议排查一下原因哈,流水线链接:https://xly.bce.baidu.com/paddlepaddle/PaddleXOpenSource/newipipe/detail/11368647/job/27277988

@hicricket
Copy link
Author

@hicricket 该PR在PaddleX的例行监控流水线阻塞了15小时,可能影响到了PaddleX的使用,建议排查一下原因哈,流水线链接:https://xly.bce.baidu.com/paddlepaddle/PaddleXOpenSource/newipipe/detail/11368647/job/27277988

对于藏文字典我使用过单字符的字典,长文本识别效果不太理想,所以把藏文叠字的所有常见组合作为字典,大概有1w行,训练时间并没有太大影响,增加类别数确实提升了叠字特别是3-4个字符叠起来的字的识别准确率,这也算是一种尝试吧

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants