Skip to content

Commit

Permalink
Feat: add more output format for table inference (#263)
Browse files Browse the repository at this point in the history
This PR addresses
[CORE-2307](https://unstructured-ai.atlassian.net/browse/CORE-2307)
- add a new kwarg to `UnstructuredTableTransformerModel.run_prediction`:
`output_format`
- default `output_format` is `html`, which is current behavior: output
html string representation of the table
- another options available is `dataframe`, which returns a pandas
dataframe representation of the table
- if not specified or any other string value for `output_format` it
returns a list of dictionaries: table cell format, the original output
format from table transformer
- `unstructured.model.tables.recognize` no longer accepts `out_html`
kwarg and it now only returns table cell format

[CORE-2307]:
https://unstructured-ai.atlassian.net/browse/CORE-2307?atlOrigin=eyJpIjoiNWRkNTljNzYxNjVmNDY3MDlhMDU5Y2ZhYzA5YTRkZjUiLCJwIjoiZ2l0aHViLWNvbS1KU1cifQ

---------

Co-authored-by: qued <[email protected]>
  • Loading branch information
badGarnet and qued authored Oct 20, 2023
1 parent 2ee38e6 commit 326f180
Show file tree
Hide file tree
Showing 5 changed files with 107 additions and 16 deletions.
4 changes: 3 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
## 0.7.10-dev1
## 0.7.10-dev2

* fix: Reduce Chipper memory consumption on x86_64 cpus
* fix: Skips ordering elements coming from Chipper
* fix: After refactoring to introduce Chipper, annotate() weren't able to show text with extra info from elements, this is fixed now.
* feat: add table cell and dataframe output formats to table transformer's `run_prediction` call
* breaking change: function `unstructured_inference.models.tables.recognize` no longer takes `out_html` parameter and it now only returns table cell data format (lists of dictionaries)

## 0.7.9

Expand Down
46 changes: 45 additions & 1 deletion test_unstructured_inference/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,11 @@
import pytest
from PIL import Image

from unstructured_inference.inference.elements import EmbeddedTextRegion, Rectangle, TextRegion
from unstructured_inference.inference.elements import (
EmbeddedTextRegion,
Rectangle,
TextRegion,
)
from unstructured_inference.inference.layoutelement import LayoutElement


Expand Down Expand Up @@ -122,3 +126,43 @@ def mock_layout(mock_embedded_text_regions):
LayoutElement(text=r.text, type="UncategorizedText", bbox=r.bbox)
for r in mock_embedded_text_regions
]


@pytest.fixture()
def example_table_cells():
cells = [
{"cell text": "Disability Category", "row_nums": [0, 1], "column_nums": [0]},
{"cell text": "Participants", "row_nums": [0, 1], "column_nums": [1]},
{"cell text": "Ballots Completed", "row_nums": [0, 1], "column_nums": [2]},
{"cell text": "Ballots Incomplete/Terminated", "row_nums": [0, 1], "column_nums": [3]},
{"cell text": "Results", "row_nums": [0], "column_nums": [4, 5]},
{"cell text": "Accuracy", "row_nums": [1], "column_nums": [4]},
{"cell text": "Time to complete", "row_nums": [1], "column_nums": [5]},
{"cell text": "Blind", "row_nums": [2], "column_nums": [0]},
{"cell text": "Low Vision", "row_nums": [3], "column_nums": [0]},
{"cell text": "Dexterity", "row_nums": [4], "column_nums": [0]},
{"cell text": "Mobility", "row_nums": [5], "column_nums": [0]},
{"cell text": "5", "row_nums": [2], "column_nums": [1]},
{"cell text": "5", "row_nums": [3], "column_nums": [1]},
{"cell text": "5", "row_nums": [4], "column_nums": [1]},
{"cell text": "3", "row_nums": [5], "column_nums": [1]},
{"cell text": "1", "row_nums": [2], "column_nums": [2]},
{"cell text": "2", "row_nums": [3], "column_nums": [2]},
{"cell text": "4", "row_nums": [4], "column_nums": [2]},
{"cell text": "3", "row_nums": [5], "column_nums": [2]},
{"cell text": "4", "row_nums": [2], "column_nums": [3]},
{"cell text": "3", "row_nums": [3], "column_nums": [3]},
{"cell text": "1", "row_nums": [4], "column_nums": [3]},
{"cell text": "0", "row_nums": [5], "column_nums": [3]},
{"cell text": "34.5%, n=1", "row_nums": [2], "column_nums": [4]},
{"cell text": "98.3% n=2 (97.7%, n=3)", "row_nums": [3], "column_nums": [4]},
{"cell text": "98.3%, n=4", "row_nums": [4], "column_nums": [4]},
{"cell text": "95.4%, n=3", "row_nums": [5], "column_nums": [4]},
{"cell text": "1199 sec, n=1", "row_nums": [2], "column_nums": [5]},
{"cell text": "1716 sec, n=3 (1934 sec, n=2)", "row_nums": [3], "column_nums": [5]},
{"cell text": "1672.1 sec, n=4", "row_nums": [4], "column_nums": [5]},
{"cell text": "1416 sec, n=3", "row_nums": [5], "column_nums": [5]},
]
for i in range(len(cells)):
cells[i]["column header"] = False
return [cells]
48 changes: 48 additions & 0 deletions test_unstructured_inference/models/test_tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,54 @@ def test_table_prediction_tesseract(table_transformer, example_image):
) in prediction


