From faac0a4d17caf3d9c145a99fb2a596925ad975e2 Mon Sep 17 00:00:00 2001 From: icecraft Date: Thu, 12 Sep 2024 14:12:47 +0800 Subject: [PATCH] fix: 1. resolve uncorrect pair relation of figure and footnote, 2. resolve uncorrect pair relation of table and caption #590 --- magic_pdf/libs/boxbase.py | 19 +++++ magic_pdf/model/magic_model.py | 141 ++++++++++++++++++++++----------- magic_pdf/tools/common.py | 2 +- 3 files changed, 116 insertions(+), 46 deletions(-) diff --git a/magic_pdf/libs/boxbase.py b/magic_pdf/libs/boxbase.py index 90f46ef2..0472328f 100644 --- a/magic_pdf/libs/boxbase.py +++ b/magic_pdf/libs/boxbase.py @@ -426,3 +426,22 @@ def dist(point1, point2): elif top: return y2 - y1b return 0.0 + + +def box_area(bbox): + return (bbox[2] - bbox[0]) * (bbox[3] - bbox[1]) + + +def get_overlap_area(bbox1, bbox2): + """计算box1和box2的重叠面积占bbox1的比例.""" + # Determine the coordinates of the intersection rectangle + x_left = max(bbox1[0], bbox2[0]) + y_top = max(bbox1[1], bbox2[1]) + x_right = min(bbox1[2], bbox2[2]) + y_bottom = min(bbox1[3], bbox2[3]) + + if x_right < x_left or y_bottom < y_top: + return 0.0 + + # The area of overlap area + return (x_right - x_left) * (y_bottom - y_top) diff --git a/magic_pdf/model/magic_model.py b/magic_pdf/model/magic_model.py index 61dc3a43..bd8e061a 100644 --- a/magic_pdf/model/magic_model.py +++ b/magic_pdf/model/magic_model.py @@ -1,8 +1,9 @@ import json from magic_pdf.libs.boxbase import (_is_in, _is_part_overlap, bbox_distance, - bbox_relative_pos, calculate_iou, - calculate_overlap_area_in_bbox1_area_ratio) + bbox_relative_pos, box_area, calculate_iou, + calculate_overlap_area_in_bbox1_area_ratio, + get_overlap_area) from magic_pdf.libs.commons import fitz, join_path from magic_pdf.libs.coordinate_transform import get_scale_ratio from magic_pdf.libs.local_math import float_gt @@ -12,6 +13,7 @@ from magic_pdf.rw.DiskReaderWriter import DiskReaderWriter CAPATION_OVERLAP_AREA_RATIO = 0.6 +MERGE_BOX_OVERLAP_AREA_RATIO = 1.1 class MagicModel: @@ -124,49 +126,51 @@ def __fix_footnote(self): tables.append(obj) if len(footnotes) * len(figures) == 0: continue - dis_figure_footnote = {} - dis_table_footnote = {} - - for i in range(len(footnotes)): - for j in range(len(figures)): - pos_flag_count = sum( - list( - map( - lambda x: 1 if x else 0, - bbox_relative_pos( - footnotes[i]['bbox'], figures[j]['bbox'] - ), - ) + dis_figure_footnote = {} + dis_table_footnote = {} + + for i in range(len(footnotes)): + for j in range(len(figures)): + pos_flag_count = sum( + list( + map( + lambda x: 1 if x else 0, + bbox_relative_pos( + footnotes[i]['bbox'], figures[j]['bbox'] + ), ) ) - if pos_flag_count > 1: - continue - dis_figure_footnote[i] = min( - bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']), - dis_figure_footnote.get(i, float('inf')), - ) - for i in range(len(footnotes)): - for j in range(len(tables)): - pos_flag_count = sum( - list( - map( - lambda x: 1 if x else 0, - bbox_relative_pos( - footnotes[i]['bbox'], tables[j]['bbox'] - ), - ) + ) + if pos_flag_count > 1: + continue + dis_figure_footnote[i] = min( + bbox_distance(figures[j]['bbox'], footnotes[i]['bbox']), + dis_figure_footnote.get(i, float('inf')), + ) + for i in range(len(footnotes)): + for j in range(len(tables)): + pos_flag_count = sum( + list( + map( + lambda x: 1 if x else 0, + bbox_relative_pos( + footnotes[i]['bbox'], tables[j]['bbox'] + ), ) ) - if pos_flag_count > 1: - continue + ) + if pos_flag_count > 1: + continue - dis_table_footnote[i] = min( - bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']), - dis_table_footnote.get(i, float('inf')), - ) - for i in range(len(footnotes)): - if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]: - footnotes[i]['category_id'] = CategoryId.ImageFootnote + dis_table_footnote[i] = min( + bbox_distance(tables[j]['bbox'], footnotes[i]['bbox']), + dis_table_footnote.get(i, float('inf')), + ) + for i in range(len(footnotes)): + if i not in dis_figure_footnote: + continue + if dis_table_footnote.get(i, float('inf')) > dis_figure_footnote[i]: + footnotes[i]['category_id'] = CategoryId.ImageFootnote def __reduct_overlap(self, bboxes): N = len(bboxes) @@ -191,6 +195,44 @@ def __tie_up_category_by_distance( 筛选出所有和 merged bbox 有 overlap 且 overlap 面积大于 object 的面积的 subjects。 再求出筛选出的 subjects 和 object 的最短距离 """ + def search_overlap_between_boxes( + subject_idx, object_idx + ): + idxes = [subject_idx, object_idx] + x0s = [all_bboxes[idx]['bbox'][0] for idx in idxes] + y0s = [all_bboxes[idx]['bbox'][1] for idx in idxes] + x1s = [all_bboxes[idx]['bbox'][2] for idx in idxes] + y1s = [all_bboxes[idx]['bbox'][3] for idx in idxes] + + merged_bbox = [ + min(x0s), + min(y0s), + max(x1s), + max(y1s), + ] + ratio = 0 + + other_objects = list( + map( + lambda x: {'bbox': x['bbox'], 'score': x['score']}, + filter( + lambda x: x['category_id'] + not in (object_category_id, subject_category_id), + self.__model_list[page_no]['layout_dets'], + ), + ) + ) + for other_object in other_objects: + ratio = max( + ratio, + get_overlap_area( + merged_bbox, other_object['bbox'] + ) * 1.0 / box_area(all_bboxes[object_idx]['bbox']) + ) + if ratio >= MERGE_BOX_OVERLAP_AREA_RATIO: + break + + return ratio def may_find_other_nearest_bbox(subject_idx, object_idx): ret = float('inf') @@ -299,6 +341,15 @@ def expand_bbbox(idxes): ): continue + subject_idx, object_idx = i, j + if all_bboxes[j]['category_id'] == subject_category_id: + subject_idx, object_idx = j, i + + if search_overlap_between_boxes(subject_idx, object_idx) >= MERGE_BOX_OVERLAP_AREA_RATIO: + dis[i][j] = float('inf') + dis[j][i] = dis[i][j] + continue + dis[i][j] = bbox_distance(all_bboxes[i]['bbox'], all_bboxes[j]['bbox']) dis[j][i] = dis[i][j] @@ -627,13 +678,13 @@ def remove_duplicate_spans(spans): span['type'] = ContentType.Image elif category_id == 5: # 获取table模型结果 - latex = layout_det.get("latex", None) - html = layout_det.get("html", None) + latex = layout_det.get('latex', None) + html = layout_det.get('html', None) if latex: - span["latex"] = latex + span['latex'] = latex elif html: - span["html"] = html - span["type"] = ContentType.Table + span['html'] = html + span['type'] = ContentType.Table elif category_id == 13: span['content'] = layout_det['latex'] span['type'] = ContentType.InlineEquation diff --git a/magic_pdf/tools/common.py b/magic_pdf/tools/common.py index 6d7a381b..419457ec 100644 --- a/magic_pdf/tools/common.py +++ b/magic_pdf/tools/common.py @@ -46,7 +46,7 @@ def do_parse( end_page_id=None, ): if debug_able: - logger.warning("debug mode is on") + logger.warning('debug mode is on') f_dump_content_list = True f_draw_model_bbox = True