Skip to content

Commit

Permalink
fix getting credentials for assumed role
Browse files Browse the repository at this point in the history
  • Loading branch information
paulineribeyre committed Dec 5, 2024
1 parent d61a1fd commit e72a214
Showing 1 changed file with 26 additions and 21 deletions.
47 changes: 26 additions & 21 deletions gen3workflow/routes/s3.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import urllib.parse

import boto3
from fastapi import APIRouter, Request
from fastapi import APIRouter, HTTPException, Request
from fastapi.security import HTTPAuthorizationCredentials
from botocore.credentials import Credentials
import hmac
import httpx
from starlette.datastructures import Headers
from starlette.responses import Response
from starlette.status import HTTP_401_UNAUTHORIZED

from gen3workflow import aws_utils, logger
from gen3workflow.auth import Auth
Expand Down Expand Up @@ -80,7 +81,7 @@ def get_signature_key(key: str, date_stamp: str, region_name: str, service_name:
"/{path:path}",
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH", "TRACE", "HEAD"]
)
async def catch_all_v4(path: str, request: Request):
async def todo_rename(path: str, request: Request):
"""TODO
Args:
Expand All @@ -102,11 +103,6 @@ async def catch_all_v4(path: str, request: Request):
token_claims = await auth.get_token_claims()
user_id = token_claims.get("sub")
user_bucket = aws_utils.get_safe_name_from_user_id(user_id)
user_bucket = "ga4ghtes-pauline-planx-pla-net" # TODO remove - for testing

query_params = dict(request.query_params)
query_params_names = sorted(list(query_params.keys())) # query params have to be sorted
canonical_r_q_params = "&".join(f"{urllib.parse.quote_plus(key)}={urllib.parse.quote_plus(query_params[key])}" for key in query_params_names)

# Example 1:
# - path = my-bucket//
Expand All @@ -116,6 +112,8 @@ async def catch_all_v4(path: str, request: Request):
# - path = my-bucket/pre/fix/
# - request_path = /pre/fix/
# - api_endpoint = pre/fix/
if user_bucket not in request_path:
raise HTTPException(HTTP_401_UNAUTHORIZED, f"'{path}' not allowed. You can make calls to your personal bucket, '{user_bucket}'")
request_path = path.split(user_bucket)[1]
api_endpoint = "/".join(request_path.split("/")[1:])

Expand Down Expand Up @@ -144,9 +142,30 @@ async def catch_all_v4(path: str, request: Request):
amz_date = datetime.utcnow().strftime('%Y%m%dT%H%M%SZ')
headers['x-amz-date'] = amz_date

# AWS Credentials for signing
if config["S3_ENDPOINTS_AWS_ROLE_ARN"]:
# sts_client = boto3.client('sts')
# response = sts_client.assume_role(
# RoleArn=config["S3_ENDPOINTS_AWS_ROLE_ARN"],
# RoleSessionName='SessionName'
# )
# credentials = response['Credentials']
session = boto3.Session()
credentials = session.get_credentials()
headers["x-amz-security-token"] = credentials.token
else:
credentials = Credentials(
access_key=config["S3_ENDPOINTS_AWS_ACCESS_KEY_ID"],
secret_key=config["S3_ENDPOINTS_AWS_SECRET_ACCESS_KEY"],
)

canon_headers = "".join(f"{key}:{headers[key]}\n" for key in sorted(list(headers.keys())))
header_names = ";".join(sorted(list(headers.keys())))

query_params = dict(request.query_params)
query_params_names = sorted(list(query_params.keys())) # query params have to be sorted
canonical_r_q_params = "&".join(f"{urllib.parse.quote_plus(key)}={urllib.parse.quote_plus(query_params[key])}" for key in query_params_names)

# Construct the canonical request (with cleaned-up path)
canonical_request = (
f"{request.method}\n"
Expand All @@ -159,20 +178,6 @@ async def catch_all_v4(path: str, request: Request):
)
# logger.debug(f"- Canonical Request:\n{canonical_request}")

# AWS Credentials for signing
if config["S3_ENDPOINTS_AWS_ROLE_ARN"]:
sts_client = boto3.client('sts')
response = sts_client.assume_role(
RoleArn=config["S3_ENDPOINTS_AWS_ROLE_ARN"],
RoleSessionName='SessionName'
)
credentials = response['Credentials']
else:
credentials = Credentials(
access_key=config["S3_ENDPOINTS_AWS_ACCESS_KEY_ID"],
secret_key=config["S3_ENDPOINTS_AWS_SECRET_ACCESS_KEY"],
)

# Create the string to sign based on the canonical request
region = config["USER_BUCKETS_REGION"]
service = 's3'
Expand Down

0 comments on commit e72a214

Please sign in to comment.