diff --git a/CHANGELOG.md b/CHANGELOG.md
index 45d58139..ca246338 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,10 @@
-## 0.7.10-dev1
+## 0.7.10-dev2
* fix: Reduce Chipper memory consumption on x86_64 cpus
* fix: Skips ordering elements coming from Chipper
* fix: After refactoring to introduce Chipper, annotate() weren't able to show text with extra info from elements, this is fixed now.
+* feat: add table cell and dataframe output formats to table transformer's `run_prediction` call
+* breaking change: function `unstructured_inference.models.tables.recognize` no longer takes `out_html` parameter and it now only returns table cell data format (lists of dictionaries)
## 0.7.9
diff --git a/test_unstructured_inference/conftest.py b/test_unstructured_inference/conftest.py
index 41a8976c..75771080 100644
--- a/test_unstructured_inference/conftest.py
+++ b/test_unstructured_inference/conftest.py
@@ -2,7 +2,11 @@
import pytest
from PIL import Image
-from unstructured_inference.inference.elements import EmbeddedTextRegion, Rectangle, TextRegion
+from unstructured_inference.inference.elements import (
+ EmbeddedTextRegion,
+ Rectangle,
+ TextRegion,
+)
from unstructured_inference.inference.layoutelement import LayoutElement
@@ -122,3 +126,43 @@ def mock_layout(mock_embedded_text_regions):
LayoutElement(text=r.text, type="UncategorizedText", bbox=r.bbox)
for r in mock_embedded_text_regions
]
+
+
+@pytest.fixture()
+def example_table_cells():
+ cells = [
+ {"cell text": "Disability Category", "row_nums": [0, 1], "column_nums": [0]},
+ {"cell text": "Participants", "row_nums": [0, 1], "column_nums": [1]},
+ {"cell text": "Ballots Completed", "row_nums": [0, 1], "column_nums": [2]},
+ {"cell text": "Ballots Incomplete/Terminated", "row_nums": [0, 1], "column_nums": [3]},
+ {"cell text": "Results", "row_nums": [0], "column_nums": [4, 5]},
+ {"cell text": "Accuracy", "row_nums": [1], "column_nums": [4]},
+ {"cell text": "Time to complete", "row_nums": [1], "column_nums": [5]},
+ {"cell text": "Blind", "row_nums": [2], "column_nums": [0]},
+ {"cell text": "Low Vision", "row_nums": [3], "column_nums": [0]},
+ {"cell text": "Dexterity", "row_nums": [4], "column_nums": [0]},
+ {"cell text": "Mobility", "row_nums": [5], "column_nums": [0]},
+ {"cell text": "5", "row_nums": [2], "column_nums": [1]},
+ {"cell text": "5", "row_nums": [3], "column_nums": [1]},
+ {"cell text": "5", "row_nums": [4], "column_nums": [1]},
+ {"cell text": "3", "row_nums": [5], "column_nums": [1]},
+ {"cell text": "1", "row_nums": [2], "column_nums": [2]},
+ {"cell text": "2", "row_nums": [3], "column_nums": [2]},
+ {"cell text": "4", "row_nums": [4], "column_nums": [2]},
+ {"cell text": "3", "row_nums": [5], "column_nums": [2]},
+ {"cell text": "4", "row_nums": [2], "column_nums": [3]},
+ {"cell text": "3", "row_nums": [3], "column_nums": [3]},
+ {"cell text": "1", "row_nums": [4], "column_nums": [3]},
+ {"cell text": "0", "row_nums": [5], "column_nums": [3]},
+ {"cell text": "34.5%, n=1", "row_nums": [2], "column_nums": [4]},
+ {"cell text": "98.3% n=2 (97.7%, n=3)", "row_nums": [3], "column_nums": [4]},
+ {"cell text": "98.3%, n=4", "row_nums": [4], "column_nums": [4]},
+ {"cell text": "95.4%, n=3", "row_nums": [5], "column_nums": [4]},
+ {"cell text": "1199 sec, n=1", "row_nums": [2], "column_nums": [5]},
+ {"cell text": "1716 sec, n=3 (1934 sec, n=2)", "row_nums": [3], "column_nums": [5]},
+ {"cell text": "1672.1 sec, n=4", "row_nums": [4], "column_nums": [5]},
+ {"cell text": "1416 sec, n=3", "row_nums": [5], "column_nums": [5]},
+ ]
+ for i in range(len(cells)):
+ cells[i]["column header"] = False
+ return [cells]
diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py
index 740cbf0f..90348c1f 100644
--- a/test_unstructured_inference/models/test_tables.py
+++ b/test_unstructured_inference/models/test_tables.py
@@ -361,6 +361,54 @@ def test_table_prediction_tesseract(table_transformer, example_image):
) in prediction
+@pytest.mark.parametrize(
+ ("output_format", "expectation"),
+ [
+ ("html", "
Blind | 5 | 1 | 4 | 34.5%, n=1 | "),
+ (
+ "cells",
+ {
+ "column_nums": [0],
+ "row_nums": [2],
+ "column header": False,
+ "cell text": "Blind",
+ },
+ ),
+ ("dataframe", ["Blind", "5", "1", "4", "34.5%, n=1", "1199 sec, n=1"]),
+ (None, "
Blind | 5 | 1 | 4 | 34.5%, n=1 | "),
+ ],
+)
+def test_table_prediction_output_format(
+ output_format,
+ expectation,
+ table_transformer,
+ example_image,
+ mocker,
+ example_table_cells,
+):
+ mocker.patch.object(tables, "recognize", return_value=example_table_cells)
+ mocker.patch.object(
+ tables.UnstructuredTableTransformerModel,
+ "get_structure",
+ return_value=None,
+ )
+ mocker.patch.object(tables.UnstructuredTableTransformerModel, "get_tokens", return_value=None)
+ if output_format:
+ result = table_transformer.run_prediction(example_image, result_format=output_format)
+ else:
+ result = table_transformer.run_prediction(example_image)
+
+ if output_format == "dataframe":
+ assert expectation in result.values
+ elif output_format == "cells":
+ # other output like bbox are flakey to test since they depend on OCR and it may change
+ # slightly when OCR pacakge changes or even on different machines
+ validation_fields = ("column_nums", "row_nums", "column header", "cell text")
+ assert expectation in [{key: cell[key] for key in validation_fields} for cell in result]
+ else:
+ assert expectation in result
+
+
def test_table_prediction_tesseract_with_ocr_tokens(table_transformer, example_image):
ocr_tokens = [
{
diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py
index a28807cc..5a48fbf0 100644
--- a/unstructured_inference/__version__.py
+++ b/unstructured_inference/__version__.py
@@ -1 +1 @@
-__version__ = "0.7.10-dev1" # pragma: no cover
+__version__ = "0.7.10-dev2" # pragma: no cover
diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py
index 760dd46b..74cd64ae 100644
--- a/unstructured_inference/models/tables.py
+++ b/unstructured_inference/models/tables.py
@@ -19,6 +19,7 @@
from unstructured_inference.constants import (
TESSERACT_TEXT_HEIGHT,
)
+from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
from unstructured_inference.logger import logger
from unstructured_inference.models.table_postprocess import Rect
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
@@ -176,6 +177,7 @@ def run_prediction(
x: Image,
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
ocr_tokens: Optional[List[Dict]] = None,
+ result_format: Optional[str] = "html",
):
"""Predict table structure"""
outputs_structure = self.get_structure(x, pad_for_structure_detection)
@@ -186,8 +188,12 @@ def run_prediction(
)
ocr_tokens = self.get_tokens(x=x)
- html = recognize(outputs_structure, x, tokens=ocr_tokens, out_html=True)["html"]
- prediction = html[0] if html else ""
+ prediction = recognize(outputs_structure, x, tokens=ocr_tokens)[0]
+ if result_format == "html":
+ # Convert cells to HTML
+ prediction = cells_to_html(prediction) or ""
+ elif result_format == "dataframe":
+ prediction = table_cells_to_dataframe(prediction)
return prediction
@@ -234,10 +240,8 @@ def get_class_map(data_type: str):
}
-def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False):
+def recognize(outputs: dict, img: Image, tokens: list):
"""Recognize table elements."""
- out_formats = {}
-
str_class_name2idx = get_class_map("structure")
str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
str_class_thresholds = structure_class_thresholds
@@ -248,14 +252,7 @@ def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False):
# Further process the detected objects so they correspond to a consistent table
tables_structure = objects_to_structures(objects, tokens, str_class_thresholds)
# Enumerate all table cells: grid cells and spanning cells
- tables_cells = [structure_to_cells(structure, tokens)[0] for structure in tables_structure]
-
- # Convert cells to HTML
- if out_html:
- tables_htmls = [cells_to_html(cells) for cells in tables_cells]
- out_formats["html"] = tables_htmls
-
- return out_formats
+ return [structure_to_cells(structure, tokens)[0] for structure in tables_structure]
def outputs_to_objects(outputs, img_size, class_idx2name):