diff --git a/CHANGELOG.md b/CHANGELOG.md index 45d58139..ca246338 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,8 +1,10 @@ -## 0.7.10-dev1 +## 0.7.10-dev2 * fix: Reduce Chipper memory consumption on x86_64 cpus * fix: Skips ordering elements coming from Chipper * fix: After refactoring to introduce Chipper, annotate() weren't able to show text with extra info from elements, this is fixed now. +* feat: add table cell and dataframe output formats to table transformer's `run_prediction` call +* breaking change: function `unstructured_inference.models.tables.recognize` no longer takes `out_html` parameter and it now only returns table cell data format (lists of dictionaries) ## 0.7.9 diff --git a/test_unstructured_inference/conftest.py b/test_unstructured_inference/conftest.py index 41a8976c..75771080 100644 --- a/test_unstructured_inference/conftest.py +++ b/test_unstructured_inference/conftest.py @@ -2,7 +2,11 @@ import pytest from PIL import Image -from unstructured_inference.inference.elements import EmbeddedTextRegion, Rectangle, TextRegion +from unstructured_inference.inference.elements import ( + EmbeddedTextRegion, + Rectangle, + TextRegion, +) from unstructured_inference.inference.layoutelement import LayoutElement @@ -122,3 +126,43 @@ def mock_layout(mock_embedded_text_regions): LayoutElement(text=r.text, type="UncategorizedText", bbox=r.bbox) for r in mock_embedded_text_regions ] + + +@pytest.fixture() +def example_table_cells(): + cells = [ + {"cell text": "Disability Category", "row_nums": [0, 1], "column_nums": [0]}, + {"cell text": "Participants", "row_nums": [0, 1], "column_nums": [1]}, + {"cell text": "Ballots Completed", "row_nums": [0, 1], "column_nums": [2]}, + {"cell text": "Ballots Incomplete/Terminated", "row_nums": [0, 1], "column_nums": [3]}, + {"cell text": "Results", "row_nums": [0], "column_nums": [4, 5]}, + {"cell text": "Accuracy", "row_nums": [1], "column_nums": [4]}, + {"cell text": "Time to complete", "row_nums": [1], "column_nums": [5]}, + {"cell text": "Blind", "row_nums": [2], "column_nums": [0]}, + {"cell text": "Low Vision", "row_nums": [3], "column_nums": [0]}, + {"cell text": "Dexterity", "row_nums": [4], "column_nums": [0]}, + {"cell text": "Mobility", "row_nums": [5], "column_nums": [0]}, + {"cell text": "5", "row_nums": [2], "column_nums": [1]}, + {"cell text": "5", "row_nums": [3], "column_nums": [1]}, + {"cell text": "5", "row_nums": [4], "column_nums": [1]}, + {"cell text": "3", "row_nums": [5], "column_nums": [1]}, + {"cell text": "1", "row_nums": [2], "column_nums": [2]}, + {"cell text": "2", "row_nums": [3], "column_nums": [2]}, + {"cell text": "4", "row_nums": [4], "column_nums": [2]}, + {"cell text": "3", "row_nums": [5], "column_nums": [2]}, + {"cell text": "4", "row_nums": [2], "column_nums": [3]}, + {"cell text": "3", "row_nums": [3], "column_nums": [3]}, + {"cell text": "1", "row_nums": [4], "column_nums": [3]}, + {"cell text": "0", "row_nums": [5], "column_nums": [3]}, + {"cell text": "34.5%, n=1", "row_nums": [2], "column_nums": [4]}, + {"cell text": "98.3% n=2 (97.7%, n=3)", "row_nums": [3], "column_nums": [4]}, + {"cell text": "98.3%, n=4", "row_nums": [4], "column_nums": [4]}, + {"cell text": "95.4%, n=3", "row_nums": [5], "column_nums": [4]}, + {"cell text": "1199 sec, n=1", "row_nums": [2], "column_nums": [5]}, + {"cell text": "1716 sec, n=3 (1934 sec, n=2)", "row_nums": [3], "column_nums": [5]}, + {"cell text": "1672.1 sec, n=4", "row_nums": [4], "column_nums": [5]}, + {"cell text": "1416 sec, n=3", "row_nums": [5], "column_nums": [5]}, + ] + for i in range(len(cells)): + cells[i]["column header"] = False + return [cells] diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 740cbf0f..90348c1f 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -361,6 +361,54 @@ def test_table_prediction_tesseract(table_transformer, example_image): ) in prediction +@pytest.mark.parametrize( + ("output_format", "expectation"), + [ + ("html", "Blind51434.5%, n=1"), + ( + "cells", + { + "column_nums": [0], + "row_nums": [2], + "column header": False, + "cell text": "Blind", + }, + ), + ("dataframe", ["Blind", "5", "1", "4", "34.5%, n=1", "1199 sec, n=1"]), + (None, "Blind51434.5%, n=1"), + ], +) +def test_table_prediction_output_format( + output_format, + expectation, + table_transformer, + example_image, + mocker, + example_table_cells, +): + mocker.patch.object(tables, "recognize", return_value=example_table_cells) + mocker.patch.object( + tables.UnstructuredTableTransformerModel, + "get_structure", + return_value=None, + ) + mocker.patch.object(tables.UnstructuredTableTransformerModel, "get_tokens", return_value=None) + if output_format: + result = table_transformer.run_prediction(example_image, result_format=output_format) + else: + result = table_transformer.run_prediction(example_image) + + if output_format == "dataframe": + assert expectation in result.values + elif output_format == "cells": + # other output like bbox are flakey to test since they depend on OCR and it may change + # slightly when OCR pacakge changes or even on different machines + validation_fields = ("column_nums", "row_nums", "column header", "cell text") + assert expectation in [{key: cell[key] for key in validation_fields} for cell in result] + else: + assert expectation in result + + def test_table_prediction_tesseract_with_ocr_tokens(table_transformer, example_image): ocr_tokens = [ { diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index a28807cc..5a48fbf0 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.10-dev1" # pragma: no cover +__version__ = "0.7.10-dev2" # pragma: no cover diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index 760dd46b..74cd64ae 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -19,6 +19,7 @@ from unstructured_inference.constants import ( TESSERACT_TEXT_HEIGHT, ) +from unstructured_inference.inference.layoutelement import table_cells_to_dataframe from unstructured_inference.logger import logger from unstructured_inference.models.table_postprocess import Rect from unstructured_inference.models.unstructuredmodel import UnstructuredModel @@ -176,6 +177,7 @@ def run_prediction( x: Image, pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD, ocr_tokens: Optional[List[Dict]] = None, + result_format: Optional[str] = "html", ): """Predict table structure""" outputs_structure = self.get_structure(x, pad_for_structure_detection) @@ -186,8 +188,12 @@ def run_prediction( ) ocr_tokens = self.get_tokens(x=x) - html = recognize(outputs_structure, x, tokens=ocr_tokens, out_html=True)["html"] - prediction = html[0] if html else "" + prediction = recognize(outputs_structure, x, tokens=ocr_tokens)[0] + if result_format == "html": + # Convert cells to HTML + prediction = cells_to_html(prediction) or "" + elif result_format == "dataframe": + prediction = table_cells_to_dataframe(prediction) return prediction @@ -234,10 +240,8 @@ def get_class_map(data_type: str): } -def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False): +def recognize(outputs: dict, img: Image, tokens: list): """Recognize table elements.""" - out_formats = {} - str_class_name2idx = get_class_map("structure") str_class_idx2name = {v: k for k, v in str_class_name2idx.items()} str_class_thresholds = structure_class_thresholds @@ -248,14 +252,7 @@ def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False): # Further process the detected objects so they correspond to a consistent table tables_structure = objects_to_structures(objects, tokens, str_class_thresholds) # Enumerate all table cells: grid cells and spanning cells - tables_cells = [structure_to_cells(structure, tokens)[0] for structure in tables_structure] - - # Convert cells to HTML - if out_html: - tables_htmls = [cells_to_html(cells) for cells in tables_cells] - out_formats["html"] = tables_htmls - - return out_formats + return [structure_to_cells(structure, tokens)[0] for structure in tables_structure] def outputs_to_objects(outputs, img_size, class_idx2name):