Commit

clean up

paulineribeyre committed Dec 5, 2024
1 parent 492276f commit 90492de
Showing 4 changed files with 95 additions and 104 deletions.
4 changes: 2 additions & 2 deletions .secrets.baseline
@@ -151,7 +151,7 @@
"filename": "gen3workflow/config-default.yaml",
"hashed_secret": "afc848c316af1a89d49826c5ae9d00ed769415f3",
"is_verified": false,
"line_number": 31
"line_number": 30
}
],
"migrations/versions/e1886270d9d2_create_system_key_table.py": [
@@ -182,5 +182,5 @@
}
]
},
"generated_at": "2024-12-05T16:27:30Z"
"generated_at": "2024-12-05T22:45:14Z"
}
1 change: 0 additions & 1 deletion gen3workflow/config-default.yaml
@@ -16,7 +16,6 @@ MAX_IAM_KEYS_PER_USER: 2 # the default AWS AccessKeysPerUser quota is 2
IAM_KEYS_LIFETIME_DAYS: 30
USER_BUCKETS_REGION: us-east-1

S3_ENDPOINTS_AWS_ROLE_ARN:
S3_ENDPOINTS_AWS_ACCESS_KEY_ID:
S3_ENDPOINTS_AWS_SECRET_ACCESS_KEY:

7 changes: 6 additions & 1 deletion gen3workflow/config.py
@@ -48,14 +48,19 @@ def validate_top_level_configs(self):
"MAX_IAM_KEYS_PER_USER": {"type": "integer", "maximum": 100},
"IAM_KEYS_LIFETIME_DAYS": {"type": "integer"},
"USER_BUCKETS_REGION": {"type": "string"},
# TODO S3_ENDPOINTS_AWS_ROLE_ARN etc
"S3_ENDPOINTS_AWS_ACCESS_KEY_ID": {"type": ["string", "null"]},
"S3_ENDPOINTS_AWS_SECRET_ACCESS_KEY": {"type": ["string", "null"]},
"ARBORIST_URL": {"type": ["string", "null"]},
"TASK_IMAGE_WHITELIST": {"type": "array", "items": {"type": "string"}},
"TES_SERVER_URL": {"type": "string"},
},
}
validate(instance=self, schema=schema)

assert bool(self["S3_ENDPOINTS_AWS_ACCESS_KEY_ID"]) == bool(
self["S3_ENDPOINTS_AWS_SECRET_ACCESS_KEY"]
), "Both 'S3_ENDPOINTS_AWS_ACCESS_KEY_ID' and 'S3_ENDPOINTS_AWS_SECRET_ACCESS_KEY' must be configured, or both must be left empty"


config = Gen3WorkflowConfig(DEFAULT_CFG_PATH)
try:
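The new assertion enforces a both-or-neither rule on the two credential settings, since bool(x) == bool(y) holds only when both are set or both are empty. A quick sketch of the semantics, with made-up values:

# both set, or both empty: the bool() values match and the assert passes
assert bool("AKIA-EXAMPLE-KEY-ID") == bool("example-secret")  # True == True
assert bool(None) == bool(None)                               # False == False
# exactly one configured: True == False, so the config assert raises AssertionError
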
187 changes: 87 additions & 100 deletions gen3workflow/routes/s3.py
@@ -1,4 +1,4 @@
from datetime import datetime
from datetime import datetime, timezone
import hashlib
import json
import os
@@ -25,29 +25,6 @@
router = APIRouter(prefix="/s3")


async def _log_request(request, path):
# Read body as bytes, then decode it as string if necessary
body_bytes = await request.body()
try:
body = body_bytes.decode()
except UnicodeDecodeError:
body = str(body_bytes) # In case of binary data
try:
body = json.loads(body)
except:
pass # Keep body as string if not JSON

timestamp = datetime.now().isoformat()
log_entry = {
'timestamp': timestamp,
'method': request.method,
'path': path,
'headers': dict(request.headers),
'body': body,
}
logger.debug(f"Incoming request: {json.dumps(log_entry, indent=2)}")


def get_access_token(headers: Headers) -> str:
"""
Extract the user's access token, which should have been provided as the key ID, from the
@@ -63,47 +40,61 @@ def get_access_token(headers: Headers) -> str:
auth_header = headers.get("authorization")
if not auth_header:
return ""
return auth_header.split("Credential=")[1].split("/")[0]
try:
return auth_header.split("Credential=")[1].split("/")[0]
except Exception as e:
logger.error(
f"Unexpected format; unable to extract access token from authorization header: {e}"
)
return ""
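Note on get_access_token: it assumes the client placed its Gen3 access token in the key-ID slot of a standard SigV4 Credential field. A minimal extraction sketch, with a made-up header value:

# hypothetical SigV4 authorization header; the access token sits where the
# AWS access key ID normally goes
auth_header = (
    "AWS4-HMAC-SHA256 Credential=example-access-token/20241205/us-east-1/s3/aws4_request, "
    "SignedHeaders=host;x-amz-date, Signature=0123abcd"
)
# the token is everything between "Credential=" and the first "/"
token = auth_header.split("Credential=")[1].split("/")[0]
assert token == "example-access-token"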


def get_signature_key(key: str, date_stamp: str, region_name: str, service_name: str) -> str:
def get_signature_key(key: str, date: str, region_name: str, service_name: str) -> str:
"""
Create a signing key using the AWS Signature Version 4 algorithm.
"""
key_date = hmac.new(f"AWS4{key}".encode('utf-8'), date_stamp.encode('utf-8'), hashlib.sha256).digest()
key_region = hmac.new(key_date, region_name.encode('utf-8'), hashlib.sha256).digest()
key_service = hmac.new(key_region, service_name.encode('utf-8'), hashlib.sha256).digest()
key_date = hmac.new(
f"AWS4{key}".encode("utf-8"), date.encode("utf-8"), hashlib.sha256
).digest()
key_region = hmac.new(
key_date, region_name.encode("utf-8"), hashlib.sha256
).digest()
key_service = hmac.new(
key_region, service_name.encode("utf-8"), hashlib.sha256
).digest()
key_signing = hmac.new(key_service, b"aws4_request", hashlib.sha256).digest()
return key_signing
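get_signature_key is the standard SigV4 derivation chain: starting from "AWS4" + secret key, HMAC the date, region, service name, and the literal "aws4_request" in sequence. Note that it returns raw bytes (from .digest()) despite the -> str annotation. A usage sketch with placeholder inputs:

# placeholder secret and credential-scope values, for illustration only
signing_key = get_signature_key(
    key="example-secret-access-key",
    date="20241205",  # YYYYMMDD; must match the date in the credential scope
    region_name="us-east-1",
    service_name="s3",
)
print(signing_key.hex())  # 32 raw bytes, later used to HMAC the string to sign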


@router.api_route(
"/{path:path}",
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH", "TRACE", "HEAD"]
methods=["GET", "POST", "PUT", "DELETE", "OPTIONS", "PATCH", "TRACE", "HEAD"],
)
async def todo_rename(path: str, request: Request):
"""TODO
Args:
path (str): _description_
request (Request): _description_
Raises:
Exception: _description_
Returns:
_type_: _description_
async def s3_endpoint(path: str, request: Request):
"""
Receive incoming S3 requests, re-sign them with the appropriate credentials to access the
current user's AWS S3 bucket, and forward them to AWS S3.
"""
# await _log_request(request, path)
logger.debug(f"Incoming S3 request: '{request.method} {path}'")

# extract the user's access token from the request headers, and use it to get the name of
# the user's bucket
auth = Auth(api_request=request)
auth.bearer_token = HTTPAuthorizationCredentials(scheme="bearer", credentials=get_access_token(request.headers))
auth.bearer_token = HTTPAuthorizationCredentials(
scheme="bearer", credentials=get_access_token(request.headers)
)
token_claims = await auth.get_token_claims()
user_id = token_claims.get("sub")
user_bucket = aws_utils.get_safe_name_from_user_id(user_id)

# TODO make sure calls to bucket1 is not allowed when user's bucket is bucket12
if user_bucket not in path:
err_msg = f"'{path}' not allowed. You can make calls to your personal bucket, '{user_bucket}'"
logger.error(err_msg)
raise HTTPException(HTTP_401_UNAUTHORIZED, err_msg)
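One caveat on the substring check, in line with the TODO above: with user_bucket = "bucket1", a path addressed to "bucket12" also satisfies user_bucket in path. An exact comparison on the leading path segment (an illustrative sketch only) would be stricter:

# illustrative only: compare the bucket path segment exactly instead of substring matching
requested_bucket = path.lstrip("/").split("/")[0]
if requested_bucket != user_bucket:
    raise HTTPException(HTTP_401_UNAUTHORIZED, f"'{path}' not allowed")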

# extract the request path (used in the canonical request) and the API endpoint (used to make
# the request to AWS).
# Example 1:
# - path = my-bucket//
# - request_path = //
@@ -112,97 +103,88 @@ async def todo_rename(path: str, request: Request):
# - path = my-bucket/pre/fix/
# - request_path = /pre/fix/
# - api_endpoint = pre/fix/
if user_bucket not in path:
err_msg = f"'{path}' not allowed. You can make calls to your personal bucket, '{user_bucket}'"
logger.error(err_msg)
raise HTTPException(HTTP_401_UNAUTHORIZED, err_msg)
request_path = path.split(user_bucket)[1]
api_endpoint = "/".join(request_path.split("/")[1:])

# generate the request headers
# headers = dict(request.headers)
# headers.pop("authorization")
headers = {}
# TODO try again to include all the headers
# `x-amz-content-sha256` is sometimes set to "STREAMING-AWS4-HMAC-SHA256-PAYLOAD" in
# the original request, but i was not able to get the signing working when copying it

# the original request, but i was not able to get the signing working when copying it.
# if "Content-Type" in request.headers:
# headers["content-type"] = request.headers["Content-Type"]
headers['host'] = f'{user_bucket}.s3.amazonaws.com'

# Hash the request body
headers["host"] = f"{user_bucket}.s3.amazonaws.com"
body = await request.body()
body_hash = hashlib.sha256(body).hexdigest()
headers['x-amz-content-sha256'] = body_hash
headers["x-amz-content-sha256"] = body_hash
# headers['x-amz-content-sha256'] = request.headers['x-amz-content-sha256']
# if 'content-length' in request.headers:
# headers['content-length'] = request.headers['content-length']
# if 'x-amz-decoded-content-length' in request.headers:
# headers['x-amz-decoded-content-length'] = request.headers['x-amz-decoded-content-length']
timestamp = datetime.now(timezone.utc).strftime("%Y%m%dT%H%M%SZ")
headers["x-amz-date"] = timestamp

# Ensure 'x-amz-date' is included in the headers (it's needed for signature calculation)
amz_date = datetime.utcnow().strftime('%Y%m%dT%H%M%SZ')
headers['x-amz-date'] = amz_date

# AWS Credentials for signing
if config["S3_ENDPOINTS_AWS_ROLE_ARN"]:
session = boto3.Session()
credentials = session.get_credentials()
headers["x-amz-security-token"] = credentials.token
else:
# get AWS credentials from the configuration or the current assumed role session
if config["S3_ENDPOINTS_AWS_ACCESS_KEY_ID"]:
credentials = Credentials(
access_key=config["S3_ENDPOINTS_AWS_ACCESS_KEY_ID"],
secret_key=config["S3_ENDPOINTS_AWS_SECRET_ACCESS_KEY"],
)
else: # running in k8s: get credentials from the assumed role
session = boto3.Session()
credentials = session.get_credentials()
headers["x-amz-security-token"] = credentials.token

canon_headers = "".join(f"{key}:{headers[key]}\n" for key in sorted(list(headers.keys())))
header_names = ";".join(sorted(list(headers.keys())))

# construct the canonical request
canonical_headers = "".join(
f"{key}:{headers[key]}\n" for key in sorted(list(headers.keys()))
)
signed_headers = ";".join(sorted([k.lower() for k in headers.keys()]))
query_params = dict(request.query_params)
query_params_names = sorted(list(query_params.keys())) # query params have to be sorted
canonical_r_q_params = "&".join(f"{urllib.parse.quote_plus(key)}={urllib.parse.quote_plus(query_params[key])}" for key in query_params_names)

# Construct the canonical request (with cleaned-up path)
# the query params in the canonical request have to be sorted:
query_params_names = sorted(list(query_params.keys()))
canonical_query_params = "&".join(
f"{urllib.parse.quote_plus(key)}={urllib.parse.quote_plus(query_params[key])}"
for key in query_params_names
)
canonical_request = (
f"{request.method}\n"
f"{request_path}\n"
f"{canonical_r_q_params}\n" # Query parameters
f"{canon_headers}"
f"{canonical_query_params}\n"
f"{canonical_headers}"
f"\n"
f"{header_names}\n" # Signed headers
f"{headers['x-amz-content-sha256']}" # Final Body hash
f"{signed_headers}\n"
f"{body_hash}"
)
# logger.debug(f"- Canonical Request:\n{canonical_request}")
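As a worked example of the format: a hypothetical "GET my-bucket//?list-type=2" with an empty body and only the three headers set above would yield a canonical request of this shape (e3b0c4... is the SHA-256 of an empty payload):

GET
//
list-type=2
host:my-bucket.s3.amazonaws.com
x-amz-content-sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855
x-amz-date:20241205T224514Z

host;x-amz-content-sha256;x-amz-date
e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855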

# Create the string to sign based on the canonical request
# construct the string to sign based on the canonical request
date = timestamp[:8] # the date portion (YYYYMMDD) of the timestamp
region = config["USER_BUCKETS_REGION"]
service = 's3'
date_stamp = headers['x-amz-date'][:8] # The date portion (YYYYMMDD)
service = "s3"
string_to_sign = (
f"AWS4-HMAC-SHA256\n"
f"{headers['x-amz-date']}\n" # The timestamp in 'YYYYMMDDTHHMMSSZ' format
f"{date_stamp}/{region}/{service}/aws4_request\n" # Credential scope
f"{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}" # Hash of the Canonical Request
f"{timestamp}\n"
f"{date}/{region}/{service}/aws4_request\n" # credential scope
f"{hashlib.sha256(canonical_request.encode('utf-8')).hexdigest()}" # canonical request hash
)
# logger.debug(f"- String to Sign:\n{string_to_sign}")

# Generate the signing key using our `get_signature_key` function
signing_key = get_signature_key(credentials.secret_key, date_stamp, region, service)

# Calculate the signature by signing the string to sign with the signing key
signature = hmac.new(signing_key, string_to_sign.encode('utf-8'), hashlib.sha256).hexdigest()
# logger.debug(f"- Signature: {signature}")

# Ensure all headers that are in the request are included in the SignedHeaders
signed_headers = ';'.join(sorted([k.lower() for k in headers.keys() if k != 'authorization']))

# Log final headers before sending the request
headers['authorization'] = f"AWS4-HMAC-SHA256 Credential={credentials.access_key}/{date_stamp}/{region}/{service}/aws4_request, SignedHeaders={signed_headers}, Signature={signature}"
# logger.debug(f"- Signed Headers:\n{aws_request.headers}")

# Perform the actual HTTP request
# generate the signing key, and generate the signature by signing the string to sign with the
# signing key
signing_key = get_signature_key(credentials.secret_key, date, region, service)
signature = hmac.new(
signing_key, string_to_sign.encode("utf-8"), hashlib.sha256
).hexdigest()

# construct the Authorization header from the credentials and the signature, and forward the
# call to AWS S3 with the new Authorization header
headers["authorization"] = (
f"AWS4-HMAC-SHA256 Credential={credentials.access_key}/{date}/{region}/{service}/aws4_request, SignedHeaders={signed_headers}, Signature={signature}"
)
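With placeholder values, the resulting header has the standard SigV4 shape:

authorization: AWS4-HMAC-SHA256 Credential=EXAMPLE-KEY-ID/20241205/us-east-1/s3/aws4_request, SignedHeaders=host;x-amz-content-sha256;x-amz-date, Signature=<64-hex-char HMAC from above>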
s3_api_url = f"https://{user_bucket}.s3.amazonaws.com/{api_endpoint}"
logger.debug(f"Making {request.method} request to {s3_api_url}")
logger.debug(f"Outgoing S3 request: '{request.method} {s3_api_url}'")
async with httpx.AsyncClient() as client:
response = await client.request(
method=request.method,
@@ -214,6 +196,11 @@ async def todo_rename(path: str, request: Request):
if response.status_code != 200:
logger.error(f"Error from AWS: {response.status_code} {response.text}")

# return the response from AWS S3
if "Content-Type" in response.headers:
return Response(content=response.content, status_code=response.status_code, media_type=response.headers['Content-Type'])
return Response(
content=response.content,
status_code=response.status_code,
media_type=response.headers["Content-Type"],
)
return Response(content=response.content, status_code=response.status_code)
