select improvements from feat/system_keys
paulineribeyre committed Dec 16, 2024
1 parent 4a716bf commit 458f9d8
Showing 12 changed files with 82 additions and 56 deletions.
4 changes: 2 additions & 2 deletions .secrets.baseline
@@ -187,7 +187,7 @@
"filename": "tests/conftest.py",
"hashed_secret": "0dd78d9147bb410f0cb0199c5037da36594f77d8",
"is_verified": false,
"line_number": 222
"line_number": 224
}
],
"tests/migrations/test_migration_e1886270d9d2.py": [
@@ -209,5 +209,5 @@
}
]
},
"generated_at": "2024-12-12T23:42:54Z"
"generated_at": "2024-12-16T16:26:23Z"
}
4 changes: 4 additions & 0 deletions docs/authorization.md
@@ -2,6 +2,10 @@

The Gen3 Workflow endpoints are protected by Arborist policies.

Contents:
- [GA4GH TES](#ga4gh-tes)
- [Authorization configuration example](#authorization-configuration-example)

## GA4GH TES

- To create a task, users need `create` access to resource `/services/workflow/gen3-workflow/tasks` on service `gen3-workflow`.
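
As a hedged illustration of the permission described in the excerpt above, here is what creating a task could look like from the client side. The base URL, the `/ga4gh/tes/v1` prefix, and the token are placeholders and assumptions, not values documented in this commit; only the required `create` permission on `/services/workflow/gen3-workflow/tasks` comes from the docs above.

```python
# Hedged client-side sketch: creating a TES task with a Gen3 access token.
# The base URL and "/ga4gh/tes/v1" prefix are assumptions; only the required
# "create" permission on /services/workflow/gen3-workflow/tasks is documented above.
import httpx

BASE_URL = "https://example-commons.org/ga4gh/tes/v1"  # placeholder
ACCESS_TOKEN = "<Gen3 access token>"  # placeholder

task = {
    "name": "hello-world",
    "executors": [{"image": "alpine", "command": ["echo", "hello"]}],
}

# Without "create" access to /services/workflow/gen3-workflow/tasks, this
# request would be rejected by the service's authorization check.
resp = httpx.post(
    f"{BASE_URL}/tasks",
    json=task,
    headers={"Authorization": f"Bearer {ACCESS_TOKEN}"},
)
print(resp.status_code, resp.text)
```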
3 changes: 3 additions & 0 deletions docs/local_installation.md
@@ -100,6 +100,9 @@ aws {
}
workDir = '<your working directory>'
```

A Gen3 access token is expected by most endpoints to verify the user's access (see [Authorization](authorization.md) documentation).

> The Gen3Workflow URL should be set to `http://localhost:8080` in this case; this is where the service runs by default when started with `python run.py`.
- Run a workflow:
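
As a quick local smoke test of the note added above, the sketch below calls the TES `service-info` route on the locally running service, attaching a Gen3 access token. The `/ga4gh/tes/v1` prefix and the `GEN3_TOKEN` environment variable are assumptions used only for illustration.

```python
# Minimal local smoke test (assumptions: service started with `python run.py`
# on http://localhost:8080; the "/ga4gh/tes/v1" prefix and GEN3_TOKEN env var
# are illustrative, not documented in this commit).
import os
import httpx

base_url = "http://localhost:8080"
token = os.environ.get("GEN3_TOKEN", "")  # hypothetical env var holding the Gen3 access token

resp = httpx.get(
    f"{base_url}/ga4gh/tes/v1/service-info",
    headers={"Authorization": f"Bearer {token}"} if token else {},
)
print(resp.status_code, resp.json())
```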
13 changes: 10 additions & 3 deletions gen3workflow/auth.py
@@ -25,11 +25,11 @@ def __init__(
self,
api_request: Request,
bearer_token: HTTPAuthorizationCredentials = Security(bearer),
):
) -> None:
self.arborist_client = api_request.app.arborist_client
self.bearer_token = bearer_token

def get_access_token(self):
def get_access_token(self) -> str:
if config["MOCK_AUTH"]:
return "123"

@@ -95,7 +95,14 @@ async def authorize(

return authorized

async def grant_user_access_to_their_own_tasks(self, username, user_id):
async def grant_user_access_to_their_own_tasks(self, username, user_id) -> None:
"""
Ensure the specified user exists in Arborist and has a policy granting them access to their
own Gen3Workflow tasks ("read" and "delete" access to resource "/users/<user ID>/gen3-workflow/tasks" for service "gen3-workflow").
Args:
username (str): The user's Gen3 username
user_id (str): The user's unique Gen3 ID
"""
logger.info(
f"Granting user '{username}' access to their own tasks if they don't already have it"
)
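
To make the docstring added to `grant_user_access_to_their_own_tasks` more concrete, here is a hedged illustration of a policy of the shape it describes: "read" and "delete" on `/users/<user ID>/gen3-workflow/tasks` for service `gen3-workflow`. The policy ID and role IDs are hypothetical; the exact document Arborist expects, and the calls the service makes to create it, may differ.

```python
# Hedged illustration only, not the service's actual Arborist calls.
# The resource path and permissions come from the docstring above; the policy ID
# and role IDs are hypothetical placeholders.
user_id = "64"
policy = {
    "id": f"gen3-workflow-user-{user_id}-tasks",  # hypothetical policy ID
    "description": "Access to the user's own Gen3Workflow tasks",
    "resource_paths": [f"/users/{user_id}/gen3-workflow/tasks"],
    # hypothetical roles mapping to "read" and "delete" on service "gen3-workflow"
    "role_ids": ["gen3-workflow-task-reader", "gen3-workflow-task-deleter"],
}
print(policy)
```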
10 changes: 3 additions & 7 deletions gen3workflow/aws_utils.py
@@ -1,19 +1,15 @@
import json
from typing import Tuple

import boto3
from botocore.exceptions import ClientError
from fastapi import HTTPException
from starlette.status import HTTP_404_NOT_FOUND

from gen3workflow import logger
from gen3workflow.config import config


iam_client = boto3.client("iam")
iam_resp_err = "Unexpected response from AWS IAM"


def get_safe_name_from_user_id(user_id):
def get_safe_name_from_user_id(user_id: str) -> str:
"""
Generate a valid IAM user name or S3 bucket name for the specified user.
- IAM user names can contain up to 64 characters. They can only contain alphanumeric characters
@@ -36,7 +32,7 @@ def get_safe_name_from_user_id(user_id):
return safe_name


def create_user_bucket(user_id):
def create_user_bucket(user_id: str) -> Tuple[str, str, str]:
"""
Create an S3 bucket for the specified user and return information about the bucket.
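
The body of `get_safe_name_from_user_id` is collapsed in this view, so as a hedged stand-in, the sketch below shows the kind of sanitization its docstring describes (alphanumeric characters only, truncated to IAM's 64-character limit). The `gen3wf-<hostname>` prefix and the `-` replacement character are assumptions.

```python
# Hedged stand-in for the sanitization described in the docstring above;
# NOT the actual implementation, which is collapsed in this diff. The
# "gen3wf-<hostname>" prefix and "-" replacement character are assumptions.
import re

def sketch_safe_name(user_id: str, hostname: str = "example-commons.org") -> str:
    raw = f"gen3wf-{hostname}-{user_id}"
    safe = re.sub(r"[^a-zA-Z0-9-]", "-", raw)  # keep only alphanumerics and dashes
    return safe[:64]  # IAM user names can contain up to 64 characters

print(sketch_safe_name("64"))  # e.g. "gen3wf-example-commons-org-64"
```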
2 changes: 1 addition & 1 deletion gen3workflow/config-default.yaml
@@ -10,7 +10,7 @@ DOCS_URL_PREFIX: /gen3workflow
ARBORIST_URL:

# /!\ only use for development! Allows running gen3workflow locally without Arborist interaction
MOCK_AUTH: false # TODO add to config validation. Also add "no unexpected props" to validation.
MOCK_AUTH: false

MAX_IAM_KEYS_PER_USER: 2 # the default AWS AccessKeysPerUser quota is 2
IAM_KEYS_LIFETIME_DAYS: 30
14 changes: 11 additions & 3 deletions gen3workflow/config.py
@@ -36,21 +36,29 @@ def validate(self) -> None:
logger.info("Validating configuration")
self.validate_top_level_configs()

def validate_top_level_configs(self):
def validate_top_level_configs(self) -> None:
schema = {
"type": "object",
"additionalProperties": True,
"additionalProperties": False,
"properties": {
"HOSTNAME": {"type": "string"},
"DEBUG": {"type": "boolean"},
"DOCS_URL_PREFIX": {"type": "string"},
"ARBORIST_URL": {"type": ["string", "null"]},
"MOCK_AUTH": {"type": "boolean"},
# aws_utils.list_iam_user_keys should be updated to fetch paginated results if >100
"MAX_IAM_KEYS_PER_USER": {"type": "integer", "maximum": 100},
"IAM_KEYS_LIFETIME_DAYS": {"type": "integer"},
"USER_BUCKETS_REGION": {"type": "string"},
"S3_ENDPOINTS_AWS_ACCESS_KEY_ID": {"type": ["string", "null"]},
"S3_ENDPOINTS_AWS_SECRET_ACCESS_KEY": {"type": ["string", "null"]},
"ARBORIST_URL": {"type": ["string", "null"]},
"DB_DRIVER": {"type": "string"},
"DB_HOST": {"type": "string"},
"DB_PORT": {"type": "integer"},
"DB_USER": {"type": "string"},
"DB_PASSWORD": {"type": "string"},
"DB_DATABASE": {"type": "string"},
"DB_CONNECTION_STRING": {"type": "string"},
"TASK_IMAGE_WHITELIST": {"type": "array", "items": {"type": "string"}},
"TES_SERVER_URL": {"type": "string"},
},
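
The switch from `"additionalProperties": True` to `False` means the jsonschema validation now rejects unexpected top-level config keys instead of silently accepting them. Below is a trimmed-down demonstration of that behavior using a few of the listed properties; it is not the service's actual `validate_top_level_configs` code.

```python
# Trimmed demonstration of the "additionalProperties": False behavior above,
# using a few of the listed properties; not the service's actual schema.
from jsonschema import ValidationError, validate

schema = {
    "type": "object",
    "additionalProperties": False,
    "properties": {
        "HOSTNAME": {"type": "string"},
        "MOCK_AUTH": {"type": "boolean"},
        "MAX_IAM_KEYS_PER_USER": {"type": "integer", "maximum": 100},
    },
}

validate(instance={"HOSTNAME": "localhost", "MOCK_AUTH": False}, schema=schema)  # passes

try:
    validate(instance={"HOSTNAME": "localhost", "UNEXPECTED_PROP": 1}, schema=schema)
except ValidationError as e:
    print(e.message)  # unexpected top-level keys are now rejected
```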
18 changes: 12 additions & 6 deletions gen3workflow/routes/ga4gh_tes.py
@@ -31,7 +31,7 @@ async def get_request_body(request: Request):


@router.get("/service-info", status_code=HTTP_200_OK)
async def service_info(request: Request):
async def service_info(request: Request) -> dict:
url = f"{config['TES_SERVER_URL']}/service-info"
res = await request.app.async_client.get(url)
if res.status_code != HTTP_200_OK:
@@ -72,7 +72,7 @@ def get_non_allowed_images(images: set, username: str) -> set:


@router.post("/tasks", status_code=HTTP_200_OK)
async def create_task(request: Request, auth=Depends(Auth)):
async def create_task(request: Request, auth=Depends(Auth)) -> dict:
await auth.authorize("create", ["/services/workflow/gen3-workflow/tasks"])

body = await get_request_body(request)
@@ -124,11 +124,17 @@ async def create_task(request: Request, auth=Depends(Auth)):
return res.json()


def apply_view_to_task(view, task) -> dict:
def apply_view_to_task(view: str, task: dict) -> dict:
"""
We always set the view to "FULL" when making get/list requests to the TES server, because we
need to get the AUTHZ tag in order to check whether users have access. This function applies
the view that was originally requested by removing fields according to the TES spec.
Args:
view (str): view to apply (FULL, MINIMAL or BASIC). If None, MINIMAL is applied.
task (dict): TES task
Returns:
dict: TES task with applied view
"""
if view == "FULL":
return task
@@ -149,7 +155,7 @@ def apply_view_to_task(view, task) -> dict:


@router.get("/tasks", status_code=HTTP_200_OK)
async def list_tasks(request: Request, auth=Depends(Auth)):
async def list_tasks(request: Request, auth=Depends(Auth)) -> dict:
supported_params = {
"name_prefix",
"state",
@@ -209,7 +215,7 @@ async def list_tasks(request: Request, auth=Depends(Auth)):


@router.get("/tasks/{task_id}", status_code=HTTP_200_OK)
async def get_task(request: Request, task_id: str, auth=Depends(Auth)):
async def get_task(request: Request, task_id: str, auth=Depends(Auth)) -> dict:
supported_params = {"view"}
query_params = {
k: v for k, v in dict(request.query_params).items() if k in supported_params
@@ -239,7 +245,7 @@ async def get_task(request: Request, task_id: str, auth=Depends(Auth)):


@router.post("/tasks/{task_id}:cancel", status_code=HTTP_200_OK)
async def cancel_task(request: Request, task_id: str, auth=Depends(Auth)):
async def cancel_task(request: Request, task_id: str, auth=Depends(Auth)) -> dict:
# check if this user has access to delete this task
url = f"{config['TES_SERVER_URL']}/tasks/{task_id}?view=FULL"
res = await request.app.async_client.get(url)
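
The docstring added to `apply_view_to_task` explains that the service always requests the `FULL` view from the TES server and then trims the response to the view originally requested. Here is a hedged sketch of that trimming: the `MINIMAL` field set ("id" and "state") follows the TES spec, while the `BASIC` trimming below is a simplified assumption, not the service's implementation (which is mostly collapsed in this diff).

```python
# Hedged sketch of view trimming as described in apply_view_to_task's docstring;
# not the service's implementation. MINIMAL keeps only "id" and "state" per the
# TES spec; the BASIC field list here is an assumption.
def sketch_apply_view(view: str, task: dict) -> dict:
    if view == "FULL":
        return task
    if view in (None, "MINIMAL"):
        return {"id": task.get("id"), "state": task.get("state")}
    # view == "BASIC": drop the most verbose fields (assumed subset)
    trimmed = dict(task)
    trimmed.pop("system_logs", None)
    if "executors" in trimmed:
        trimmed["executors"] = [
            {k: v for k, v in executor.items() if k not in ("stdout", "stderr")}
            for executor in trimmed["executors"]
        ]
    return trimmed

print(sketch_apply_view("MINIMAL", {"id": "task-123", "state": "RUNNING", "logs": []}))
```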
2 changes: 1 addition & 1 deletion gen3workflow/routes/storage.py
@@ -9,7 +9,7 @@


@router.get("/info", status_code=HTTP_200_OK)
async def get_storage_info(request: Request, auth=Depends(Auth)):
async def get_storage_info(request: Request, auth=Depends(Auth)) -> dict:
token_claims = await auth.get_token_claims()
user_id = token_claims.get("sub")
bucket_name, bucket_prefix, bucket_region = aws_utils.create_user_bucket(user_id)
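
The route above unpacks `bucket_name`, `bucket_prefix`, and `bucket_region` from `aws_utils.create_user_bucket`, so its response presumably exposes those three values; the rest of the handler is collapsed here. A hypothetical response shape, with all key names and values as placeholders:

```python
# Hypothetical response shape only; key names and values are placeholders and
# the route's actual return statement is collapsed in this diff.
storage_info = {
    "bucket": "gen3wf-example-commons-org-64",                   # bucket_name
    "workdir": "s3://gen3wf-example-commons-org-64/ga4gh-tes",   # derived from bucket_prefix (assumed)
    "region": "us-east-1",                                       # bucket_region
}
print(storage_info)
```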
62 changes: 32 additions & 30 deletions tests/conftest.py
@@ -2,7 +2,6 @@
See https://github.com/uc-cdis/gen3-user-data-library/blob/main/tests/conftest.py#L1
"""

import asyncio
from datetime import datetime
from dateutil.tz import tzutc
import json
@@ -16,18 +15,18 @@
import pytest_asyncio
from sqlalchemy.ext.asyncio import async_sessionmaker, create_async_engine
from starlette.config import environ
from threading import Thread
import uvicorn

# Set GEN3WORKFLOW_CONFIG_PATH *before* loading the app, which loads the configuration
# Set up the config *before* loading the app, which loads the configuration
CURRENT_DIR = os.path.dirname(os.path.realpath(__file__))
environ["GEN3WORKFLOW_CONFIG_PATH"] = os.path.join(
CURRENT_DIR, "test-gen3workflow-config.yaml"
)
from gen3workflow.config import config

config.validate()

from gen3workflow.app import get_app
from gen3workflow.config import config
from gen3workflow.models import Base
from tests.migrations.migration_utils import MigrationRunner


TEST_USER_ID = "64"
@@ -64,44 +63,47 @@
}


@pytest_asyncio.fixture(scope="function")
async def engine():
@pytest_asyncio.fixture(scope="session", autouse=True)
async def migrate_database_to_the_latest():
"""
Non-session scoped engine which recreates the database, yields, then drops the tables
Migrate the database to the latest version before running the tests.
"""
engine = create_async_engine(
config["DB_CONNECTION_STRING"], echo=False, future=True
)
migration_runner = MigrationRunner()
await migration_runner.upgrade("head")

async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
await conn.run_sync(Base.metadata.create_all)

yield engine
@pytest_asyncio.fixture(scope="function")
async def reset_database():
"""
Most tests do not store data in the database, so for performance this fixture does not
autorun. To be used in tests that interact with the database.
"""
migration_runner = MigrationRunner()
await migration_runner.downgrade("base")
await migration_runner.upgrade("head")

async with engine.begin() as conn:
await conn.run_sync(Base.metadata.drop_all)
yield

await engine.dispose()
await migration_runner.downgrade("base")
await migration_runner.upgrade("head")


@pytest_asyncio.fixture()
async def session(engine):
@pytest_asyncio.fixture(scope="function")
async def session():
"""
Database session which utilizes the above engine and event loop and sets up a nested transaction before yielding.
It rolls back the nested transaction after yield.
Database session
"""
event_loop = asyncio.get_running_loop()
engine = create_async_engine(
config["DB_CONNECTION_STRING"], echo=False, future=True
)
session_maker = async_sessionmaker(
engine, expire_on_commit=False, autocommit=False, autoflush=False
)

async with engine.connect() as conn:
tsx = await conn.begin()
async with session_maker(bind=conn) as session:
yield session
async with session_maker() as session:
yield session

await tsx.rollback()
await engine.dispose()


@pytest.fixture(scope="function")
@@ -352,7 +354,7 @@ def run_uvicorn():
else:
# the tests use a real httpx client that forwards requests to the app
async with httpx.AsyncClient(
app=app, base_url="http://test-gen3-wf"
transport=httpx.ASGITransport(app=app), base_url="http://test-gen3-wf"
) as real_httpx_client:
# for easier access to the param in the tests
real_httpx_client.tes_resp_code = tes_resp_code
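
The `transport=httpx.ASGITransport(app=app)` change reflects newer httpx versions, where the `app=...` shortcut on `AsyncClient` is deprecated in favor of an explicit ASGI transport. A standalone sketch of that pattern against a stand-in FastAPI app (not the gen3workflow app):

```python
# Standalone sketch of the httpx ASGITransport pattern adopted above; the app
# below is a stand-in, not the gen3workflow app.
import asyncio

import httpx
from fastapi import FastAPI

app = FastAPI()

@app.get("/ping")
async def ping() -> dict:
    return {"ping": "pong"}

async def main() -> None:
    transport = httpx.ASGITransport(app=app)
    async with httpx.AsyncClient(transport=transport, base_url="http://test") as client:
        resp = await client.get("/ping")
        print(resp.json())  # {'ping': 'pong'}

asyncio.run(main())
```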
4 changes: 2 additions & 2 deletions tests/migrations/migration_utils.py
@@ -41,6 +41,6 @@ def _run_command(connection):
else:
raise Exception(f"Unknown MigrationRunner action '{self.action}'")

async_engine = create_async_engine(config["DB_CONNECTION_STRING"], echo=True)
async with async_engine.begin() as conn:
engine = create_async_engine(config["DB_CONNECTION_STRING"], echo=True)
async with engine.begin() as conn:
await conn.run_sync(_run_command)
2 changes: 1 addition & 1 deletion tests/migrations/test_migration_e1886270d9d2.py
@@ -7,7 +7,7 @@


@pytest.mark.asyncio
async def test_e1886270d9d2_upgrade(session):
async def test_e1886270d9d2_upgrade(session, reset_database):
# state before the migration
migration_runner = MigrationRunner()
await migration_runner.downgrade("base")