@pytest.mark.parametrize(
("output_format", "expectation"),
[
("html", "<tr><td>Blind</td><td>5</td><td>1</td><td>4</td><td>34.5%, n=1</td>"),
(
"cells",
{
"column_nums": [0],
"row_nums": [2],
"column header": False,
"cell text": "Blind",
},
),
("dataframe", ["Blind", "5", "1", "4", "34.5%, n=1", "1199 sec, n=1"]),
(None, "<tr><td>Blind</td><td>5</td><td>1</td><td>4</td><td>34.5%, n=1</td>"),
],
)
def test_table_prediction_output_format(
output_format,
expectation,
table_transformer,
example_image,
mocker,
example_table_cells,
):
mocker.patch.object(tables, "recognize", return_value=example_table_cells)
mocker.patch.object(
tables.UnstructuredTableTransformerModel,
"get_structure",
return_value=None,
)
mocker.patch.object(tables.UnstructuredTableTransformerModel, "get_tokens", return_value=None)
if output_format:
result = table_transformer.run_prediction(example_image, result_format=output_format)
else:
result = table_transformer.run_prediction(example_image)

if output_format == "dataframe":
assert expectation in result.values
elif output_format == "cells":
# other output like bbox are flakey to test since they depend on OCR and it may change
# slightly when OCR pacakge changes or even on different machines
validation_fields = ("column_nums", "row_nums", "column header", "cell text")
assert expectation in [{key: cell[key] for key in validation_fields} for cell in result]
else:
assert expectation in result


def test_table_prediction_tesseract_with_ocr_tokens(table_transformer, example_image):
ocr_tokens = [
{
Expand Down
2 changes: 1 addition & 1 deletion unstructured_inference/__version__.py
Original file line number Diff line number Diff line change
@@ -1 +1 @@
__version__ = "0.7.10-dev1" # pragma: no cover
__version__ = "0.7.10-dev2" # pragma: no cover
23 changes: 10 additions & 13 deletions unstructured_inference/models/tables.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from unstructured_inference.constants import (
TESSERACT_TEXT_HEIGHT,
)
from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
from unstructured_inference.logger import logger
from unstructured_inference.models.table_postprocess import Rect
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
Expand Down Expand Up @@ -176,6 +177,7 @@ def run_prediction(
x: Image,
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
ocr_tokens: Optional[List[Dict]] = None,
result_format: Optional[str] = "html",
):
"""Predict table structure"""
outputs_structure = self.get_structure(x, pad_for_structure_detection)
Expand All @@ -186,8 +188,12 @@ def run_prediction(
)
ocr_tokens = self.get_tokens(x=x)

html = recognize(outputs_structure, x, tokens=ocr_tokens, out_html=True)["html"]
prediction = html[0] if html else ""
prediction = recognize(outputs_structure, x, tokens=ocr_tokens)[0]
if result_format == "html":
# Convert cells to HTML
prediction = cells_to_html(prediction) or ""
elif result_format == "dataframe":
prediction = table_cells_to_dataframe(prediction)
return prediction


Expand Down Expand Up @@ -234,10 +240,8 @@ def get_class_map(data_type: str):
}


def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False):
def recognize(outputs: dict, img: Image, tokens: list):
"""Recognize table elements."""
out_formats = {}

str_class_name2idx = get_class_map("structure")
str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
str_class_thresholds = structure_class_thresholds
Expand All @@ -248,14 +252,7 @@ def recognize(outputs: dict, img: Image, tokens: list, out_html: bool = False):
# Further process the detected objects so they correspond to a consistent table
tables_structure = objects_to_structures(objects, tokens, str_class_thresholds)
# Enumerate all table cells: grid cells and spanning cells
tables_cells = [structure_to_cells(structure, tokens)[0] for structure in tables_structure]

# Convert cells to HTML
if out_html:
tables_htmls = [cells_to_html(cells) for cells in tables_cells]
out_formats["html"] = tables_htmls

return out_formats
return [structure_to_cells(structure, tokens)[0] for structure in tables_structure]


def outputs_to_objects(outputs, img_size, class_idx2name):
Expand Down

0 comments on commit 326f180

Please sign in to comment.