Skip to content

Commit

Permalink
refactor: migrate pydantic to v2 (#182)
Browse files Browse the repository at this point in the history
This PR removes the `pydantic` v1 compatibility layer (`import pydantic.v1`)
and applies the changes needed for `pydantic` v2, such as:

- the `dict()`, `parse_obj()`, and `parse_file()` methods are replaced with
`model_dump()`, `model_validate()`, and `model_validate_json()`
- the `validate_arguments` decorator is renamed to `validate_call`,
and `root_validator` to `model_validator`
- add the `pydantic-settings` dependency, the new package to which settings
classes such as `BaseSettings` moved, and update the corresponding imports
- add a custom validator for unique items

Signed-off-by: Tiago Santana <[email protected]>
  • Loading branch information
SantanaTiago authored Aug 5, 2024
1 parent e731d24 commit 5b172db
Show file tree
Hide file tree
Showing 25 changed files with 337 additions and 164 deletions.
2 changes: 1 addition & 1 deletion deepsearch/chemistry/queries/molecules.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from enum import Enum
from typing import List, Union

from pydantic.v1 import BaseModel
from pydantic import BaseModel

from deepsearch.cps.client.queries import Query

Expand Down
2 changes: 1 addition & 1 deletion deepsearch/core/client/config.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Union

from pydantic.v1 import BaseModel
from pydantic import BaseModel


class DeepSearchBearerTokenAuth(BaseModel):
Expand Down
21 changes: 8 additions & 13 deletions deepsearch/core/client/settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,17 +4,18 @@
from pathlib import Path
from typing import Dict, Optional, Union

from pydantic.v1 import BaseSettings, SecretStr
from pydantic import SecretStr
from pydantic_settings import BaseSettings, SettingsConfigDict


class DumpableSettings(BaseSettings):
@classmethod
def get_env_var_name(cls, attr_name) -> str:
return cls.Config.env_prefix + attr_name.upper()
def get_env_var_name(cls, attr_name: str) -> str:
return cls.model_config["env_prefix"] + attr_name.upper()

def _get_serializable_dict(self) -> Dict[str, str]:
result = {}
model_dict = self.dict()
model_dict = self.model_dump()
for k in model_dict:
new_key = self.get_env_var_name(attr_name=k)
if isinstance((old_val := model_dict[k]), SecretStr):
Expand All @@ -38,9 +39,7 @@ class ProfileSettings(DumpableSettings):
username: str
api_key: SecretStr
verify_ssl: bool = True

class Config:
env_prefix = "DEEPSEARCH_"
model_config = SettingsConfigDict(env_prefix="DEEPSEARCH_")

@classmethod
def from_cli_prompt(cls) -> ProfileSettings:
Expand All @@ -54,13 +53,9 @@ def from_cli_prompt(cls) -> ProfileSettings:

class MainSettings(DumpableSettings):
profile: Optional[str] = None

class Config:
env_prefix = "DEEPSEARCH_"
model_config = SettingsConfigDict(env_prefix="DEEPSEARCH_")


class CLISettings(DumpableSettings):
show_cli_stack_traces: bool = False

class Config:
env_prefix = "DEEPSEARCH_"
model_config = SettingsConfigDict(env_prefix="DEEPSEARCH_")
4 changes: 2 additions & 2 deletions deepsearch/core/client/settings_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import Dict, List, Optional

import platformdirs
from pydantic.v1 import ValidationError
from pydantic import ValidationError

from deepsearch.core.cli.profile_utils import (
MSG_AMBIGUOUS_SUCCESSOR,
Expand Down Expand Up @@ -111,7 +111,7 @@ def _migrate_legacy_config(self) -> None:
if self._main_settings.profile is None:
legacy_cfg_path = self.config_root_path / LEGACY_CFG_FILENAME
if legacy_cfg_path.exists():
legacy_cfg = DeepSearchConfig.parse_file(legacy_cfg_path)
legacy_cfg = DeepSearchConfig.model_validate(legacy_cfg_path)
if isinstance(legacy_cfg.auth, DeepSearchKeyAuth):
new_cfg = ProfileSettings(
host=legacy_cfg.host,
Expand Down
4 changes: 2 additions & 2 deletions deepsearch/cps/cli/data_indices_typer.py
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,7 @@ def upload_files(

if conv_settings is not None:
try:
final_conv_settings = ConversionSettings.parse_obj(
final_conv_settings = ConversionSettings.model_validate(
json.loads(conv_settings)
)
except json.JSONDecodeError:
Expand All @@ -177,7 +177,7 @@ def upload_files(

if target_settings is not None:
try:
final_target_settings = TargetSettings.parse_file(target_settings)
final_target_settings = TargetSettings.model_validate_json(target_settings)
except Exception as e:
raise e
else:
Expand Down
2 changes: 1 addition & 1 deletion deepsearch/cps/client/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from typing import Optional

import requests
from pydantic.v1 import ValidationError
from pydantic import ValidationError

import deepsearch.cps.apis.user
from deepsearch.core.client import (
Expand Down
8 changes: 4 additions & 4 deletions deepsearch/cps/client/components/data_indices.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from urllib.parse import urlparse

import requests
from pydantic.v1 import BaseModel
from pydantic import BaseModel

from deepsearch.cps.apis import public as sw_client
from deepsearch.cps.apis.public.models.attachment_upload_data import (
Expand Down Expand Up @@ -40,7 +40,7 @@ def list(self, proj_key: str) -> List[DataIndex]:

# filter out saved searchs index
return [
DataIndex.parse_obj(item.to_dict())
DataIndex.model_validate(item.to_dict())
for item in response
if item.to_dict()["type"] != "View"
]
Expand Down Expand Up @@ -94,7 +94,7 @@ def create(
self.sw_api.create_project_data_index(proj_key=proj_key, data=data)
)

return DataIndex.parse_obj(response.to_dict())
return DataIndex.model_validate(response.to_dict())

def delete(
self,
Expand Down Expand Up @@ -146,7 +146,7 @@ def convert_from_page_urls(
task: Task = self.sw_api.html_print_convert_upload(
proj_key=coords.proj_key,
index_key=coords.index_key,
body={"urls": [item.dict() for item in urls]},
body={"urls": [item.model_dump() for item in urls]},
)
return task

Expand Down
2 changes: 1 addition & 1 deletion deepsearch/cps/client/components/documents.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import urllib.parse
from typing import TYPE_CHECKING, Any, Dict, Optional, Union

from pydantic.v1 import BaseModel
from pydantic import BaseModel

from deepsearch.cps.apis.public_v2 import SemanticApi
from deepsearch.cps.apis.public_v2.models.cps_task import CpsTask
Expand Down
6 changes: 4 additions & 2 deletions deepsearch/cps/client/components/elastic.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import datetime
from typing import TYPE_CHECKING, Any, Dict, List, Union

from pydantic.v1 import BaseModel
from pydantic import BaseModel

from deepsearch.cps.apis import public as sw_client
from deepsearch.cps.client.components.data_indices import (
Expand Down Expand Up @@ -31,7 +31,9 @@ def list(self, domain: str = "all") -> List[ElasticDataCollection]:
index_type="all", index_domain=domain
)

return [ElasticDataCollection.parse_obj(item.to_dict()) for item in response]
return [
ElasticDataCollection.model_validate(item.to_dict()) for item in response
]


class ElasticDataCollectionSource(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion deepsearch/cps/client/components/projects.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from enum import Enum
from typing import TYPE_CHECKING, List, Literal, Optional, Union

from pydantic.v1 import BaseModel
from pydantic import BaseModel

import deepsearch.cps.apis.user
from deepsearch.cps.apis.user.models.token_response import TokenResponse
Expand Down
2 changes: 1 addition & 1 deletion deepsearch/cps/client/components/uploader.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Union

import requests
from pydantic.v1 import BaseModel
from pydantic import BaseModel

from deepsearch.cps.apis import public as sw_client
from deepsearch.cps.apis.public.models.temporary_upload_file_result import (
Expand Down
6 changes: 4 additions & 2 deletions deepsearch/cps/data_indices/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,7 +168,9 @@ def process_local_file(
if conv_settings is not None:
payload["conversion_settings"] = conv_settings.to_ccs_spec()
if target_settings is not None:
payload["target_settings"] = target_settings.dict(exclude_none=True)
payload["target_settings"] = target_settings.model_dump(
exclude_none=True
)

task_id = api.data_indices.upload_file(coords=coords, body=payload)
task_ids.append(task_id)
Expand Down Expand Up @@ -203,7 +205,7 @@ def process_external_cos(
bar_format=progressbar.bar_format,
) as progress:
# upload using coordinates
payload = {"s3_source": {"coordinates": s3_coordinates.dict()}}
payload = {"s3_source": {"coordinates": s3_coordinates.model_dump()}}
task_id = api.data_indices.upload_file(
coords=coords,
body=payload,
Expand Down
11 changes: 7 additions & 4 deletions deepsearch/cps/queries/__init__.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Dict, List, Literal, Optional, Union

from pydantic.v1 import BaseModel, Field, validate_arguments
from pydantic import BaseModel, Field, validate_call
from pydantic_settings import SettingsConfigDict
from typing_extensions import Annotated

from deepsearch.cps.client.components.documents import (
Expand Down Expand Up @@ -114,8 +115,10 @@ class _APISemanticRagParameters(_APISemanticRetrievalParameters):
chunk_refs: Optional[List[ChunkRef]] = None
gen_timeout: Optional[float] = None

model_config = SettingsConfigDict(protected_namespaces=())

@validate_arguments

@validate_call
def RAGQuery(
question: str,
*,
Expand Down Expand Up @@ -202,7 +205,7 @@ def RAGQuery(
return query


@validate_arguments
@validate_call
def SemanticQuery(
question: str,
*,
Expand Down Expand Up @@ -254,7 +257,7 @@ def SemanticQuery(
task = query.add(
task_id="QA",
kind_or_task="SemanticRetrieval",
parameters=params.dict(),
parameters=params.model_dump(),
coordinates=coords,
)
task.output("items").output_as("items")
Expand Down
6 changes: 4 additions & 2 deletions deepsearch/cps/queries/results.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

from typing import List, Optional

from pydantic.v1 import BaseModel
from pydantic import BaseModel

from deepsearch.cps.client.components.queries import RunQueryResult

Expand Down Expand Up @@ -71,7 +71,9 @@ def from_api_output(cls, data: RunQueryResult, raise_on_error=True):
grounding=RAGGroundingInfo(
retr_items=(
[
SearchResultItem.parse_obj(search_result_items[i])
SearchResultItem.model_validate(
search_result_items[i]
)
for i in retr_idxs
]
if retr_idxs is not None and retrieval_part is not None
Expand Down
2 changes: 1 addition & 1 deletion deepsearch/documents/core/input_process.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ def process_cos_input(
cps_proj_key=cps_proj_key,
source={
"type": "s3",
"coordinates": source_cos.dict(),
"coordinates": source_cos.model_dump(),
},
target=target,
conversion_settings=conversion_settings,
Expand Down
2 changes: 1 addition & 1 deletion deepsearch/documents/core/lookup.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import Any, Dict, List

from pydantic.v1 import BaseModel
from pydantic import BaseModel


def _resolve_item(item, doc):
Expand Down
Loading

0 comments on commit 5b172db

Please sign in to comment.