diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index ff79d80d..97cae495 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -7,7 +7,7 @@ on:
branches: [ main ]
env:
- PYTHON_VERSION: 3.8
+ PYTHON_VERSION: 3.9
jobs:
setup:
@@ -22,7 +22,7 @@ jobs:
key: ${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('requirements/*.txt') }}
lookup-only: true
- name: Set up Python ${{ env.PYTHON_VERSION }}
- uses: actions/setup-python@v4
+ uses: actions/setup-python@v5
with:
python-version: ${{ env.PYTHON_VERSION }}
- name: Install Poppler
@@ -104,48 +104,50 @@ jobs:
CI=true make test
make check-coverage
- test_ingest:
- strategy:
- matrix:
- python-version: ["3.8","3.9","3.10"]
- runs-on: ubuntu-latest
- env:
- NLTK_DATA: ${{ github.workspace }}/nltk_data
- needs: lint
- steps:
- - name: Checkout unstructured repo for integration testing
- uses: actions/checkout@v4
- with:
- repository: 'Unstructured-IO/unstructured'
- - name: Checkout this repo
- uses: actions/checkout@v4
- with:
- path: inference
- - name: Set up Python ${{ matrix.python-version }}
- uses: actions/setup-python@v4
- with:
- python-version: ${{ matrix.python-version }}
- - name: Test
- env:
- GH_READ_ONLY_ACCESS_TOKEN: ${{ secrets.GH_READ_ONLY_ACCESS_TOKEN }}
- SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
- DISCORD_TOKEN: ${{ secrets.DISCORD_TOKEN }}
- run: |
- python${{ matrix.python-version }} -m venv .venv
- source .venv/bin/activate
- [ ! -d "$NLTK_DATA" ] && mkdir "$NLTK_DATA"
- make install-ci
- pip install -e inference/
- sudo apt-get update
- sudo apt-get install -y libmagic-dev poppler-utils libreoffice pandoc
- sudo add-apt-repository -y ppa:alex-p/tesseract-ocr5
- sudo apt-get install -y tesseract-ocr
- sudo apt-get install -y tesseract-ocr-kor
- sudo apt-get install -y diffstat
- tesseract --version
- make install-all-ingest
- # only run ingest tests that check expected output diffs.
- bash inference/scripts/test-unstructured-ingest-helper.sh
+ # NOTE(robinson) - disabling ingest tests for now, as of 5/22/2024 they seem to have been
+ # broken for the past six months
+ # test_ingest:
+ # strategy:
+ # matrix:
+ # python-version: ["3.9","3.10"]
+ # runs-on: ubuntu-latest
+ # env:
+ # NLTK_DATA: ${{ github.workspace }}/nltk_data
+ # needs: lint
+ # steps:
+ # - name: Checkout unstructured repo for integration testing
+ # uses: actions/checkout@v4
+ # with:
+ # repository: 'Unstructured-IO/unstructured'
+ # - name: Checkout this repo
+ # uses: actions/checkout@v4
+ # with:
+ # path: inference
+ # - name: Set up Python ${{ matrix.python-version }}
+ # uses: actions/setup-python@v4
+ # with:
+ # python-version: ${{ matrix.python-version }}
+ # - name: Test
+ # env:
+ # GH_READ_ONLY_ACCESS_TOKEN: ${{ secrets.GH_READ_ONLY_ACCESS_TOKEN }}
+ # SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }}
+ # DISCORD_TOKEN: ${{ secrets.DISCORD_TOKEN }}
+ # run: |
+ # python${{ matrix.python-version }} -m venv .venv
+ # source .venv/bin/activate
+ # [ ! -d "$NLTK_DATA" ] && mkdir "$NLTK_DATA"
+ # make install-ci
+ # pip install -e inference/
+ # sudo apt-get update
+ # sudo apt-get install -y libmagic-dev poppler-utils libreoffice pandoc
+ # sudo add-apt-repository -y ppa:alex-p/tesseract-ocr5
+ # sudo apt-get install -y tesseract-ocr
+ # sudo apt-get install -y tesseract-ocr-kor
+ # sudo apt-get install -y diffstat
+ # tesseract --version
+ # make install-all-ingest
+ # # only run ingest tests that check expected output diffs.
+ # bash inference/scripts/test-unstructured-ingest-helper.sh
changelog:
runs-on: ubuntu-latest
diff --git a/CHANGELOG.md b/CHANGELOG.md
index 11560e32..944f1b27 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,7 +1,85 @@
-## 0.7.20
+## 0.7.37-dev0
* refactor: remove layout analysis related code
+## 0.7.36
+
+fix: add input parameter validation to `fill_cells()` when converting cells to html
+
+## 0.7.35
+
+Fix syntax for generated HTML tables
+
+## 0.7.34
+
+* Reduce excessive logging
+
+## 0.7.33
+
+* BREAKING CHANGE: removes legacy detectron2 model
+* deps: remove layoutparser optional dependencies
+
+## 0.7.32
+
+* refactor: remove all code related to filling inferred elements text from embedded text (pdfminer).
+* bug: set the Chipper max_length variable
+
+## 0.7.31
+
+* refactor: remove all `cid` related code that was originally added to filter out invalid `pdfminer` text
+* enhancement: Wrapped hf_hub_download with a function that checks for local file before checking HF
+
+## 0.7.30
+
+* fix: table transformer doesn't return multiple cells with same coordinates
+*
+## 0.7.29
+
+* fix: table transformer predictions are now removed if confidence is below threshold
+
+
+## 0.7.28
+
+* feat: allow table transformer agent to return table prediction in unparsed format
+
+## 0.7.27
+
+* fix: remove pin from `onnxruntime` dependency.
+
+## 0.7.26
+
+* feat: add a set of new `ElementType`s to extend future element types recognition
+* feat: allow registering of new models for inference using `unstructured_inference.models.base.register_new_model` function
+
+## 0.7.25
+
+* fix: replace `Rectangle.is_in()` with `Rectangle.is_almost_subregion_of()` when filling in an inferred element with embedded text
+* bug: check for None in Chipper bounding box reduction
+* chore: removes `install-detectron2` from the `Makefile`
+* fix: convert label_map keys read from os.environment `UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH` to int type
+* feat: removes supergradients references
+
+## 0.7.24
+
+* fix: assign value to `text_as_html` element attribute only if `text` attribute contains HTML tags.
+
+## 0.7.23
+
+* fix: added handling in `UnstructuredTableTransformerModel` for if `recognize` returns an empty
+ list in `run_prediction`.
+
+## 0.7.22
+
+* fix: add logic to handle computation of intersections between 2 `Rectangle`s when a `Rectangle` has `None` value in its coordinates
+
+## 0.7.21
+
+* fix: fix a bug where chipper, or any element extraction model based `PageLayout` object, lack `image_metadata` and other attributes that are required for downstream processing; this fix also reduces the memory overhead of using chipper model
+
+## 0.7.20
+
+* chipper-v3: improved table prediction
+
## 0.7.19
* refactor: remove all OCR related code
diff --git a/Makefile b/Makefile
index ab987bb8..5b806af5 100644
--- a/Makefile
+++ b/Makefile
@@ -19,7 +19,7 @@ install-base: install-base-pip-packages
## install: installs all test, dev, and experimental requirements
.PHONY: install
-install: install-base-pip-packages install-dev install-detectron2
+install: install-base-pip-packages install-dev
.PHONY: install-ci
install-ci: install-base-pip-packages install-test
@@ -28,10 +28,6 @@ install-ci: install-base-pip-packages install-test
install-base-pip-packages:
python3 -m pip install pip==${PIP_VERSION}
-.PHONY: install-detectron2
-install-detectron2:
- pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@57bdb21249d5418c130d54e2ebdc94dda7a4c01a"
-
.PHONY: install-test
install-test: install-base
pip install -r requirements/test.txt
@@ -44,10 +40,6 @@ install-dev: install-test
.PHONY: pip-compile
pip-compile:
pip-compile --upgrade requirements/base.in
- # NOTE(robinson) - We want the dependencies for detectron2 in the requirements.txt, but not
- # the detectron2 repo itself. If detectron2 is in the requirements.txt file, an order of
- # operations issue related to the torch library causes the install to fail
- sed 's/^detectron2 @/# detectron2 @/g' requirements/base.txt
pip-compile --upgrade requirements/test.in
pip-compile --upgrade requirements/dev.in
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..aa4949aa
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,2 @@
+[tool.black]
+line-length = 100
diff --git a/requirements/base.in b/requirements/base.in
index 5f4a8f5d..fc60b9bc 100644
--- a/requirements/base.in
+++ b/requirements/base.in
@@ -1,11 +1,13 @@
-c constraints.in
-layoutparser[layoutmodels,tesseract]
+layoutparser
python-multipart
huggingface-hub
opencv-python!=4.7.0.68
onnx
-# NOTE(benjamin): Pinned because onnxruntime changed the way quantization is done, and we need to update our code to support it
-onnxruntime<1.16
+onnxruntime>=1.17.0
+matplotlib
+torch
+timm
# NOTE(alan): Pinned because this is when the most recent module we import appeared
transformers>=4.25.1
rapidfuzz
diff --git a/requirements/base.txt b/requirements/base.txt
index 28857d7a..536ea640 100644
--- a/requirements/base.txt
+++ b/requirements/base.txt
@@ -1,12 +1,10 @@
#
-# This file is autogenerated by pip-compile with Python 3.8
+# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
#
# pip-compile requirements/base.in
#
-antlr4-python3-runtime==4.9.3
- # via omegaconf
-certifi==2023.7.22
+certifi==2024.2.2
# via requests
cffi==1.16.0
# via cryptography
@@ -16,28 +14,26 @@ charset-normalizer==3.3.2
# requests
coloredlogs==15.0.1
# via onnxruntime
-contourpy==1.1.1
+contourpy==1.2.1
# via matplotlib
-cryptography==41.0.5
+cryptography==42.0.7
# via pdfminer-six
cycler==0.12.1
# via matplotlib
-effdet==0.4.1
- # via layoutparser
-filelock==3.13.1
+filelock==3.14.0
# via
# huggingface-hub
# torch
# transformers
-flatbuffers==23.5.26
+flatbuffers==24.3.25
# via onnxruntime
-fonttools==4.44.3
+fonttools==4.51.0
# via matplotlib
-fsspec==2023.10.0
+fsspec==2024.5.0
# via
# huggingface-hub
# torch
-huggingface-hub==0.19.4
+huggingface-hub==0.23.1
# via
# -r requirements/base.in
# timm
@@ -45,27 +41,27 @@ huggingface-hub==0.19.4
# transformers
humanfriendly==10.0
# via coloredlogs
-idna==3.4
+idna==3.7
# via requests
-importlib-resources==6.1.1
+importlib-resources==6.4.0
# via matplotlib
iopath==0.1.10
# via layoutparser
-jinja2==3.1.2
+jinja2==3.1.4
# via torch
kiwisolver==1.4.5
# via matplotlib
-layoutparser[layoutmodels,tesseract]==0.3.4
+layoutparser==0.3.4
# via -r requirements/base.in
-markupsafe==2.1.3
+markupsafe==2.1.5
# via jinja2
-matplotlib==3.7.3
- # via pycocotools
+matplotlib==3.9.0
+ # via -r requirements/base.in
mpmath==1.3.0
# via sympy
-networkx==3.1
+networkx==3.2.1
# via torch
-numpy==1.24.4
+numpy==1.26.4
# via
# contourpy
# layoutparser
@@ -74,88 +70,77 @@ numpy==1.24.4
# onnxruntime
# opencv-python
# pandas
- # pycocotools
# scipy
# torchvision
# transformers
-omegaconf==2.3.0
- # via effdet
-onnx==1.15.0
+onnx==1.16.0
# via -r requirements/base.in
-onnxruntime==1.15.1
+onnxruntime==1.18.0
# via -r requirements/base.in
-opencv-python==4.8.1.78
+opencv-python==4.9.0.80
# via
# -r requirements/base.in
# layoutparser
-packaging==23.2
+packaging==24.0
# via
# huggingface-hub
# matplotlib
# onnxruntime
- # pytesseract
# transformers
-pandas==2.0.3
+pandas==2.2.2
# via layoutparser
-pdf2image==1.16.3
+pdf2image==1.17.0
# via layoutparser
-pdfminer-six==20221105
+pdfminer-six==20231228
# via pdfplumber
-pdfplumber==0.10.3
+pdfplumber==0.11.0
# via layoutparser
-pillow==10.1.0
+pillow==10.3.0
# via
# layoutparser
# matplotlib
# pdf2image
# pdfplumber
- # pytesseract
# torchvision
portalocker==2.8.2
# via iopath
-protobuf==4.25.1
+protobuf==5.26.1
# via
# onnx
# onnxruntime
-pycocotools==2.0.7
- # via effdet
-pycparser==2.21
+pycparser==2.22
# via cffi
-pyparsing==3.1.1
+pyparsing==3.1.2
# via matplotlib
-pypdfium2==4.24.0
+pypdfium2==4.30.0
# via pdfplumber
-pytesseract==0.3.10
- # via layoutparser
-python-dateutil==2.8.2
+python-dateutil==2.9.0.post0
# via
# matplotlib
# pandas
-python-multipart==0.0.6
+python-multipart==0.0.9
# via -r requirements/base.in
-pytz==2023.3.post1
+pytz==2024.1
# via pandas
pyyaml==6.0.1
# via
# huggingface-hub
# layoutparser
- # omegaconf
# timm
# transformers
-rapidfuzz==3.5.2
+rapidfuzz==3.9.1
# via -r requirements/base.in
-regex==2023.10.3
+regex==2024.5.15
# via transformers
-requests==2.31.0
+requests==2.32.2
# via
# huggingface-hub
- # torchvision
# transformers
-safetensors==0.4.0
+safetensors==0.4.3
# via
# timm
# transformers
-scipy==1.10.1
+scipy==1.13.0
# via layoutparser
six==1.16.0
# via python-dateutil
@@ -163,36 +148,32 @@ sympy==1.12
# via
# onnxruntime
# torch
-timm==0.9.10
- # via effdet
-tokenizers==0.15.0
+timm==1.0.3
+ # via -r requirements/base.in
+tokenizers==0.19.1
# via transformers
-torch==2.1.1
+torch==2.3.0
# via
- # effdet
- # layoutparser
+ # -r requirements/base.in
# timm
# torchvision
-torchvision==0.16.1
- # via
- # effdet
- # layoutparser
- # timm
-tqdm==4.66.1
+torchvision==0.18.0
+ # via timm
+tqdm==4.66.4
# via
# huggingface-hub
# iopath
# transformers
-transformers==4.35.2
+transformers==4.41.0
# via -r requirements/base.in
-typing-extensions==4.8.0
+typing-extensions==4.11.0
# via
# huggingface-hub
# iopath
# torch
-tzdata==2023.3
+tzdata==2024.1
# via pandas
-urllib3==2.1.0
+urllib3==2.2.1
# via requests
-zipp==3.17.0
+zipp==3.18.2
# via importlib-resources
diff --git a/requirements/dev.txt b/requirements/dev.txt
index 8f45a80c..1021b087 100644
--- a/requirements/dev.txt
+++ b/requirements/dev.txt
@@ -1,17 +1,16 @@
#
-# This file is autogenerated by pip-compile with Python 3.8
+# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
#
# pip-compile requirements/dev.in
#
-anyio==4.0.0
+anyio==4.3.0
# via
# -c requirements/test.txt
+ # httpx
# jupyter-server
-appnope==0.1.3
- # via
- # ipykernel
- # ipython
+appnope==0.1.4
+ # via ipykernel
argon2-cffi==23.1.0
# via jupyter-server
argon2-cffi-bindings==21.2.0
@@ -22,24 +21,24 @@ asttokens==2.4.1
# via stack-data
async-lru==2.0.4
# via jupyterlab
-attrs==23.1.0
+attrs==23.2.0
# via
# jsonschema
# referencing
-babel==2.13.1
+babel==2.15.0
# via jupyterlab-server
-backcall==0.2.0
- # via ipython
-beautifulsoup4==4.12.2
+beautifulsoup4==4.12.3
# via nbconvert
bleach==6.1.0
# via nbconvert
-build==1.0.3
+build==1.2.1
# via pip-tools
-certifi==2023.7.22
+certifi==2024.2.2
# via
# -c requirements/base.txt
# -c requirements/test.txt
+ # httpcore
+ # httpx
# requests
cffi==1.16.0
# via
@@ -54,11 +53,11 @@ click==8.1.7
# via
# -c requirements/test.txt
# pip-tools
-comm==0.2.0
+comm==0.2.2
# via
# ipykernel
# ipywidgets
-contourpy==1.1.1
+contourpy==1.2.1
# via
# -c requirements/base.txt
# matplotlib
@@ -66,34 +65,48 @@ cycler==0.12.1
# via
# -c requirements/base.txt
# matplotlib
-debugpy==1.8.0
+debugpy==1.8.1
# via ipykernel
decorator==5.1.1
# via ipython
defusedxml==0.7.1
# via nbconvert
-exceptiongroup==1.1.3
+exceptiongroup==1.2.1
# via
# -c requirements/test.txt
# anyio
+ # ipython
executing==2.0.1
# via stack-data
-fastjsonschema==2.19.0
+fastjsonschema==2.19.1
# via nbformat
-fonttools==4.44.3
+fonttools==4.51.0
# via
# -c requirements/base.txt
# matplotlib
fqdn==1.5.1
# via jsonschema
-idna==3.4
+h11==0.14.0
+ # via
+ # -c requirements/test.txt
+ # httpcore
+httpcore==1.0.5
+ # via
+ # -c requirements/test.txt
+ # httpx
+httpx==0.27.0
+ # via
+ # -c requirements/test.txt
+ # jupyterlab
+idna==3.7
# via
# -c requirements/base.txt
# -c requirements/test.txt
# anyio
+ # httpx
# jsonschema
# requests
-importlib-metadata==6.8.0
+importlib-metadata==7.1.0
# via
# build
# jupyter-client
@@ -101,52 +114,49 @@ importlib-metadata==6.8.0
# jupyterlab
# jupyterlab-server
# nbconvert
-importlib-resources==6.1.1
+importlib-resources==6.4.0
# via
# -c requirements/base.txt
- # jsonschema
- # jsonschema-specifications
- # jupyterlab
# matplotlib
-ipykernel==6.26.0
+ipykernel==6.29.4
# via
# jupyter
# jupyter-console
# jupyterlab
# qtconsole
-ipython==8.12.3
+ipython==8.18.1
# via
# -r requirements/dev.in
# ipykernel
# ipywidgets
# jupyter-console
-ipywidgets==8.1.1
+ipywidgets==8.1.2
# via jupyter
isoduration==20.11.0
# via jsonschema
jedi==0.19.1
# via ipython
-jinja2==3.1.2
+jinja2==3.1.4
# via
# -c requirements/base.txt
# jupyter-server
# jupyterlab
# jupyterlab-server
# nbconvert
-json5==0.9.14
+json5==0.9.25
# via jupyterlab-server
jsonpointer==2.4
# via jsonschema
-jsonschema[format-nongpl]==4.20.0
+jsonschema[format-nongpl]==4.22.0
# via
# jupyter-events
# jupyterlab-server
# nbformat
-jsonschema-specifications==2023.11.1
+jsonschema-specifications==2023.12.1
# via jsonschema
jupyter==1.0.0
# via -r requirements/dev.in
-jupyter-client==8.6.0
+jupyter-client==8.6.1
# via
# ipykernel
# jupyter-console
@@ -155,7 +165,7 @@ jupyter-client==8.6.0
# qtconsole
jupyter-console==6.6.3
# via jupyter
-jupyter-core==5.5.0
+jupyter-core==5.7.2
# via
# ipykernel
# jupyter-client
@@ -166,75 +176,75 @@ jupyter-core==5.5.0
# nbconvert
# nbformat
# qtconsole
-jupyter-events==0.9.0
+jupyter-events==0.10.0
# via jupyter-server
-jupyter-lsp==2.2.0
+jupyter-lsp==2.2.5
# via jupyterlab
-jupyter-server==2.10.1
+jupyter-server==2.14.0
# via
# jupyter-lsp
# jupyterlab
# jupyterlab-server
# notebook
# notebook-shim
-jupyter-server-terminals==0.4.4
+jupyter-server-terminals==0.5.3
# via jupyter-server
-jupyterlab==4.0.8
+jupyterlab==4.2.0
# via notebook
-jupyterlab-pygments==0.2.2
+jupyterlab-pygments==0.3.0
# via nbconvert
-jupyterlab-server==2.25.1
+jupyterlab-server==2.27.1
# via
# jupyterlab
# notebook
-jupyterlab-widgets==3.0.9
+jupyterlab-widgets==3.0.10
# via ipywidgets
kiwisolver==1.4.5
# via
# -c requirements/base.txt
# matplotlib
-markupsafe==2.1.3
+markupsafe==2.1.5
# via
# -c requirements/base.txt
# jinja2
# nbconvert
-matplotlib==3.7.3
+matplotlib==3.9.0
# via
# -c requirements/base.txt
# -r requirements/dev.in
-matplotlib-inline==0.1.6
+matplotlib-inline==0.1.7
# via
# ipykernel
# ipython
mistune==3.0.2
# via nbconvert
-nbclient==0.9.0
+nbclient==0.10.0
# via nbconvert
-nbconvert==7.11.0
+nbconvert==7.16.4
# via
# jupyter
# jupyter-server
-nbformat==5.9.2
+nbformat==5.10.4
# via
# jupyter-server
# nbclient
# nbconvert
-nest-asyncio==1.5.8
+nest-asyncio==1.6.0
# via ipykernel
-notebook==7.0.6
+notebook==7.2.0
# via jupyter
-notebook-shim==0.2.3
+notebook-shim==0.2.4
# via
# jupyterlab
# notebook
-numpy==1.24.4
+numpy==1.26.4
# via
# -c requirements/base.txt
# contourpy
# matplotlib
-overrides==7.4.0
+overrides==7.7.0
# via jupyter-server
-packaging==23.2
+packaging==24.0
# via
# -c requirements/base.txt
# -c requirements/test.txt
@@ -247,34 +257,30 @@ packaging==23.2
# nbconvert
# qtconsole
# qtpy
-pandocfilters==1.5.0
+pandocfilters==1.5.1
# via nbconvert
-parso==0.8.3
+parso==0.8.4
# via jedi
-pexpect==4.8.0
- # via ipython
-pickleshare==0.7.5
+pexpect==4.9.0
# via ipython
-pillow==10.1.0
+pillow==10.3.0
# via
# -c requirements/base.txt
# -c requirements/test.txt
# matplotlib
-pip-tools==7.3.0
+pip-tools==7.4.1
# via -r requirements/dev.in
-pkgutil-resolve-name==1.3.10
- # via jsonschema
-platformdirs==4.0.0
+platformdirs==4.2.2
# via
# -c requirements/test.txt
# jupyter-core
-prometheus-client==0.18.0
+prometheus-client==0.20.0
# via jupyter-server
-prompt-toolkit==3.0.41
+prompt-toolkit==3.0.43
# via
# ipython
# jupyter-console
-psutil==5.9.6
+psutil==5.9.8
# via ipykernel
ptyprocess==0.7.0
# via
@@ -282,23 +288,25 @@ ptyprocess==0.7.0
# terminado
pure-eval==0.2.2
# via stack-data
-pycparser==2.21
+pycparser==2.22
# via
# -c requirements/base.txt
# cffi
-pygments==2.16.1
+pygments==2.18.0
# via
# ipython
# jupyter-console
# nbconvert
# qtconsole
-pyparsing==3.1.1
+pyparsing==3.1.2
# via
# -c requirements/base.txt
# matplotlib
-pyproject-hooks==1.0.0
- # via build
-python-dateutil==2.8.2
+pyproject-hooks==1.1.0
+ # via
+ # build
+ # pip-tools
+python-dateutil==2.9.0.post0
# via
# -c requirements/base.txt
# arrow
@@ -306,32 +314,28 @@ python-dateutil==2.8.2
# matplotlib
python-json-logger==2.0.7
# via jupyter-events
-pytz==2023.3.post1
- # via
- # -c requirements/base.txt
- # babel
pyyaml==6.0.1
# via
# -c requirements/base.txt
# -c requirements/test.txt
# jupyter-events
-pyzmq==25.1.1
+pyzmq==26.0.3
# via
# ipykernel
# jupyter-client
# jupyter-console
# jupyter-server
# qtconsole
-qtconsole==5.5.1
+qtconsole==5.5.2
# via jupyter
qtpy==2.4.1
# via qtconsole
-referencing==0.31.0
+referencing==0.35.1
# via
# jsonschema
# jsonschema-specifications
# jupyter-events
-requests==2.31.0
+requests==2.32.2
# via
# -c requirements/base.txt
# -c requirements/test.txt
@@ -344,11 +348,11 @@ rfc3986-validator==0.1.1
# via
# jsonschema
# jupyter-events
-rpds-py==0.13.0
+rpds-py==0.18.1
# via
# jsonschema
# referencing
-send2trash==1.8.2
+send2trash==1.8.3
# via jupyter-server
six==1.16.0
# via
@@ -357,19 +361,20 @@ six==1.16.0
# bleach
# python-dateutil
# rfc3339-validator
-sniffio==1.3.0
+sniffio==1.3.1
# via
# -c requirements/test.txt
# anyio
+ # httpx
soupsieve==2.5
# via beautifulsoup4
stack-data==0.6.3
# via ipython
-terminado==0.18.0
+terminado==0.18.1
# via
# jupyter-server
# jupyter-server-terminals
-tinycss2==1.2.1
+tinycss2==1.3.0
# via nbconvert
tomli==2.0.1
# via
@@ -377,8 +382,7 @@ tomli==2.0.1
# build
# jupyterlab
# pip-tools
- # pyproject-hooks
-tornado==6.3.3
+tornado==6.4
# via
# ipykernel
# jupyter-client
@@ -386,7 +390,7 @@ tornado==6.3.3
# jupyterlab
# notebook
# terminado
-traitlets==5.13.0
+traitlets==5.14.3
# via
# comm
# ipykernel
@@ -403,22 +407,23 @@ traitlets==5.13.0
# nbconvert
# nbformat
# qtconsole
-types-python-dateutil==2.8.19.14
+types-python-dateutil==2.9.0.20240316
# via arrow
-typing-extensions==4.8.0
+typing-extensions==4.11.0
# via
# -c requirements/base.txt
# -c requirements/test.txt
+ # anyio
# async-lru
# ipython
uri-template==1.3.0
# via jsonschema
-urllib3==2.1.0
+urllib3==2.2.1
# via
# -c requirements/base.txt
# -c requirements/test.txt
# requests
-wcwidth==0.2.10
+wcwidth==0.2.13
# via prompt-toolkit
webcolors==1.13
# via jsonschema
@@ -426,13 +431,13 @@ webencodings==0.5.1
# via
# bleach
# tinycss2
-websocket-client==1.6.4
+websocket-client==1.8.0
# via jupyter-server
-wheel==0.41.3
+wheel==0.43.0
# via pip-tools
-widgetsnbextension==4.0.9
+widgetsnbextension==4.0.10
# via ipywidgets
-zipp==3.17.0
+zipp==3.18.2
# via
# -c requirements/base.txt
# importlib-metadata
diff --git a/requirements/test.txt b/requirements/test.txt
index 4e5a2adb..5f738730 100644
--- a/requirements/test.txt
+++ b/requirements/test.txt
@@ -1,14 +1,14 @@
#
-# This file is autogenerated by pip-compile with Python 3.8
+# This file is autogenerated by pip-compile with Python 3.9
# by the following command:
#
# pip-compile requirements/test.in
#
-anyio==4.0.0
+anyio==4.3.0
# via httpx
-black==23.11.0
+black==24.4.2
# via -r requirements/test.in
-certifi==2023.7.22
+certifi==2024.2.2
# via
# -c requirements/base.txt
# httpcore
@@ -22,39 +22,39 @@ click==8.1.7
# via
# -r requirements/test.in
# black
-coverage[toml]==7.3.2
+coverage[toml]==7.5.1
# via
# -r requirements/test.in
# pytest-cov
-exceptiongroup==1.1.3
+exceptiongroup==1.2.1
# via
# anyio
# pytest
-filelock==3.13.1
+filelock==3.14.0
# via
# -c requirements/base.txt
# huggingface-hub
-flake8==6.1.0
+flake8==7.0.0
# via
# -r requirements/test.in
# flake8-docstrings
flake8-docstrings==1.7.0
# via -r requirements/test.in
-fsspec==2023.10.0
+fsspec==2024.5.0
# via
# -c requirements/base.txt
# huggingface-hub
h11==0.14.0
# via httpcore
-httpcore==1.0.2
+httpcore==1.0.5
# via httpx
-httpx==0.25.1
+httpx==0.27.0
# via -r requirements/test.in
-huggingface-hub==0.19.4
+huggingface-hub==0.23.1
# via
# -c requirements/base.txt
# -r requirements/test.in
-idna==3.4
+idna==3.7
# via
# -c requirements/base.txt
# anyio
@@ -64,57 +64,57 @@ iniconfig==2.0.0
# via pytest
mccabe==0.7.0
# via flake8
-mypy==1.7.0
+mypy==1.10.0
# via -r requirements/test.in
mypy-extensions==1.0.0
# via
# black
# mypy
-packaging==23.2
+packaging==24.0
# via
# -c requirements/base.txt
# black
# huggingface-hub
# pytest
-pathspec==0.11.2
+pathspec==0.12.1
# via black
-pdf2image==1.16.3
+pdf2image==1.17.0
# via
# -c requirements/base.txt
# -r requirements/test.in
-pillow==10.1.0
+pillow==10.3.0
# via
# -c requirements/base.txt
# pdf2image
-platformdirs==4.0.0
+platformdirs==4.2.2
# via black
-pluggy==1.3.0
+pluggy==1.5.0
# via pytest
pycodestyle==2.11.1
# via flake8
pydocstyle==6.3.0
# via flake8-docstrings
-pyflakes==3.1.0
+pyflakes==3.2.0
# via flake8
-pytest==7.4.3
+pytest==8.2.1
# via
# pytest-cov
# pytest-mock
-pytest-cov==4.1.0
+pytest-cov==5.0.0
# via -r requirements/test.in
-pytest-mock==3.12.0
+pytest-mock==3.14.0
# via -r requirements/test.in
pyyaml==6.0.1
# via
# -c requirements/base.txt
# huggingface-hub
-requests==2.31.0
+requests==2.32.2
# via
# -c requirements/base.txt
# huggingface-hub
-ruff==0.1.5
+ruff==0.4.4
# via -r requirements/test.in
-sniffio==1.3.0
+sniffio==1.3.1
# via
# anyio
# httpx
@@ -126,19 +126,20 @@ tomli==2.0.1
# coverage
# mypy
# pytest
-tqdm==4.66.1
+tqdm==4.66.4
# via
# -c requirements/base.txt
# huggingface-hub
-types-pyyaml==6.0.12.12
+types-pyyaml==6.0.12.20240311
# via -r requirements/test.in
-typing-extensions==4.8.0
+typing-extensions==4.11.0
# via
# -c requirements/base.txt
+ # anyio
# black
# huggingface-hub
# mypy
-urllib3==2.1.0
+urllib3==2.2.1
# via
# -c requirements/base.txt
# requests
diff --git a/scripts/version-sync.sh b/scripts/version-sync.sh
index e8888efa..4a62d26e 100755
--- a/scripts/version-sync.sh
+++ b/scripts/version-sync.sh
@@ -13,12 +13,12 @@ function usage {
}
function getopts-extra () {
- declare i=1
+ declare -i i=1
# if the next argument is not an option, then append it to array OPTARG
while [[ ${OPTIND} -le $# && ${!OPTIND:0:1} != '-' ]]; do
OPTARG[i]=${!OPTIND}
- i+=1
- OPTIND+=1
+ ((i += 1))
+ ((OPTIND += 1))
done
}
diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py
index e7f77164..6cb618ce 100644
--- a/test_unstructured_inference/inference/test_layout.py
+++ b/test_unstructured_inference/inference/test_layout.py
@@ -9,7 +9,10 @@
import unstructured_inference.models.base as models
from unstructured_inference.inference import elements, layout, layoutelement
-from unstructured_inference.inference.elements import EmbeddedTextRegion, ImageTextRegion
+from unstructured_inference.inference.elements import (
+ EmbeddedTextRegion,
+ ImageTextRegion,
+)
from unstructured_inference.models.unstructuredmodel import (
UnstructuredElementExtractionModel,
UnstructuredObjectDetectionModel,
@@ -156,7 +159,7 @@ def test_process_data_with_model_raises_on_invalid_model_name():
layout.process_data_with_model(fp, model_name="fake")
-@pytest.mark.parametrize("model_name", [None, "checkbox"])
+@pytest.mark.parametrize("model_name", [None, "yolox"])
def test_process_file_with_model(monkeypatch, mock_final_layout, model_name):
def mock_initialize(self, *args, **kwargs):
self.model = MockLayoutModel(mock_final_layout)
@@ -166,7 +169,7 @@ def mock_initialize(self, *args, **kwargs):
"from_file",
lambda *args, **kwargs: layout.DocumentLayout.from_pages([]),
)
- monkeypatch.setattr(models.UnstructuredDetectronModel, "initialize", mock_initialize)
+ monkeypatch.setattr(models.UnstructuredDetectronONNXModel, "initialize", mock_initialize)
filename = ""
assert layout.process_file_with_model(filename, model_name=model_name)
@@ -180,7 +183,7 @@ def mock_initialize(self, *args, **kwargs):
"from_file",
lambda *args, **kwargs: layout.DocumentLayout.from_pages([]),
)
- monkeypatch.setattr(models.UnstructuredDetectronModel, "initialize", mock_initialize)
+ monkeypatch.setattr(models.UnstructuredDetectronONNXModel, "initialize", mock_initialize)
filename = ""
layout.process_file_with_model(filename, model_name=None)
# There should be no UserWarning, but if there is one it should not have the following message
@@ -224,33 +227,6 @@ def __init__(
self.detection_model = detection_model
-@pytest.mark.parametrize(
- ("text", "expected"),
- [
- ("base", 0.0),
- ("", 0.0),
- ("(cid:2)", 1.0),
- ("(cid:1)a", 0.5),
- ("c(cid:1)ab", 0.25),
- ],
-)
-def test_cid_ratio(text, expected):
- assert elements.cid_ratio(text) == expected
-
-
-@pytest.mark.parametrize(
- ("text", "expected"),
- [
- ("base", False),
- ("(cid:2)", True),
- ("(cid:1234567890)", True),
- ("jkl;(cid:12)asdf", True),
- ],
-)
-def test_is_cid_present(text, expected):
- assert elements.is_cid_present(text) == expected
-
-
class MockLayout:
def __init__(self, *elements):
self.elements = elements
@@ -271,12 +247,14 @@ def filter_by(self, *args, **kwargs):
return MockLayout()
+@pytest.mark.parametrize("element_extraction_model", [None, "foo"])
@pytest.mark.parametrize("filetype", ["png", "jpg", "tiff"])
-def test_from_image_file(monkeypatch, mock_final_layout, filetype):
+def test_from_image_file(monkeypatch, mock_final_layout, filetype, element_extraction_model):
def mock_get_elements(self, *args, **kwargs):
self.elements = [mock_final_layout]
monkeypatch.setattr(layout.PageLayout, "get_elements_with_detection_model", mock_get_elements)
+ monkeypatch.setattr(layout.PageLayout, "get_elements_using_image_extraction", mock_get_elements)
filename = f"sample-docs/loremipsum.{filetype}"
image = Image.open(filename)
image_metadata = {
@@ -285,7 +263,10 @@ def mock_get_elements(self, *args, **kwargs):
"height": image.height,
}
- doc = layout.DocumentLayout.from_image_file(filename)
+ doc = layout.DocumentLayout.from_image_file(
+ filename,
+ element_extraction_model=element_extraction_model,
+ )
page = doc.pages[0]
assert page.elements[0] == mock_final_layout
assert page.image is None
@@ -331,16 +312,6 @@ def test_from_image_file_raises_isadirectoryerror_with_dir():
layout.DocumentLayout.from_image_file(tempdir)
-@pytest.mark.parametrize("idx", range(2))
-def test_get_elements_from_layout(mock_initial_layout, idx):
- page = MockPageLayout()
- block = mock_initial_layout[idx]
- block.bbox.pad(3)
- fixed_layout = [block]
- elements = page.get_elements_from_layout(fixed_layout)
- assert elements[0].text == block.text
-
-
def test_page_numbers_in_page_objects():
with patch(
"unstructured_inference.inference.layout.PageLayout.get_elements_with_detection_model",
@@ -350,49 +321,8 @@ def test_page_numbers_in_page_objects():
assert [page.number for page in doc.pages] == list(range(1, len(doc.pages) + 1))
-@pytest.mark.parametrize(
- ("fixed_layouts", "called_method", "not_called_method"),
- [
- (
- [MockLayout()],
- "get_elements_from_layout",
- "get_elements_with_detection_model",
- ),
- (None, "get_elements_with_detection_model", "get_elements_from_layout"),
- ],
-)
-def test_from_file_fixed_layout(fixed_layouts, called_method, not_called_method):
- with patch.object(
- layout.PageLayout,
- "get_elements_with_detection_model",
- return_value=[],
- ), patch.object(
- layout.PageLayout,
- "get_elements_from_layout",
- return_value=[],
- ):
- layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf", fixed_layouts=fixed_layouts)
- getattr(layout.PageLayout, called_method).assert_called()
- getattr(layout.PageLayout, not_called_method).assert_not_called()
-
-
-@pytest.mark.parametrize(
- ("text", "expected"),
- [("c\to\x0cn\ftrol\ncharacter\rs\b", "control characters"), ("\"'\\", "\"'\\")],
-)
-def test_remove_control_characters(text, expected):
- assert elements.remove_control_characters(text) == expected
-
-
no_text_region = EmbeddedTextRegion.from_coords(0, 0, 100, 100)
text_region = EmbeddedTextRegion.from_coords(0, 0, 100, 100, text="test")
-cid_text_region = EmbeddedTextRegion.from_coords(
- 0,
- 0,
- 100,
- 100,
- text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)",
-)
overlapping_rect = ImageTextRegion.from_coords(50, 50, 150, 150)
nonoverlapping_rect = ImageTextRegion.from_coords(150, 150, 200, 200)
populated_text_region = EmbeddedTextRegion.from_coords(50, 50, 60, 60, text="test")
@@ -443,12 +373,6 @@ def check_annotated_image():
check_annotated_image()
-@pytest.mark.parametrize(("text", "expected"), [("asdf", "asdf"), (None, "")])
-def test_embedded_text_region(text, expected):
- etr = elements.EmbeddedTextRegion.from_coords(0, 0, 24, 24, text=text)
- assert etr.extract_text(objects=None) == expected
-
-
class MockDetectionModel(layout.UnstructuredObjectDetectionModel):
def initialize(self, *args, **kwargs):
pass
diff --git a/test_unstructured_inference/inference/test_layout_element.py b/test_unstructured_inference/inference/test_layout_element.py
index da4b0f10..f814c180 100644
--- a/test_unstructured_inference/inference/test_layout_element.py
+++ b/test_unstructured_inference/inference/test_layout_element.py
@@ -5,18 +5,6 @@
from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion
-def test_layout_element_extract_text(
- mock_layout_element,
- mock_text_region,
-):
- extracted_text = mock_layout_element.extract_text(
- objects=[mock_text_region],
- )
-
- assert isinstance(extracted_text, str)
- assert "Sample text" in extracted_text
-
-
def test_layout_element_do_dict(mock_layout_element):
expected = {
"coordinates": ((100, 100), (100, 300), (300, 300), (300, 100)),
diff --git a/test_unstructured_inference/models/test_chippermodel.py b/test_unstructured_inference/models/test_chippermodel.py
index 5a63894b..c68aa6bc 100644
--- a/test_unstructured_inference/models/test_chippermodel.py
+++ b/test_unstructured_inference/models/test_chippermodel.py
@@ -3,6 +3,7 @@
import pytest
import torch
from PIL import Image
+from unstructured_inference.inference.layoutelement import LayoutElement
from unstructured_inference.models import chipper
from unstructured_inference.models.base import get_model
@@ -139,13 +140,8 @@ def test_no_repeat_ngram_logits():
no_repeat_ngram_size = 2
- output = chipper._no_repeat_ngram_logits(
- input_ids=input_ids,
- cur_len=cur_len,
- logits=logits,
- batch_size=batch_size,
- no_repeat_ngram_size=no_repeat_ngram_size,
- )
+ logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2)
+ output = logitsProcessor(input_ids=input_ids, scores=logits)
assert (
int(
@@ -194,6 +190,25 @@ def test_no_repeat_ngram_logits():
)
+def test_ngram_repetiton_stopping_criteria():
+ input_ids = torch.tensor([[1, 2, 3, 4, 0, 1, 2, 3, 4]])
+ logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]])
+
+ stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
+ repetition_window=2, skip_tokens={0, 1, 2, 3, 4}
+ )
+
+ output = stoppingCriteria(input_ids=input_ids, scores=logits)
+
+ assert output is False
+
+ stoppingCriteria = chipper.NGramRepetitonStoppingCriteria(
+ repetition_window=2, skip_tokens={1, 2, 3, 4}
+ )
+ output = stoppingCriteria(input_ids=input_ids, scores=logits)
+ assert output is True
+
+
@pytest.mark.parametrize(
("decoded_str", "expected_classes"),
[
@@ -241,7 +256,51 @@ def test_postprocess_bbox(decoded_str, expected_classes):
assert out[i].type == expected_classes[i]
-def test_run_chipper_v2():
+def test_predict_tokens_beam_indices():
+ model = get_model("chipper")
+ model.stopping_criteria = [
+ chipper.NGramRepetitonStoppingCriteria(
+ repetition_window=1,
+ skip_tokens={},
+ ),
+ ]
+ img = Image.open("sample-docs/easy_table.jpg")
+ output = model.predict_tokens(image=img)
+ assert len(output) > 0
+
+
+def test_largest_margin_edge():
+ model = get_model("chipper")
+ img = Image.open("sample-docs/easy_table.jpg")
+ output = model.largest_margin(image=img, input_bbox=[0, 1, 0, 0], transpose=False)
+
+ assert output is None
+
+ output = model.largest_margin(img, [1, 1, 1, 1], False)
+
+ assert output is None
+
+ output = model.largest_margin(img, [2, 1, 3, 10], True)
+
+ assert output == (0, 0, 0)
+
+
+def test_deduplicate_detected_elements():
+ model = get_model("chipper")
+ img = Image.open("sample-docs/easy_table.jpg")
+ elements = model(img)
+
+ output = model.deduplicate_detected_elements(elements)
+
+ assert len(output) == 2
+
+
+def test_norepeatnGramlogitsprocessor_exception():
+ with pytest.raises(ValueError):
+ chipper.NoRepeatNGramLogitsProcessor(ngram_size="")
+
+
+def test_run_chipper_v3():
model = get_model("chipper")
img = Image.open("sample-docs/easy_table.jpg")
elements = model(img)
@@ -364,3 +423,26 @@ def test_check_overlap(bbox1, bbox2, output):
model = get_model("chipper")
assert model.check_overlap(bbox1, bbox2) == output
+
+
+def test_format_table_elements():
+    table_html = "<table><tr><td>Cell 1</td><td>Cell 2</td><td>Cell 3</td></tr></table>"
+ texts = [
+ "Text",
+ " - List element",
+ table_html,
+ None,
+ ]
+ elements = [LayoutElement(bbox=mock.MagicMock(), text=text) for text in texts]
+ formatted_elements = chipper.UnstructuredChipperModel.format_table_elements(elements)
+ text_attributes = [fe.text for fe in formatted_elements]
+ text_as_html_attributes = [
+ fe.text_as_html if hasattr(fe, "text_as_html") else None for fe in formatted_elements
+ ]
+ assert text_attributes == [
+ "Text",
+ " - List element",
+ "Cell 1Cell 2Cell 3",
+ None,
+ ]
+ assert text_as_html_attributes == [None, None, table_html, None]
diff --git a/test_unstructured_inference/models/test_detectron2.py b/test_unstructured_inference/models/test_detectron2.py
deleted file mode 100644
index 987120e1..00000000
--- a/test_unstructured_inference/models/test_detectron2.py
+++ /dev/null
@@ -1,50 +0,0 @@
-from unittest.mock import patch
-
-import pytest
-
-import unstructured_inference.models.base as models
-from unstructured_inference.models import detectron2
-
-
-class MockDetectron2LayoutModel:
- def __init__(self, *args, **kwargs):
- self.args = args
- self.kwargs = kwargs
-
- def detect(self, x):
- return []
-
-
-def test_load_default_model(monkeypatch):
- monkeypatch.setattr(detectron2, "Detectron2LayoutModel", MockDetectron2LayoutModel)
- monkeypatch.setattr(models, "models", {})
-
- with patch.object(detectron2, "is_detectron2_available", return_value=True):
- model = models.get_model("detectron2_lp")
-
- assert isinstance(model.model, MockDetectron2LayoutModel)
-
-
-def test_load_default_model_raises_when_not_available(monkeypatch):
- monkeypatch.setattr(models, "models", {})
- with patch.object(detectron2, "is_detectron2_available", return_value=False):
- with pytest.raises(ImportError):
- models.get_model("detectron2_lp")
-
-
-@pytest.mark.parametrize(("config_path", "model_path"), [("asdf", "diufs"), ("dfaw", "hfhfhfh")])
-def test_load_model(monkeypatch, config_path, model_path):
- monkeypatch.setattr(detectron2, "Detectron2LayoutModel", MockDetectron2LayoutModel)
- with patch.object(detectron2, "is_detectron2_available", return_value=True):
- model = detectron2.UnstructuredDetectronModel()
- model.initialize(config_path=config_path, model_path=model_path)
- assert config_path == model.model.args[0]
- assert model_path == model.model.kwargs["model_path"]
-
-
-def test_unstructured_detectron_model():
- model = detectron2.UnstructuredDetectronModel()
- model.model = MockDetectron2LayoutModel()
- result = model(None)
- assert isinstance(result, list)
- assert len(result) == 0
diff --git a/test_unstructured_inference/models/test_model.py b/test_unstructured_inference/models/test_model.py
index fc7c18aa..411ef3d4 100644
--- a/test_unstructured_inference/models/test_model.py
+++ b/test_unstructured_inference/models/test_model.py
@@ -1,3 +1,4 @@
+import json
from typing import Any
from unittest import mock
@@ -25,10 +26,33 @@ def predict(self, x: Any) -> Any:
return []
+MOCK_MODEL_TYPES = {
+ "foo": {
+ "input_shape": (640, 640),
+ },
+}
+
+
def test_get_model(monkeypatch):
monkeypatch.setattr(models, "models", {})
- with mock.patch.dict(models.model_class_map, {"checkbox": MockModel}):
- assert isinstance(models.get_model("checkbox"), MockModel)
+ with mock.patch.dict(models.model_class_map, {"yolox": MockModel}):
+ assert isinstance(models.get_model("yolox"), MockModel)
+
+
+def test_register_new_model():
+ assert "foo" not in models.model_class_map
+ assert "foo" not in models.model_config_map
+ models.register_new_model(MOCK_MODEL_TYPES, MockModel)
+ assert "foo" in models.model_class_map
+ assert "foo" in models.model_config_map
+ model = models.get_model("foo")
+ assert len(model.initializer.mock_calls) == 1
+ assert model.initializer.mock_calls[0][-1] == MOCK_MODEL_TYPES["foo"]
+ assert isinstance(model, MockModel)
+ # unregister the new model by reset to default
+ models.model_class_map, models.model_config_map = models.get_default_model_mappings()
+ assert "foo" not in models.model_class_map
+ assert "foo" not in models.model_config_map
def test_raises_invalid_model():
@@ -38,14 +62,16 @@ def test_raises_invalid_model():
def test_raises_uninitialized():
with pytest.raises(ModelNotInitializedError):
- models.UnstructuredDetectronModel().predict(None)
+ models.UnstructuredDetectronONNXModel().predict(None)
def test_model_initializes_once():
from unstructured_inference.inference import layout
with mock.patch.dict(models.model_class_map, {"yolox": MockModel}), mock.patch.object(
- models, "models", {}
+ models,
+ "models",
+ {},
):
doc = layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf")
doc.pages[0].detection_model.initializer.assert_called_once()
@@ -143,23 +169,32 @@ def test_env_variables_override_default_model(monkeypatch):
# args, we should get back the model the env var calls for
monkeypatch.setattr(models, "models", {})
with mock.patch.dict(
- models.os.environ, {"UNSTRUCTURED_DEFAULT_MODEL_NAME": "checkbox"}
- ), mock.patch.dict(models.model_class_map, {"checkbox": MockModel}):
+ models.os.environ,
+ {"UNSTRUCTURED_DEFAULT_MODEL_NAME": "yolox"},
+ ), mock.patch.dict(models.model_class_map, {"yolox": MockModel}):
model = models.get_model()
assert isinstance(model, MockModel)
-def test_env_variables_override_intialization_params(monkeypatch):
+def test_env_variables_override_initialization_params(monkeypatch):
# When initialization params are specified in an environment variable, and we call get_model, we
# should see that the model was initialized with those params
monkeypatch.setattr(models, "models", {})
+ fake_label_map = {"1": "label1", "2": "label2"}
with mock.patch.dict(
models.os.environ,
{"UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH": "fake_json.json"},
), mock.patch.object(models, "DEFAULT_MODEL", "fake"), mock.patch.dict(
- models.model_class_map, {"fake": mock.MagicMock()}
+ models.model_class_map,
+ {"fake": mock.MagicMock()},
), mock.patch(
- "builtins.open", mock.mock_open(read_data='{"date": "3/26/81"}')
+ "builtins.open",
+ mock.mock_open(
+ read_data='{"model_path": "fakepath", "label_map": ' + json.dumps(fake_label_map) + "}",
+ ),
):
model = models.get_model()
- model.initialize.assert_called_once_with(date="3/26/81")
+ model.initialize.assert_called_once_with(
+ model_path="fakepath",
+ label_map={1: "label1", 2: "label2"},
+ )
diff --git a/test_unstructured_inference/models/test_supergradients.py b/test_unstructured_inference/models/test_supergradients.py
deleted file mode 100644
index d56df7d0..00000000
--- a/test_unstructured_inference/models/test_supergradients.py
+++ /dev/null
@@ -1,42 +0,0 @@
-from os import path
-from PIL import Image
-
-from unstructured_inference.constants import Source
-from unstructured_inference.models import super_gradients
-
-
-def test_supergradients_model():
- model_path = path.join(path.dirname(__file__), "test_ci_model.onnx")
- model = super_gradients.UnstructuredSuperGradients()
- model.initialize(
- model_path=model_path,
- label_map={
- "0": "Picture",
- "1": "Caption",
- "2": "Text",
- "3": "Formula",
- "4": "Page number",
- "5": "Address",
- "6": "Footer",
- "7": "Subheadline",
- "8": "Chart",
- "9": "Metadata",
- "10": "Title",
- "11": "Misc",
- "12": "Header",
- "13": "Table",
- "14": "Headline",
- "15": "List-item",
- "16": "List",
- "17": "Author",
- "18": "Value",
- "19": "Link",
- "20": "Field-Name",
- },
- input_shape=(1024, 1024),
- )
- img = Image.open("sample-docs/loremipsum.jpg")
- el, *_ = model(img)
- assert el.source == Source.SUPER_GRADIENTS
- assert el.prob > 0.70
- assert el.type == "Text"
diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py
index 9131cbce..15c467cd 100644
--- a/test_unstructured_inference/models/test_tables.py
+++ b/test_unstructured_inference/models/test_tables.py
@@ -7,9 +7,11 @@
from transformers.models.table_transformer.modeling_table_transformer import (
TableTransformerDecoder,
)
+from copy import deepcopy
import unstructured_inference.models.table_postprocess as postprocess
from unstructured_inference.models import tables
+from unstructured_inference.models.tables import apply_thresholds_on_objects, structure_to_cells
skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}
@@ -932,9 +934,43 @@ def test_table_prediction_output_format(
assert expectation in result
+def test_table_prediction_output_format_when_wrong_type_then_value_error(
+ table_transformer,
+ example_image,
+ mocker,
+ example_table_cells,
+ mocked_ocr_tokens,
+):
+ mocker.patch.object(tables, "recognize", return_value=example_table_cells)
+ mocker.patch.object(
+ tables.UnstructuredTableTransformerModel,
+ "get_structure",
+ return_value=None,
+ )
+ with pytest.raises(ValueError):
+ table_transformer.run_prediction(
+ example_image, result_format="Wrong format", ocr_tokens=mocked_ocr_tokens
+ )
+
+
+def test_table_prediction_runs_with_empty_recognize(
+ table_transformer,
+ example_image,
+ mocker,
+ mocked_ocr_tokens,
+):
+ mocker.patch.object(tables, "recognize", return_value=[])
+ mocker.patch.object(
+ tables.UnstructuredTableTransformerModel,
+ "get_structure",
+ return_value=None,
+ )
+ assert table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens) == ""
+
+
def test_table_prediction_with_ocr_tokens(table_transformer, example_image, mocked_ocr_tokens):
prediction = table_transformer.predict(example_image, ocr_tokens=mocked_ocr_tokens)
- assert '' in prediction
+ assert '' in prediction
assert " |
---|
Blind | 5 | 1 | 4 | 34.5%, n=1 | " in prediction
@@ -943,6 +979,55 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image):
table_transformer.predict(example_image)
+@pytest.mark.parametrize(
+ ("thresholds", "expected_object_number"),
+ [
+ ({"0": 0.5}, 1),
+ ({"0": 0.1}, 3),
+ ({"0": 0.9}, 0),
+ ],
+)
+def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold(
+ thresholds, expected_object_number
+):
+ objects = [
+ {"label": "0", "score": 0.2},
+ {"label": "0", "score": 0.4},
+ {"label": "0", "score": 0.55},
+ ]
+ assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number
+
+
+@pytest.mark.parametrize(
+ ("thresholds", "expected_object_number"),
+ [
+ ({"0": 0.5, "1": 0.1}, 4),
+ ({"0": 0.1, "1": 0.9}, 3),
+ ({"0": 0.9, "1": 0.5}, 1),
+ ],
+)
+def test_objects_are_filtered_based_on_class_thresholds_when_two_classes(
+ thresholds, expected_object_number
+):
+ objects = [
+ {"label": "0", "score": 0.2},
+ {"label": "0", "score": 0.4},
+ {"label": "0", "score": 0.55},
+ {"label": "1", "score": 0.2},
+ {"label": "1", "score": 0.4},
+ {"label": "1", "score": 0.55},
+ ]
+ assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number
+
+
+def test_objects_filtering_when_missing_threshold():
+ class_name = "class_name"
+ objects = [{"label": class_name, "score": 0.2}]
+ thresholds = {"1": 0.5}
+ with pytest.raises(KeyError, match=class_name):
+ apply_thresholds_on_objects(objects, thresholds)
+
+
def test_intersect():
a = postprocess.Rect()
b = postprocess.Rect([1, 2, 3, 4])
@@ -1131,26 +1216,6 @@ def test_header_supercell_tree(supercells, expected_len):
assert len(supercells) == expected_len
-def test_cells_to_html():
- # example table
- # +----------+---------------------+
- # | two | two columns |
- # | |----------+----------|
- # | rows |sub cell 1|sub cell 2|
- # +----------+----------+----------+
- cells = [
- {"row_nums": [0, 1], "column_nums": [0], "cell text": "two row", "column header": False},
- {"row_nums": [0], "column_nums": [1, 2], "cell text": "two cols", "column header": False},
- {"row_nums": [1], "column_nums": [1], "cell text": "sub cell 1", "column header": False},
- {"row_nums": [1], "column_nums": [2], "cell text": "sub cell 2", "column header": False},
- ]
- expected = (
- 'two row | two '
- "cols | | sub cell 1 | sub cell 2 | "
- )
- assert tables.cells_to_html(cells) == expected
-
-
@pytest.mark.parametrize("zoom", [1, 0.1, 5, -1, 0])
def test_zoom_image(example_image, zoom):
width, height = example_image.size
@@ -1162,6 +1227,534 @@ def test_zoom_image(example_image, zoom):
assert new_h == np.round(height * zoom, 0)
+@pytest.mark.parametrize(
+ ("input_cells", "expected_html"),
+ [
+ # +----------+---------------------+
+ # | row1col1 | row1col2 | row1col3 |
+ # |----------|----------+----------|
+ # | row2col1 | row2col2 | row2col3 |
+ # +----------+----------+----------+
+ pytest.param(
+ [
+ {
+ "row_nums": [0],
+ "column_nums": [0],
+ "cell text": "row1col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [0],
+ "column_nums": [1],
+ "cell text": "row1col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [0],
+ "column_nums": [2],
+ "cell text": "row1col3",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [0],
+ "cell text": "row2col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [1],
+ "cell text": "row2col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [2],
+ "cell text": "row2col3",
+ "column header": False,
+ },
+ ],
+ (
+ "row1col1 | row1col2 | row1col3 | "
+ "row2col1 | row2col2 | row2col3 | "
+ ),
+ id="simple table without header",
+ ),
+ # +----------+---------------------+
+ # | h1col1 | h1col2 | h1col3 |
+ # |----------|----------+----------|
+ # | row1col1 | row1col2 | row1col3 |
+ # |----------|----------+----------|
+ # | row2col1 | row2col2 | row2col3 |
+ # +----------+----------+----------+
+ pytest.param(
+ [
+ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
+ {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True},
+ {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [0],
+ "cell text": "row1col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [1],
+ "cell text": "row1col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [2],
+ "cell text": "row1col3",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "row2col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [1],
+ "cell text": "row2col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [2],
+ "cell text": "row2col3",
+ "column header": False,
+ },
+ ],
+ (
+ "h1col1 | h1col2 | h1col2 | "
+ "row1col1 | row1col2 | row1col3 | "
+ "row2col1 | row2col2 | row2col3 | "
+ ),
+ id="simple table with header",
+ ),
+ # +----------+---------------------+
+ # | h1col1 | h1col2 | h1col3 |
+ # |----------|----------+----------|
+ # | row1col1 | row1col2 | row1col3 |
+ # |----------|----------+----------|
+ # | row2col1 | row2col2 | row2col3 |
+ # +----------+----------+----------+
+ pytest.param(
+ [
+ {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True},
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "row2col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [0],
+ "cell text": "row1col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [1],
+ "cell text": "row2col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [1],
+ "cell text": "row1col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [2],
+ "cell text": "row2col3",
+ "column header": False,
+ },
+ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [2],
+ "cell text": "row1col3",
+ "column header": False,
+ },
+ {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
+ ],
+ (
+ "h1col1 | h1col2 | h1col2 | "
+ "row1col1 | row1col2 | row1col3 | "
+ "row2col1 | row2col2 | row2col3 | "
+ ),
+ id="simple table with header, mixed elements",
+ ),
+ # +----------+---------------------+
+ # | two | two columns |
+ # | |----------+----------|
+ # | rows |sub cell 1|sub cell 2|
+ # +----------+----------+----------+
+ pytest.param(
+ [
+ {
+ "row_nums": [0, 1],
+ "column_nums": [0],
+ "cell text": "two row",
+ "column header": False,
+ },
+ {
+ "row_nums": [0],
+ "column_nums": [1, 2],
+ "cell text": "two cols",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [1],
+ "cell text": "sub cell 1",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [2],
+ "cell text": "sub cell 2",
+ "column header": False,
+ },
+ ],
+ (
+ 'two row | two '
+ "cols | sub cell 1 | sub cell 2 | "
+ " "
+ ),
+ id="various spans, no headers",
+ ),
+ # +----------+---------------------+----------+
+ # | | h1col23 | h1col4 |
+ # | h12col1 |----------+----------+----------|
+ # | | h2col2 | h2col34 |
+ # |----------|----------+----------+----------+
+ # | r3col1 | r3col2 | |
+ # |----------+----------| r34col34 |
+ # | r4col12 | |
+ # +----------+----------+----------+----------+
+ pytest.param(
+ [
+ {
+ "row_nums": [0, 1],
+ "column_nums": [0],
+ "cell text": "h12col1",
+ "column header": True,
+ },
+ {
+ "row_nums": [0],
+ "column_nums": [1, 2],
+ "cell text": "h1col23",
+ "column header": True,
+ },
+ {"row_nums": [0], "column_nums": [3], "cell text": "h1col4", "column header": True},
+ {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [2, 3],
+ "cell text": "h2col34",
+ "column header": True,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "r3col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [1],
+ "cell text": "r3col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [2, 3],
+ "column_nums": [2, 3],
+ "cell text": "r34col34",
+ "column header": False,
+ },
+ {
+ "row_nums": [3],
+ "column_nums": [0, 1],
+ "cell text": "r4col12",
+ "column header": False,
+ },
+ ],
+ (
+ 'h12col1 | '
+ 'h1col23 | h1col4 | '
+ 'h2col2 | h2col34 | '
+ 'r3col1 | r3col2 | r34col34 | '
+ 'r4col12 | '
+ ),
+ id="various spans, with 2 row header",
+ ),
+ ],
+)
+def test_cells_to_html(input_cells, expected_html):
+ assert tables.cells_to_html(input_cells) == expected_html
+
+
+@pytest.mark.parametrize(
+ ("input_cells", "expected_cells"),
+ [
+ pytest.param(
+ [
+ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
+ {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True},
+ {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [0],
+ "cell text": "row1col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [1],
+ "cell text": "row1col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [2],
+ "cell text": "row1col3",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "row2col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [1],
+ "cell text": "row2col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [2],
+ "cell text": "row2col3",
+ "column header": False,
+ },
+ ],
+ [
+ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
+ {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True},
+ {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [0],
+ "cell text": "row1col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [1],
+ "cell text": "row1col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [2],
+ "cell text": "row1col3",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "row2col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [1],
+ "cell text": "row2col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [2],
+ "cell text": "row2col3",
+ "column header": False,
+ },
+ ],
+ id="identical tables, no changes expected",
+ ),
+ pytest.param(
+ [
+ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
+ {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [0],
+ "cell text": "row1col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [1],
+ "cell text": "row1col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "row2col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [1],
+ "cell text": "row2col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [2],
+ "cell text": "row2col3",
+ "column header": False,
+ },
+ ],
+ [
+ {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True},
+ {"row_nums": [0], "column_nums": [1], "cell text": "", "column header": True},
+ {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [0],
+ "cell text": "row1col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [1],
+ "column_nums": [1],
+ "cell text": "row1col2",
+ "column header": False,
+ },
+ {"row_nums": [1], "column_nums": [2], "cell text": "", "column header": False},
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "row2col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [1],
+ "cell text": "row2col2",
+ "column header": False,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [2],
+ "cell text": "row2col3",
+ "column header": False,
+ },
+ ],
+ id="missing column in header and in the middle",
+ ),
+ pytest.param(
+ [
+ {
+ "row_nums": [0, 1],
+ "column_nums": [0],
+ "cell text": "h12col1",
+ "column header": True,
+ },
+ {
+ "row_nums": [0],
+ "column_nums": [1, 2],
+ "cell text": "h1col23",
+ "column header": True,
+ },
+ {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [2, 3],
+ "cell text": "h2col34",
+ "column header": True,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "r3col1",
+ "column header": False,
+ },
+ {
+ "row_nums": [2, 3],
+ "column_nums": [2, 3],
+ "cell text": "r34col34",
+ "column header": False,
+ },
+ {
+ "row_nums": [3],
+ "column_nums": [0, 1],
+ "cell text": "r4col12",
+ "column header": False,
+ },
+ ],
+ [
+ {
+ "row_nums": [0, 1],
+ "column_nums": [0],
+ "cell text": "h12col1",
+ "column header": True,
+ },
+ {
+ "row_nums": [0],
+ "column_nums": [1, 2],
+ "cell text": "h1col23",
+ "column header": True,
+ },
+ {"row_nums": [0], "column_nums": [3], "cell text": "", "column header": True},
+ {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True},
+ {
+ "row_nums": [1],
+ "column_nums": [2, 3],
+ "cell text": "h2col34",
+ "column header": True,
+ },
+ {
+ "row_nums": [2],
+ "column_nums": [0],
+ "cell text": "r3col1",
+ "column header": False,
+ },
+ {"row_nums": [2], "column_nums": [1], "cell text": "", "column header": False},
+ {
+ "row_nums": [2, 3],
+ "column_nums": [2, 3],
+ "cell text": "r34col34",
+ "column header": False,
+ },
+ {
+ "row_nums": [3],
+ "column_nums": [0, 1],
+ "cell text": "r4col12",
+ "column header": False,
+ },
+ ],
+ id="missing column in header and in the middle in table with spans",
+ ),
+ ],
+)
+def test_fill_cells(input_cells, expected_cells):
+ def sort_cells(cells):
+ return sorted(cells, key=lambda x: (x["row_nums"], x["column_nums"]))
+
+ assert sort_cells(tables.fill_cells(input_cells)) == sort_cells(expected_cells)
+
+
def test_padded_results_has_right_dimensions(table_transformer, example_image):
str_class_name2idx = tables.get_class_map("structure")
# a simpler mapping so we keep all structure in the returned objs below for test
@@ -1201,3 +1794,100 @@ def test_padded_results_has_right_dimensions(table_transformer, example_image):
def test_compute_confidence_score_zero_division_error_handling():
assert tables.compute_confidence_score([]) == 0
+
+
+@pytest.mark.parametrize(
+ "column_span_score, row_span_score, expected_text_to_indexes",
+ [
+ (
+ 0.9,
+ 0.8,
+ (
+ {
+ "one three": {"row_nums": [0, 1], "column_nums": [0]},
+ "two": {"row_nums": [0], "column_nums": [1]},
+ "four": {"row_nums": [1], "column_nums": [1]},
+ }
+ ),
+ ),
+ (
+ 0.8,
+ 0.9,
+ (
+ {
+ "one two": {"row_nums": [0], "column_nums": [0, 1]},
+ "three": {"row_nums": [1], "column_nums": [0]},
+ "four": {"row_nums": [1], "column_nums": [1]},
+ }
+ ),
+ ),
+ ],
+)
+def test_subcells_filtering_when_overlapping_spanning_cells(
+ column_span_score, row_span_score, expected_text_to_indexes
+):
+ """
+ # table
+ # +-----------+----------+
+ # | one | two |
+ # |-----------+----------|
+ # | three | four |
+ # +-----------+----------+
+
+ spanning cells over first row and over first column
+ """
+ table_structure = {
+ "rows": [
+ {"bbox": [0, 0, 10, 20]},
+ {"bbox": [10, 0, 20, 20]},
+ ],
+ "columns": [
+ {"bbox": [0, 0, 20, 10]},
+ {"bbox": [0, 10, 20, 20]},
+ ],
+ "spanning cells": [
+ {"bbox": [0, 0, 20, 10], "score": column_span_score},
+ {"bbox": [0, 0, 10, 20], "score": row_span_score},
+ ],
+ }
+ tokens = [
+ {
+ "text": "one",
+ "bbox": [0, 0, 10, 10],
+ },
+ {
+ "text": "two",
+ "bbox": [0, 10, 10, 20],
+ },
+ {
+ "text": "three",
+ "bbox": [10, 0, 20, 10],
+ },
+ {"text": "four", "bbox": [10, 10, 20, 20]},
+ ]
+ token_args = {"span_num": 1, "line_num": 1, "block_num": 1}
+ for token in tokens:
+ token.update(token_args)
+ for spanning_cell in table_structure["spanning cells"]:
+ spanning_cell["projected row header"] = False
+
+ # table structure is edited inside structure_to_cells, save copy for future runs
+ saved_table_structure = deepcopy(table_structure)
+
+ predicted_cells, _ = structure_to_cells(table_structure, tokens=tokens)
+ predicted_text_to_indexes = {
+ cell["cell text"]: {
+ "row_nums": cell["row_nums"],
+ "column_nums": cell["column_nums"],
+ }
+ for cell in predicted_cells
+ }
+ assert predicted_text_to_indexes == expected_text_to_indexes
+
+ # swap spanning cells to ensure the highest prob spanning cell is used
+ spans = saved_table_structure["spanning cells"]
+ spans[0], spans[1] = spans[1], spans[0]
+ saved_table_structure["spanning cells"] = spans
+
+ predicted_cells_after_reorder, _ = structure_to_cells(saved_table_structure, tokens=tokens)
+ assert predicted_cells_after_reorder == predicted_cells
diff --git a/test_unstructured_inference/test_elements.py b/test_unstructured_inference/test_elements.py
index abd4ce0c..f6f5b568 100644
--- a/test_unstructured_inference/test_elements.py
+++ b/test_unstructured_inference/test_elements.py
@@ -6,12 +6,12 @@
from unstructured_inference.constants import ElementType
from unstructured_inference.inference import elements
-from unstructured_inference.inference.elements import TextRegion, ImageTextRegion
+from unstructured_inference.inference.elements import Rectangle, TextRegion, ImageTextRegion
from unstructured_inference.inference.layoutelement import (
+ LayoutElement,
+ merge_inferred_layout_with_extracted_layout,
partition_groups_from_regions,
separate,
- merge_inferred_layout_with_extracted_layout,
- LayoutElement,
)
skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"}
@@ -31,6 +31,18 @@ def rand_rect(size=10):
return elements.Rectangle(x1, y1, x1 + size, y1 + size)
+@pytest.mark.parametrize(
+ ("rect1", "rect2", "expected"),
+ [
+ (Rectangle(0, 0, 1, 1), Rectangle(0, 0, None, None), None),
+ (Rectangle(0, 0, None, None), Rectangle(0, 0, 1, 1), None),
+ ],
+)
+def test_unhappy_intersection(rect1, rect2, expected):
+ assert rect1.intersection(rect2) == expected
+ assert not rect1.intersects(rect2)
+
+
@pytest.mark.parametrize("second_size", [10, 20])
def test_intersects(second_size):
for _ in range(1000):
@@ -106,10 +118,6 @@ def test_partition_groups_from_regions(mock_embedded_text_regions):
text = "".join([el.text for el in sorted_groups[-1]])
assert text.startswith("Layout")
- words = []
- groups = partition_groups_from_regions(words)
- assert len(groups) == 0
-
def test_rectangle_area(monkeypatch):
for _ in range(1000):
@@ -202,7 +210,10 @@ def test_intersection_over_min(
def test_grow_region_to_match_region():
- from unstructured_inference.inference.elements import Rectangle, grow_region_to_match_region
+ from unstructured_inference.inference.elements import (
+ Rectangle,
+ grow_region_to_match_region,
+ )
a = Rectangle(1, 1, 2, 2)
b = Rectangle(1, 1, 5, 5)
diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py
index 0f031048..33c5779d 100644
--- a/unstructured_inference/__version__.py
+++ b/unstructured_inference/__version__.py
@@ -1 +1 @@
-__version__ = "0.7.20" # pragma: no cover
+__version__ = "0.7.37-dev0" # pragma: no cover
diff --git a/unstructured_inference/config.py b/unstructured_inference/config.py
index 0d85b6a3..d5765bbf 100644
--- a/unstructured_inference/config.py
+++ b/unstructured_inference/config.py
@@ -5,6 +5,7 @@
settings that should not be altered without making a code change (e.g., definition of 1Gb of memory
in bytes). Constants should go into `./constants.py`
"""
+
import os
from dataclasses import dataclass
diff --git a/unstructured_inference/constants.py b/unstructured_inference/constants.py
index 5de6eea8..173e37b4 100644
--- a/unstructured_inference/constants.py
+++ b/unstructured_inference/constants.py
@@ -8,16 +8,33 @@ class Source(Enum):
CHIPPER = "chipper"
CHIPPERV1 = "chipperv1"
CHIPPERV2 = "chipperv2"
+ CHIPPERV3 = "chipperv3"
MERGED = "merged"
- SUPER_GRADIENTS = "super-gradients"
+
+
+CHIPPER_VERSIONS = (
+ Source.CHIPPER,
+ Source.CHIPPERV1,
+ Source.CHIPPERV2,
+ Source.CHIPPERV3,
+)
class ElementType:
+ PARAGRAPH = "Paragraph"
IMAGE = "Image"
+ PARAGRAPH_IN_IMAGE = "ParagraphInImage"
FIGURE = "Figure"
PICTURE = "Picture"
TABLE = "Table"
+ PARAGRAPH_IN_TABLE = "ParagraphInTable"
LIST = "List"
+ FORM = "Form"
+ PARAGRAPH_IN_FORM = "ParagraphInForm"
+ CHECK_BOX_CHECKED = "CheckBoxChecked"
+ CHECK_BOX_UNCHECKED = "CheckBoxUnchecked"
+ RADIO_BUTTON_CHECKED = "RadioButtonChecked"
+ RADIO_BUTTON_UNCHECKED = "RadioButtonUnchecked"
LIST_ITEM = "List-item"
FORMULA = "Formula"
CAPTION = "Caption"
@@ -29,6 +46,12 @@ class ElementType:
TEXT = "Text"
UNCATEGORIZED_TEXT = "UncategorizedText"
PAGE_BREAK = "PageBreak"
+ CODE_SNIPPET = "CodeSnippet"
+ PAGE_NUMBER = "PageNumber"
+ OTHER = "Other"
FULL_PAGE_REGION_THRESHOLD = 0.99
+
+# this field is defined by pytesseract/unstructured.pytesseract
+TESSERACT_TEXT_HEIGHT = "height"
diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py
index e1194fd6..f6b4d4d8 100644
--- a/unstructured_inference/inference/elements.py
+++ b/unstructured_inference/inference/elements.py
@@ -1,10 +1,8 @@
from __future__ import annotations
-import re
-import unicodedata
from copy import deepcopy
from dataclasses import dataclass
-from typing import Collection, Optional, Union
+from typing import Optional, Union
import numpy as np
@@ -67,6 +65,8 @@ def is_disjoint(self, other: Rectangle) -> bool:
def intersects(self, other: Rectangle) -> bool:
"""Checks whether this rectangle intersects another rectangle."""
+ if self._has_none() or other._has_none():
+ return False
return intersections(self, other)[0, 1]
def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = None) -> bool:
@@ -81,6 +81,10 @@ def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = No
],
)
+ def _has_none(self) -> bool:
+ """return false when one of the coord is nan"""
+ return any((self.x1 is None, self.x2 is None, self.y1 is None, self.y2 is None))
+
@property
def coordinates(self):
"""Gets coordinates of the rectangle"""
@@ -89,6 +93,8 @@ def coordinates(self):
def intersection(self, other: Rectangle) -> Optional[Rectangle]:
"""Gives the rectangle that is the intersection of two rectangles, or None if the
rectangles are disjoint."""
+ if self._has_none() or other._has_none():
+ return None
x1 = max(self.x1, other.x1)
x2 = min(self.x2, other.x2)
y1 = max(self.y1, other.y1)
@@ -176,21 +182,6 @@ class TextRegion:
def __str__(self) -> str:
return str(self.text)
- def extract_text(
- self,
- objects: Optional[Collection[TextRegion]],
- ) -> str:
- """Extracts text contained in region."""
- if self.text is not None:
- # If block text is already populated, we'll assume it's correct
- text = self.text
- elif objects is not None:
- text = aggregate_by_block(self, objects)
- else:
- text = ""
- cleaned_text = remove_control_characters(text)
- return cleaned_text
-
@classmethod
def from_coords(
cls,
@@ -209,67 +200,11 @@ def from_coords(
class EmbeddedTextRegion(TextRegion):
- def extract_text(
- self,
- objects: Optional[Collection[TextRegion]],
- ) -> str:
- """Extracts text contained in region."""
- if self.text is None:
- return ""
- else:
- return self.text
+ pass
class ImageTextRegion(TextRegion):
- def extract_text(
- self,
- objects: Optional[Collection[TextRegion]],
- ) -> str:
- """Extracts text contained in region."""
- if self.text is None:
- return ""
- else:
- return super().extract_text(objects)
-
-
-def aggregate_by_block(
- text_region: TextRegion,
- pdf_objects: Collection[TextRegion],
-) -> str:
- """Extracts the text aggregated from the elements of the given layout that lie within the given
- block."""
- filtered_blocks = [
- obj for obj in pdf_objects if obj.bbox.is_in(text_region.bbox, error_margin=5)
- ]
- text = " ".join([x.text for x in filtered_blocks if x.text])
- return text
-
-
-def cid_ratio(text: str) -> float:
- """Gets ratio of unknown 'cid' characters extracted from text to all characters."""
- if not is_cid_present(text):
- return 0.0
- cid_pattern = r"\(cid\:(\d+)\)"
- unmatched, n_cid = re.subn(cid_pattern, "", text)
- total = n_cid + len(unmatched)
- return n_cid / total
-
-
-def is_cid_present(text: str) -> bool:
- """Checks if a cid code is present in a text selection."""
- if len(text) < len("(cid:x)"):
- return False
- return text.find("(cid:") != -1
-
-
-def remove_control_characters(text: str) -> str:
- """Removes control characters from text."""
-
- # Replace newline character with a space
- text = text.replace("\n", " ")
- # Remove other control characters
- out_text = "".join(c for c in text if unicodedata.category(c)[0] != "C")
- return out_text
+ pass
def region_bounding_boxes_are_almost_the_same(
diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py
index e1491c57..517a2dba 100644
--- a/unstructured_inference/inference/layout.py
+++ b/unstructured_inference/inference/layout.py
@@ -15,10 +15,8 @@
from unstructured_inference.inference.layoutelement import (
LayoutElement,
)
-from unstructured_inference.inference.ordering import order_layout
from unstructured_inference.logger import logger
from unstructured_inference.models.base import get_model
-from unstructured_inference.models.chipper import UnstructuredChipperModel
from unstructured_inference.models.unstructuredmodel import (
UnstructuredElementExtractionModel,
UnstructuredObjectDetectionModel,
@@ -140,7 +138,7 @@ def __init__(
):
if detection_model is not None and element_extraction_model is not None:
raise ValueError("Only one of detection_model and extraction_model should be passed.")
- self.image = image
+ self.image: Optional[Image.Image] = image
if image_metadata is None:
image_metadata = {}
self.image_metadata = image_metadata
@@ -167,6 +165,7 @@ def get_elements_using_image_extraction(
raise ValueError(
"Cannot get elements using image extraction, no image extraction model defined",
)
+ assert self.image is not None
elements = self.element_extraction_model(self.image)
if inplace:
self.elements = elements
@@ -178,7 +177,6 @@ def get_elements_with_detection_model(
inplace: bool = True,
) -> Optional[List[LayoutElement]]:
"""Uses specified model to detect the elements on the page."""
- logger.info("Detecting page elements ...")
if self.detection_model is None:
model = get_model()
if isinstance(model, UnstructuredObjectDetectionModel):
@@ -188,6 +186,7 @@ def get_elements_with_detection_model(
# NOTE(mrobinson) - We'll want make this model inference step some kind of
# remote call in the future.
+ assert self.image is not None
inferred_layout: List[LayoutElement] = self.detection_model(self.image)
inferred_layout = self.detection_model.deduplicate_detected_elements(
inferred_layout,
@@ -199,29 +198,6 @@ def get_elements_with_detection_model(
return inferred_layout
- def get_elements_from_layout(
- self,
- layout: List[TextRegion],
- pdf_objects: Optional[List[TextRegion]] = None,
- ) -> List[LayoutElement]:
- """Uses the given Layout to separate the page text into elements, either extracting the
- text from the discovered layout blocks."""
-
- # If the model is a chipper model, we don't want to order the
- # elements, as they are already ordered
- order_elements = not isinstance(self.detection_model, UnstructuredChipperModel)
- if order_elements:
- layout = order_layout(layout)
-
- elements = [
- get_element_from_block(
- block=e,
- pdf_objects=pdf_objects,
- )
- for e in layout
- ]
- return elements
-
def _get_image_array(self) -> Union[np.ndarray, None]:
"""Converts the raw image into a numpy array."""
if self.image_array is None:
@@ -322,13 +298,13 @@ def from_image(
detection_model=detection_model,
element_extraction_model=element_extraction_model,
)
+ # FIXME (yao): refactor the other methods so they all return elements like the third route
if page.element_extraction_model is not None:
page.get_elements_using_image_extraction()
- return page
- if fixed_layout is None:
+ elif fixed_layout is None:
page.get_elements_with_detection_model()
else:
- page.elements = page.get_elements_from_layout(fixed_layout)
+ page.elements = []
page.image_metadata = {
"format": page.image.format if page.image else None,
@@ -403,19 +379,6 @@ def process_file_with_model(
return layout
-def get_element_from_block(
- block: TextRegion,
- pdf_objects: Optional[List[TextRegion]] = None,
-) -> LayoutElement:
- """Creates a LayoutElement from a given layout or image by finding all the text that lies within
- a given block."""
- element = block if isinstance(block, LayoutElement) else LayoutElement.from_region(block)
- element.text = element.extract_text(
- objects=pdf_objects,
- )
- return element
-
-
def convert_pdf_to_image(
filename: str,
dpi: int = 200,
diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py
index 376c419e..37a9ef24 100644
--- a/unstructured_inference/inference/layoutelement.py
+++ b/unstructured_inference/inference/layoutelement.py
@@ -10,6 +10,7 @@
from unstructured_inference.config import inference_config
from unstructured_inference.constants import (
+ CHIPPER_VERSIONS,
FULL_PAGE_REGION_THRESHOLD,
ElementType,
Source,
@@ -31,16 +32,6 @@ class LayoutElement(TextRegion):
image_path: Optional[str] = None
parent: Optional[LayoutElement] = None
- def extract_text(
- self,
- objects: Optional[Collection[TextRegion]],
- ):
- """Extracts text contained in region"""
- text = super().extract_text(
- objects=objects,
- )
- return text
-
def to_dict(self) -> dict:
"""Converts the class instance to dictionary form."""
out_dict = {
@@ -108,7 +99,7 @@ def merge_inferred_layout_with_extracted_layout(
continue
region_matched = False
for inferred_region in inferred_layout:
- if inferred_region.source in (Source.CHIPPER, Source.CHIPPERV1):
+ if inferred_region.source in CHIPPER_VERSIONS:
continue
if inferred_region.bbox.intersects(extracted_region.bbox):
@@ -164,9 +155,11 @@ def merge_inferred_layout_with_extracted_layout(
categorized_extracted_elements_to_add = [
LayoutElement(
text=el.text,
- type=ElementType.IMAGE
- if isinstance(el, ImageTextRegion)
- else ElementType.UNCATEGORIZED_TEXT,
+ type=(
+ ElementType.IMAGE
+ if isinstance(el, ImageTextRegion)
+ else ElementType.UNCATEGORIZED_TEXT
+ ),
source=el.source,
bbox=el.bbox,
)
@@ -221,7 +214,9 @@ def reduce(keep: Rectangle, reduce: Rectangle):
reduce(keep=region_b, reduce=region_a)
-def table_cells_to_dataframe(cells: dict, nrows: int = 1, ncols: int = 1, header=None) -> DataFrame:
+def table_cells_to_dataframe(
+ cells: List[dict], nrows: int = 1, ncols: int = 1, header=None
+) -> DataFrame:
"""convert table-transformer's cells data into a pandas dataframe"""
arr = np.empty((nrows, ncols), dtype=object)
for cell in cells:
diff --git a/unstructured_inference/inference/ordering.py b/unstructured_inference/inference/ordering.py
deleted file mode 100644
index 33b823c7..00000000
--- a/unstructured_inference/inference/ordering.py
+++ /dev/null
@@ -1,35 +0,0 @@
-from typing import List
-
-from unstructured_inference.inference.elements import TextRegion
-
-
-def order_layout(
- layout: List[TextRegion],
- column_tol_factor: float = 0.2,
- full_page_threshold_factor: float = 0.9,
-) -> List[TextRegion]:
- """Orders the layout elements detected on a page. For groups of elements that are not
- the width of the page, the algorithm attempts to group elements into column based on
- the coordinates of the bounding box. Columns are ordered left to right, and elements
- within columns are ordered top to bottom.
-
- Parameters
- ----------
- layout
- the layout elements to order.
- column_tol_factor
- multiplied by the page width to find the tolerance for considering two elements as
- part of the same column.
- full_page_threshold_factor
- multiplied by the page width to find the minimum width an elements need to be
- for it to be considered a full page width element.
- """
- if len(layout) == 0:
- return []
-
- layout.sort(
- key=lambda element: (element.bbox.y1, element.bbox.x1, element.bbox.y2, element.bbox.x2),
- )
- # NOTE(alan): Temporarily revert to orginal logic pending fixing the new logic
- # See code prior to this commit for new logic.
- return layout
diff --git a/unstructured_inference/models/base.py b/unstructured_inference/models/base.py
index 26336e23..7ebe6e42 100644
--- a/unstructured_inference/models/base.py
+++ b/unstructured_inference/models/base.py
@@ -1,43 +1,49 @@
+from __future__ import annotations
+
import json
import os
-from typing import Dict, Optional, Type
+from typing import Dict, Optional, Tuple, Type
from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES
from unstructured_inference.models.chipper import UnstructuredChipperModel
-from unstructured_inference.models.detectron2 import (
- MODEL_TYPES as DETECTRON2_MODEL_TYPES,
-)
-from unstructured_inference.models.detectron2 import (
- UnstructuredDetectronModel,
-)
-from unstructured_inference.models.detectron2onnx import (
- MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES,
-)
-from unstructured_inference.models.detectron2onnx import (
- UnstructuredDetectronONNXModel,
-)
-from unstructured_inference.models.super_gradients import (
- UnstructuredSuperGradients,
-)
+from unstructured_inference.models.detectron2onnx import MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES
+from unstructured_inference.models.detectron2onnx import UnstructuredDetectronONNXModel
from unstructured_inference.models.unstructuredmodel import UnstructuredModel
-from unstructured_inference.models.yolox import (
- MODEL_TYPES as YOLOX_MODEL_TYPES,
-)
-from unstructured_inference.models.yolox import (
- UnstructuredYoloXModel,
-)
+from unstructured_inference.models.yolox import MODEL_TYPES as YOLOX_MODEL_TYPES
+from unstructured_inference.models.yolox import UnstructuredYoloXModel
+from unstructured_inference.utils import LazyDict
DEFAULT_MODEL = "yolox"
models: Dict[str, UnstructuredModel] = {}
-model_class_map: Dict[str, Type[UnstructuredModel]] = {
- **{name: UnstructuredDetectronModel for name in DETECTRON2_MODEL_TYPES},
- **{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES},
- **{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES},
- **{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES},
- "super_gradients": UnstructuredSuperGradients,
-}
+
+def get_default_model_mappings() -> Tuple[
+ Dict[str, Type[UnstructuredModel]],
+ Dict[str, dict | LazyDict],
+]:
+ """default model mappings for models that are in `unstructured_inference` repo"""
+ return {
+ **{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES},
+ **{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES},
+ **{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES},
+ }, {
+ **DETECTRON2_ONNX_MODEL_TYPES,
+ **YOLOX_MODEL_TYPES,
+ **CHIPPER_MODEL_TYPES,
+ }
+
+
+model_class_map, model_config_map = get_default_model_mappings()
+
+
+def register_new_model(model_config: dict, model_class: UnstructuredModel):
+ """Register this model in model_config_map and model_class_map.
+
+ Those maps are updated with the with the new model class information.
+ """
+ model_config_map.update(model_config)
+ model_class_map.update({name: model_class for name in model_config})
def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
@@ -58,15 +64,13 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
if initialize_param_json is not None:
with open(initialize_param_json) as fp:
initialize_params = json.load(fp)
+ label_map_int_keys = {
+ int(key): value for key, value in initialize_params["label_map"].items()
+ }
+ initialize_params["label_map"] = label_map_int_keys
else:
- if model_name in DETECTRON2_MODEL_TYPES:
- initialize_params = DETECTRON2_MODEL_TYPES[model_name]
- elif model_name in DETECTRON2_ONNX_MODEL_TYPES:
- initialize_params = DETECTRON2_ONNX_MODEL_TYPES[model_name]
- elif model_name in YOLOX_MODEL_TYPES:
- initialize_params = YOLOX_MODEL_TYPES[model_name]
- elif model_name in CHIPPER_MODEL_TYPES:
- initialize_params = CHIPPER_MODEL_TYPES[model_name]
+ if model_name in model_config_map:
+ initialize_params = model_config_map[model_name]
else:
raise UnknownModelException(f"Unknown model type: {model_name}")
@@ -78,6 +82,6 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel:
class UnknownModelException(Exception):
- """Exception for the case where a model is called for with an unrecognized identifier."""
+ """A model was requested with an unrecognized identifier."""
pass
diff --git a/unstructured_inference/models/chipper.py b/unstructured_inference/models/chipper.py
index 0ffa9c7d..857c83e9 100644
--- a/unstructured_inference/models/chipper.py
+++ b/unstructured_inference/models/chipper.py
@@ -9,20 +9,19 @@
import torch
import transformers
from cv2.typing import MatLike
-from huggingface_hub import hf_hub_download
from PIL.Image import Image
from transformers import DonutProcessor, VisionEncoderDecoderModel
from transformers.generation.logits_process import LogitsProcessor
from transformers.generation.stopping_criteria import StoppingCriteria
-from unstructured_inference.constants import Source
+from unstructured_inference.constants import CHIPPER_VERSIONS, Source
from unstructured_inference.inference.elements import Rectangle
from unstructured_inference.inference.layoutelement import LayoutElement
from unstructured_inference.logger import logger
from unstructured_inference.models.unstructuredmodel import (
UnstructuredElementExtractionModel,
)
-from unstructured_inference.utils import LazyDict, strip_tags
+from unstructured_inference.utils import LazyDict, download_if_needed_and_get_local_path, strip_tags
MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = {
"chipperv1": {
@@ -44,11 +43,22 @@
"max_length": 1536,
"heatmap_h": 40,
"heatmap_w": 30,
+ "source": Source.CHIPPERV2,
+ },
+ "chipperv3": {
+ "pre_trained_model_repo": "unstructuredio/chipper-v3",
+ "swap_head": True,
+ "swap_head_hidden_layer_size": 128,
+ "start_token_prefix": "",
+ "max_length": 1536,
+ "heatmap_h": 40,
+ "heatmap_w": 30,
"source": Source.CHIPPER,
},
}
-MODEL_TYPES["chipper"] = MODEL_TYPES["chipperv2"]
+MODEL_TYPES["chipper"] = MODEL_TYPES["chipperv3"]
class UnstructuredChipperModel(UnstructuredElementExtractionModel):
@@ -104,8 +114,8 @@ def initialize(
token=auth_token,
)
if swap_head:
- lm_head_file = hf_hub_download(
- repo_id=pre_trained_model_repo,
+ lm_head_file = download_if_needed_and_get_local_path(
+ path_or_repo=pre_trained_model_repo,
filename="lm_head.pth",
token=auth_token,
)
@@ -160,16 +170,18 @@ def predict(self, image) -> List[LayoutElement]:
return elements
@staticmethod
- def format_table_elements(elements):
- """makes chipper table element return the same as other layout models
+ def format_table_elements(elements: List[LayoutElement]) -> List[LayoutElement]:
+ """Makes chipper table element return the same as other layout models.
- - copies the html representation to attribute text_as_html
- - strip html tags from the attribute text
+ 1. If `text` attribute is an html (has html tags in it), copies the `text`
+ attribute to `text_as_html` attribute.
+ 2. Strips html tags from the `text` attribute.
"""
for element in elements:
- element.text_as_html = element.text
- element.text = strip_tags(element.text)
-
+ text = strip_tags(element.text) if element.text is not None else element.text
+ if text != element.text:
+ element.text_as_html = element.text # type: ignore[attr-defined]
+ element.text = text
return elements
def predict_tokens(
@@ -196,6 +208,7 @@ def predict_tokens(
outputs = self.model.generate(
encoder_outputs=encoder_outputs,
input_ids=self.input_ids,
+ max_length=self.max_length,
no_repeat_ngram_size=0,
num_beams=1,
return_dict_in_generate=True,
@@ -212,6 +225,7 @@ def predict_tokens(
outputs = self.model.generate(
encoder_outputs=encoder_outputs,
input_ids=self.input_ids,
+ max_length=self.max_length,
logits_processor=self.logits_processor,
do_sample=True,
no_repeat_ngram_size=0,
@@ -390,7 +404,7 @@ def deduplicate_detected_elements(
min_text_size: int = 15,
) -> List[LayoutElement]:
"""For chipper, remove elements from other sources."""
- return [el for el in elements if el.source in (Source.CHIPPER, Source.CHIPPERV1)]
+ return [el for el in elements if el.source in CHIPPER_VERSIONS]
def adjust_bbox(self, bbox, x_offset, y_offset, ratio, target_size):
"""Translate bbox by (x_offset, y_offset) and shrink by ratio."""
@@ -450,7 +464,7 @@ def get_bounding_box(
np.asarray(
[
agg_heatmap,
- cv2.resize(
+ cv2.resize( # type: ignore
hmap,
(final_w, final_h),
interpolation=cv2.INTER_LINEAR_EXACT, # cv2.INTER_CUBIC
@@ -516,12 +530,13 @@ def reduce_element_bbox(
Given a LayoutElement element, reduce the size of the bounding box,
depending on existing elements
"""
- bbox = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
+ if element.bbox:
+ bbox = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2]
- if not self.element_overlap(elements, element):
- element.bbox = Rectangle(*self.reduce_bbox_no_overlap(image, bbox))
- else:
- element.bbox = Rectangle(*self.reduce_bbox_overlap(image, bbox))
+ if not self.element_overlap(elements, element):
+ element.bbox = Rectangle(*self.reduce_bbox_no_overlap(image, bbox))
+ else:
+ element.bbox = Rectangle(*self.reduce_bbox_overlap(image, bbox))
def bbox_overlap(
self,
@@ -606,7 +621,7 @@ def reduce_bbox_no_overlap(
):
return input_bbox
- nimage = np.array(image.crop(input_bbox))
+ nimage = np.array(image.crop(input_bbox)) # type: ignore
nimage = self.remove_horizontal_lines(nimage)
@@ -655,7 +670,7 @@ def reduce_bbox_overlap(
):
return input_bbox
- nimage = np.array(image.crop(input_bbox))
+ nimage = np.array(image.crop(input_bbox)) # type: ignore
nimage = self.remove_horizontal_lines(nimage)
@@ -759,7 +774,7 @@ def largest_margin(
):
return None
- nimage = np.array(image.crop(input_bbox))
+ nimage = np.array(image.crop(input_bbox)) # type: ignore
if nimage.shape[0] * nimage.shape[1] == 0:
return None
@@ -868,6 +883,8 @@ def resolve_bbox_overlaps(
continue
ebbox1 = element.bbox
+ if ebbox1 is None:
+ continue
bbox1 = [ebbox1.x1, ebbox1.y1, ebbox1.x2, max(ebbox1.y1, ebbox1.y2)]
for celement in elements:
diff --git a/unstructured_inference/models/detectron2.py b/unstructured_inference/models/detectron2.py
deleted file mode 100644
index 98939f88..00000000
--- a/unstructured_inference/models/detectron2.py
+++ /dev/null
@@ -1,99 +0,0 @@
-from pathlib import Path
-from typing import Any, Dict, Final, List, Optional, Union
-
-from huggingface_hub import hf_hub_download
-from layoutparser.models.detectron2.layoutmodel import (
- Detectron2LayoutModel,
- is_detectron2_available,
-)
-from layoutparser.models.model_config import LayoutModelConfig
-from PIL import Image
-
-from unstructured_inference.constants import ElementType
-from unstructured_inference.inference.layoutelement import LayoutElement
-from unstructured_inference.logger import logger
-from unstructured_inference.models.unstructuredmodel import (
- UnstructuredObjectDetectionModel,
-)
-from unstructured_inference.utils import LazyDict, LazyEvaluateInfo
-
-DETECTRON_CONFIG: Final = "lp://PubLayNet/faster_rcnn_R_50_FPN_3x/config"
-DEFAULT_LABEL_MAP: Final[Dict[int, str]] = {
- 0: ElementType.TEXT,
- 1: ElementType.TITLE,
- 2: ElementType.LIST,
- 3: ElementType.TABLE,
- 4: ElementType.FIGURE,
-}
-DEFAULT_EXTRA_CONFIG: Final[List[Any]] = ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8]
-
-
-# NOTE(alan): Entries are implemented as LazyDicts so that models aren't downloaded until they are
-# needed.
-MODEL_TYPES = {
- "detectron2_lp": LazyDict(
- model_path=LazyEvaluateInfo(
- hf_hub_download,
- "layoutparser/detectron2",
- "PubLayNet/faster_rcnn_R_50_FPN_3x/model_final.pth",
- ),
- config_path=LazyEvaluateInfo(
- hf_hub_download,
- "layoutparser/detectron2",
- "PubLayNet/faster_rcnn_R_50_FPN_3x/config.yml",
- ),
- label_map=DEFAULT_LABEL_MAP,
- extra_config=DEFAULT_EXTRA_CONFIG,
- ),
- "checkbox": LazyDict(
- model_path=LazyEvaluateInfo(
- hf_hub_download,
- "unstructuredio/oer-checkbox",
- "detectron2_finetuned_oer_checkbox.pth",
- ),
- config_path=LazyEvaluateInfo(
- hf_hub_download,
- "unstructuredio/oer-checkbox",
- "detectron2_oer_checkbox.json",
- ),
- label_map={0: "Unchecked", 1: "Checked"},
- extra_config=None,
- ),
-}
-
-
-class UnstructuredDetectronModel(UnstructuredObjectDetectionModel):
- """Unstructured model wrapper for Detectron2LayoutModel."""
-
- def predict(self, x: Image):
- """Makes a prediction using detectron2 model."""
- super().predict(x)
- prediction = self.model.detect(x)
- return [LayoutElement.from_lp_textblock(block) for block in prediction]
-
- def initialize(
- self,
- config_path: Union[str, Path, LayoutModelConfig],
- model_path: Optional[Union[str, Path]] = None,
- label_map: Optional[Dict[int, str]] = None,
- extra_config: Optional[list] = None,
- device: Optional[str] = None,
- ):
- """Loads the detectron2 model using the specified parameters"""
-
- if not is_detectron2_available():
- raise ImportError(
- "Failed to load the Detectron2 model. Ensure that the Detectron2 "
- "module is correctly installed.",
- )
-
- config_path_str = str(config_path)
- model_path_str: Optional[str] = None if model_path is None else str(model_path)
- logger.info("Loading the Detectron2 layout model ...")
- self.model = Detectron2LayoutModel(
- config_path_str,
- model_path=model_path_str,
- label_map=label_map,
- extra_config=extra_config,
- device=device,
- )
diff --git a/unstructured_inference/models/detectron2onnx.py b/unstructured_inference/models/detectron2onnx.py
index 3def8ced..79cd0a1a 100644
--- a/unstructured_inference/models/detectron2onnx.py
+++ b/unstructured_inference/models/detectron2onnx.py
@@ -4,7 +4,6 @@
import cv2
import numpy as np
import onnxruntime
-from huggingface_hub import hf_hub_download
from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE
from onnxruntime.capi import _pybind_state as C
from onnxruntime.quantization import QuantType, quantize_dynamic
@@ -16,7 +15,11 @@
from unstructured_inference.models.unstructuredmodel import (
UnstructuredObjectDetectionModel,
)
-from unstructured_inference.utils import LazyDict, LazyEvaluateInfo
+from unstructured_inference.utils import (
+ LazyDict,
+ LazyEvaluateInfo,
+ download_if_needed_and_get_local_path,
+)
onnxruntime.set_default_logger_severity(logger_onnx.getEffectiveLevel())
@@ -34,7 +37,7 @@
MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = {
"detectron2_onnx": LazyDict(
model_path=LazyEvaluateInfo(
- hf_hub_download,
+ download_if_needed_and_get_local_path,
"unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x",
"model.onnx",
),
@@ -52,7 +55,7 @@
},
"detectron2_mask_rcnn": LazyDict(
model_path=LazyEvaluateInfo(
- hf_hub_download,
+ download_if_needed_and_get_local_path,
"unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x",
"model.onnx",
),
diff --git a/unstructured_inference/models/donut.py b/unstructured_inference/models/donut.py
index 1f753f56..bc60d2c6 100644
--- a/unstructured_inference/models/donut.py
+++ b/unstructured_inference/models/donut.py
@@ -3,7 +3,7 @@
from typing import Optional, Union
import torch
-from PIL import Image
+from PIL import Image as PILImage
from transformers import (
DonutProcessor,
VisionEncoderDecoderConfig,
@@ -16,7 +16,7 @@
class UnstructuredDonutModel(UnstructuredModel):
"""Unstructured model wrapper for Donut image transformer."""
- def predict(self, x: Image):
+ def predict(self, x: PILImage.Image):
"""Make prediction using donut model"""
super().predict(x)
return self.run_prediction(x)
@@ -50,7 +50,7 @@ def initialize(
raise ImportError("Review the parameters to initialize a UnstructuredDonutModel obj")
self.model.to(device)
- def run_prediction(self, x: Image):
+ def run_prediction(self, x: PILImage.Image):
"""Internal prediction method."""
pixel_values = self.processor(x, return_tensors="pt").pixel_values
decoder_input_ids = self.processor.tokenizer(
diff --git a/unstructured_inference/models/super_gradients.py b/unstructured_inference/models/super_gradients.py
deleted file mode 100644
index 6d9a25fa..00000000
--- a/unstructured_inference/models/super_gradients.py
+++ /dev/null
@@ -1,100 +0,0 @@
-import os
-from typing import List, cast
-
-import cv2
-import numpy as np
-import onnxruntime
-from onnxruntime.capi import _pybind_state as C
-from PIL import Image
-
-from unstructured_inference.constants import Source
-from unstructured_inference.inference.layoutelement import LayoutElement
-from unstructured_inference.logger import logger
-from unstructured_inference.models.unstructuredmodel import (
- UnstructuredObjectDetectionModel,
-)
-
-
-class UnstructuredSuperGradients(UnstructuredObjectDetectionModel):
- def predict(self, x: Image):
- """Predict using Super-Gradients model."""
- super().predict(x)
- return self.image_processing(x)
-
- def initialize(self, model_path: str, label_map: dict, input_shape: tuple):
- """Start inference session for SuperGradients model."""
-
- if not os.path.exists(model_path):
- logger.info("ONNX Model Path Does Not Exist!")
- self.model_path = model_path
-
- available_providers = C.get_available_providers()
- ordered_providers = [
- "TensorrtExecutionProvider",
- "CUDAExecutionProvider",
- "CPUExecutionProvider",
- ]
-
- providers = [provider for provider in ordered_providers if provider in available_providers]
-
- self.model = onnxruntime.InferenceSession(
- model_path,
- providers=providers,
- )
-
- self.layout_classes = label_map
-
- self.input_shape = input_shape
-
- def image_processing(
- self,
- image: Image.Image,
- ) -> List[LayoutElement]:
- """Method runing SuperGradients Model for layout detection, returns a PageLayout"""
- # Not handling various input images right now
- # TODO (Pravin): check other shapes for inference
- input_shape = self.input_shape
- origin_img = np.array(image)
- img = preprocess(origin_img, input_shape)
- session = self.model
- inputs = [o.name for o in session.get_inputs()]
- outputs = [o.name for o in session.get_outputs()]
- predictions = session.run(outputs, {inputs[0]: img})
-
- regions = []
-
- num_detections, pred_boxes, pred_scores, pred_classes = predictions
- for image_index in range(num_detections.shape[0]):
- for i in range(num_detections[image_index, 0]):
- class_id = pred_classes[image_index, i]
- prob = pred_scores[image_index, i]
- x1, y1, x2, y2 = pred_boxes[image_index, i]
- detected_class = self.layout_classes[str(class_id)]
- region = LayoutElement.from_coords(
- float(x1),
- float(y1),
- float(x2),
- float(y2),
- text=None,
- type=detected_class,
- prob=float(prob),
- source=Source.SUPER_GRADIENTS,
- )
- regions.append(cast(LayoutElement, region))
-
- regions.sort(key=lambda element: element.bbox.y1)
-
- page_layout = regions
-
- return page_layout
-
-
-def preprocess(origin_img, input_shape, swap=(0, 3, 1, 2)):
- """Preprocess image data before Super-Gradients Inputted Model
- Giving a generic preprocess function which simply resizes the image before prediction
- TODO(Pravin): Look into allowing user to specify their own pre-process function
- Which takes a numpy array image and returns a numpy array image
- """
- new_img = cv2.resize(origin_img, input_shape).astype(np.uint8)
- image_bchw = np.transpose(np.expand_dims(new_img, 0), swap)
- return image_bchw
diff --git a/unstructured_inference/models/table_postprocess.py b/unstructured_inference/models/table_postprocess.py
index 602ae564..741277f8 100644
--- a/unstructured_inference/models/table_postprocess.py
+++ b/unstructured_inference/models/table_postprocess.py
@@ -80,24 +80,6 @@ def apply_threshold(objects, threshold):
return [obj for obj in objects if obj["score"] >= threshold]
-# def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds):
-# """
-# Filter out bounding boxes whose confidence is below the confidence threshold for
-# its associated class label.
-# """
-# # Apply class-specific thresholds
-# indices_above_threshold = [
-# idx
-# for idx, (score, label) in enumerate(zip(scores, labels))
-# if score >= class_thresholds[class_names[label]]
-# ]
-# bboxes = [bboxes[idx] for idx in indices_above_threshold]
-# scores = [scores[idx] for idx in indices_above_threshold]
-# labels = [labels[idx] for idx in indices_above_threshold]
-
-# return bboxes, scores, labels
-
-
def refine_rows(rows, tokens, score_threshold):
"""
Apply operations to the detected rows, such as
diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py
index 9fc66aee..ade5349e 100644
--- a/unstructured_inference/models/tables.py
+++ b/unstructured_inference/models/tables.py
@@ -3,13 +3,16 @@
import xml.etree.ElementTree as ET
from collections import defaultdict
from pathlib import Path
-from typing import Dict, List, Optional, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import cv2
import numpy as np
import torch
-from PIL import Image
+from PIL import Image as PILImage
from transformers import DetrImageProcessor, TableTransformerForObjectDetection
+from transformers.models.table_transformer.modeling_table_transformer import (
+ TableTransformerObjectDetectionOutput,
+)
from unstructured_inference.config import inference_config
from unstructured_inference.inference.layoutelement import table_cells_to_dataframe
@@ -27,7 +30,12 @@ class UnstructuredTableTransformerModel(UnstructuredModel):
def __init__(self):
pass
- def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None):
+ def predict(
+ self,
+ x: PILImage.Image,
+ ocr_tokens: Optional[List[Dict]] = None,
+ result_format: str = "html",
+ ):
"""Predict table structure deferring to run_prediction with ocr tokens
Note:
@@ -44,7 +52,7 @@ def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None):
FIXME: refactor token data into a dataclass so we have clear expectations of the fields
"""
super().predict(x)
- return self.run_prediction(x, ocr_tokens=ocr_tokens)
+ return self.run_prediction(x, ocr_tokens=ocr_tokens, result_format=result_format)
def initialize(
self,
@@ -70,13 +78,12 @@ def initialize(
def get_structure(
self,
- x: Image,
+ x: PILImage.Image,
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
) -> dict:
"""get the table structure as a dictionary contaning different types of elements as
key-value pairs; check table-transformer documentation for more information"""
with torch.no_grad():
- logger.info(f"padding image by {pad_for_structure_detection} for structure detection")
encoding = self.feature_extractor(
pad_image_with_background_color(x, pad_for_structure_detection),
return_tensors="pt",
@@ -87,7 +94,7 @@ def get_structure(
def run_prediction(
self,
- x: Image,
+ x: PILImage.Image,
pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD,
ocr_tokens: Optional[List[Dict]] = None,
result_format: Optional[str] = "html",
@@ -96,12 +103,27 @@ def run_prediction(
outputs_structure = self.get_structure(x, pad_for_structure_detection)
if ocr_tokens is None:
raise ValueError("Cannot predict table structure with no OCR tokens")
- prediction = recognize(outputs_structure, x, tokens=ocr_tokens)[0]
+
+ recognized_table = recognize(outputs_structure, x, tokens=ocr_tokens)
+ if len(recognized_table) > 0:
+ prediction = recognized_table[0]
+ # NOTE(robinson) - This means that the table was not recognized
+ else:
+ return ""
+
if result_format == "html":
# Convert cells to HTML
prediction = cells_to_html(prediction) or ""
elif result_format == "dataframe":
prediction = table_cells_to_dataframe(prediction)
+ elif result_format == "cells":
+ prediction = prediction
+ else:
+ raise ValueError(
+ f"result_format {result_format} is not a valid format. "
+ f'Valid formats are: "html", "dataframe", "cells"',
+ )
+
return prediction
@@ -148,22 +170,26 @@ def get_class_map(data_type: str):
}
-def recognize(outputs: dict, img: Image, tokens: list):
+def recognize(outputs: dict, img: PILImage.Image, tokens: list):
"""Recognize table elements."""
str_class_name2idx = get_class_map("structure")
str_class_idx2name = {v: k for k, v in str_class_name2idx.items()}
- str_class_thresholds = structure_class_thresholds
+ class_thresholds = structure_class_thresholds
# Post-process detected objects, assign class labels
objects = outputs_to_objects(outputs, img.size, str_class_idx2name)
-
+ high_confidence_objects = apply_thresholds_on_objects(objects, class_thresholds)
# Further process the detected objects so they correspond to a consistent table
- tables_structure = objects_to_structures(objects, tokens, str_class_thresholds)
+ tables_structure = objects_to_structures(high_confidence_objects, tokens, class_thresholds)
# Enumerate all table cells: grid cells and spanning cells
return [structure_to_cells(structure, tokens)[0] for structure in tables_structure]
-def outputs_to_objects(outputs, img_size, class_idx2name):
+def outputs_to_objects(
+ outputs: TableTransformerObjectDetectionOutput,
+ img_size: Tuple[int, int],
+ class_idx2name: Mapping[int, str],
+):
"""Output table element types."""
m = outputs["logits"].softmax(-1).max(-1)
pred_labels = list(m.indices.detach().cpu().numpy())[0]
@@ -192,6 +218,33 @@ def outputs_to_objects(outputs, img_size, class_idx2name):
return objects
+def apply_thresholds_on_objects(
+ objects: Sequence[Mapping[str, Any]],
+ thresholds: Mapping[str, float],
+) -> Sequence[Mapping[str, Any]]:
+ """
+ Filters out predicted objects whose confidence scores are below the thresholds
+
+ Args:
+ objects: Sequence of mappings for example:
+ [
+ {
+ "label": "table row",
+ "score": 0.55,
+ "bbox": [...],
+ },
+ ...,
+ ]
+ thresholds: Mapping from labels to thresholds
+
+ Returns:
+ Filtered list of objects
+
+ """
+ objects = [obj for obj in objects if obj["score"] >= thresholds[obj["label"]]]
+ return objects
+
+
# for output bounding box post-processing
def box_cxcywh_to_xyxy(x):
"""Convert rectangle format from center-x, center-y, width, height to
@@ -431,6 +484,8 @@ def structure_to_cells(table_structure, tokens):
columns = table_structure["columns"]
rows = table_structure["rows"]
spanning_cells = table_structure["spanning cells"]
+ spanning_cells = sorted(spanning_cells, reverse=True, key=lambda cell: cell["score"])
+
cells = []
subcells = []
# Identify complete cells and subcells
@@ -454,6 +509,7 @@ def structure_to_cells(table_structure, tokens):
spanning_cell_rect.intersect(cell_rect).get_area() / cell_rect.get_area()
) > inference_config.TABLE_IOB_THRESHOLD:
cell["subcell"] = True
+ cell["is_merged"] = False
break
if cell["subcell"]:
@@ -475,7 +531,7 @@ def structure_to_cells(table_structure, tokens):
subcell_rect_area = subcell_rect.get_area()
if (
subcell_rect.intersect(spanning_cell_rect).get_area() / subcell_rect_area
- ) > inference_config.TABLE_IOB_THRESHOLD:
+ ) > inference_config.TABLE_IOB_THRESHOLD and subcell["is_merged"] is False:
if cell_rect is None:
cell_rect = Rect(list(subcell["bbox"]))
else:
@@ -486,6 +542,8 @@ def structure_to_cells(table_structure, tokens):
# as header cells for a spanning cell to be classified as a header cell;
# otherwise, this could lead to a non-rectangular header region
header = header and "column header" in subcell and subcell["column header"]
+ subcell["is_merged"] = True
+
if len(cell_rows) > 0 and len(cell_columns) > 0:
cell = {
"bbox": cell_rect.get_bbox(),
@@ -590,11 +648,8 @@ def structure_to_cells(table_structure, tokens):
def fill_cells(cells: List[dict]) -> List[dict]:
- """add empty cells to pad cells that spans multiple rows for html conversion
-
- For example if a cell takes row 0 and 1 and column 0, we add a new empty cell at row 1 and
- column 0. This padding ensures the structure of the output table is intact. In this example the
- cell data is {"row_nums": [0, 1], "column_nums": [0], ...}
+ """fills the missing cells in the table by adding cells with empty text
+ where there are no cells detected by the model.
A cell contains the following keys relevent to the html conversion:
row_nums: List[int]
@@ -605,28 +660,63 @@ def fill_cells(cells: List[dict]) -> List[dict]:
than one numbers
cell text: str
the text in this cell
+ column header: bool
+ whether this cell is a column header
"""
- new_cells = cells.copy()
+ if not cells:
+ return []
+
+ table_rows_no = max({row for cell in cells for row in cell["row_nums"]})
+ table_cols_no = max({col for cell in cells for col in cell["column_nums"]})
+ filled = np.zeros((table_rows_no + 1, table_cols_no + 1), dtype=bool)
for cell in cells:
- for extra_row in sorted(cell["row_nums"][1:]):
- new_cell = cell.copy()
- new_cell["row_nums"] = [extra_row]
- new_cell["cell text"] = ""
- new_cells.append(new_cell)
+ for row in cell["row_nums"]:
+ for col in cell["column_nums"]:
+ filled[row, col] = True
+ # add cells for which filled is false
+ header_rows = {row for cell in cells if cell["column header"] for row in cell["row_nums"]}
+ new_cells = cells.copy()
+ not_filled_idx = np.where(filled == False) # noqa: E712
+ for row, col in zip(not_filled_idx[0], not_filled_idx[1]):
+ new_cell = {
+ "row_nums": [row],
+ "column_nums": [col],
+ "cell text": "",
+ "column header": row in header_rows,
+ }
+ new_cells.append(new_cell)
return new_cells
-def cells_to_html(cells):
- """Convert table structure to html format."""
+def cells_to_html(cells: List[dict]) -> str:
+ """Convert table structure to html format.
+
+ Args:
+ cells: List of dictionaries representing table cells, where each dictionary has the
+ following format:
+ {
+ "row_nums": List[int],
+ "column_nums": List[int],
+ "cell text": str,
+ "column header": bool,
+ }
+ Returns:
+ str: HTML table string
+ """
cells = sorted(fill_cells(cells), key=lambda k: (min(k["row_nums"]), min(k["column_nums"])))
table = ET.Element("table")
current_row = -1
+ table_header = None
+ table_has_header = any(cell["column header"] for cell in cells)
+ if table_has_header:
+ table_header = ET.SubElement(table, "thead")
+
+ table_body = ET.SubElement(table, "tbody")
for cell in cells:
this_row = min(cell["row_nums"])
-
attrib = {}
colspan = len(cell["column_nums"])
if colspan > 1:
@@ -637,18 +727,19 @@ def cells_to_html(cells):
if this_row > current_row:
current_row = this_row
if cell["column header"]:
+ table_subelement = table_header
cell_tag = "th"
- row = ET.SubElement(table, "thead")
else:
+ table_subelement = table_body
cell_tag = "td"
- row = ET.SubElement(table, "tr")
+ row = ET.SubElement(table_subelement, "tr") # type: ignore
tcell = ET.SubElement(row, cell_tag, attrib=attrib)
tcell.text = cell["cell text"]
return str(ET.tostring(table, encoding="unicode", short_empty_elements=False))
-def zoom_image(image: Image, zoom: float) -> Image:
+def zoom_image(image: PILImage.Image, zoom: float) -> PILImage.Image:
"""scale an image based on the zoom factor using cv2; the scaled image is post processed by
dilation then erosion to improve edge sharpness for OCR tasks"""
if zoom <= 0:
@@ -666,4 +757,4 @@ def zoom_image(image: Image, zoom: float) -> Image:
new_image = cv2.dilate(new_image, kernel, iterations=1)
new_image = cv2.erode(new_image, kernel, iterations=1)
- return Image.fromarray(new_image)
+ return PILImage.fromarray(new_image)
diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py
index 47455cf4..0acd93f3 100644
--- a/unstructured_inference/models/yolox.py
+++ b/unstructured_inference/models/yolox.py
@@ -8,14 +8,17 @@
import cv2
import numpy as np
import onnxruntime
-from huggingface_hub import hf_hub_download
from onnxruntime.capi import _pybind_state as C
-from PIL import Image
+from PIL import Image as PILImage
from unstructured_inference.constants import ElementType, Source
from unstructured_inference.inference.layoutelement import LayoutElement
from unstructured_inference.models.unstructuredmodel import UnstructuredObjectDetectionModel
-from unstructured_inference.utils import LazyDict, LazyEvaluateInfo
+from unstructured_inference.utils import (
+ LazyDict,
+ LazyEvaluateInfo,
+ download_if_needed_and_get_local_path,
+)
YOLOX_LABEL_MAP = {
0: ElementType.CAPTION,
@@ -34,7 +37,7 @@
MODEL_TYPES = {
"yolox": LazyDict(
model_path=LazyEvaluateInfo(
- hf_hub_download,
+ download_if_needed_and_get_local_path,
"unstructuredio/yolo_x_layout",
"yolox_l0.05.onnx",
),
@@ -42,7 +45,7 @@
),
"yolox_tiny": LazyDict(
model_path=LazyEvaluateInfo(
- hf_hub_download,
+ download_if_needed_and_get_local_path,
"unstructuredio/yolo_x_layout",
"yolox_tiny.onnx",
),
@@ -50,7 +53,7 @@
),
"yolox_quantized": LazyDict(
model_path=LazyEvaluateInfo(
- hf_hub_download,
+ download_if_needed_and_get_local_path,
"unstructuredio/yolo_x_layout",
"yolox_l0.05_quantized.onnx",
),
@@ -60,7 +63,7 @@
class UnstructuredYoloXModel(UnstructuredObjectDetectionModel):
- def predict(self, x: Image):
+ def predict(self, x: PILImage.Image):
"""Predict using YoloX model."""
super().predict(x)
return self.image_processing(x)
@@ -86,7 +89,7 @@ def initialize(self, model_path: str, label_map: dict):
def image_processing(
self,
- image: Image = None,
+ image: PILImage.Image,
) -> List[LayoutElement]:
"""Method runing YoloX for layout detection, returns a PageLayout
parameters
diff --git a/unstructured_inference/utils.py b/unstructured_inference/utils.py
index 86285954..696a2e8a 100644
--- a/unstructured_inference/utils.py
+++ b/unstructured_inference/utils.py
@@ -1,8 +1,10 @@
+import os
from collections.abc import Mapping
from html.parser import HTMLParser
from io import StringIO
from typing import Any, Callable, Hashable, Iterable, Iterator, Union
+from huggingface_hub import hf_hub_download
from PIL import Image
from unstructured_inference.inference.layoutelement import LayoutElement
@@ -101,3 +103,13 @@ def strip_tags(html: str) -> str:
s = MLStripper()
s.feed(html)
return s.get_data()
+
+
+def download_if_needed_and_get_local_path(path_or_repo: str, filename: str, **kwargs) -> str:
+ """Returns path to local file if it exists, otherwise treats it as a huggingface repo and
+ attempts to download."""
+ full_path = os.path.join(path_or_repo, filename)
+ if os.path.exists(full_path):
+ return full_path
+ else:
+ return hf_hub_download(path_or_repo, filename, **kwargs)
|