Skip to content

Commit

Permalink
Merge pull request #932 from roboflow/feature/add-auth-middleware
Browse files Browse the repository at this point in the history
Add env-injectable headers to RF API requests
  • Loading branch information
PawelPeczek-Roboflow authored Jan 9, 2025
2 parents 0f75f11 + d580daf commit d149fb7
Show file tree
Hide file tree
Showing 5 changed files with 130 additions and 18 deletions.
2 changes: 2 additions & 0 deletions inference/core/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
else "https://api.roboflow.one"
),
)
# extra headers expected to be serialised json
ROBOFLOW_API_EXTRA_HEADERS = os.getenv("ROBOFLOW_API_EXTRA_HEADERS")

# Debug flag for the API, default is False
API_DEBUG = os.getenv("API_DEBUG", False)
Expand Down
33 changes: 30 additions & 3 deletions inference/core/roboflow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from inference.core.env import (
API_BASE_URL,
MODEL_CACHE_DIR,
ROBOFLOW_API_EXTRA_HEADERS,
USE_FILE_CACHE_FOR_WORKFLOWS_DEFINITIONS,
WORKFLOWS_DEFINITION_CACHE_EXPIRY,
)
Expand Down Expand Up @@ -147,6 +148,7 @@ def add_custom_metadata(
}
]
},
headers=build_roboflow_api_headers(),
)
api_key_safe_raise_for_status(response=response)

