
[Question]: The ERNIE model supports a 512-token sequence length, so why does prediction truncate at 512 characters? #9236

Open
Mengyueke opened this issue Oct 9, 2024 · 3 comments
Labels: question (Further information is requested)

Comments

@Mengyueke

Please describe your question

The ERNIE model supports an input length of 512 tokens, so why is the input truncated to a 512-character string length at prediction time? A 512-character limit is far too short for English text.
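A quick way to see the character/token gap (a minimal sketch; whitespace splitting is only a crude stand-in for ERNIE's subword tokenizer, but the disparity it shows is the point of this issue):

```python
# Character length vs. a rough token count for the same sentence.
en = "The quick brown fox jumps over the lazy dog."
zh = "敏捷的棕色狐狸跳过了懒狗。"

# English: many characters per token, so a 512-character cut
# leaves most of a 512-token budget unused.
print(len(en), len(en.split()))  # 44 characters vs. 9 word-level tokens

# Chinese: roughly one character per token, so the two limits
# happen to coincide.
print(len(zh))  # 13 characters ≈ 13 tokens
```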

@Mengyueke added the question label on Oct 9, 2024
@Mengyueke
Author

In UIE.

@DrownFish19
Collaborator

Hi, could you point to the specific location in the code?

@Mengyueke
Author

In PaddleNLP/paddlenlp/taskflow/information_extraction.py, the inference method _single_stage_predict calls the _auto_splitter function to split the input strings:

    def _single_stage_predict(self, inputs):
        input_texts = [d["text"] for d in inputs]
        prompts = [d["prompt"] for d in inputs]

        # max predict length should exclude the length of prompt and summary tokens
        max_predict_len = self._max_seq_len - len(max(prompts)) - self._summary_token_num

        if self._init_class in ["UIEX"]:
            bbox_list = [d["bbox"] for d in inputs]
            short_input_texts, short_bbox_list, input_mapping = self._auto_splitter(
                input_texts, max_predict_len, bbox_list=bbox_list, split_sentence=self._split_sentence
            )
        else:
            short_input_texts, input_mapping = self._auto_splitter(
                input_texts, max_predict_len, split_sentence=self._split_sentence
            )

That is the _auto_splitter function in PaddleNLP/paddlenlp/taskflow/task.py. It splits by string (character) length. For Chinese, character length and token length are indeed roughly the same, but for English they differ greatly, so the token length of the split English chunks will be far below 512:

    def _auto_splitter(self, input_texts, max_text_len, bbox_list=None, split_sentence=False):
        """
        Split the raw texts automatically for model inference.
        Args:
            input_texts (List[str]): input raw texts.
            max_text_len (int): cutting length.
            bbox_list (List[float, float,float, float]): bbox for document input.
            split_sentence (bool): If True, sentence-level split will be performed.
            `split_sentence` will be set to False if bbox_list is not None since sentence-level split is not supported for document input.
        return:
            short_input_texts (List[str]): the short input texts for model inference.
            input_mapping (dict): mapping between raw text and short input texts.
        """
        input_mapping = {}
        short_input_texts = []
        cnt_org = 0
        cnt_short = 0
        with_bbox = False
        if bbox_list:
            with_bbox = True
            short_bbox_list = []
            if split_sentence:
                logger.warning(
                    "`split_sentence` will be set to False if bbox_list is not None since sentence-level split is not support for document."
                )
                split_sentence = False

        for idx in range(len(input_texts)):
            if not split_sentence:
                sens = [input_texts[idx]]
            else:
                sens = cut_chinese_sent(input_texts[idx])
            for sen in sens:
                lens = len(sen)
                if lens <= max_text_len:
                    short_input_texts.append(sen)
                    if with_bbox:
                        short_bbox_list.append(bbox_list[idx])
                    input_mapping.setdefault(cnt_org, []).append(cnt_short)
                    cnt_short += 1
                else:
                    temp_text_list = [sen[i : i + max_text_len] for i in range(0, lens, max_text_len)]
                    short_input_texts.extend(temp_text_list)
                    if with_bbox:
                        if bbox_list[idx] is not None:
                            temp_bbox_list = [
                                bbox_list[idx][i : i + max_text_len] for i in range(0, lens, max_text_len)
                            ]
                            short_bbox_list.extend(temp_bbox_list)
                        else:
                            short_bbox_list.extend([None for _ in range(len(temp_text_list))])
                    short_idx = cnt_short
                    cnt_short += math.ceil(lens / max_text_len)
                    temp_text_id = [short_idx + i for i in range(cnt_short - short_idx)]
                    input_mapping.setdefault(cnt_org, []).extend(temp_text_id)
            cnt_org += 1
        if with_bbox:
            return short_input_texts, short_bbox_list, input_mapping
        else:
            return short_input_texts, input_mapping
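The core of the behavior in question is the character-window list comprehension above (`sen[i : i + max_text_len]`). A standalone sketch of just that step, with a hypothetical max_text_len of 20 for illustration, shows it cuts by characters regardless of how many tokens they contain:

```python
def char_window_split(sen, max_text_len):
    # Same windowing as in _auto_splitter: fixed-size *character*
    # slices, not token-aware slices.
    return [sen[i : i + max_text_len] for i in range(0, len(sen), max_text_len)]

text = "Information extraction with UIE works on long English documents."
chunks = char_window_split(text, 20)
print(chunks)
# Each 20-character chunk carries only a handful of English tokens,
# so scaled up to 512 characters, most of a 512-token budget is wasted
# (and words can even be cut in the middle of a chunk boundary).
```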
