Skip to content

Commit

Permalink
feat: Configure role and workspaces from realm access roles
Browse files Browse the repository at this point in the history
  • Loading branch information
frascuchon committed Dec 5, 2024
1 parent 0da655e commit bf2a0f6
Show file tree
Hide file tree
Showing 6 changed files with 208 additions and 3 deletions.
5 changes: 4 additions & 1 deletion argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,15 @@ async def get_access_token(

user = await User.get_by(db, username=userinfo.username)
if user is None:
default_available_workspaces = [workspace.name for workspace in settings.oauth.allowed_workspaces]
workspaces = userinfo.available_workspaces or default_available_workspaces

user = await accounts.create_user_with_random_password(
db,
username=userinfo.username,
first_name=userinfo.first_name,
role=userinfo.role,
workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
workspaces=workspaces,
)

return Token(access_token=accounts.generate_user_token(user))
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# limitations under the License.

import os
from typing import Type, Dict, Any, Optional
from typing import Type, Dict, Any, Optional, List

from social_core.backends.oauth import BaseOAuth2
from social_core.backends.open_id_connect import OpenIdConnectAuth
Expand Down Expand Up @@ -67,6 +67,41 @@ def oidc_endpoint(self) -> str:

return value

def get_user_details(self, response: Dict[str, Any]) -> Dict[str, Any]:
user = super().get_user_details(response)

if role := self._extract_role(response):
user["role"] = role

if available_workspaces := self._extract_available_workspaces(response):
user["available_workspaces"] = available_workspaces

return user

def _extract_role(self, response: Dict[str, Any]) -> Optional[str]:
roles = self._read_realm_roles(response)

for role in roles:
if role.startswith("argilla_role:"):
role = role.split(":")[1]
return role

def _extract_available_workspaces(self, response: Dict[str, Any]) -> List[str]:
roles = self._read_realm_roles(response)

workspaces = []
for role in roles:
if role.startswith("argilla_workspace:"):
workspace = role.split(":")[1]
workspaces.append(workspace)

return workspaces

@classmethod
def _read_realm_roles(cls, response) -> List[str]:
realm_access = response.get("realm_access") or {}
return realm_access.get("roles") or []


_SUPPORTED_BACKENDS = {}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,10 @@ def role(self) -> UserRole:
role = self.get("role") or self._parse_role_from_environment()
return UserRole(role)

@property
def available_workspaces(self) -> Optional[list]:
return self.get("available_workspaces")

def _parse_role_from_environment(self) -> Optional[UserRole]:
"""This is a temporal solution, and it will be replaced by a proper Sign up process"""
if self["username"] == os.getenv("USERNAME"):
Expand Down
37 changes: 36 additions & 1 deletion argilla-server/tests/unit/api/handlers/v1/test_oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
from httpx import AsyncClient
from sqlalchemy import select
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import selectinload

from argilla_server.enums import UserRole
from argilla_server.errors.future import AuthenticationError
from argilla_server.models import User
from argilla_server.security.authentication import JWT
from argilla_server.security.authentication.oauth2 import OAuth2Settings
from tests.factories import AdminFactory, AnnotatorFactory
from tests.factories import AdminFactory, AnnotatorFactory, WorkspaceFactory


@pytest.fixture
Expand Down Expand Up @@ -110,6 +111,40 @@ async def test_provider_huggingface_access_token(
assert user is not None
assert user.role == UserRole.annotator

async def test_provider_access_token_with_specific_userinfo_workspace(
self,
async_client: AsyncClient,
db: AsyncSession,
owner_auth_header: dict,
default_oauth_settings: OAuth2Settings,
):
workspace = await WorkspaceFactory.create()

with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings):
with mock.patch(
"argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data",
return_value={"username": "username", "name": "name", "available_workspaces": [workspace.name]},
):
response = await async_client.get(
"/api/v1/oauth2/providers/huggingface/access-token",
params={"code": "code", "state": "valid"},
headers=owner_auth_header,
cookies={"huggingface_oauth2_state": "valid"},
)

assert response.status_code == 200

json_response = response.json()
assert JWT.decode(json_response["access_token"])["username"] == "username"
assert json_response["token_type"] == "bearer"

user = await db.scalar(
select(User).options(selectinload(User.workspaces)).filter_by(username="username")
)
assert user is not None
assert user.role == UserRole.annotator
assert user.workspaces == [workspace]

async def test_provider_huggingface_access_token_with_missing_username(
self,
async_client: AsyncClient,
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2021-present, the Recognai S.L. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from argilla_server.security.authentication.oauth2._backends import KeycloakOpenId, Strategy


class TestKeyCloackOpenIdBackend:
def test_get_user_details_with_argilla_role(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details(
{
"realm_access": {"roles": ["role1", "role2", "argilla_role:annotator"]},
}
)

assert user_details["role"] == "annotator"

def test_get_user_details_with_wrong_argilla_role_definition(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details(
{
"realm_access": {"roles": ["role1", "role2", "argilla_role=annotator"]},
}
)

assert "role" not in user_details

def test_get_user_details_without_argilla_role(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details(
{
"realm_access": {"roles": ["role1", "role2"]},
}
)

assert "role" not in user_details

def test_get_user_details_with_argilla_workspaces(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details(
{
"realm_access": {"roles": ["role1", "role2", "argilla_workspace:ws1"]},
}
)

assert user_details["available_workspaces"] == ["ws1"]

def test_get_user_details_with_wrong_argilla_workspace_definition(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details(
{
"realm_access": {"roles": ["role1", "role2", "argilla_workspace=ws1"]},
}
)

assert "available_workspaces" not in user_details

def test_get_user_details_with_multiple_argilla_workspaces(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details(
{
"realm_access": {"roles": ["role1", "role2", "argilla_workspace:ws1", "argilla_workspace:ws2"]},
}
)

assert user_details["available_workspaces"] == ["ws1", "ws2"]

def test_get_user_details_with_missing_argilla_workspaces(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details(
{
"realm_access": {"roles": ["role1", "role2"]},
}
)

assert "available_workspaces" not in user_details

def test_get_user_details_with_missing_roles_key(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details(
{
"realm_access": {"other": "stuff"},
}
)

assert "role" not in user_details
assert "available_workspaces" not in user_details

def test_get_user_details_with_missing_realm_access_key(self):
backend = KeycloakOpenId(strategy=Strategy())

user_details = backend.get_user_details({"other": "stuff"})

assert "role" not in user_details
assert "available_workspaces" not in user_details

0 comments on commit bf2a0f6

Please sign in to comment.