Expand Down Expand Up @@ -318,10 +320,13 @@ def register_image_at_roboflow(
"file": ("imageToUpload", image_bytes, "image/jpeg"),
}
)
headers = build_roboflow_api_headers(
explicit_headers={"Content-Type": m.content_type},
)
response = requests.post(
url=wrapped_url,
data=m,
headers={"Content-Type": m.content_type},
headers=headers,
)
api_key_safe_raise_for_status(response=response)
parsed_response = response.json()
Expand Down Expand Up @@ -357,10 +362,13 @@ def annotate_image_at_roboflow(
("prediction", str(is_prediction).lower()),
]
wrapped_url = wrap_url(_add_params_to_url(url=url, params=params))
headers = build_roboflow_api_headers(
explicit_headers={"Content-Type": "text/plain"},
)
response = requests.post(
wrapped_url,
data=annotation_content,
headers={"Content-Type": "text/plain"},
headers=headers,
)
api_key_safe_raise_for_status(response=response)
parsed_response = response.json()
Expand Down Expand Up @@ -597,7 +605,10 @@ def get_from_url(


def _get_from_url(url: str, json_response: bool = True) -> Union[Response, dict]:
response = requests.get(wrap_url(url))
response = requests.get(
wrap_url(url),
headers=build_roboflow_api_headers(),
)
api_key_safe_raise_for_status(response=response)
if json_response:
return response.json()
Expand Down Expand Up @@ -627,5 +638,21 @@ def send_inference_results_to_model_monitoring(
response = requests.post(
url=api_url,
json=inference_data,
headers=build_roboflow_api_headers(),
)
api_key_safe_raise_for_status(response=response)


def build_roboflow_api_headers(
explicit_headers: Optional[Dict[str, Union[str, List[str]]]] = None,
) -> Optional[Dict[str, Union[List[str]]]]:
if not ROBOFLOW_API_EXTRA_HEADERS:
return explicit_headers
try:
extra_headers: dict = json.loads(ROBOFLOW_API_EXTRA_HEADERS)
if explicit_headers:
extra_headers.update(explicit_headers)
return extra_headers
except ValueError:
logger.warning("Could not decode ROBOFLOW_API_EXTRA_HEADERS")
return explicit_headers
16 changes: 3 additions & 13 deletions inference/usage_tracking/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,31 +6,21 @@
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings, SettingsConfigDict

from inference.core.env import PROJECT
from inference.core.env import API_BASE_URL
from inference.core.utils.url_utils import wrap_url


class TelemetrySettings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="telemetry_")

api_usage_endpoint_url: str = "https://api.roboflow.com/usage/inference"
api_plan_endpoint_url: str = "https://api.roboflow.com/usage/plan"
api_usage_endpoint_url: str = wrap_url(f"{API_BASE_URL}/usage/inference")
api_plan_endpoint_url: str = wrap_url(f"{API_BASE_URL}/usage/plan")
flush_interval: int = Field(default=10, ge=10, le=300)
opt_out: Optional[bool] = False
queue_size: int = Field(default=10, ge=10, le=10000)

@model_validator(mode="after")
def check_values(cls, inst: TelemetrySettings):
if PROJECT == "roboflow-platform":
inst.api_usage_endpoint_url = wrap_url(
"https://api.roboflow.com/usage/inference"
)
inst.api_plan_endpoint_url = wrap_url("https://api.roboflow.com/usage/plan")
else:
inst.api_usage_endpoint_url = wrap_url(
"https://api.roboflow.one/usage/inference"
)
inst.api_plan_endpoint_url = wrap_url("https://api.roboflow.one/usage/plan")
inst.flush_interval = min(max(inst.flush_interval, 10), 300)
inst.queue_size = min(max(inst.queue_size, 10), 10000)
return inst
Expand Down
7 changes: 6 additions & 1 deletion inference/usage_tracking/payload_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@

import requests

from inference.core.roboflow_api import build_roboflow_api_headers

ResourceID = str
Usage = Union[DefaultDict[str, Any], Dict[str, Any]]
ResourceUsage = Union[DefaultDict[ResourceID, Usage], Dict[ResourceID, Usage]]
Expand Down Expand Up @@ -160,11 +162,14 @@ def send_usage_payload(
if "api_key_hash" in workflow_payload:
del workflow_payload["api_key_hash"]
workflow_payload["api_key"] = api_key
headers = build_roboflow_api_headers(
explicit_headers={"Authorization": f"Bearer {api_key}"}
)
response = requests.post(
api_usage_endpoint_url,
json=complete_workflow_payloads,
verify=ssl_verify,
headers={"Authorization": f"Bearer {api_key}"},
headers=headers,
timeout=1,
)
except Exception:
Expand Down
90 changes: 89 additions & 1 deletion tests/inference/unit_tests/core/test_roboflow_api.py
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@
get_workflow_specification,
raise_from_lambda,
register_image_at_roboflow,
wrap_roboflow_api_errors,
wrap_roboflow_api_errors, build_roboflow_api_headers,
)
from inference.core.utils.url_utils import wrap_url

Expand Down Expand Up @@ -224,6 +224,11 @@ def test_get_roboflow_workspace_when_workspace_id_is_empty(
assert requests_mock.last_request.query == "api_key=my_api_key&nocache=true"


@mock.patch.object(
roboflow_api,
"ROBOFLOW_API_EXTRA_HEADERS",
json.dumps({"extra": "header"})
)
def test_get_roboflow_workspace_when_response_is_valid(requests_mock: Mocker) -> None:
# given
requests_mock.get(
Expand Down Expand Up @@ -2179,3 +2184,86 @@ def test_get_workflow_specification_when_valid_response_given_on_consecutive_req
}
)
assert len(ephemeral_cache.cache) == 1, "Expected cache content to appear"


@mock.patch.object(roboflow_api, "ROBOFLOW_API_EXTRA_HEADERS", None)
def test_build_roboflow_api_headers_when_no_extra_headers() -> None:
# when
result = build_roboflow_api_headers()

# then
assert result is None


@mock.patch.object(roboflow_api, "ROBOFLOW_API_EXTRA_HEADERS", None)
def test_build_roboflow_api_headers_when_no_extra_headers_but_explicit_headers_given() -> None:
# when
result = build_roboflow_api_headers(explicit_headers={"my": "header"})

# then
assert result == {"my": "header"}, "Expected to preserve explicit header"


@mock.patch.object(
roboflow_api,
"ROBOFLOW_API_EXTRA_HEADERS",
json.dumps({"extra": "header", "another": "extra"})
)
def test_build_roboflow_api_headers_when_extra_headers_given() -> None:
# when
result = build_roboflow_api_headers()

# then
assert result == {"extra": "header", "another": "extra"}, "Expected extra headers to be decoded"


@mock.patch.object(
roboflow_api,
"ROBOFLOW_API_EXTRA_HEADERS",
json.dumps({"extra": "header", "another": "extra"})
)
def test_build_roboflow_api_headers_when_extra_headers_given_and_explicit_headers_present() -> None:
# when
result = build_roboflow_api_headers(explicit_headers={"my": "header"})

# then
assert result == {
"my": "header",
"extra": "header",
"another": "extra",
}, "Expected extra headers to be decoded and shipped along with explicit headers"


@mock.patch.object(
roboflow_api,
"ROBOFLOW_API_EXTRA_HEADERS",
"For sure not a JSON :)"
)
def test_build_roboflow_api_headers_when_extra_headers_given_as_invalid_json() -> None:
# when
result = build_roboflow_api_headers(explicit_headers={"my": "header"})

# then
assert result == {
"my": "header",
}, "Expected extra headers to be decoded and shipped along with explicit headers"


@mock.patch.object(
roboflow_api,
"ROBOFLOW_API_EXTRA_HEADERS",
json.dumps({"extra": "header", "another": "extra"})
)
def test_build_roboflow_api_headers_when_extra_headers_given_and_explicit_headers_collide_with_extras() -> None:
# when
result = build_roboflow_api_headers(explicit_headers={
"extra": "explicit-is-better",
"my": "header",
})

# then
assert result == {
"another": "extra",
"extra": "explicit-is-better",
"my": "header",
}, "Expected extra headers to be decoded and explicit header to override implicit one"

0 comments on commit d149fb7

Please sign in to comment.