diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ff79d80d..97cae495 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -7,7 +7,7 @@ on: branches: [ main ] env: - PYTHON_VERSION: 3.8 + PYTHON_VERSION: 3.9 jobs: setup: @@ -22,7 +22,7 @@ jobs: key: ${{ runner.os }}-${{ env.PYTHON_VERSION }}-${{ hashFiles('requirements/*.txt') }} lookup-only: true - name: Set up Python ${{ env.PYTHON_VERSION }} - uses: actions/setup-python@v4 + uses: actions/setup-python@v5 with: python-version: ${{ env.PYTHON_VERSION }} - name: Install Poppler @@ -104,48 +104,50 @@ jobs: CI=true make test make check-coverage - test_ingest: - strategy: - matrix: - python-version: ["3.8","3.9","3.10"] - runs-on: ubuntu-latest - env: - NLTK_DATA: ${{ github.workspace }}/nltk_data - needs: lint - steps: - - name: Checkout unstructured repo for integration testing - uses: actions/checkout@v4 - with: - repository: 'Unstructured-IO/unstructured' - - name: Checkout this repo - uses: actions/checkout@v4 - with: - path: inference - - name: Set up Python ${{ matrix.python-version }} - uses: actions/setup-python@v4 - with: - python-version: ${{ matrix.python-version }} - - name: Test - env: - GH_READ_ONLY_ACCESS_TOKEN: ${{ secrets.GH_READ_ONLY_ACCESS_TOKEN }} - SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }} - DISCORD_TOKEN: ${{ secrets.DISCORD_TOKEN }} - run: | - python${{ matrix.python-version }} -m venv .venv - source .venv/bin/activate - [ ! -d "$NLTK_DATA" ] && mkdir "$NLTK_DATA" - make install-ci - pip install -e inference/ - sudo apt-get update - sudo apt-get install -y libmagic-dev poppler-utils libreoffice pandoc - sudo add-apt-repository -y ppa:alex-p/tesseract-ocr5 - sudo apt-get install -y tesseract-ocr - sudo apt-get install -y tesseract-ocr-kor - sudo apt-get install -y diffstat - tesseract --version - make install-all-ingest - # only run ingest tests that check expected output diffs. 
- bash inference/scripts/test-unstructured-ingest-helper.sh + # NOTE(robinson) - disabling ingest tests for now, as of 5/22/2024 they seem to have been + # broken for the past six months + # test_ingest: + # strategy: + # matrix: + # python-version: ["3.9","3.10"] + # runs-on: ubuntu-latest + # env: + # NLTK_DATA: ${{ github.workspace }}/nltk_data + # needs: lint + # steps: + # - name: Checkout unstructured repo for integration testing + # uses: actions/checkout@v4 + # with: + # repository: 'Unstructured-IO/unstructured' + # - name: Checkout this repo + # uses: actions/checkout@v4 + # with: + # path: inference + # - name: Set up Python ${{ matrix.python-version }} + # uses: actions/setup-python@v4 + # with: + # python-version: ${{ matrix.python-version }} + # - name: Test + # env: + # GH_READ_ONLY_ACCESS_TOKEN: ${{ secrets.GH_READ_ONLY_ACCESS_TOKEN }} + # SLACK_TOKEN: ${{ secrets.SLACK_TOKEN }} + # DISCORD_TOKEN: ${{ secrets.DISCORD_TOKEN }} + # run: | + # python${{ matrix.python-version }} -m venv .venv + # source .venv/bin/activate + # [ ! -d "$NLTK_DATA" ] && mkdir "$NLTK_DATA" + # make install-ci + # pip install -e inference/ + # sudo apt-get update + # sudo apt-get install -y libmagic-dev poppler-utils libreoffice pandoc + # sudo add-apt-repository -y ppa:alex-p/tesseract-ocr5 + # sudo apt-get install -y tesseract-ocr + # sudo apt-get install -y tesseract-ocr-kor + # sudo apt-get install -y diffstat + # tesseract --version + # make install-all-ingest + # # only run ingest tests that check expected output diffs. 
+ # bash inference/scripts/test-unstructured-ingest-helper.sh changelog: runs-on: ubuntu-latest diff --git a/CHANGELOG.md b/CHANGELOG.md index 11560e32..944f1b27 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,7 +1,85 @@ -## 0.7.20 +## 0.7.37-dev0 * refactor: remove layout analysis related code +## 0.7.36 + +* fix: add input parameter validation to `fill_cells()` when converting cells to html + +## 0.7.35 + +* Fix syntax for generated HTML tables + +## 0.7.34 + +* Reduce excessive logging + +## 0.7.33 + +* BREAKING CHANGE: removes legacy detectron2 model +* deps: remove layoutparser optional dependencies + +## 0.7.32 + +* refactor: remove all code related to filling inferred elements text from embedded text (pdfminer). +* bug: set the Chipper max_length variable + +## 0.7.31 + +* refactor: remove all `cid` related code that was originally added to filter out invalid `pdfminer` text +* enhancement: Wrapped hf_hub_download with a function that checks for local file before checking HF + +## 0.7.30 + +* fix: table transformer doesn't return multiple cells with same coordinates + +## 0.7.29 + +* fix: table transformer predictions are now removed if confidence is below threshold + + +## 0.7.28 + +* feat: allow table transformer agent to return table prediction in not parsed format + +## 0.7.27 + +* fix: remove pin from `onnxruntime` dependency. 
+ +## 0.7.26 + +* feat: add a set of new `ElementType`s to extend future element types recognition +* feat: allow registering of new models for inference using `unstructured_inference.models.base.register_new_model` function + +## 0.7.25 + +* fix: replace `Rectangle.is_in()` with `Rectangle.is_almost_subregion_of()` when filling in an inferred element with embedded text +* bug: check for None in Chipper bounding box reduction +* chore: removes `install-detectron2` from the `Makefile` +* fix: convert label_map keys read from os.environment `UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH` to int type +* feat: removes supergradients references + +## 0.7.24 + +* fix: assign value to `text_as_html` element attribute only if `text` attribute contains HTML tags. + +## 0.7.23 + +* fix: added handling in `UnstructuredTableTransformerModel` for if `recognize` returns an empty + list in `run_prediction`. + +## 0.7.22 + +* fix: add logic to handle computation of intersections between 2 `Rectangle`s when a `Rectangle` has `None` value in its coordinates + +## 0.7.21 + +* fix: fix a bug where chipper, or any element extraction model based `PageLayout` object, lacks `image_metadata` and other attributes that are required for downstream processing; this fix also reduces the memory overhead of using chipper model + +## 0.7.20 + +* chipper-v3: improved table prediction + ## 0.7.19 * refactor: remove all OCR related code diff --git a/Makefile b/Makefile index ab987bb8..5b806af5 100644 --- a/Makefile +++ b/Makefile @@ -19,7 +19,7 @@ install-base: install-base-pip-packages ## install: installs all test, dev, and experimental requirements .PHONY: install -install: install-base-pip-packages install-dev install-detectron2 +install: install-base-pip-packages install-dev .PHONY: install-ci install-ci: install-base-pip-packages install-test @@ -28,10 +28,6 @@ install-ci: install-base-pip-packages install-test install-base-pip-packages: python3 -m pip install pip==${PIP_VERSION} 
-.PHONY: install-detectron2 -install-detectron2: - pip install "detectron2@git+https://github.com/facebookresearch/detectron2.git@57bdb21249d5418c130d54e2ebdc94dda7a4c01a" - .PHONY: install-test install-test: install-base pip install -r requirements/test.txt @@ -44,10 +40,6 @@ install-dev: install-test .PHONY: pip-compile pip-compile: pip-compile --upgrade requirements/base.in - # NOTE(robinson) - We want the dependencies for detectron2 in the requirements.txt, but not - # the detectron2 repo itself. If detectron2 is in the requirements.txt file, an order of - # operations issue related to the torch library causes the install to fail - sed 's/^detectron2 @/# detectron2 @/g' requirements/base.txt pip-compile --upgrade requirements/test.in pip-compile --upgrade requirements/dev.in diff --git a/pyproject.toml b/pyproject.toml new file mode 100644 index 00000000..aa4949aa --- /dev/null +++ b/pyproject.toml @@ -0,0 +1,2 @@ +[tool.black] +line-length = 100 diff --git a/requirements/base.in b/requirements/base.in index 5f4a8f5d..fc60b9bc 100644 --- a/requirements/base.in +++ b/requirements/base.in @@ -1,11 +1,13 @@ -c constraints.in -layoutparser[layoutmodels,tesseract] +layoutparser python-multipart huggingface-hub opencv-python!=4.7.0.68 onnx -# NOTE(benjamin): Pinned because onnxruntime changed the way quantization is done, and we need to update our code to support it -onnxruntime<1.16 +onnxruntime>=1.17.0 +matplotlib +torch +timm # NOTE(alan): Pinned because this is when the most recent module we import appeared transformers>=4.25.1 rapidfuzz diff --git a/requirements/base.txt b/requirements/base.txt index 28857d7a..536ea640 100644 --- a/requirements/base.txt +++ b/requirements/base.txt @@ -1,12 +1,10 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements/base.in # -antlr4-python3-runtime==4.9.3 - # via omegaconf -certifi==2023.7.22 
+certifi==2024.2.2 # via requests cffi==1.16.0 # via cryptography @@ -16,28 +14,26 @@ charset-normalizer==3.3.2 # requests coloredlogs==15.0.1 # via onnxruntime -contourpy==1.1.1 +contourpy==1.2.1 # via matplotlib -cryptography==41.0.5 +cryptography==42.0.7 # via pdfminer-six cycler==0.12.1 # via matplotlib -effdet==0.4.1 - # via layoutparser -filelock==3.13.1 +filelock==3.14.0 # via # huggingface-hub # torch # transformers -flatbuffers==23.5.26 +flatbuffers==24.3.25 # via onnxruntime -fonttools==4.44.3 +fonttools==4.51.0 # via matplotlib -fsspec==2023.10.0 +fsspec==2024.5.0 # via # huggingface-hub # torch -huggingface-hub==0.19.4 +huggingface-hub==0.23.1 # via # -r requirements/base.in # timm @@ -45,27 +41,27 @@ huggingface-hub==0.19.4 # transformers humanfriendly==10.0 # via coloredlogs -idna==3.4 +idna==3.7 # via requests -importlib-resources==6.1.1 +importlib-resources==6.4.0 # via matplotlib iopath==0.1.10 # via layoutparser -jinja2==3.1.2 +jinja2==3.1.4 # via torch kiwisolver==1.4.5 # via matplotlib -layoutparser[layoutmodels,tesseract]==0.3.4 +layoutparser==0.3.4 # via -r requirements/base.in -markupsafe==2.1.3 +markupsafe==2.1.5 # via jinja2 -matplotlib==3.7.3 - # via pycocotools +matplotlib==3.9.0 + # via -r requirements/base.in mpmath==1.3.0 # via sympy -networkx==3.1 +networkx==3.2.1 # via torch -numpy==1.24.4 +numpy==1.26.4 # via # contourpy # layoutparser @@ -74,88 +70,77 @@ numpy==1.24.4 # onnxruntime # opencv-python # pandas - # pycocotools # scipy # torchvision # transformers -omegaconf==2.3.0 - # via effdet -onnx==1.15.0 +onnx==1.16.0 # via -r requirements/base.in -onnxruntime==1.15.1 +onnxruntime==1.18.0 # via -r requirements/base.in -opencv-python==4.8.1.78 +opencv-python==4.9.0.80 # via # -r requirements/base.in # layoutparser -packaging==23.2 +packaging==24.0 # via # huggingface-hub # matplotlib # onnxruntime - # pytesseract # transformers -pandas==2.0.3 +pandas==2.2.2 # via layoutparser -pdf2image==1.16.3 +pdf2image==1.17.0 # via layoutparser 
-pdfminer-six==20221105 +pdfminer-six==20231228 # via pdfplumber -pdfplumber==0.10.3 +pdfplumber==0.11.0 # via layoutparser -pillow==10.1.0 +pillow==10.3.0 # via # layoutparser # matplotlib # pdf2image # pdfplumber - # pytesseract # torchvision portalocker==2.8.2 # via iopath -protobuf==4.25.1 +protobuf==5.26.1 # via # onnx # onnxruntime -pycocotools==2.0.7 - # via effdet -pycparser==2.21 +pycparser==2.22 # via cffi -pyparsing==3.1.1 +pyparsing==3.1.2 # via matplotlib -pypdfium2==4.24.0 +pypdfium2==4.30.0 # via pdfplumber -pytesseract==0.3.10 - # via layoutparser -python-dateutil==2.8.2 +python-dateutil==2.9.0.post0 # via # matplotlib # pandas -python-multipart==0.0.6 +python-multipart==0.0.9 # via -r requirements/base.in -pytz==2023.3.post1 +pytz==2024.1 # via pandas pyyaml==6.0.1 # via # huggingface-hub # layoutparser - # omegaconf # timm # transformers -rapidfuzz==3.5.2 +rapidfuzz==3.9.1 # via -r requirements/base.in -regex==2023.10.3 +regex==2024.5.15 # via transformers -requests==2.31.0 +requests==2.32.2 # via # huggingface-hub - # torchvision # transformers -safetensors==0.4.0 +safetensors==0.4.3 # via # timm # transformers -scipy==1.10.1 +scipy==1.13.0 # via layoutparser six==1.16.0 # via python-dateutil @@ -163,36 +148,32 @@ sympy==1.12 # via # onnxruntime # torch -timm==0.9.10 - # via effdet -tokenizers==0.15.0 +timm==1.0.3 + # via -r requirements/base.in +tokenizers==0.19.1 # via transformers -torch==2.1.1 +torch==2.3.0 # via - # effdet - # layoutparser + # -r requirements/base.in # timm # torchvision -torchvision==0.16.1 - # via - # effdet - # layoutparser - # timm -tqdm==4.66.1 +torchvision==0.18.0 + # via timm +tqdm==4.66.4 # via # huggingface-hub # iopath # transformers -transformers==4.35.2 +transformers==4.41.0 # via -r requirements/base.in -typing-extensions==4.8.0 +typing-extensions==4.11.0 # via # huggingface-hub # iopath # torch -tzdata==2023.3 +tzdata==2024.1 # via pandas -urllib3==2.1.0 +urllib3==2.2.1 # via requests -zipp==3.17.0 
+zipp==3.18.2 # via importlib-resources diff --git a/requirements/dev.txt b/requirements/dev.txt index 8f45a80c..1021b087 100644 --- a/requirements/dev.txt +++ b/requirements/dev.txt @@ -1,17 +1,16 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements/dev.in # -anyio==4.0.0 +anyio==4.3.0 # via # -c requirements/test.txt + # httpx # jupyter-server -appnope==0.1.3 - # via - # ipykernel - # ipython +appnope==0.1.4 + # via ipykernel argon2-cffi==23.1.0 # via jupyter-server argon2-cffi-bindings==21.2.0 @@ -22,24 +21,24 @@ asttokens==2.4.1 # via stack-data async-lru==2.0.4 # via jupyterlab -attrs==23.1.0 +attrs==23.2.0 # via # jsonschema # referencing -babel==2.13.1 +babel==2.15.0 # via jupyterlab-server -backcall==0.2.0 - # via ipython -beautifulsoup4==4.12.2 +beautifulsoup4==4.12.3 # via nbconvert bleach==6.1.0 # via nbconvert -build==1.0.3 +build==1.2.1 # via pip-tools -certifi==2023.7.22 +certifi==2024.2.2 # via # -c requirements/base.txt # -c requirements/test.txt + # httpcore + # httpx # requests cffi==1.16.0 # via @@ -54,11 +53,11 @@ click==8.1.7 # via # -c requirements/test.txt # pip-tools -comm==0.2.0 +comm==0.2.2 # via # ipykernel # ipywidgets -contourpy==1.1.1 +contourpy==1.2.1 # via # -c requirements/base.txt # matplotlib @@ -66,34 +65,48 @@ cycler==0.12.1 # via # -c requirements/base.txt # matplotlib -debugpy==1.8.0 +debugpy==1.8.1 # via ipykernel decorator==5.1.1 # via ipython defusedxml==0.7.1 # via nbconvert -exceptiongroup==1.1.3 +exceptiongroup==1.2.1 # via # -c requirements/test.txt # anyio + # ipython executing==2.0.1 # via stack-data -fastjsonschema==2.19.0 +fastjsonschema==2.19.1 # via nbformat -fonttools==4.44.3 +fonttools==4.51.0 # via # -c requirements/base.txt # matplotlib fqdn==1.5.1 # via jsonschema -idna==3.4 +h11==0.14.0 + # via + # -c requirements/test.txt + # httpcore +httpcore==1.0.5 + # via + # -c 
requirements/test.txt + # httpx +httpx==0.27.0 + # via + # -c requirements/test.txt + # jupyterlab +idna==3.7 # via # -c requirements/base.txt # -c requirements/test.txt # anyio + # httpx # jsonschema # requests -importlib-metadata==6.8.0 +importlib-metadata==7.1.0 # via # build # jupyter-client @@ -101,52 +114,49 @@ importlib-metadata==6.8.0 # jupyterlab # jupyterlab-server # nbconvert -importlib-resources==6.1.1 +importlib-resources==6.4.0 # via # -c requirements/base.txt - # jsonschema - # jsonschema-specifications - # jupyterlab # matplotlib -ipykernel==6.26.0 +ipykernel==6.29.4 # via # jupyter # jupyter-console # jupyterlab # qtconsole -ipython==8.12.3 +ipython==8.18.1 # via # -r requirements/dev.in # ipykernel # ipywidgets # jupyter-console -ipywidgets==8.1.1 +ipywidgets==8.1.2 # via jupyter isoduration==20.11.0 # via jsonschema jedi==0.19.1 # via ipython -jinja2==3.1.2 +jinja2==3.1.4 # via # -c requirements/base.txt # jupyter-server # jupyterlab # jupyterlab-server # nbconvert -json5==0.9.14 +json5==0.9.25 # via jupyterlab-server jsonpointer==2.4 # via jsonschema -jsonschema[format-nongpl]==4.20.0 +jsonschema[format-nongpl]==4.22.0 # via # jupyter-events # jupyterlab-server # nbformat -jsonschema-specifications==2023.11.1 +jsonschema-specifications==2023.12.1 # via jsonschema jupyter==1.0.0 # via -r requirements/dev.in -jupyter-client==8.6.0 +jupyter-client==8.6.1 # via # ipykernel # jupyter-console @@ -155,7 +165,7 @@ jupyter-client==8.6.0 # qtconsole jupyter-console==6.6.3 # via jupyter -jupyter-core==5.5.0 +jupyter-core==5.7.2 # via # ipykernel # jupyter-client @@ -166,75 +176,75 @@ jupyter-core==5.5.0 # nbconvert # nbformat # qtconsole -jupyter-events==0.9.0 +jupyter-events==0.10.0 # via jupyter-server -jupyter-lsp==2.2.0 +jupyter-lsp==2.2.5 # via jupyterlab -jupyter-server==2.10.1 +jupyter-server==2.14.0 # via # jupyter-lsp # jupyterlab # jupyterlab-server # notebook # notebook-shim -jupyter-server-terminals==0.4.4 +jupyter-server-terminals==0.5.3 # via 
jupyter-server -jupyterlab==4.0.8 +jupyterlab==4.2.0 # via notebook -jupyterlab-pygments==0.2.2 +jupyterlab-pygments==0.3.0 # via nbconvert -jupyterlab-server==2.25.1 +jupyterlab-server==2.27.1 # via # jupyterlab # notebook -jupyterlab-widgets==3.0.9 +jupyterlab-widgets==3.0.10 # via ipywidgets kiwisolver==1.4.5 # via # -c requirements/base.txt # matplotlib -markupsafe==2.1.3 +markupsafe==2.1.5 # via # -c requirements/base.txt # jinja2 # nbconvert -matplotlib==3.7.3 +matplotlib==3.9.0 # via # -c requirements/base.txt # -r requirements/dev.in -matplotlib-inline==0.1.6 +matplotlib-inline==0.1.7 # via # ipykernel # ipython mistune==3.0.2 # via nbconvert -nbclient==0.9.0 +nbclient==0.10.0 # via nbconvert -nbconvert==7.11.0 +nbconvert==7.16.4 # via # jupyter # jupyter-server -nbformat==5.9.2 +nbformat==5.10.4 # via # jupyter-server # nbclient # nbconvert -nest-asyncio==1.5.8 +nest-asyncio==1.6.0 # via ipykernel -notebook==7.0.6 +notebook==7.2.0 # via jupyter -notebook-shim==0.2.3 +notebook-shim==0.2.4 # via # jupyterlab # notebook -numpy==1.24.4 +numpy==1.26.4 # via # -c requirements/base.txt # contourpy # matplotlib -overrides==7.4.0 +overrides==7.7.0 # via jupyter-server -packaging==23.2 +packaging==24.0 # via # -c requirements/base.txt # -c requirements/test.txt @@ -247,34 +257,30 @@ packaging==23.2 # nbconvert # qtconsole # qtpy -pandocfilters==1.5.0 +pandocfilters==1.5.1 # via nbconvert -parso==0.8.3 +parso==0.8.4 # via jedi -pexpect==4.8.0 - # via ipython -pickleshare==0.7.5 +pexpect==4.9.0 # via ipython -pillow==10.1.0 +pillow==10.3.0 # via # -c requirements/base.txt # -c requirements/test.txt # matplotlib -pip-tools==7.3.0 +pip-tools==7.4.1 # via -r requirements/dev.in -pkgutil-resolve-name==1.3.10 - # via jsonschema -platformdirs==4.0.0 +platformdirs==4.2.2 # via # -c requirements/test.txt # jupyter-core -prometheus-client==0.18.0 +prometheus-client==0.20.0 # via jupyter-server -prompt-toolkit==3.0.41 +prompt-toolkit==3.0.43 # via # ipython # jupyter-console 
-psutil==5.9.6 +psutil==5.9.8 # via ipykernel ptyprocess==0.7.0 # via @@ -282,23 +288,25 @@ ptyprocess==0.7.0 # terminado pure-eval==0.2.2 # via stack-data -pycparser==2.21 +pycparser==2.22 # via # -c requirements/base.txt # cffi -pygments==2.16.1 +pygments==2.18.0 # via # ipython # jupyter-console # nbconvert # qtconsole -pyparsing==3.1.1 +pyparsing==3.1.2 # via # -c requirements/base.txt # matplotlib -pyproject-hooks==1.0.0 - # via build -python-dateutil==2.8.2 +pyproject-hooks==1.1.0 + # via + # build + # pip-tools +python-dateutil==2.9.0.post0 # via # -c requirements/base.txt # arrow @@ -306,32 +314,28 @@ python-dateutil==2.8.2 # matplotlib python-json-logger==2.0.7 # via jupyter-events -pytz==2023.3.post1 - # via - # -c requirements/base.txt - # babel pyyaml==6.0.1 # via # -c requirements/base.txt # -c requirements/test.txt # jupyter-events -pyzmq==25.1.1 +pyzmq==26.0.3 # via # ipykernel # jupyter-client # jupyter-console # jupyter-server # qtconsole -qtconsole==5.5.1 +qtconsole==5.5.2 # via jupyter qtpy==2.4.1 # via qtconsole -referencing==0.31.0 +referencing==0.35.1 # via # jsonschema # jsonschema-specifications # jupyter-events -requests==2.31.0 +requests==2.32.2 # via # -c requirements/base.txt # -c requirements/test.txt @@ -344,11 +348,11 @@ rfc3986-validator==0.1.1 # via # jsonschema # jupyter-events -rpds-py==0.13.0 +rpds-py==0.18.1 # via # jsonschema # referencing -send2trash==1.8.2 +send2trash==1.8.3 # via jupyter-server six==1.16.0 # via @@ -357,19 +361,20 @@ six==1.16.0 # bleach # python-dateutil # rfc3339-validator -sniffio==1.3.0 +sniffio==1.3.1 # via # -c requirements/test.txt # anyio + # httpx soupsieve==2.5 # via beautifulsoup4 stack-data==0.6.3 # via ipython -terminado==0.18.0 +terminado==0.18.1 # via # jupyter-server # jupyter-server-terminals -tinycss2==1.2.1 +tinycss2==1.3.0 # via nbconvert tomli==2.0.1 # via @@ -377,8 +382,7 @@ tomli==2.0.1 # build # jupyterlab # pip-tools - # pyproject-hooks -tornado==6.3.3 +tornado==6.4 # via # ipykernel 
# jupyter-client @@ -386,7 +390,7 @@ tornado==6.3.3 # jupyterlab # notebook # terminado -traitlets==5.13.0 +traitlets==5.14.3 # via # comm # ipykernel @@ -403,22 +407,23 @@ traitlets==5.13.0 # nbconvert # nbformat # qtconsole -types-python-dateutil==2.8.19.14 +types-python-dateutil==2.9.0.20240316 # via arrow -typing-extensions==4.8.0 +typing-extensions==4.11.0 # via # -c requirements/base.txt # -c requirements/test.txt + # anyio # async-lru # ipython uri-template==1.3.0 # via jsonschema -urllib3==2.1.0 +urllib3==2.2.1 # via # -c requirements/base.txt # -c requirements/test.txt # requests -wcwidth==0.2.10 +wcwidth==0.2.13 # via prompt-toolkit webcolors==1.13 # via jsonschema @@ -426,13 +431,13 @@ webencodings==0.5.1 # via # bleach # tinycss2 -websocket-client==1.6.4 +websocket-client==1.8.0 # via jupyter-server -wheel==0.41.3 +wheel==0.43.0 # via pip-tools -widgetsnbextension==4.0.9 +widgetsnbextension==4.0.10 # via ipywidgets -zipp==3.17.0 +zipp==3.18.2 # via # -c requirements/base.txt # importlib-metadata diff --git a/requirements/test.txt b/requirements/test.txt index 4e5a2adb..5f738730 100644 --- a/requirements/test.txt +++ b/requirements/test.txt @@ -1,14 +1,14 @@ # -# This file is autogenerated by pip-compile with Python 3.8 +# This file is autogenerated by pip-compile with Python 3.9 # by the following command: # # pip-compile requirements/test.in # -anyio==4.0.0 +anyio==4.3.0 # via httpx -black==23.11.0 +black==24.4.2 # via -r requirements/test.in -certifi==2023.7.22 +certifi==2024.2.2 # via # -c requirements/base.txt # httpcore @@ -22,39 +22,39 @@ click==8.1.7 # via # -r requirements/test.in # black -coverage[toml]==7.3.2 +coverage[toml]==7.5.1 # via # -r requirements/test.in # pytest-cov -exceptiongroup==1.1.3 +exceptiongroup==1.2.1 # via # anyio # pytest -filelock==3.13.1 +filelock==3.14.0 # via # -c requirements/base.txt # huggingface-hub -flake8==6.1.0 +flake8==7.0.0 # via # -r requirements/test.in # flake8-docstrings flake8-docstrings==1.7.0 # via -r 
requirements/test.in -fsspec==2023.10.0 +fsspec==2024.5.0 # via # -c requirements/base.txt # huggingface-hub h11==0.14.0 # via httpcore -httpcore==1.0.2 +httpcore==1.0.5 # via httpx -httpx==0.25.1 +httpx==0.27.0 # via -r requirements/test.in -huggingface-hub==0.19.4 +huggingface-hub==0.23.1 # via # -c requirements/base.txt # -r requirements/test.in -idna==3.4 +idna==3.7 # via # -c requirements/base.txt # anyio @@ -64,57 +64,57 @@ iniconfig==2.0.0 # via pytest mccabe==0.7.0 # via flake8 -mypy==1.7.0 +mypy==1.10.0 # via -r requirements/test.in mypy-extensions==1.0.0 # via # black # mypy -packaging==23.2 +packaging==24.0 # via # -c requirements/base.txt # black # huggingface-hub # pytest -pathspec==0.11.2 +pathspec==0.12.1 # via black -pdf2image==1.16.3 +pdf2image==1.17.0 # via # -c requirements/base.txt # -r requirements/test.in -pillow==10.1.0 +pillow==10.3.0 # via # -c requirements/base.txt # pdf2image -platformdirs==4.0.0 +platformdirs==4.2.2 # via black -pluggy==1.3.0 +pluggy==1.5.0 # via pytest pycodestyle==2.11.1 # via flake8 pydocstyle==6.3.0 # via flake8-docstrings -pyflakes==3.1.0 +pyflakes==3.2.0 # via flake8 -pytest==7.4.3 +pytest==8.2.1 # via # pytest-cov # pytest-mock -pytest-cov==4.1.0 +pytest-cov==5.0.0 # via -r requirements/test.in -pytest-mock==3.12.0 +pytest-mock==3.14.0 # via -r requirements/test.in pyyaml==6.0.1 # via # -c requirements/base.txt # huggingface-hub -requests==2.31.0 +requests==2.32.2 # via # -c requirements/base.txt # huggingface-hub -ruff==0.1.5 +ruff==0.4.4 # via -r requirements/test.in -sniffio==1.3.0 +sniffio==1.3.1 # via # anyio # httpx @@ -126,19 +126,20 @@ tomli==2.0.1 # coverage # mypy # pytest -tqdm==4.66.1 +tqdm==4.66.4 # via # -c requirements/base.txt # huggingface-hub -types-pyyaml==6.0.12.12 +types-pyyaml==6.0.12.20240311 # via -r requirements/test.in -typing-extensions==4.8.0 +typing-extensions==4.11.0 # via # -c requirements/base.txt + # anyio # black # huggingface-hub # mypy -urllib3==2.1.0 +urllib3==2.2.1 # via # -c 
requirements/base.txt # requests diff --git a/scripts/version-sync.sh b/scripts/version-sync.sh index e8888efa..4a62d26e 100755 --- a/scripts/version-sync.sh +++ b/scripts/version-sync.sh @@ -13,12 +13,12 @@ function usage { } function getopts-extra () { - declare i=1 + declare -i i=1 # if the next argument is not an option, then append it to array OPTARG while [[ ${OPTIND} -le $# && ${!OPTIND:0:1} != '-' ]]; do OPTARG[i]=${!OPTIND} - i+=1 - OPTIND+=1 + ((i += 1)) + ((OPTIND += 1)) done } diff --git a/test_unstructured_inference/inference/test_layout.py b/test_unstructured_inference/inference/test_layout.py index e7f77164..6cb618ce 100644 --- a/test_unstructured_inference/inference/test_layout.py +++ b/test_unstructured_inference/inference/test_layout.py @@ -9,7 +9,10 @@ import unstructured_inference.models.base as models from unstructured_inference.inference import elements, layout, layoutelement -from unstructured_inference.inference.elements import EmbeddedTextRegion, ImageTextRegion +from unstructured_inference.inference.elements import ( + EmbeddedTextRegion, + ImageTextRegion, +) from unstructured_inference.models.unstructuredmodel import ( UnstructuredElementExtractionModel, UnstructuredObjectDetectionModel, @@ -156,7 +159,7 @@ def test_process_data_with_model_raises_on_invalid_model_name(): layout.process_data_with_model(fp, model_name="fake") -@pytest.mark.parametrize("model_name", [None, "checkbox"]) +@pytest.mark.parametrize("model_name", [None, "yolox"]) def test_process_file_with_model(monkeypatch, mock_final_layout, model_name): def mock_initialize(self, *args, **kwargs): self.model = MockLayoutModel(mock_final_layout) @@ -166,7 +169,7 @@ def mock_initialize(self, *args, **kwargs): "from_file", lambda *args, **kwargs: layout.DocumentLayout.from_pages([]), ) - monkeypatch.setattr(models.UnstructuredDetectronModel, "initialize", mock_initialize) + monkeypatch.setattr(models.UnstructuredDetectronONNXModel, "initialize", mock_initialize) filename = "" 
assert layout.process_file_with_model(filename, model_name=model_name) @@ -180,7 +183,7 @@ def mock_initialize(self, *args, **kwargs): "from_file", lambda *args, **kwargs: layout.DocumentLayout.from_pages([]), ) - monkeypatch.setattr(models.UnstructuredDetectronModel, "initialize", mock_initialize) + monkeypatch.setattr(models.UnstructuredDetectronONNXModel, "initialize", mock_initialize) filename = "" layout.process_file_with_model(filename, model_name=None) # There should be no UserWarning, but if there is one it should not have the following message @@ -224,33 +227,6 @@ def __init__( self.detection_model = detection_model -@pytest.mark.parametrize( - ("text", "expected"), - [ - ("base", 0.0), - ("", 0.0), - ("(cid:2)", 1.0), - ("(cid:1)a", 0.5), - ("c(cid:1)ab", 0.25), - ], -) -def test_cid_ratio(text, expected): - assert elements.cid_ratio(text) == expected - - -@pytest.mark.parametrize( - ("text", "expected"), - [ - ("base", False), - ("(cid:2)", True), - ("(cid:1234567890)", True), - ("jkl;(cid:12)asdf", True), - ], -) -def test_is_cid_present(text, expected): - assert elements.is_cid_present(text) == expected - - class MockLayout: def __init__(self, *elements): self.elements = elements @@ -271,12 +247,14 @@ def filter_by(self, *args, **kwargs): return MockLayout() +@pytest.mark.parametrize("element_extraction_model", [None, "foo"]) @pytest.mark.parametrize("filetype", ["png", "jpg", "tiff"]) -def test_from_image_file(monkeypatch, mock_final_layout, filetype): +def test_from_image_file(monkeypatch, mock_final_layout, filetype, element_extraction_model): def mock_get_elements(self, *args, **kwargs): self.elements = [mock_final_layout] monkeypatch.setattr(layout.PageLayout, "get_elements_with_detection_model", mock_get_elements) + monkeypatch.setattr(layout.PageLayout, "get_elements_using_image_extraction", mock_get_elements) filename = f"sample-docs/loremipsum.{filetype}" image = Image.open(filename) image_metadata = { @@ -285,7 +263,10 @@ def 
mock_get_elements(self, *args, **kwargs): "height": image.height, } - doc = layout.DocumentLayout.from_image_file(filename) + doc = layout.DocumentLayout.from_image_file( + filename, + element_extraction_model=element_extraction_model, + ) page = doc.pages[0] assert page.elements[0] == mock_final_layout assert page.image is None @@ -331,16 +312,6 @@ def test_from_image_file_raises_isadirectoryerror_with_dir(): layout.DocumentLayout.from_image_file(tempdir) -@pytest.mark.parametrize("idx", range(2)) -def test_get_elements_from_layout(mock_initial_layout, idx): - page = MockPageLayout() - block = mock_initial_layout[idx] - block.bbox.pad(3) - fixed_layout = [block] - elements = page.get_elements_from_layout(fixed_layout) - assert elements[0].text == block.text - - def test_page_numbers_in_page_objects(): with patch( "unstructured_inference.inference.layout.PageLayout.get_elements_with_detection_model", @@ -350,49 +321,8 @@ def test_page_numbers_in_page_objects(): assert [page.number for page in doc.pages] == list(range(1, len(doc.pages) + 1)) -@pytest.mark.parametrize( - ("fixed_layouts", "called_method", "not_called_method"), - [ - ( - [MockLayout()], - "get_elements_from_layout", - "get_elements_with_detection_model", - ), - (None, "get_elements_with_detection_model", "get_elements_from_layout"), - ], -) -def test_from_file_fixed_layout(fixed_layouts, called_method, not_called_method): - with patch.object( - layout.PageLayout, - "get_elements_with_detection_model", - return_value=[], - ), patch.object( - layout.PageLayout, - "get_elements_from_layout", - return_value=[], - ): - layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf", fixed_layouts=fixed_layouts) - getattr(layout.PageLayout, called_method).assert_called() - getattr(layout.PageLayout, not_called_method).assert_not_called() - - -@pytest.mark.parametrize( - ("text", "expected"), - [("c\to\x0cn\ftrol\ncharacter\rs\b", "control characters"), ("\"'\\", "\"'\\")], -) -def 
test_remove_control_characters(text, expected): - assert elements.remove_control_characters(text) == expected - - no_text_region = EmbeddedTextRegion.from_coords(0, 0, 100, 100) text_region = EmbeddedTextRegion.from_coords(0, 0, 100, 100, text="test") -cid_text_region = EmbeddedTextRegion.from_coords( - 0, - 0, - 100, - 100, - text="(cid:1)(cid:2)(cid:3)(cid:4)(cid:5)", -) overlapping_rect = ImageTextRegion.from_coords(50, 50, 150, 150) nonoverlapping_rect = ImageTextRegion.from_coords(150, 150, 200, 200) populated_text_region = EmbeddedTextRegion.from_coords(50, 50, 60, 60, text="test") @@ -443,12 +373,6 @@ def check_annotated_image(): check_annotated_image() -@pytest.mark.parametrize(("text", "expected"), [("asdf", "asdf"), (None, "")]) -def test_embedded_text_region(text, expected): - etr = elements.EmbeddedTextRegion.from_coords(0, 0, 24, 24, text=text) - assert etr.extract_text(objects=None) == expected - - class MockDetectionModel(layout.UnstructuredObjectDetectionModel): def initialize(self, *args, **kwargs): pass diff --git a/test_unstructured_inference/inference/test_layout_element.py b/test_unstructured_inference/inference/test_layout_element.py index da4b0f10..f814c180 100644 --- a/test_unstructured_inference/inference/test_layout_element.py +++ b/test_unstructured_inference/inference/test_layout_element.py @@ -5,18 +5,6 @@ from unstructured_inference.inference.layoutelement import LayoutElement, TextRegion -def test_layout_element_extract_text( - mock_layout_element, - mock_text_region, -): - extracted_text = mock_layout_element.extract_text( - objects=[mock_text_region], - ) - - assert isinstance(extracted_text, str) - assert "Sample text" in extracted_text - - def test_layout_element_do_dict(mock_layout_element): expected = { "coordinates": ((100, 100), (100, 300), (300, 300), (300, 100)), diff --git a/test_unstructured_inference/models/test_chippermodel.py b/test_unstructured_inference/models/test_chippermodel.py index 5a63894b..c68aa6bc 100644 --- 
a/test_unstructured_inference/models/test_chippermodel.py +++ b/test_unstructured_inference/models/test_chippermodel.py @@ -3,6 +3,7 @@ import pytest import torch from PIL import Image +from unstructured_inference.inference.layoutelement import LayoutElement from unstructured_inference.models import chipper from unstructured_inference.models.base import get_model @@ -139,13 +140,8 @@ def test_no_repeat_ngram_logits(): no_repeat_ngram_size = 2 - output = chipper._no_repeat_ngram_logits( - input_ids=input_ids, - cur_len=cur_len, - logits=logits, - batch_size=batch_size, - no_repeat_ngram_size=no_repeat_ngram_size, - ) + logitsProcessor = chipper.NoRepeatNGramLogitsProcessor(ngram_size=2) + output = logitsProcessor(input_ids=input_ids, scores=logits) assert ( int( @@ -194,6 +190,25 @@ def test_no_repeat_ngram_logits(): ) +def test_ngram_repetiton_stopping_criteria(): + input_ids = torch.tensor([[1, 2, 3, 4, 0, 1, 2, 3, 4]]) + logits = torch.tensor([[0.1, -0.3, -0.5, 0, 1.0, -0.9]]) + + stoppingCriteria = chipper.NGramRepetitonStoppingCriteria( + repetition_window=2, skip_tokens={0, 1, 2, 3, 4} + ) + + output = stoppingCriteria(input_ids=input_ids, scores=logits) + + assert output is False + + stoppingCriteria = chipper.NGramRepetitonStoppingCriteria( + repetition_window=2, skip_tokens={1, 2, 3, 4} + ) + output = stoppingCriteria(input_ids=input_ids, scores=logits) + assert output is True + + @pytest.mark.parametrize( ("decoded_str", "expected_classes"), [ @@ -241,7 +256,51 @@ def test_postprocess_bbox(decoded_str, expected_classes): assert out[i].type == expected_classes[i] -def test_run_chipper_v2(): +def test_predict_tokens_beam_indices(): + model = get_model("chipper") + model.stopping_criteria = [ + chipper.NGramRepetitonStoppingCriteria( + repetition_window=1, + skip_tokens={}, + ), + ] + img = Image.open("sample-docs/easy_table.jpg") + output = model.predict_tokens(image=img) + assert len(output) > 0 + + +def test_largest_margin_edge(): + model = 
get_model("chipper") + img = Image.open("sample-docs/easy_table.jpg") + output = model.largest_margin(image=img, input_bbox=[0, 1, 0, 0], transpose=False) + + assert output is None + + output = model.largest_margin(img, [1, 1, 1, 1], False) + + assert output is None + + output = model.largest_margin(img, [2, 1, 3, 10], True) + + assert output == (0, 0, 0) + + +def test_deduplicate_detected_elements(): + model = get_model("chipper") + img = Image.open("sample-docs/easy_table.jpg") + elements = model(img) + + output = model.deduplicate_detected_elements(elements) + + assert len(output) == 2 + + +def test_norepeatnGramlogitsprocessor_exception(): + with pytest.raises(ValueError): + chipper.NoRepeatNGramLogitsProcessor(ngram_size="") + + +def test_run_chipper_v3(): model = get_model("chipper") img = Image.open("sample-docs/easy_table.jpg") elements = model(img) @@ -364,3 +423,26 @@ def test_check_overlap(bbox1, bbox2, output): model = get_model("chipper") assert model.check_overlap(bbox1, bbox2) == output + + +def test_format_table_elements(): + table_html = "
Cell 1Cell 2
Cell 3
" + texts = [ + "Text", + " - List element", + table_html, + None, + ] + elements = [LayoutElement(bbox=mock.MagicMock(), text=text) for text in texts] + formatted_elements = chipper.UnstructuredChipperModel.format_table_elements(elements) + text_attributes = [fe.text for fe in formatted_elements] + text_as_html_attributes = [ + fe.text_as_html if hasattr(fe, "text_as_html") else None for fe in formatted_elements + ] + assert text_attributes == [ + "Text", + " - List element", + "Cell 1Cell 2Cell 3", + None, + ] + assert text_as_html_attributes == [None, None, table_html, None] diff --git a/test_unstructured_inference/models/test_detectron2.py b/test_unstructured_inference/models/test_detectron2.py deleted file mode 100644 index 987120e1..00000000 --- a/test_unstructured_inference/models/test_detectron2.py +++ /dev/null @@ -1,50 +0,0 @@ -from unittest.mock import patch - -import pytest - -import unstructured_inference.models.base as models -from unstructured_inference.models import detectron2 - - -class MockDetectron2LayoutModel: - def __init__(self, *args, **kwargs): - self.args = args - self.kwargs = kwargs - - def detect(self, x): - return [] - - -def test_load_default_model(monkeypatch): - monkeypatch.setattr(detectron2, "Detectron2LayoutModel", MockDetectron2LayoutModel) - monkeypatch.setattr(models, "models", {}) - - with patch.object(detectron2, "is_detectron2_available", return_value=True): - model = models.get_model("detectron2_lp") - - assert isinstance(model.model, MockDetectron2LayoutModel) - - -def test_load_default_model_raises_when_not_available(monkeypatch): - monkeypatch.setattr(models, "models", {}) - with patch.object(detectron2, "is_detectron2_available", return_value=False): - with pytest.raises(ImportError): - models.get_model("detectron2_lp") - - -@pytest.mark.parametrize(("config_path", "model_path"), [("asdf", "diufs"), ("dfaw", "hfhfhfh")]) -def test_load_model(monkeypatch, config_path, model_path): - monkeypatch.setattr(detectron2, 
"Detectron2LayoutModel", MockDetectron2LayoutModel) - with patch.object(detectron2, "is_detectron2_available", return_value=True): - model = detectron2.UnstructuredDetectronModel() - model.initialize(config_path=config_path, model_path=model_path) - assert config_path == model.model.args[0] - assert model_path == model.model.kwargs["model_path"] - - -def test_unstructured_detectron_model(): - model = detectron2.UnstructuredDetectronModel() - model.model = MockDetectron2LayoutModel() - result = model(None) - assert isinstance(result, list) - assert len(result) == 0 diff --git a/test_unstructured_inference/models/test_model.py b/test_unstructured_inference/models/test_model.py index fc7c18aa..411ef3d4 100644 --- a/test_unstructured_inference/models/test_model.py +++ b/test_unstructured_inference/models/test_model.py @@ -1,3 +1,4 @@ +import json from typing import Any from unittest import mock @@ -25,10 +26,33 @@ def predict(self, x: Any) -> Any: return [] +MOCK_MODEL_TYPES = { + "foo": { + "input_shape": (640, 640), + }, +} + + def test_get_model(monkeypatch): monkeypatch.setattr(models, "models", {}) - with mock.patch.dict(models.model_class_map, {"checkbox": MockModel}): - assert isinstance(models.get_model("checkbox"), MockModel) + with mock.patch.dict(models.model_class_map, {"yolox": MockModel}): + assert isinstance(models.get_model("yolox"), MockModel) + + +def test_register_new_model(): + assert "foo" not in models.model_class_map + assert "foo" not in models.model_config_map + models.register_new_model(MOCK_MODEL_TYPES, MockModel) + assert "foo" in models.model_class_map + assert "foo" in models.model_config_map + model = models.get_model("foo") + assert len(model.initializer.mock_calls) == 1 + assert model.initializer.mock_calls[0][-1] == MOCK_MODEL_TYPES["foo"] + assert isinstance(model, MockModel) + # unregister the new model by reset to default + models.model_class_map, models.model_config_map = models.get_default_model_mappings() + assert "foo" not in 
models.model_class_map + assert "foo" not in models.model_config_map def test_raises_invalid_model(): @@ -38,14 +62,16 @@ def test_raises_invalid_model(): def test_raises_uninitialized(): with pytest.raises(ModelNotInitializedError): - models.UnstructuredDetectronModel().predict(None) + models.UnstructuredDetectronONNXModel().predict(None) def test_model_initializes_once(): from unstructured_inference.inference import layout with mock.patch.dict(models.model_class_map, {"yolox": MockModel}), mock.patch.object( - models, "models", {} + models, + "models", + {}, ): doc = layout.DocumentLayout.from_file("sample-docs/loremipsum.pdf") doc.pages[0].detection_model.initializer.assert_called_once() @@ -143,23 +169,32 @@ def test_env_variables_override_default_model(monkeypatch): # args, we should get back the model the env var calls for monkeypatch.setattr(models, "models", {}) with mock.patch.dict( - models.os.environ, {"UNSTRUCTURED_DEFAULT_MODEL_NAME": "checkbox"} - ), mock.patch.dict(models.model_class_map, {"checkbox": MockModel}): + models.os.environ, + {"UNSTRUCTURED_DEFAULT_MODEL_NAME": "yolox"}, + ), mock.patch.dict(models.model_class_map, {"yolox": MockModel}): model = models.get_model() assert isinstance(model, MockModel) -def test_env_variables_override_intialization_params(monkeypatch): +def test_env_variables_override_initialization_params(monkeypatch): # When initialization params are specified in an environment variable, and we call get_model, we # should see that the model was initialized with those params monkeypatch.setattr(models, "models", {}) + fake_label_map = {"1": "label1", "2": "label2"} with mock.patch.dict( models.os.environ, {"UNSTRUCTURED_DEFAULT_MODEL_INITIALIZE_PARAMS_JSON_PATH": "fake_json.json"}, ), mock.patch.object(models, "DEFAULT_MODEL", "fake"), mock.patch.dict( - models.model_class_map, {"fake": mock.MagicMock()} + models.model_class_map, + {"fake": mock.MagicMock()}, ), mock.patch( - "builtins.open", 
mock.mock_open(read_data='{"date": "3/26/81"}') + "builtins.open", + mock.mock_open( + read_data='{"model_path": "fakepath", "label_map": ' + json.dumps(fake_label_map) + "}", + ), ): model = models.get_model() - model.initialize.assert_called_once_with(date="3/26/81") + model.initialize.assert_called_once_with( + model_path="fakepath", + label_map={1: "label1", 2: "label2"}, + ) diff --git a/test_unstructured_inference/models/test_supergradients.py b/test_unstructured_inference/models/test_supergradients.py deleted file mode 100644 index d56df7d0..00000000 --- a/test_unstructured_inference/models/test_supergradients.py +++ /dev/null @@ -1,42 +0,0 @@ -from os import path -from PIL import Image - -from unstructured_inference.constants import Source -from unstructured_inference.models import super_gradients - - -def test_supergradients_model(): - model_path = path.join(path.dirname(__file__), "test_ci_model.onnx") - model = super_gradients.UnstructuredSuperGradients() - model.initialize( - model_path=model_path, - label_map={ - "0": "Picture", - "1": "Caption", - "2": "Text", - "3": "Formula", - "4": "Page number", - "5": "Address", - "6": "Footer", - "7": "Subheadline", - "8": "Chart", - "9": "Metadata", - "10": "Title", - "11": "Misc", - "12": "Header", - "13": "Table", - "14": "Headline", - "15": "List-item", - "16": "List", - "17": "Author", - "18": "Value", - "19": "Link", - "20": "Field-Name", - }, - input_shape=(1024, 1024), - ) - img = Image.open("sample-docs/loremipsum.jpg") - el, *_ = model(img) - assert el.source == Source.SUPER_GRADIENTS - assert el.prob > 0.70 - assert el.type == "Text" diff --git a/test_unstructured_inference/models/test_tables.py b/test_unstructured_inference/models/test_tables.py index 9131cbce..15c467cd 100644 --- a/test_unstructured_inference/models/test_tables.py +++ b/test_unstructured_inference/models/test_tables.py @@ -7,9 +7,11 @@ from transformers.models.table_transformer.modeling_table_transformer import ( 
TableTransformerDecoder, ) +from copy import deepcopy import unstructured_inference.models.table_postprocess as postprocess from unstructured_inference.models import tables +from unstructured_inference.models.tables import apply_thresholds_on_objects, structure_to_cells skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"} @@ -932,9 +934,43 @@ def test_table_prediction_output_format( assert expectation in result +def test_table_prediction_output_format_when_wrong_type_then_value_error( + table_transformer, + example_image, + mocker, + example_table_cells, + mocked_ocr_tokens, +): + mocker.patch.object(tables, "recognize", return_value=example_table_cells) + mocker.patch.object( + tables.UnstructuredTableTransformerModel, + "get_structure", + return_value=None, + ) + with pytest.raises(ValueError): + table_transformer.run_prediction( + example_image, result_format="Wrong format", ocr_tokens=mocked_ocr_tokens + ) + + +def test_table_prediction_runs_with_empty_recognize( + table_transformer, + example_image, + mocker, + mocked_ocr_tokens, +): + mocker.patch.object(tables, "recognize", return_value=[]) + mocker.patch.object( + tables.UnstructuredTableTransformerModel, + "get_structure", + return_value=None, + ) + assert table_transformer.run_prediction(example_image, ocr_tokens=mocked_ocr_tokens) == "" + + def test_table_prediction_with_ocr_tokens(table_transformer, example_image, mocked_ocr_tokens): prediction = table_transformer.predict(example_image, ocr_tokens=mocked_ocr_tokens) - assert '
' in prediction + assert '" in prediction @@ -943,6 +979,55 @@ def test_table_prediction_with_no_ocr_tokens(table_transformer, example_image): table_transformer.predict(example_image) +@pytest.mark.parametrize( + ("thresholds", "expected_object_number"), + [ + ({"0": 0.5}, 1), + ({"0": 0.1}, 3), + ({"0": 0.9}, 0), + ], +) +def test_objects_are_filtered_based_on_class_thresholds_when_correct_prediction_and_threshold( + thresholds, expected_object_number +): + objects = [ + {"label": "0", "score": 0.2}, + {"label": "0", "score": 0.4}, + {"label": "0", "score": 0.55}, + ] + assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number + + +@pytest.mark.parametrize( + ("thresholds", "expected_object_number"), + [ + ({"0": 0.5, "1": 0.1}, 4), + ({"0": 0.1, "1": 0.9}, 3), + ({"0": 0.9, "1": 0.5}, 1), + ], +) +def test_objects_are_filtered_based_on_class_thresholds_when_two_classes( + thresholds, expected_object_number +): + objects = [ + {"label": "0", "score": 0.2}, + {"label": "0", "score": 0.4}, + {"label": "0", "score": 0.55}, + {"label": "1", "score": 0.2}, + {"label": "1", "score": 0.4}, + {"label": "1", "score": 0.55}, + ] + assert len(apply_thresholds_on_objects(objects, thresholds)) == expected_object_number + + +def test_objects_filtering_when_missing_threshold(): + class_name = "class_name" + objects = [{"label": class_name, "score": 0.2}] + thresholds = {"1": 0.5} + with pytest.raises(KeyError, match=class_name): + apply_thresholds_on_objects(objects, thresholds) + + def test_intersect(): a = postprocess.Rect() b = postprocess.Rect([1, 2, 3, 4]) @@ -1131,26 +1216,6 @@ def test_header_supercell_tree(supercells, expected_len): assert len(supercells) == expected_len -def test_cells_to_html(): - # example table - # +----------+---------------------+ - # | two | two columns | - # | |----------+----------| - # | rows |sub cell 1|sub cell 2| - # +----------+----------+----------+ - cells = [ - {"row_nums": [0, 1], "column_nums": [0], "cell 
text": "two row", "column header": False}, - {"row_nums": [0], "column_nums": [1, 2], "cell text": "two cols", "column header": False}, - {"row_nums": [1], "column_nums": [1], "cell text": "sub cell 1", "column header": False}, - {"row_nums": [1], "column_nums": [2], "cell text": "sub cell 2", "column header": False}, - ] - expected = ( - '
' in prediction assert "
Blind51434.5%, n=1
two rowtwo ' - "cols
sub cell 1sub cell 2
" - ) - assert tables.cells_to_html(cells) == expected - - @pytest.mark.parametrize("zoom", [1, 0.1, 5, -1, 0]) def test_zoom_image(example_image, zoom): width, height = example_image.size @@ -1162,6 +1227,534 @@ def test_zoom_image(example_image, zoom): assert new_h == np.round(height * zoom, 0) +@pytest.mark.parametrize( + ("input_cells", "expected_html"), + [ + # +----------+---------------------+ + # | row1col1 | row1col2 | row1col3 | + # |----------|----------+----------| + # | row2col1 | row2col2 | row2col3 | + # +----------+----------+----------+ + pytest.param( + [ + { + "row_nums": [0], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [0], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [0], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, + ], + ( + "" + "
row1col1row1col2row1col3
row2col1row2col2row2col3
" + ), + id="simple table without header", + ), + # +----------+---------------------+ + # | h1col1 | h1col2 | h1col3 | + # |----------|----------+----------| + # | row1col1 | row1col2 | row1col3 | + # |----------|----------+----------| + # | row2col1 | row2col2 | row2col3 | + # +----------+----------+----------+ + pytest.param( + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, + ], + ( + "" + "" + "
h1col1h1col2h1col2
row1col1row1col2row1col3
row2col1row2col2row2col3
" + ), + id="simple table with header", + ), + # +----------+---------------------+ + # | h1col1 | h1col2 | h1col3 | + # |----------|----------+----------| + # | row1col1 | row1col2 | row1col3 | + # |----------|----------+----------| + # | row2col1 | row2col2 | row2col3 | + # +----------+----------+----------+ + pytest.param( + [ + {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + ], + ( + "" + "" + "
h1col1h1col2h1col2
row1col1row1col2row1col3
row2col1row2col2row2col3
" + ), + id="simple table with header, mixed elements", + ), + # +----------+---------------------+ + # | two | two columns | + # | |----------+----------| + # | rows |sub cell 1|sub cell 2| + # +----------+----------+----------+ + pytest.param( + [ + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "two row", + "column header": False, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "two cols", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "sub cell 1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "sub cell 2", + "column header": False, + }, + ], + ( + '" + "
two rowtwo ' + "cols
sub cell 1sub cell 2
" + ), + id="various spans, no headers", + ), + # +----------+---------------------+----------+ + # | | h1col23 | h1col4 | + # | h12col1 |----------+----------+----------| + # | | h2col2 | h2col34 | + # |----------|----------+----------+----------+ + # | r3col1 | r3col2 | | + # |----------+----------| r34col34 | + # | r4col12 | | + # +----------+----------+----------+----------+ + pytest.param( + [ + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "h12col1", + "column header": True, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "h1col23", + "column header": True, + }, + {"row_nums": [0], "column_nums": [3], "cell text": "h1col4", "column header": True}, + {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True}, + { + "row_nums": [1], + "column_nums": [2, 3], + "cell text": "h2col34", + "column header": True, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "r3col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "r3col2", + "column header": False, + }, + { + "row_nums": [2, 3], + "column_nums": [2, 3], + "cell text": "r34col34", + "column header": False, + }, + { + "row_nums": [3], + "column_nums": [0, 1], + "cell text": "r4col12", + "column header": False, + }, + ], + ( + '' + '' + '' + '' + '
h12col1h1col23h1col4
h2col2h2col34
r3col1r3col2r34col34
r4col12
' + ), + id="various spans, with 2 row header", + ), + ], +) +def test_cells_to_html(input_cells, expected_html): + assert tables.cells_to_html(input_cells) == expected_html + + +@pytest.mark.parametrize( + ("input_cells", "expected_cells"), + [ + pytest.param( + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, + ], + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [1], "cell text": "h1col2", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [2], + "cell text": "row1col3", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": 
"row2col3", + "column header": False, + }, + ], + id="identical tables, no changes expected", + ), + pytest.param( + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, + ], + [ + {"row_nums": [0], "column_nums": [0], "cell text": "h1col1", "column header": True}, + {"row_nums": [0], "column_nums": [1], "cell text": "", "column header": True}, + {"row_nums": [0], "column_nums": [2], "cell text": "h1col2", "column header": True}, + { + "row_nums": [1], + "column_nums": [0], + "cell text": "row1col1", + "column header": False, + }, + { + "row_nums": [1], + "column_nums": [1], + "cell text": "row1col2", + "column header": False, + }, + {"row_nums": [1], "column_nums": [2], "cell text": "", "column header": False}, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "row2col1", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [1], + "cell text": "row2col2", + "column header": False, + }, + { + "row_nums": [2], + "column_nums": [2], + "cell text": "row2col3", + "column header": False, + }, + ], + id="missing column in header and in the middle", + ), + pytest.param( + [ + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "h12col1", + "column header": True, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "h1col23", + "column header": True, + }, + {"row_nums": [1], "column_nums": [1], 
"cell text": "h2col2", "column header": True}, + { + "row_nums": [1], + "column_nums": [2, 3], + "cell text": "h2col34", + "column header": True, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "r3col1", + "column header": False, + }, + { + "row_nums": [2, 3], + "column_nums": [2, 3], + "cell text": "r34col34", + "column header": False, + }, + { + "row_nums": [3], + "column_nums": [0, 1], + "cell text": "r4col12", + "column header": False, + }, + ], + [ + { + "row_nums": [0, 1], + "column_nums": [0], + "cell text": "h12col1", + "column header": True, + }, + { + "row_nums": [0], + "column_nums": [1, 2], + "cell text": "h1col23", + "column header": True, + }, + {"row_nums": [0], "column_nums": [3], "cell text": "", "column header": True}, + {"row_nums": [1], "column_nums": [1], "cell text": "h2col2", "column header": True}, + { + "row_nums": [1], + "column_nums": [2, 3], + "cell text": "h2col34", + "column header": True, + }, + { + "row_nums": [2], + "column_nums": [0], + "cell text": "r3col1", + "column header": False, + }, + {"row_nums": [2], "column_nums": [1], "cell text": "", "column header": False}, + { + "row_nums": [2, 3], + "column_nums": [2, 3], + "cell text": "r34col34", + "column header": False, + }, + { + "row_nums": [3], + "column_nums": [0, 1], + "cell text": "r4col12", + "column header": False, + }, + ], + id="missing column in header and in the middle in table with spans", + ), + ], +) +def test_fill_cells(input_cells, expected_cells): + def sort_cells(cells): + return sorted(cells, key=lambda x: (x["row_nums"], x["column_nums"])) + + assert sort_cells(tables.fill_cells(input_cells)) == sort_cells(expected_cells) + + def test_padded_results_has_right_dimensions(table_transformer, example_image): str_class_name2idx = tables.get_class_map("structure") # a simpler mapping so we keep all structure in the returned objs below for test @@ -1201,3 +1794,100 @@ def test_padded_results_has_right_dimensions(table_transformer, example_image): 
def test_compute_confidence_score_zero_division_error_handling(): assert tables.compute_confidence_score([]) == 0 + + +@pytest.mark.parametrize( + "column_span_score, row_span_score, expected_text_to_indexes", + [ + ( + 0.9, + 0.8, + ( + { + "one three": {"row_nums": [0, 1], "column_nums": [0]}, + "two": {"row_nums": [0], "column_nums": [1]}, + "four": {"row_nums": [1], "column_nums": [1]}, + } + ), + ), + ( + 0.8, + 0.9, + ( + { + "one two": {"row_nums": [0], "column_nums": [0, 1]}, + "three": {"row_nums": [1], "column_nums": [0]}, + "four": {"row_nums": [1], "column_nums": [1]}, + } + ), + ), + ], +) +def test_subcells_filtering_when_overlapping_spanning_cells( + column_span_score, row_span_score, expected_text_to_indexes +): + """ + # table + # +-----------+----------+ + # | one | two | + # |-----------+----------| + # | three | four | + # +-----------+----------+ + + spanning cells over first row and over first column + """ + table_structure = { + "rows": [ + {"bbox": [0, 0, 10, 20]}, + {"bbox": [10, 0, 20, 20]}, + ], + "columns": [ + {"bbox": [0, 0, 20, 10]}, + {"bbox": [0, 10, 20, 20]}, + ], + "spanning cells": [ + {"bbox": [0, 0, 20, 10], "score": column_span_score}, + {"bbox": [0, 0, 10, 20], "score": row_span_score}, + ], + } + tokens = [ + { + "text": "one", + "bbox": [0, 0, 10, 10], + }, + { + "text": "two", + "bbox": [0, 10, 10, 20], + }, + { + "text": "three", + "bbox": [10, 0, 20, 10], + }, + {"text": "four", "bbox": [10, 10, 20, 20]}, + ] + token_args = {"span_num": 1, "line_num": 1, "block_num": 1} + for token in tokens: + token.update(token_args) + for spanning_cell in table_structure["spanning cells"]: + spanning_cell["projected row header"] = False + + # table structure is edited inside structure_to_cells, save copy for future runs + saved_table_structure = deepcopy(table_structure) + + predicted_cells, _ = structure_to_cells(table_structure, tokens=tokens) + predicted_text_to_indexes = { + cell["cell text"]: { + "row_nums": cell["row_nums"], + 
"column_nums": cell["column_nums"], + } + for cell in predicted_cells + } + assert predicted_text_to_indexes == expected_text_to_indexes + + # swap spanning cells to ensure the highest prob spanning cell is used + spans = saved_table_structure["spanning cells"] + spans[0], spans[1] = spans[1], spans[0] + saved_table_structure["spanning cells"] = spans + + predicted_cells_after_reorder, _ = structure_to_cells(saved_table_structure, tokens=tokens) + assert predicted_cells_after_reorder == predicted_cells diff --git a/test_unstructured_inference/test_elements.py b/test_unstructured_inference/test_elements.py index abd4ce0c..f6f5b568 100644 --- a/test_unstructured_inference/test_elements.py +++ b/test_unstructured_inference/test_elements.py @@ -6,12 +6,12 @@ from unstructured_inference.constants import ElementType from unstructured_inference.inference import elements -from unstructured_inference.inference.elements import TextRegion, ImageTextRegion +from unstructured_inference.inference.elements import Rectangle, TextRegion, ImageTextRegion from unstructured_inference.inference.layoutelement import ( + LayoutElement, + merge_inferred_layout_with_extracted_layout, partition_groups_from_regions, separate, - merge_inferred_layout_with_extracted_layout, - LayoutElement, ) skip_outside_ci = os.getenv("CI", "").lower() in {"", "false", "f", "0"} @@ -31,6 +31,18 @@ def rand_rect(size=10): return elements.Rectangle(x1, y1, x1 + size, y1 + size) +@pytest.mark.parametrize( + ("rect1", "rect2", "expected"), + [ + (Rectangle(0, 0, 1, 1), Rectangle(0, 0, None, None), None), + (Rectangle(0, 0, None, None), Rectangle(0, 0, 1, 1), None), + ], +) +def test_unhappy_intersection(rect1, rect2, expected): + assert rect1.intersection(rect2) == expected + assert not rect1.intersects(rect2) + + @pytest.mark.parametrize("second_size", [10, 20]) def test_intersects(second_size): for _ in range(1000): @@ -106,10 +118,6 @@ def test_partition_groups_from_regions(mock_embedded_text_regions): text = 
"".join([el.text for el in sorted_groups[-1]]) assert text.startswith("Layout") - words = [] - groups = partition_groups_from_regions(words) - assert len(groups) == 0 - def test_rectangle_area(monkeypatch): for _ in range(1000): @@ -202,7 +210,10 @@ def test_intersection_over_min( def test_grow_region_to_match_region(): - from unstructured_inference.inference.elements import Rectangle, grow_region_to_match_region + from unstructured_inference.inference.elements import ( + Rectangle, + grow_region_to_match_region, + ) a = Rectangle(1, 1, 2, 2) b = Rectangle(1, 1, 5, 5) diff --git a/unstructured_inference/__version__.py b/unstructured_inference/__version__.py index 0f031048..33c5779d 100644 --- a/unstructured_inference/__version__.py +++ b/unstructured_inference/__version__.py @@ -1 +1 @@ -__version__ = "0.7.20" # pragma: no cover +__version__ = "0.7.37-dev0" # pragma: no cover diff --git a/unstructured_inference/config.py b/unstructured_inference/config.py index 0d85b6a3..d5765bbf 100644 --- a/unstructured_inference/config.py +++ b/unstructured_inference/config.py @@ -5,6 +5,7 @@ settings that should not be altered without making a code change (e.g., definition of 1Gb of memory in bytes). 
Constants should go into `./constants.py` """ + import os from dataclasses import dataclass diff --git a/unstructured_inference/constants.py b/unstructured_inference/constants.py index 5de6eea8..173e37b4 100644 --- a/unstructured_inference/constants.py +++ b/unstructured_inference/constants.py @@ -8,16 +8,33 @@ class Source(Enum): CHIPPER = "chipper" CHIPPERV1 = "chipperv1" CHIPPERV2 = "chipperv2" + CHIPPERV3 = "chipperv3" MERGED = "merged" - SUPER_GRADIENTS = "super-gradients" + + +CHIPPER_VERSIONS = ( + Source.CHIPPER, + Source.CHIPPERV1, + Source.CHIPPERV2, + Source.CHIPPERV3, +) class ElementType: + PARAGRAPH = "Paragraph" IMAGE = "Image" + PARAGRAPH_IN_IMAGE = "ParagraphInImage" FIGURE = "Figure" PICTURE = "Picture" TABLE = "Table" + PARAGRAPH_IN_TABLE = "ParagraphInTable" LIST = "List" + FORM = "Form" + PARAGRAPH_IN_FORM = "ParagraphInForm" + CHECK_BOX_CHECKED = "CheckBoxChecked" + CHECK_BOX_UNCHECKED = "CheckBoxUnchecked" + RADIO_BUTTON_CHECKED = "RadioButtonChecked" + RADIO_BUTTON_UNCHECKED = "RadioButtonUnchecked" LIST_ITEM = "List-item" FORMULA = "Formula" CAPTION = "Caption" @@ -29,6 +46,12 @@ class ElementType: TEXT = "Text" UNCATEGORIZED_TEXT = "UncategorizedText" PAGE_BREAK = "PageBreak" + CODE_SNIPPET = "CodeSnippet" + PAGE_NUMBER = "PageNumber" + OTHER = "Other" FULL_PAGE_REGION_THRESHOLD = 0.99 + +# this field is defined by pytesseract/unstructured.pytesseract +TESSERACT_TEXT_HEIGHT = "height" diff --git a/unstructured_inference/inference/elements.py b/unstructured_inference/inference/elements.py index e1194fd6..f6b4d4d8 100644 --- a/unstructured_inference/inference/elements.py +++ b/unstructured_inference/inference/elements.py @@ -1,10 +1,8 @@ from __future__ import annotations -import re -import unicodedata from copy import deepcopy from dataclasses import dataclass -from typing import Collection, Optional, Union +from typing import Optional, Union import numpy as np @@ -67,6 +65,8 @@ def is_disjoint(self, other: Rectangle) -> bool: def 
intersects(self, other: Rectangle) -> bool: """Checks whether this rectangle intersects another rectangle.""" + if self._has_none() or other._has_none(): + return False return intersections(self, other)[0, 1] def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = None) -> bool: @@ -81,6 +81,10 @@ def is_in(self, other: Rectangle, error_margin: Optional[Union[int, float]] = No ], ) + def _has_none(self) -> bool: + """Return True when any of the coordinates is None.""" + return any((self.x1 is None, self.x2 is None, self.y1 is None, self.y2 is None)) + @property def coordinates(self): """Gets coordinates of the rectangle""" @@ -89,6 +93,8 @@ def coordinates(self): def intersection(self, other: Rectangle) -> Optional[Rectangle]: """Gives the rectangle that is the intersection of two rectangles, or None if the rectangles are disjoint.""" + if self._has_none() or other._has_none(): + return None x1 = max(self.x1, other.x1) x2 = min(self.x2, other.x2) y1 = max(self.y1, other.y1) @@ -176,21 +182,6 @@ class TextRegion: def __str__(self) -> str: return str(self.text) - def extract_text( - self, - objects: Optional[Collection[TextRegion]], - ) -> str: - """Extracts text contained in region.""" - if self.text is not None: - # If block text is already populated, we'll assume it's correct - text = self.text - elif objects is not None: - text = aggregate_by_block(self, objects) - else: - text = "" - cleaned_text = remove_control_characters(text) - return cleaned_text - @classmethod def from_coords( cls, @@ -209,67 +200,11 @@ def from_coords( class EmbeddedTextRegion(TextRegion): - def extract_text( - self, - objects: Optional[Collection[TextRegion]], - ) -> str: - """Extracts text contained in region.""" - if self.text is None: - return "" - else: - return self.text + pass class ImageTextRegion(TextRegion): - def extract_text( - self, - objects: Optional[Collection[TextRegion]], - ) -> str: - """Extracts text contained in region.""" - if self.text is None: - 
return "" - else: - return super().extract_text(objects) - - -def aggregate_by_block( - text_region: TextRegion, - pdf_objects: Collection[TextRegion], -) -> str: - """Extracts the text aggregated from the elements of the given layout that lie within the given - block.""" - filtered_blocks = [ - obj for obj in pdf_objects if obj.bbox.is_in(text_region.bbox, error_margin=5) - ] - text = " ".join([x.text for x in filtered_blocks if x.text]) - return text - - -def cid_ratio(text: str) -> float: - """Gets ratio of unknown 'cid' characters extracted from text to all characters.""" - if not is_cid_present(text): - return 0.0 - cid_pattern = r"\(cid\:(\d+)\)" - unmatched, n_cid = re.subn(cid_pattern, "", text) - total = n_cid + len(unmatched) - return n_cid / total - - -def is_cid_present(text: str) -> bool: - """Checks if a cid code is present in a text selection.""" - if len(text) < len("(cid:x)"): - return False - return text.find("(cid:") != -1 - - -def remove_control_characters(text: str) -> str: - """Removes control characters from text.""" - - # Replace newline character with a space - text = text.replace("\n", " ") - # Remove other control characters - out_text = "".join(c for c in text if unicodedata.category(c)[0] != "C") - return out_text + pass def region_bounding_boxes_are_almost_the_same( diff --git a/unstructured_inference/inference/layout.py b/unstructured_inference/inference/layout.py index e1491c57..517a2dba 100644 --- a/unstructured_inference/inference/layout.py +++ b/unstructured_inference/inference/layout.py @@ -15,10 +15,8 @@ from unstructured_inference.inference.layoutelement import ( LayoutElement, ) -from unstructured_inference.inference.ordering import order_layout from unstructured_inference.logger import logger from unstructured_inference.models.base import get_model -from unstructured_inference.models.chipper import UnstructuredChipperModel from unstructured_inference.models.unstructuredmodel import ( UnstructuredElementExtractionModel, 
UnstructuredObjectDetectionModel, @@ -140,7 +138,7 @@ def __init__( ): if detection_model is not None and element_extraction_model is not None: raise ValueError("Only one of detection_model and extraction_model should be passed.") - self.image = image + self.image: Optional[Image.Image] = image if image_metadata is None: image_metadata = {} self.image_metadata = image_metadata @@ -167,6 +165,7 @@ def get_elements_using_image_extraction( raise ValueError( "Cannot get elements using image extraction, no image extraction model defined", ) + assert self.image is not None elements = self.element_extraction_model(self.image) if inplace: self.elements = elements @@ -178,7 +177,6 @@ def get_elements_with_detection_model( inplace: bool = True, ) -> Optional[List[LayoutElement]]: """Uses specified model to detect the elements on the page.""" - logger.info("Detecting page elements ...") if self.detection_model is None: model = get_model() if isinstance(model, UnstructuredObjectDetectionModel): @@ -188,6 +186,7 @@ def get_elements_with_detection_model( # NOTE(mrobinson) - We'll want make this model inference step some kind of # remote call in the future. 
+ assert self.image is not None inferred_layout: List[LayoutElement] = self.detection_model(self.image) inferred_layout = self.detection_model.deduplicate_detected_elements( inferred_layout, @@ -199,29 +198,6 @@ def get_elements_with_detection_model( return inferred_layout - def get_elements_from_layout( - self, - layout: List[TextRegion], - pdf_objects: Optional[List[TextRegion]] = None, - ) -> List[LayoutElement]: - """Uses the given Layout to separate the page text into elements, either extracting the - text from the discovered layout blocks.""" - - # If the model is a chipper model, we don't want to order the - # elements, as they are already ordered - order_elements = not isinstance(self.detection_model, UnstructuredChipperModel) - if order_elements: - layout = order_layout(layout) - - elements = [ - get_element_from_block( - block=e, - pdf_objects=pdf_objects, - ) - for e in layout - ] - return elements - def _get_image_array(self) -> Union[np.ndarray, None]: """Converts the raw image into a numpy array.""" if self.image_array is None: @@ -322,13 +298,13 @@ def from_image( detection_model=detection_model, element_extraction_model=element_extraction_model, ) + # FIXME (yao): refactor the other methods so they all return elements like the third route if page.element_extraction_model is not None: page.get_elements_using_image_extraction() - return page - if fixed_layout is None: + elif fixed_layout is None: page.get_elements_with_detection_model() else: - page.elements = page.get_elements_from_layout(fixed_layout) + page.elements = [] page.image_metadata = { "format": page.image.format if page.image else None, @@ -403,19 +379,6 @@ def process_file_with_model( return layout -def get_element_from_block( - block: TextRegion, - pdf_objects: Optional[List[TextRegion]] = None, -) -> LayoutElement: - """Creates a LayoutElement from a given layout or image by finding all the text that lies within - a given block.""" - element = block if isinstance(block, LayoutElement) 
else LayoutElement.from_region(block) - element.text = element.extract_text( - objects=pdf_objects, - ) - return element - - def convert_pdf_to_image( filename: str, dpi: int = 200, diff --git a/unstructured_inference/inference/layoutelement.py b/unstructured_inference/inference/layoutelement.py index 376c419e..37a9ef24 100644 --- a/unstructured_inference/inference/layoutelement.py +++ b/unstructured_inference/inference/layoutelement.py @@ -10,6 +10,7 @@ from unstructured_inference.config import inference_config from unstructured_inference.constants import ( + CHIPPER_VERSIONS, FULL_PAGE_REGION_THRESHOLD, ElementType, Source, @@ -31,16 +32,6 @@ class LayoutElement(TextRegion): image_path: Optional[str] = None parent: Optional[LayoutElement] = None - def extract_text( - self, - objects: Optional[Collection[TextRegion]], - ): - """Extracts text contained in region""" - text = super().extract_text( - objects=objects, - ) - return text - def to_dict(self) -> dict: """Converts the class instance to dictionary form.""" out_dict = { @@ -108,7 +99,7 @@ def merge_inferred_layout_with_extracted_layout( continue region_matched = False for inferred_region in inferred_layout: - if inferred_region.source in (Source.CHIPPER, Source.CHIPPERV1): + if inferred_region.source in CHIPPER_VERSIONS: continue if inferred_region.bbox.intersects(extracted_region.bbox): @@ -164,9 +155,11 @@ def merge_inferred_layout_with_extracted_layout( categorized_extracted_elements_to_add = [ LayoutElement( text=el.text, - type=ElementType.IMAGE - if isinstance(el, ImageTextRegion) - else ElementType.UNCATEGORIZED_TEXT, + type=( + ElementType.IMAGE + if isinstance(el, ImageTextRegion) + else ElementType.UNCATEGORIZED_TEXT + ), source=el.source, bbox=el.bbox, ) @@ -221,7 +214,9 @@ def reduce(keep: Rectangle, reduce: Rectangle): reduce(keep=region_b, reduce=region_a) -def table_cells_to_dataframe(cells: dict, nrows: int = 1, ncols: int = 1, header=None) -> DataFrame: +def table_cells_to_dataframe( + cells: 
List[dict], nrows: int = 1, ncols: int = 1, header=None +) -> DataFrame: """convert table-transformer's cells data into a pandas dataframe""" arr = np.empty((nrows, ncols), dtype=object) for cell in cells: diff --git a/unstructured_inference/inference/ordering.py b/unstructured_inference/inference/ordering.py deleted file mode 100644 index 33b823c7..00000000 --- a/unstructured_inference/inference/ordering.py +++ /dev/null @@ -1,35 +0,0 @@ -from typing import List - -from unstructured_inference.inference.elements import TextRegion - - -def order_layout( - layout: List[TextRegion], - column_tol_factor: float = 0.2, - full_page_threshold_factor: float = 0.9, -) -> List[TextRegion]: - """Orders the layout elements detected on a page. For groups of elements that are not - the width of the page, the algorithm attempts to group elements into column based on - the coordinates of the bounding box. Columns are ordered left to right, and elements - within columns are ordered top to bottom. - - Parameters - ---------- - layout - the layout elements to order. - column_tol_factor - multiplied by the page width to find the tolerance for considering two elements as - part of the same column. - full_page_threshold_factor - multiplied by the page width to find the minimum width an elements need to be - for it to be considered a full page width element. - """ - if len(layout) == 0: - return [] - - layout.sort( - key=lambda element: (element.bbox.y1, element.bbox.x1, element.bbox.y2, element.bbox.x2), - ) - # NOTE(alan): Temporarily revert to orginal logic pending fixing the new logic - # See code prior to this commit for new logic. 
- return layout diff --git a/unstructured_inference/models/base.py b/unstructured_inference/models/base.py index 26336e23..7ebe6e42 100644 --- a/unstructured_inference/models/base.py +++ b/unstructured_inference/models/base.py @@ -1,43 +1,49 @@ +from __future__ import annotations + import json import os -from typing import Dict, Optional, Type +from typing import Dict, Optional, Tuple, Type from unstructured_inference.models.chipper import MODEL_TYPES as CHIPPER_MODEL_TYPES from unstructured_inference.models.chipper import UnstructuredChipperModel -from unstructured_inference.models.detectron2 import ( - MODEL_TYPES as DETECTRON2_MODEL_TYPES, -) -from unstructured_inference.models.detectron2 import ( - UnstructuredDetectronModel, -) -from unstructured_inference.models.detectron2onnx import ( - MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES, -) -from unstructured_inference.models.detectron2onnx import ( - UnstructuredDetectronONNXModel, -) -from unstructured_inference.models.super_gradients import ( - UnstructuredSuperGradients, -) +from unstructured_inference.models.detectron2onnx import MODEL_TYPES as DETECTRON2_ONNX_MODEL_TYPES +from unstructured_inference.models.detectron2onnx import UnstructuredDetectronONNXModel from unstructured_inference.models.unstructuredmodel import UnstructuredModel -from unstructured_inference.models.yolox import ( - MODEL_TYPES as YOLOX_MODEL_TYPES, -) -from unstructured_inference.models.yolox import ( - UnstructuredYoloXModel, -) +from unstructured_inference.models.yolox import MODEL_TYPES as YOLOX_MODEL_TYPES +from unstructured_inference.models.yolox import UnstructuredYoloXModel +from unstructured_inference.utils import LazyDict DEFAULT_MODEL = "yolox" models: Dict[str, UnstructuredModel] = {} -model_class_map: Dict[str, Type[UnstructuredModel]] = { - **{name: UnstructuredDetectronModel for name in DETECTRON2_MODEL_TYPES}, - **{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES}, - **{name: 
UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES}, - **{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES}, - "super_gradients": UnstructuredSuperGradients, -} + +def get_default_model_mappings() -> Tuple[ + Dict[str, Type[UnstructuredModel]], + Dict[str, dict | LazyDict], +]: + """default model mappings for models that are in `unstructured_inference` repo""" + return { + **{name: UnstructuredDetectronONNXModel for name in DETECTRON2_ONNX_MODEL_TYPES}, + **{name: UnstructuredYoloXModel for name in YOLOX_MODEL_TYPES}, + **{name: UnstructuredChipperModel for name in CHIPPER_MODEL_TYPES}, + }, { + **DETECTRON2_ONNX_MODEL_TYPES, + **YOLOX_MODEL_TYPES, + **CHIPPER_MODEL_TYPES, + } + + +model_class_map, model_config_map = get_default_model_mappings() + + +def register_new_model(model_config: dict, model_class: UnstructuredModel): + """Register this model in model_config_map and model_class_map. + + Those maps are updated with the new model class information. + """ + model_config_map.update(model_config) + model_class_map.update({name: model_class for name in model_config}) def get_model(model_name: Optional[str] = None) -> UnstructuredModel: @@ -58,15 +64,13 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel: if initialize_param_json is not None: with open(initialize_param_json) as fp: initialize_params = json.load(fp) + label_map_int_keys = { + int(key): value for key, value in initialize_params["label_map"].items() + } + initialize_params["label_map"] = label_map_int_keys else: - if model_name in DETECTRON2_MODEL_TYPES: - initialize_params = DETECTRON2_MODEL_TYPES[model_name] - elif model_name in DETECTRON2_ONNX_MODEL_TYPES: - initialize_params = DETECTRON2_ONNX_MODEL_TYPES[model_name] - elif model_name in YOLOX_MODEL_TYPES: - initialize_params = YOLOX_MODEL_TYPES[model_name] - elif model_name in CHIPPER_MODEL_TYPES: - initialize_params = CHIPPER_MODEL_TYPES[model_name] + if model_name in model_config_map: + 
initialize_params = model_config_map[model_name] else: raise UnknownModelException(f"Unknown model type: {model_name}") @@ -78,6 +82,6 @@ def get_model(model_name: Optional[str] = None) -> UnstructuredModel: class UnknownModelException(Exception): - """Exception for the case where a model is called for with an unrecognized identifier.""" + """A model was requested with an unrecognized identifier.""" pass diff --git a/unstructured_inference/models/chipper.py b/unstructured_inference/models/chipper.py index 0ffa9c7d..857c83e9 100644 --- a/unstructured_inference/models/chipper.py +++ b/unstructured_inference/models/chipper.py @@ -9,20 +9,19 @@ import torch import transformers from cv2.typing import MatLike -from huggingface_hub import hf_hub_download from PIL.Image import Image from transformers import DonutProcessor, VisionEncoderDecoderModel from transformers.generation.logits_process import LogitsProcessor from transformers.generation.stopping_criteria import StoppingCriteria -from unstructured_inference.constants import Source +from unstructured_inference.constants import CHIPPER_VERSIONS, Source from unstructured_inference.inference.elements import Rectangle from unstructured_inference.inference.layoutelement import LayoutElement from unstructured_inference.logger import logger from unstructured_inference.models.unstructuredmodel import ( UnstructuredElementExtractionModel, ) -from unstructured_inference.utils import LazyDict, strip_tags +from unstructured_inference.utils import LazyDict, download_if_needed_and_get_local_path, strip_tags MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = { "chipperv1": { @@ -44,11 +43,22 @@ "max_length": 1536, "heatmap_h": 40, "heatmap_w": 30, + "source": Source.CHIPPERV2, + }, + "chipperv3": { + "pre_trained_model_repo": "unstructuredio/chipper-v3", + "swap_head": True, + "swap_head_hidden_layer_size": 128, + "start_token_prefix": "", + "max_length": 1536, + "heatmap_h": 40, + "heatmap_w": 30, "source": Source.CHIPPER, }, } 
-MODEL_TYPES["chipper"] = MODEL_TYPES["chipperv2"] +MODEL_TYPES["chipper"] = MODEL_TYPES["chipperv3"] class UnstructuredChipperModel(UnstructuredElementExtractionModel): @@ -104,8 +114,8 @@ def initialize( token=auth_token, ) if swap_head: - lm_head_file = hf_hub_download( - repo_id=pre_trained_model_repo, + lm_head_file = download_if_needed_and_get_local_path( + path_or_repo=pre_trained_model_repo, filename="lm_head.pth", token=auth_token, ) @@ -160,16 +170,18 @@ def predict(self, image) -> List[LayoutElement]: return elements @staticmethod - def format_table_elements(elements): - """makes chipper table element return the same as other layout models + def format_table_elements(elements: List[LayoutElement]) -> List[LayoutElement]: + """Makes chipper table element return the same as other layout models. - - copies the html representation to attribute text_as_html - - strip html tags from the attribute text + 1. If `text` attribute is an html (has html tags in it), copies the `text` + attribute to `text_as_html` attribute. + 2. Strips html tags from the `text` attribute. 
""" for element in elements: - element.text_as_html = element.text - element.text = strip_tags(element.text) - + text = strip_tags(element.text) if element.text is not None else element.text + if text != element.text: + element.text_as_html = element.text # type: ignore[attr-defined] + element.text = text return elements def predict_tokens( @@ -196,6 +208,7 @@ def predict_tokens( outputs = self.model.generate( encoder_outputs=encoder_outputs, input_ids=self.input_ids, + max_length=self.max_length, no_repeat_ngram_size=0, num_beams=1, return_dict_in_generate=True, @@ -212,6 +225,7 @@ def predict_tokens( outputs = self.model.generate( encoder_outputs=encoder_outputs, input_ids=self.input_ids, + max_length=self.max_length, logits_processor=self.logits_processor, do_sample=True, no_repeat_ngram_size=0, @@ -390,7 +404,7 @@ def deduplicate_detected_elements( min_text_size: int = 15, ) -> List[LayoutElement]: """For chipper, remove elements from other sources.""" - return [el for el in elements if el.source in (Source.CHIPPER, Source.CHIPPERV1)] + return [el for el in elements if el.source in CHIPPER_VERSIONS] def adjust_bbox(self, bbox, x_offset, y_offset, ratio, target_size): """Translate bbox by (x_offset, y_offset) and shrink by ratio.""" @@ -450,7 +464,7 @@ def get_bounding_box( np.asarray( [ agg_heatmap, - cv2.resize( + cv2.resize( # type: ignore hmap, (final_w, final_h), interpolation=cv2.INTER_LINEAR_EXACT, # cv2.INTER_CUBIC @@ -516,12 +530,13 @@ def reduce_element_bbox( Given a LayoutElement element, reduce the size of the bounding box, depending on existing elements """ - bbox = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2] + if element.bbox: + bbox = [element.bbox.x1, element.bbox.y1, element.bbox.x2, element.bbox.y2] - if not self.element_overlap(elements, element): - element.bbox = Rectangle(*self.reduce_bbox_no_overlap(image, bbox)) - else: - element.bbox = Rectangle(*self.reduce_bbox_overlap(image, bbox)) + if not 
self.element_overlap(elements, element): + element.bbox = Rectangle(*self.reduce_bbox_no_overlap(image, bbox)) + else: + element.bbox = Rectangle(*self.reduce_bbox_overlap(image, bbox)) def bbox_overlap( self, @@ -606,7 +621,7 @@ def reduce_bbox_no_overlap( ): return input_bbox - nimage = np.array(image.crop(input_bbox)) + nimage = np.array(image.crop(input_bbox)) # type: ignore nimage = self.remove_horizontal_lines(nimage) @@ -655,7 +670,7 @@ def reduce_bbox_overlap( ): return input_bbox - nimage = np.array(image.crop(input_bbox)) + nimage = np.array(image.crop(input_bbox)) # type: ignore nimage = self.remove_horizontal_lines(nimage) @@ -759,7 +774,7 @@ def largest_margin( ): return None - nimage = np.array(image.crop(input_bbox)) + nimage = np.array(image.crop(input_bbox)) # type: ignore if nimage.shape[0] * nimage.shape[1] == 0: return None @@ -868,6 +883,8 @@ def resolve_bbox_overlaps( continue ebbox1 = element.bbox + if ebbox1 is None: + continue bbox1 = [ebbox1.x1, ebbox1.y1, ebbox1.x2, max(ebbox1.y1, ebbox1.y2)] for celement in elements: diff --git a/unstructured_inference/models/detectron2.py b/unstructured_inference/models/detectron2.py deleted file mode 100644 index 98939f88..00000000 --- a/unstructured_inference/models/detectron2.py +++ /dev/null @@ -1,99 +0,0 @@ -from pathlib import Path -from typing import Any, Dict, Final, List, Optional, Union - -from huggingface_hub import hf_hub_download -from layoutparser.models.detectron2.layoutmodel import ( - Detectron2LayoutModel, - is_detectron2_available, -) -from layoutparser.models.model_config import LayoutModelConfig -from PIL import Image - -from unstructured_inference.constants import ElementType -from unstructured_inference.inference.layoutelement import LayoutElement -from unstructured_inference.logger import logger -from unstructured_inference.models.unstructuredmodel import ( - UnstructuredObjectDetectionModel, -) -from unstructured_inference.utils import LazyDict, LazyEvaluateInfo - 
-DETECTRON_CONFIG: Final = "lp://PubLayNet/faster_rcnn_R_50_FPN_3x/config" -DEFAULT_LABEL_MAP: Final[Dict[int, str]] = { - 0: ElementType.TEXT, - 1: ElementType.TITLE, - 2: ElementType.LIST, - 3: ElementType.TABLE, - 4: ElementType.FIGURE, -} -DEFAULT_EXTRA_CONFIG: Final[List[Any]] = ["MODEL.ROI_HEADS.SCORE_THRESH_TEST", 0.8] - - -# NOTE(alan): Entries are implemented as LazyDicts so that models aren't downloaded until they are -# needed. -MODEL_TYPES = { - "detectron2_lp": LazyDict( - model_path=LazyEvaluateInfo( - hf_hub_download, - "layoutparser/detectron2", - "PubLayNet/faster_rcnn_R_50_FPN_3x/model_final.pth", - ), - config_path=LazyEvaluateInfo( - hf_hub_download, - "layoutparser/detectron2", - "PubLayNet/faster_rcnn_R_50_FPN_3x/config.yml", - ), - label_map=DEFAULT_LABEL_MAP, - extra_config=DEFAULT_EXTRA_CONFIG, - ), - "checkbox": LazyDict( - model_path=LazyEvaluateInfo( - hf_hub_download, - "unstructuredio/oer-checkbox", - "detectron2_finetuned_oer_checkbox.pth", - ), - config_path=LazyEvaluateInfo( - hf_hub_download, - "unstructuredio/oer-checkbox", - "detectron2_oer_checkbox.json", - ), - label_map={0: "Unchecked", 1: "Checked"}, - extra_config=None, - ), -} - - -class UnstructuredDetectronModel(UnstructuredObjectDetectionModel): - """Unstructured model wrapper for Detectron2LayoutModel.""" - - def predict(self, x: Image): - """Makes a prediction using detectron2 model.""" - super().predict(x) - prediction = self.model.detect(x) - return [LayoutElement.from_lp_textblock(block) for block in prediction] - - def initialize( - self, - config_path: Union[str, Path, LayoutModelConfig], - model_path: Optional[Union[str, Path]] = None, - label_map: Optional[Dict[int, str]] = None, - extra_config: Optional[list] = None, - device: Optional[str] = None, - ): - """Loads the detectron2 model using the specified parameters""" - - if not is_detectron2_available(): - raise ImportError( - "Failed to load the Detectron2 model. 
Ensure that the Detectron2 " - "module is correctly installed.", - ) - - config_path_str = str(config_path) - model_path_str: Optional[str] = None if model_path is None else str(model_path) - logger.info("Loading the Detectron2 layout model ...") - self.model = Detectron2LayoutModel( - config_path_str, - model_path=model_path_str, - label_map=label_map, - extra_config=extra_config, - device=device, - ) diff --git a/unstructured_inference/models/detectron2onnx.py b/unstructured_inference/models/detectron2onnx.py index 3def8ced..79cd0a1a 100644 --- a/unstructured_inference/models/detectron2onnx.py +++ b/unstructured_inference/models/detectron2onnx.py @@ -4,7 +4,6 @@ import cv2 import numpy as np import onnxruntime -from huggingface_hub import hf_hub_download from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE from onnxruntime.capi import _pybind_state as C from onnxruntime.quantization import QuantType, quantize_dynamic @@ -16,7 +15,11 @@ from unstructured_inference.models.unstructuredmodel import ( UnstructuredObjectDetectionModel, ) -from unstructured_inference.utils import LazyDict, LazyEvaluateInfo +from unstructured_inference.utils import ( + LazyDict, + LazyEvaluateInfo, + download_if_needed_and_get_local_path, +) onnxruntime.set_default_logger_severity(logger_onnx.getEffectiveLevel()) @@ -34,7 +37,7 @@ MODEL_TYPES: Dict[str, Union[LazyDict, dict]] = { "detectron2_onnx": LazyDict( model_path=LazyEvaluateInfo( - hf_hub_download, + download_if_needed_and_get_local_path, "unstructuredio/detectron2_faster_rcnn_R_50_FPN_3x", "model.onnx", ), @@ -52,7 +55,7 @@ }, "detectron2_mask_rcnn": LazyDict( model_path=LazyEvaluateInfo( - hf_hub_download, + download_if_needed_and_get_local_path, "unstructuredio/detectron2_mask_rcnn_X_101_32x8d_FPN_3x", "model.onnx", ), diff --git a/unstructured_inference/models/donut.py b/unstructured_inference/models/donut.py index 1f753f56..bc60d2c6 100644 --- a/unstructured_inference/models/donut.py +++ 
b/unstructured_inference/models/donut.py @@ -3,7 +3,7 @@ from typing import Optional, Union import torch -from PIL import Image +from PIL import Image as PILImage from transformers import ( DonutProcessor, VisionEncoderDecoderConfig, @@ -16,7 +16,7 @@ class UnstructuredDonutModel(UnstructuredModel): """Unstructured model wrapper for Donut image transformer.""" - def predict(self, x: Image): + def predict(self, x: PILImage.Image): """Make prediction using donut model""" super().predict(x) return self.run_prediction(x) @@ -50,7 +50,7 @@ def initialize( raise ImportError("Review the parameters to initialize a UnstructuredDonutModel obj") self.model.to(device) - def run_prediction(self, x: Image): + def run_prediction(self, x: PILImage.Image): """Internal prediction method.""" pixel_values = self.processor(x, return_tensors="pt").pixel_values decoder_input_ids = self.processor.tokenizer( diff --git a/unstructured_inference/models/super_gradients.py b/unstructured_inference/models/super_gradients.py deleted file mode 100644 index 6d9a25fa..00000000 --- a/unstructured_inference/models/super_gradients.py +++ /dev/null @@ -1,100 +0,0 @@ -import os -from typing import List, cast - -import cv2 -import numpy as np -import onnxruntime -from onnxruntime.capi import _pybind_state as C -from PIL import Image - -from unstructured_inference.constants import Source -from unstructured_inference.inference.layoutelement import LayoutElement -from unstructured_inference.logger import logger -from unstructured_inference.models.unstructuredmodel import ( - UnstructuredObjectDetectionModel, -) - - -class UnstructuredSuperGradients(UnstructuredObjectDetectionModel): - def predict(self, x: Image): - """Predict using Super-Gradients model.""" - super().predict(x) - return self.image_processing(x) - - def initialize(self, model_path: str, label_map: dict, input_shape: tuple): - """Start inference session for SuperGradients model.""" - - if not os.path.exists(model_path): - logger.info("ONNX 
Model Path Does Not Exist!") - self.model_path = model_path - - available_providers = C.get_available_providers() - ordered_providers = [ - "TensorrtExecutionProvider", - "CUDAExecutionProvider", - "CPUExecutionProvider", - ] - - providers = [provider for provider in ordered_providers if provider in available_providers] - - self.model = onnxruntime.InferenceSession( - model_path, - providers=providers, - ) - - self.layout_classes = label_map - - self.input_shape = input_shape - - def image_processing( - self, - image: Image.Image, - ) -> List[LayoutElement]: - """Method runing SuperGradients Model for layout detection, returns a PageLayout""" - # Not handling various input images right now - # TODO (Pravin): check other shapes for inference - input_shape = self.input_shape - origin_img = np.array(image) - img = preprocess(origin_img, input_shape) - session = self.model - inputs = [o.name for o in session.get_inputs()] - outputs = [o.name for o in session.get_outputs()] - predictions = session.run(outputs, {inputs[0]: img}) - - regions = [] - - num_detections, pred_boxes, pred_scores, pred_classes = predictions - for image_index in range(num_detections.shape[0]): - for i in range(num_detections[image_index, 0]): - class_id = pred_classes[image_index, i] - prob = pred_scores[image_index, i] - x1, y1, x2, y2 = pred_boxes[image_index, i] - detected_class = self.layout_classes[str(class_id)] - region = LayoutElement.from_coords( - float(x1), - float(y1), - float(x2), - float(y2), - text=None, - type=detected_class, - prob=float(prob), - source=Source.SUPER_GRADIENTS, - ) - regions.append(cast(LayoutElement, region)) - - regions.sort(key=lambda element: element.bbox.y1) - - page_layout = regions - - return page_layout - - -def preprocess(origin_img, input_shape, swap=(0, 3, 1, 2)): - """Preprocess image data before Super-Gradients Inputted Model - Giving a generic preprocess function which simply resizes the image before prediction - TODO(Pravin): Look into allowing user 
to specify their own pre-process function - Which takes a numpy array image and returns a numpy array image - """ - new_img = cv2.resize(origin_img, input_shape).astype(np.uint8) - image_bchw = np.transpose(np.expand_dims(new_img, 0), swap) - return image_bchw diff --git a/unstructured_inference/models/table_postprocess.py b/unstructured_inference/models/table_postprocess.py index 602ae564..741277f8 100644 --- a/unstructured_inference/models/table_postprocess.py +++ b/unstructured_inference/models/table_postprocess.py @@ -80,24 +80,6 @@ def apply_threshold(objects, threshold): return [obj for obj in objects if obj["score"] >= threshold] -# def apply_class_thresholds(bboxes, labels, scores, class_names, class_thresholds): -# """ -# Filter out bounding boxes whose confidence is below the confidence threshold for -# its associated class label. -# """ -# # Apply class-specific thresholds -# indices_above_threshold = [ -# idx -# for idx, (score, label) in enumerate(zip(scores, labels)) -# if score >= class_thresholds[class_names[label]] -# ] -# bboxes = [bboxes[idx] for idx in indices_above_threshold] -# scores = [scores[idx] for idx in indices_above_threshold] -# labels = [labels[idx] for idx in indices_above_threshold] - -# return bboxes, scores, labels - - def refine_rows(rows, tokens, score_threshold): """ Apply operations to the detected rows, such as diff --git a/unstructured_inference/models/tables.py b/unstructured_inference/models/tables.py index 9fc66aee..ade5349e 100644 --- a/unstructured_inference/models/tables.py +++ b/unstructured_inference/models/tables.py @@ -3,13 +3,16 @@ import xml.etree.ElementTree as ET from collections import defaultdict from pathlib import Path -from typing import Dict, List, Optional, Union +from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union import cv2 import numpy as np import torch -from PIL import Image +from PIL import Image as PILImage from transformers import DetrImageProcessor, 
TableTransformerForObjectDetection +from transformers.models.table_transformer.modeling_table_transformer import ( + TableTransformerObjectDetectionOutput, +) from unstructured_inference.config import inference_config from unstructured_inference.inference.layoutelement import table_cells_to_dataframe @@ -27,7 +30,12 @@ class UnstructuredTableTransformerModel(UnstructuredModel): def __init__(self): pass - def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None): + def predict( + self, + x: PILImage.Image, + ocr_tokens: Optional[List[Dict]] = None, + result_format: str = "html", + ): """Predict table structure deferring to run_prediction with ocr tokens Note: @@ -44,7 +52,7 @@ def predict(self, x: Image, ocr_tokens: Optional[List[Dict]] = None): FIXME: refactor token data into a dataclass so we have clear expectations of the fields """ super().predict(x) - return self.run_prediction(x, ocr_tokens=ocr_tokens) + return self.run_prediction(x, ocr_tokens=ocr_tokens, result_format=result_format) def initialize( self, @@ -70,13 +78,12 @@ def initialize( def get_structure( self, - x: Image, + x: PILImage.Image, pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD, ) -> dict: """get the table structure as a dictionary contaning different types of elements as key-value pairs; check table-transformer documentation for more information""" with torch.no_grad(): - logger.info(f"padding image by {pad_for_structure_detection} for structure detection") encoding = self.feature_extractor( pad_image_with_background_color(x, pad_for_structure_detection), return_tensors="pt", @@ -87,7 +94,7 @@ def get_structure( def run_prediction( self, - x: Image, + x: PILImage.Image, pad_for_structure_detection: int = inference_config.TABLE_IMAGE_BACKGROUND_PAD, ocr_tokens: Optional[List[Dict]] = None, result_format: Optional[str] = "html", @@ -96,12 +103,27 @@ def run_prediction( outputs_structure = self.get_structure(x, pad_for_structure_detection) if ocr_tokens 
is None: raise ValueError("Cannot predict table structure with no OCR tokens") - prediction = recognize(outputs_structure, x, tokens=ocr_tokens)[0] + + recognized_table = recognize(outputs_structure, x, tokens=ocr_tokens) + if len(recognized_table) > 0: + prediction = recognized_table[0] + # NOTE(robinson) - This means that the table was not recognized + else: + return "" + if result_format == "html": # Convert cells to HTML prediction = cells_to_html(prediction) or "" elif result_format == "dataframe": prediction = table_cells_to_dataframe(prediction) + elif result_format == "cells": + prediction = prediction + else: + raise ValueError( + f"result_format {result_format} is not a valid format. " + f'Valid formats are: "html", "dataframe", "cells"', + ) + return prediction @@ -148,22 +170,26 @@ def get_class_map(data_type: str): } -def recognize(outputs: dict, img: Image, tokens: list): +def recognize(outputs: dict, img: PILImage.Image, tokens: list): """Recognize table elements.""" str_class_name2idx = get_class_map("structure") str_class_idx2name = {v: k for k, v in str_class_name2idx.items()} - str_class_thresholds = structure_class_thresholds + class_thresholds = structure_class_thresholds # Post-process detected objects, assign class labels objects = outputs_to_objects(outputs, img.size, str_class_idx2name) - + high_confidence_objects = apply_thresholds_on_objects(objects, class_thresholds) # Further process the detected objects so they correspond to a consistent table - tables_structure = objects_to_structures(objects, tokens, str_class_thresholds) + tables_structure = objects_to_structures(high_confidence_objects, tokens, class_thresholds) # Enumerate all table cells: grid cells and spanning cells return [structure_to_cells(structure, tokens)[0] for structure in tables_structure] -def outputs_to_objects(outputs, img_size, class_idx2name): +def outputs_to_objects( + outputs: TableTransformerObjectDetectionOutput, + img_size: Tuple[int, int], + class_idx2name: 
Mapping[int, str], +): """Output table element types.""" m = outputs["logits"].softmax(-1).max(-1) pred_labels = list(m.indices.detach().cpu().numpy())[0] @@ -192,6 +218,33 @@ def outputs_to_objects(outputs, img_size, class_idx2name): return objects +def apply_thresholds_on_objects( + objects: Sequence[Mapping[str, Any]], + thresholds: Mapping[str, float], +) -> Sequence[Mapping[str, Any]]: + """ + Filters predicted objects which the confidence scores below the thresholds + + Args: + objects: Sequence of mappings for example: + [ + { + "label": "table row", + "score": 0.55, + "bbox": [...], + }, + ..., + ] + thresholds: Mapping from labels to thresholds + + Returns: + Filtered list of objects + + """ + objects = [obj for obj in objects if obj["score"] >= thresholds[obj["label"]]] + return objects + + # for output bounding box post-processing def box_cxcywh_to_xyxy(x): """Convert rectangle format from center-x, center-y, width, height to @@ -431,6 +484,8 @@ def structure_to_cells(table_structure, tokens): columns = table_structure["columns"] rows = table_structure["rows"] spanning_cells = table_structure["spanning cells"] + spanning_cells = sorted(spanning_cells, reverse=True, key=lambda cell: cell["score"]) + cells = [] subcells = [] # Identify complete cells and subcells @@ -454,6 +509,7 @@ def structure_to_cells(table_structure, tokens): spanning_cell_rect.intersect(cell_rect).get_area() / cell_rect.get_area() ) > inference_config.TABLE_IOB_THRESHOLD: cell["subcell"] = True + cell["is_merged"] = False break if cell["subcell"]: @@ -475,7 +531,7 @@ def structure_to_cells(table_structure, tokens): subcell_rect_area = subcell_rect.get_area() if ( subcell_rect.intersect(spanning_cell_rect).get_area() / subcell_rect_area - ) > inference_config.TABLE_IOB_THRESHOLD: + ) > inference_config.TABLE_IOB_THRESHOLD and subcell["is_merged"] is False: if cell_rect is None: cell_rect = Rect(list(subcell["bbox"])) else: @@ -486,6 +542,8 @@ def structure_to_cells(table_structure, 
tokens): # as header cells for a spanning cell to be classified as a header cell; # otherwise, this could lead to a non-rectangular header region header = header and "column header" in subcell and subcell["column header"] + subcell["is_merged"] = True + if len(cell_rows) > 0 and len(cell_columns) > 0: cell = { "bbox": cell_rect.get_bbox(), @@ -590,11 +648,8 @@ def structure_to_cells(table_structure, tokens): def fill_cells(cells: List[dict]) -> List[dict]: - """add empty cells to pad cells that spans multiple rows for html conversion - - For example if a cell takes row 0 and 1 and column 0, we add a new empty cell at row 1 and - column 0. This padding ensures the structure of the output table is intact. In this example the - cell data is {"row_nums": [0, 1], "column_nums": [0], ...} + """fills the missing cells in the table by adding a cells with empty text + where there are no cells detected by the model. A cell contains the following keys relevent to the html conversion: row_nums: List[int] @@ -605,28 +660,63 @@ def fill_cells(cells: List[dict]) -> List[dict]: than one numbers cell text: str the text in this cell + column header: bool + whether this cell is a column header """ - new_cells = cells.copy() + if not cells: + return [] + + table_rows_no = max({row for cell in cells for row in cell["row_nums"]}) + table_cols_no = max({col for cell in cells for col in cell["column_nums"]}) + filled = np.zeros((table_rows_no + 1, table_cols_no + 1), dtype=bool) for cell in cells: - for extra_row in sorted(cell["row_nums"][1:]): - new_cell = cell.copy() - new_cell["row_nums"] = [extra_row] - new_cell["cell text"] = "" - new_cells.append(new_cell) + for row in cell["row_nums"]: + for col in cell["column_nums"]: + filled[row, col] = True + # add cells for which filled is false + header_rows = {row for cell in cells if cell["column header"] for row in cell["row_nums"]} + new_cells = cells.copy() + not_filled_idx = np.where(filled == False) # noqa: E712 + for row, col in 
zip(not_filled_idx[0], not_filled_idx[1]): + new_cell = { + "row_nums": [row], + "column_nums": [col], + "cell text": "", + "column header": row in header_rows, + } + new_cells.append(new_cell) return new_cells -def cells_to_html(cells): - """Convert table structure to html format.""" +def cells_to_html(cells: List[dict]) -> str: + """Convert table structure to html format. + + Args: + cells: List of dictionaries representing table cells, where each dictionary has the + following format: + { + "row_nums": List[int], + "column_nums": List[int], + "cell text": str, + "column header": bool, + } + Returns: + str: HTML table string + """ cells = sorted(fill_cells(cells), key=lambda k: (min(k["row_nums"]), min(k["column_nums"]))) table = ET.Element("table") current_row = -1 + table_header = None + table_has_header = any(cell["column header"] for cell in cells) + if table_has_header: + table_header = ET.SubElement(table, "thead") + + table_body = ET.SubElement(table, "tbody") for cell in cells: this_row = min(cell["row_nums"]) - attrib = {} colspan = len(cell["column_nums"]) if colspan > 1: @@ -637,18 +727,19 @@ def cells_to_html(cells): if this_row > current_row: current_row = this_row if cell["column header"]: + table_subelement = table_header cell_tag = "th" - row = ET.SubElement(table, "thead") else: + table_subelement = table_body cell_tag = "td" - row = ET.SubElement(table, "tr") + row = ET.SubElement(table_subelement, "tr") # type: ignore tcell = ET.SubElement(row, cell_tag, attrib=attrib) tcell.text = cell["cell text"] return str(ET.tostring(table, encoding="unicode", short_empty_elements=False)) -def zoom_image(image: Image, zoom: float) -> Image: +def zoom_image(image: PILImage.Image, zoom: float) -> PILImage.Image: """scale an image based on the zoom factor using cv2; the scaled image is post processed by dilation then erosion to improve edge sharpness for OCR tasks""" if zoom <= 0: @@ -666,4 +757,4 @@ def zoom_image(image: Image, zoom: float) -> Image: 
new_image = cv2.dilate(new_image, kernel, iterations=1) new_image = cv2.erode(new_image, kernel, iterations=1) - return Image.fromarray(new_image) + return PILImage.fromarray(new_image) diff --git a/unstructured_inference/models/yolox.py b/unstructured_inference/models/yolox.py index 47455cf4..0acd93f3 100644 --- a/unstructured_inference/models/yolox.py +++ b/unstructured_inference/models/yolox.py @@ -8,14 +8,17 @@ import cv2 import numpy as np import onnxruntime -from huggingface_hub import hf_hub_download from onnxruntime.capi import _pybind_state as C -from PIL import Image +from PIL import Image as PILImage from unstructured_inference.constants import ElementType, Source from unstructured_inference.inference.layoutelement import LayoutElement from unstructured_inference.models.unstructuredmodel import UnstructuredObjectDetectionModel -from unstructured_inference.utils import LazyDict, LazyEvaluateInfo +from unstructured_inference.utils import ( + LazyDict, + LazyEvaluateInfo, + download_if_needed_and_get_local_path, +) YOLOX_LABEL_MAP = { 0: ElementType.CAPTION, @@ -34,7 +37,7 @@ MODEL_TYPES = { "yolox": LazyDict( model_path=LazyEvaluateInfo( - hf_hub_download, + download_if_needed_and_get_local_path, "unstructuredio/yolo_x_layout", "yolox_l0.05.onnx", ), @@ -42,7 +45,7 @@ ), "yolox_tiny": LazyDict( model_path=LazyEvaluateInfo( - hf_hub_download, + download_if_needed_and_get_local_path, "unstructuredio/yolo_x_layout", "yolox_tiny.onnx", ), @@ -50,7 +53,7 @@ ), "yolox_quantized": LazyDict( model_path=LazyEvaluateInfo( - hf_hub_download, + download_if_needed_and_get_local_path, "unstructuredio/yolo_x_layout", "yolox_l0.05_quantized.onnx", ), @@ -60,7 +63,7 @@ class UnstructuredYoloXModel(UnstructuredObjectDetectionModel): - def predict(self, x: Image): + def predict(self, x: PILImage.Image): """Predict using YoloX model.""" super().predict(x) return self.image_processing(x) @@ -86,7 +89,7 @@ def initialize(self, model_path: str, label_map: dict): def 
image_processing( self, - image: Image = None, + image: PILImage.Image, ) -> List[LayoutElement]: """Method runing YoloX for layout detection, returns a PageLayout parameters diff --git a/unstructured_inference/utils.py b/unstructured_inference/utils.py index 86285954..696a2e8a 100644 --- a/unstructured_inference/utils.py +++ b/unstructured_inference/utils.py @@ -1,8 +1,10 @@ +import os from collections.abc import Mapping from html.parser import HTMLParser from io import StringIO from typing import Any, Callable, Hashable, Iterable, Iterator, Union +from huggingface_hub import hf_hub_download from PIL import Image from unstructured_inference.inference.layoutelement import LayoutElement @@ -101,3 +103,13 @@ def strip_tags(html: str) -> str: s = MLStripper() s.feed(html) return s.get_data() + + +def download_if_needed_and_get_local_path(path_or_repo: str, filename: str, **kwargs) -> str: + """Returns path to local file if it exists, otherwise treats it as a huggingface repo and + attempts to download.""" + full_path = os.path.join(path_or_repo, filename) + if os.path.exists(full_path): + return full_path + else: + return hf_hub_download(path_or_repo, filename, **kwargs)