Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: Changes for line-item prompt type #880

Merged
merged 10 commits into from
Jan 15, 2025
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,8 @@
"boolean":"boolean",
"json":"json",
"table":"table",
"record":"record"
"record":"record",
"line_item":"line-item"
},
"output_processing":{
"DEFAULT":"Default"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,7 @@ def update_or_create_prompt_output(

output = outputs.get(prompt.prompt_key)
# TODO: use enums here
if prompt.enforce_type in {"json", "table", "record"}:
if prompt.enforce_type in {"json", "table", "record", "line-item"}:
output = json.dumps(output)
profile_manager = default_profile
eval_metrics = outputs.get(f"{prompt.prompt_key}__evaluation", [])
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# Generated by Django 4.2.1 on 2025-01-09 21:09

from django.db import migrations, models


class Migration(migrations.Migration):
    """Alter ``ToolStudioPrompt.enforce_type`` choices to add "line-item".

    Adds the new "line-item" enforce type (large JSON output with
    continuation on token-limit truncation) to the existing choice set.
    """

    dependencies = [
        ("prompt_studio_v2", "0005_alter_toolstudioprompt_required"),
    ]

    operations = [
        migrations.AlterField(
            model_name="toolstudioprompt",
            name="enforce_type",
            field=models.TextField(
                blank=True,
                # NOTE(review): labels below must stay byte-identical to the
                # EnforceType TextChoices in models.py, or makemigrations will
                # detect a phantom change. The "distint" / "large a JSON"
                # typos are preserved here for that reason — fix them in
                # models.py first, then regenerate this migration.
                choices=[
                    ("Text", "Response sent as Text"),
                    ("number", "Response sent as number"),
                    ("email", "Response sent as email"),
                    ("date", "Response sent as date"),
                    ("boolean", "Response sent as boolean"),
                    ("json", "Response sent as json"),
                    ("table", "Response sent as table"),
                    (
                        "record",
                        "Response sent for records. Entries of records are list of logical and organized individual entities with distint values",
                    ),
                    (
                        "line-item",
                        "Response sent as line-item which is large a JSON output. If extraction stopped due to token limitation, we try to continue extraction from where it stopped",
                    ),
                ],
                db_comment="Field to store the type in which the response to be returned.",
                default="Text",
            ),
        ),
    ]
6 changes: 6 additions & 0 deletions backend/prompt_studio/prompt_studio_v2/models.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,12 @@ class EnforceType(models.TextChoices):
"logical and organized individual "
"entities with distint values"
)
LINE_ITEM = "line-item", (
"Response sent as line-item "
"which is large a JSON output. "
"If extraction stopped due to token limitation, "
"we try to continue extraction from where it stopped"
)
chandrasekharan-zipstack marked this conversation as resolved.
Show resolved Hide resolved

class PromptType(models.TextChoices):
PROMPT = "PROMPT", "Response sent as Text"
Expand Down
1 change: 1 addition & 0 deletions prompt-service/src/unstract/prompt_service/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ class PromptServiceContants:
REQUIRED = "required"
EXECUTION_SOURCE = "execution_source"
METRICS = "metrics"
LINE_ITEM = "line-item"


class RunLevel(Enum):
Expand Down
84 changes: 84 additions & 0 deletions prompt-service/src/unstract/prompt_service/helper.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,11 @@
from unstract.sdk.file_storage.constants import StorageType
from unstract.sdk.file_storage.env_helper import EnvHelper

# User-facing message raised when a cloud-/enterprise-only plugin
# (e.g. table or line-item extraction) is not present in this deployment.
PAID_FEATURE_MSG = (
    "It is a cloud / enterprise feature. If you have purchased a plan and still "
    "face this issue, please contact support"
)

load_dotenv()

# Global variable to store plugins
Expand Down Expand Up @@ -394,3 +399,82 @@ def extract_table(
except table_extractor["exception_cls"] as e:
msg = f"Couldn't extract table. {e}"
raise APIError(message=msg)


def extract_line_item(
    tool_settings: dict[str, Any],
    output: dict[str, Any],
    plugins: dict[str, dict[str, Any]],
    structured_output: dict[str, Any],
    llm: LLM,
    file_path: str,
    metadata: Optional[dict[str, str]],
    execution_source: str,
) -> dict[str, Any]:
    """Run the line-item extraction plugin for a single prompt.

    Reads the extracted text for ``file_path`` (from the sibling ``extract``
    folder when running from the IDE), builds the prompt, and delegates to
    the ``line-item-extraction`` plugin. The answer is stored in
    ``structured_output`` under the prompt's name.

    Args:
        tool_settings: Tool-level settings (preamble, postamble, grammar).
        output: Prompt definition; must contain ``promptx`` and the name key.
        plugins: Loaded plugin registry; requires ``line-item-extraction``.
        structured_output: Accumulated results dict, updated in place.
        llm: LLM instance handed to the plugin.
        file_path: Path of the source document.
        metadata: Optional run metadata; the context used is recorded under
            ``PSKeys.CONTEXT`` when provided.
        execution_source: ``ExecutionSource`` value (IDE or TOOL).

    Returns:
        The updated ``structured_output``.

    Raises:
        APIError: If the plugin is unavailable or extraction fails.
        FileNotFoundError: If the extracted text file is missing.
    """
    line_item_extraction_plugin: dict[str, Any] = plugins.get(
        "line-item-extraction", {}
    )
    if not line_item_extraction_plugin:
        # Plugin ships only with cloud / enterprise builds.
        raise APIError(PAID_FEATURE_MSG)

    extract_file_path = file_path
    if execution_source == ExecutionSource.IDE.value:
        # Adjust file path to read from the extract folder
        base_name = os.path.splitext(os.path.basename(file_path))[0]
        extract_file_path = os.path.join(
            os.path.dirname(file_path), "extract", f"{base_name}.txt"
        )

    # Read file content into context
    if check_feature_flag_status(FeatureFlag.REMOTE_FILE_STORAGE):
        fs_instance: FileStorage = FileStorage(FileStorageProvider.LOCAL)
        if execution_source == ExecutionSource.IDE.value:
            fs_instance = EnvHelper.get_storage(
                storage_type=StorageType.PERMANENT,
                env_name=FileStorageKeys.PERMANENT_REMOTE_STORAGE,
            )
        if execution_source == ExecutionSource.TOOL.value:
            fs_instance = EnvHelper.get_storage(
                storage_type=StorageType.TEMPORARY,
                env_name=FileStorageKeys.TEMPORARY_REMOTE_STORAGE,
            )

        if not fs_instance.exists(extract_file_path):
            raise FileNotFoundError(
                f"The file at path '{extract_file_path}' does not exist."
            )
        # NOTE(review): binary mode combined with an explicit encoding —
        # assumes the SDK decodes to str here; confirm, since
        # construct_prompt expects text context.
        context = fs_instance.read(path=extract_file_path, encoding="utf-8", mode="rb")
    else:
        if not os.path.exists(extract_file_path):
            raise FileNotFoundError(
                f"The file at path '{extract_file_path}' does not exist."
            )

        with open(extract_file_path, encoding="utf-8") as file:
            context = file.read()

    prompt = construct_prompt(
        preamble=tool_settings.get(PSKeys.PREAMBLE, ""),
        prompt=output["promptx"],
        postamble=tool_settings.get(PSKeys.POSTAMBLE, ""),
        grammar_list=tool_settings.get(PSKeys.GRAMMAR, []),
        context=context,
        platform_postamble="",
    )

    try:
        line_item_extraction = line_item_extraction_plugin["entrypoint_cls"](
            llm=llm,
            tool_settings=tool_settings,
            output=output,
            prompt=prompt,
            structured_output=structured_output,
            logger=current_app.logger,
        )
        answer = line_item_extraction.run()
        structured_output[output[PSKeys.NAME]] = answer
        # metadata is Optional — only record the context when it is supplied,
        # instead of raising TypeError on None.
        if metadata is not None:
            metadata[PSKeys.CONTEXT][output[PSKeys.NAME]] = [context]
        return structured_output
    except line_item_extraction_plugin["exception_cls"] as e:
        # Fix: message previously said "table" (copy-paste from
        # extract_table); report the operation that actually failed.
        msg = f"Couldn't extract line-item. {e}"
        raise APIError(message=msg) from e
39 changes: 39 additions & 0 deletions prompt-service/src/unstract/prompt_service/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from unstract.prompt_service.exceptions import APIError, ErrorResponse, NoPayloadError
from unstract.prompt_service.helper import (
construct_and_run_prompt,
extract_line_item,
extract_table,
extract_variable,
get_cleaned_context,
Expand Down Expand Up @@ -264,6 +265,44 @@ def prompt_processor() -> Any:
"Error while extracting table for the prompt",
)
raise api_error
elif output[PSKeys.TYPE] == PSKeys.LINE_ITEM:
try:
Deepak-Kesavan marked this conversation as resolved.
Show resolved Hide resolved
structured_output = extract_line_item(
tool_settings=tool_settings,
output=output,
plugins=plugins,
structured_output=structured_output,
llm=llm,
file_path=file_path,
metadata=metadata,
execution_source=execution_source,
)
metadata = query_usage_metadata(token=platform_key, metadata=metadata)
# TODO: Handle metrics for line-item extraction
response = {
PSKeys.METADATA: metadata,
PSKeys.OUTPUT: structured_output,
PSKeys.METRICS: metrics,
}
continue
except APIError as e:
app.logger.error(
"Failed to extract line-item for the prompt %s: %s",
output[PSKeys.NAME],
str(e),
)
publish_log(
log_events_id,
{
"tool_id": tool_id,
"prompt_key": prompt_name,
"doc_name": doc_name,
},
LogLevel.ERROR,
RunLevel.RUN,
"Error while extracting line-item for the prompt",
)
raise e

try:
if chunk_size == 0:
Expand Down
Loading