Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

DM-48432: Move quota calculation to QuotaConfig model #1217

Merged
merged 1 commit into from
Jan 14, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 43 additions & 44 deletions src/gafaelfawr/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,7 @@
from .exceptions import InvalidTokenError
from .keypair import RSAKeyPair
from .models.token import Token
from .models.userinfo import Quota
from .util import group_name_for_github_team

HttpsUrl = Annotated[
Expand All @@ -79,12 +80,10 @@
"GitHubGroupTeam",
"HttpsUrl",
"LDAPConfig",
"NotebookQuota",
"OIDCClient",
"OIDCConfig",
"OIDCServerConfig",
"QuotaConfig",
"QuotaGrant",
]


Expand Down Expand Up @@ -657,56 +656,16 @@ def keypair(self) -> RSAKeyPair:
return self._keypair


class NotebookQuota(BaseModel):
"""Quota settings for the Notebook Aspect."""

model_config = ConfigDict(extra="forbid")

cpu: float = Field(
..., title="CPU limit", description="Maximum number of CPU equivalents"
)

memory: float = Field(
...,
title="Memory limit (GiB)",
description="Maximum memory usage in GiB",
)


class QuotaGrant(BaseModel):
"""One grant of quotas.

There may be one of these per group, as well as a default one, in the
overall quota configuration.
"""

model_config = ConfigDict(extra="forbid")

api: dict[str, int] = Field(
{},
title="Service quotas",
description=(
"Mapping of service names to quota of requests per 15 minutes"
),
)

notebook: NotebookQuota | None = Field(
None,
title="Notebook quota",
description="Quota settings for the Notebook Aspect",
)


class QuotaConfig(BaseModel):
"""Quota configuration."""

model_config = ConfigDict(extra="forbid")

default: QuotaGrant = Field(
default: Quota = Field(
..., title="Default quota", description="Default quotas for all users"
)

groups: dict[str, QuotaGrant] = Field(
groups: dict[str, Quota] = Field(
{},
title="Quota grants by group",
description="Additional quota grants by group name",
Expand All @@ -718,6 +677,46 @@ class QuotaConfig(BaseModel):
description="Groups whose members bypass all quota restrictions",
)

def calculate_quota(self, groups: set[str]) -> Quota | None:
"""Calculate user's quota given their group membership.

Parameters
----------
groups
Group membership of the user.

Returns
-------
Quota or None
Quota information for that user or `None` if no quotas apply.
"""
if groups & self.bypass:
return None

# Start with the defaults.
api = dict(self.default.api)
notebook = None
if self.default.notebook:
notebook = self.default.notebook.model_copy()

# Look for group-specific rules.
for group in groups & set(self.groups.keys()):
extra = self.groups[group]
if extra.notebook:
if notebook:
notebook.cpu += extra.notebook.cpu
notebook.memory += extra.notebook.memory
else:
notebook = extra.notebook.model_copy()
for service, quota in extra.api.items():
if service in api:
api[service] += quota
else:
api[service] = quota

# Return the results.
return Quota(api=api, notebook=notebook)


class GitHubGroupTeam(BaseModel):
"""Specification for a GitHub team."""
Expand Down
6 changes: 5 additions & 1 deletion src/gafaelfawr/models/userinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from dataclasses import dataclass
from datetime import datetime

from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field

from ..constants import GROUPNAME_REGEX
from ..pydantic import Timestamp
Expand Down Expand Up @@ -77,6 +77,8 @@ class Group(BaseModel):
class NotebookQuota(BaseModel):
"""Notebook Aspect quota information for a user."""

model_config = ConfigDict(extra="forbid")

cpu: float = Field(..., title="CPU equivalents", examples=[4.0])

memory: float = Field(
Expand All @@ -87,6 +89,8 @@ class NotebookQuota(BaseModel):
class Quota(BaseModel):
"""Quota information for a user."""

model_config = ConfigDict(extra="forbid")

api: dict[str, int] = Field(
{},
title="API quotas",
Expand Down
61 changes: 8 additions & 53 deletions src/gafaelfawr/services/userinfo.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from ..exceptions import FirestoreError
from ..models.ldap import LDAPUserData
from ..models.token import TokenData, TokenUserInfo
from ..models.userinfo import Group, NotebookQuota, Quota, UserInfo
from ..models.userinfo import Group, UserInfo
from .firestore import FirestoreService
from .ldap import LDAPService

Expand Down Expand Up @@ -112,6 +112,12 @@ async def get_user_info_from_token(
if not gid and not ldap_data.gid and self._config.add_user_group:
gid = uid or ldap_data.uid

# Calculate the quota.
quota = None
if self._config.quota:
group_names = {g.name for g in groups}
quota = self._config.quota.calculate_quota(group_names)

# Return the results.
return UserInfo(
username=username,
Expand All @@ -120,7 +126,7 @@ async def get_user_info_from_token(
gid=gid or ldap_data.gid,
email=token_data.email or ldap_data.email,
groups=sorted(groups, key=lambda g: g.name),
quota=self._calculate_quota(groups),
quota=quota,
)

async def get_scopes(self, user_info: TokenUserInfo) -> set[str] | None:
Expand Down Expand Up @@ -210,57 +216,6 @@ async def invalidate_cache(self, username: str) -> None:
if self._ldap:
await self._ldap.invalidate_cache(username)

def _calculate_quota(self, groups: list[Group]) -> Quota | None:
"""Calculate the quota for a user.

Parameters
----------
groups
The user's group membership.

Returns
-------
gafaelfawr.models.token.Quota
Quota information for that user.
"""
if not self._config.quota:
return None
group_names = {g.name for g in groups}
if group_names & self._config.quota.bypass:
return Quota()

# Start with the defaults.
api = dict(self._config.quota.default.api)
notebook = None
if self._config.quota.default.notebook:
notebook = NotebookQuota(
cpu=self._config.quota.default.notebook.cpu,
memory=self._config.quota.default.notebook.memory,
)

# Look for group-specific rules.
for group in group_names:
if group not in self._config.quota.groups:
continue
extra = self._config.quota.groups[group]
if extra.notebook:
if notebook:
notebook.cpu += extra.notebook.cpu
notebook.memory += extra.notebook.memory
else:
notebook = NotebookQuota(
cpu=extra.notebook.cpu,
memory=extra.notebook.memory,
)
for service in extra.api:
if service in api:
api[service] += extra.api[service]
else:
api[service] = extra.api[service]

# Return the results.
return Quota(api=api, notebook=notebook)

async def _get_groups_from_ldap(
self,
username: str,
Expand Down
Loading