diff --git a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
index 0e16f07e28..8cb0bf10e6 100644
--- a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
+++ b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py
@@ -63,12 +63,15 @@ async def get_access_token(
 
         user = await User.get_by(db, username=userinfo.username)
         if user is None:
+            default_available_workspaces = [workspace.name for workspace in settings.oauth.allowed_workspaces]
+            workspaces = userinfo.available_workspaces or default_available_workspaces
+
             user = await accounts.create_user_with_random_password(
                 db,
                 username=userinfo.username,
                 first_name=userinfo.first_name,
                 role=userinfo.role,
-                workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces],
+                workspaces=workspaces,
             )
 
         return Token(access_token=accounts.generate_user_token(user))
diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py
index 3caa90b9a3..011fce1f5e 100644
--- a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py
+++ b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 import os
-from typing import Type, Dict, Any, Optional
+from typing import Type, Dict, Any, Optional, List
 
 from social_core.backends.oauth import BaseOAuth2
 from social_core.backends.open_id_connect import OpenIdConnectAuth
@@ -67,6 +67,51 @@ def oidc_endpoint(self) -> str:
 
         return value
 
+    def get_user_details(self, response: Dict[str, Any]) -> Dict[str, Any]:
+        """Extend the base user details with the Argilla role and workspaces found in the realm roles.
+
+        The `role` and `available_workspaces` keys are only set when a non-empty value was extracted.
+        """
+        user = super().get_user_details(response)
+
+        if role := self._extract_role(response):
+            user["role"] = role
+
+        if available_workspaces := self._extract_available_workspaces(response):
+            user["available_workspaces"] = available_workspaces
+
+        return user
+
+    def _extract_role(self, response: Dict[str, Any]) -> Optional[str]:
+        """Return the first non-empty `argilla_role:<role>` realm role value, or None."""
+        for role in self._read_realm_roles(response):
+            if role.startswith("argilla_role:"):
+                # Split only once so that the whole suffix is kept as the value.
+                value = role.split(":", 1)[1]
+                if value:
+                    return value
+
+        return None
+
+    def _extract_available_workspaces(self, response: Dict[str, Any]) -> List[str]:
+        """Return every non-empty `argilla_workspace:<name>` realm role value, in order."""
+        workspaces = []
+        for role in self._read_realm_roles(response):
+            if role.startswith("argilla_workspace:"):
+                # Split only once so that the whole suffix is kept as the name.
+                workspace = role.split(":", 1)[1]
+                # Skip empty names: they would otherwise be reported as a workspace.
+                if workspace:
+                    workspaces.append(workspace)
+
+        return workspaces
+
+    @classmethod
+    def _read_realm_roles(cls, response: Dict[str, Any]) -> List[str]:
+        """Return `response["realm_access"]["roles"]`, or an empty list when missing or empty."""
+        realm_access = response.get("realm_access") or {}
+        return realm_access.get("roles") or []
+
 
 _SUPPORTED_BACKENDS = {}
 
diff --git a/argilla-server/src/argilla_server/security/authentication/userinfo.py b/argilla-server/src/argilla_server/security/authentication/userinfo.py
index 70173cda7e..5f01628e76 100644
--- a/argilla-server/src/argilla_server/security/authentication/userinfo.py
+++ b/argilla-server/src/argilla_server/security/authentication/userinfo.py
@@ -41,6 +41,10 @@ def role(self) -> UserRole:
         role = self.get("role") or self._parse_role_from_environment()
         return UserRole(role)
 
+    @property
+    def available_workspaces(self) -> Optional[list]:
+        return self.get("available_workspaces")
+
     def _parse_role_from_environment(self) -> Optional[UserRole]:
         """This is a temporal solution, and it will be replaced by a proper Sign up process"""
         if self["username"] == os.getenv("USERNAME"):
diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py
index 76bc4c6b4c..be29523c40 100644
--- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py
+++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py
@@ -19,13 +19,14 @@
 from httpx import AsyncClient
 from sqlalchemy import select
 from sqlalchemy.ext.asyncio import AsyncSession
+from sqlalchemy.orm import selectinload
 
 from argilla_server.enums import UserRole
 from argilla_server.errors.future import AuthenticationError
 from argilla_server.models import User
 from argilla_server.security.authentication import JWT
 from argilla_server.security.authentication.oauth2 import OAuth2Settings
-from tests.factories import AdminFactory, AnnotatorFactory
+from tests.factories import AdminFactory, AnnotatorFactory, WorkspaceFactory
 
 
 @pytest.fixture
@@ -110,6 +111,40 @@ async def test_provider_huggingface_access_token(
         assert user is not None
         assert user.role == UserRole.annotator
 
+    async def test_provider_access_token_with_specific_userinfo_workspace(
+        self,
+        async_client: AsyncClient,
+        db: AsyncSession,
+        owner_auth_header: dict,
+        default_oauth_settings: OAuth2Settings,
+    ):
+        workspace = await WorkspaceFactory.create()
+
+        with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings):
+            with mock.patch(
+                "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data",
+                return_value={"username": "username", "name": "name", "available_workspaces": [workspace.name]},
+            ):
+                response = await async_client.get(
+                    "/api/v1/oauth2/providers/huggingface/access-token",
+                    params={"code": "code", "state": "valid"},
+                    headers=owner_auth_header,
+                    cookies={"huggingface_oauth2_state": "valid"},
+                )
+
+                assert response.status_code == 200
+
+                json_response = response.json()
+                assert JWT.decode(json_response["access_token"])["username"] == "username"
+                assert json_response["token_type"] == "bearer"
+
+                user = await db.scalar(
+                    select(User).options(selectinload(User.workspaces)).filter_by(username="username")
+                )
+                assert user is not None
+                assert user.role == UserRole.annotator
+                assert user.workspaces == [workspace]
+
     async def test_provider_huggingface_access_token_with_missing_username(
         self,
         async_client: AsyncClient,
diff --git a/argilla-server/tests/unit/security/authentication/oauth2/backends/__init__.py b/argilla-server/tests/unit/security/authentication/oauth2/backends/__init__.py
new file mode 100644
index 0000000000..4b6cecae7f
--- /dev/null
+++ b/argilla-server/tests/unit/security/authentication/oauth2/backends/__init__.py
@@ -0,0 +1,14 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
diff --git a/argilla-server/tests/unit/security/authentication/oauth2/backends/test_keycloack_backend.py b/argilla-server/tests/unit/security/authentication/oauth2/backends/test_keycloack_backend.py
new file mode 100644
index 0000000000..c542abf771
--- /dev/null
+++ b/argilla-server/tests/unit/security/authentication/oauth2/backends/test_keycloack_backend.py
@@ -0,0 +1,137 @@
+# Copyright 2021-present, the Recognai S.L. team.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from argilla_server.security.authentication.oauth2._backends import KeycloakOpenId, Strategy
+
+
+class TestKeyCloackOpenIdBackend:
+    def test_get_user_details_with_argilla_role(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["role1", "role2", "argilla_role:annotator"]},
+            }
+        )
+
+        assert user_details["role"] == "annotator"
+
+    def test_get_user_details_with_wrong_argilla_role_definition(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["role1", "role2", "argilla_role=annotator"]},
+            }
+        )
+
+        assert "role" not in user_details
+
+    def test_get_user_details_without_argilla_role(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["role1", "role2"]},
+            }
+        )
+
+        assert "role" not in user_details
+
+    def test_get_user_details_with_argilla_workspaces(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["role1", "role2", "argilla_workspace:ws1"]},
+            }
+        )
+
+        assert user_details["available_workspaces"] == ["ws1"]
+
+    def test_get_user_details_with_wrong_argilla_workspace_definition(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["role1", "role2", "argilla_workspace=ws1"]},
+            }
+        )
+
+        assert "available_workspaces" not in user_details
+
+    def test_get_user_details_with_multiple_argilla_workspaces(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["role1", "role2", "argilla_workspace:ws1", "argilla_workspace:ws2"]},
+            }
+        )
+
+        assert user_details["available_workspaces"] == ["ws1", "ws2"]
+
+    def test_get_user_details_with_missing_argilla_workspaces(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["role1", "role2"]},
+            }
+        )
+
+        assert "available_workspaces" not in user_details
+
+    def test_get_user_details_with_missing_roles_key(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"other": "stuff"},
+            }
+        )
+
+        assert "role" not in user_details
+        assert "available_workspaces" not in user_details
+
+    def test_get_user_details_with_missing_realm_access_key(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details({"other": "stuff"})
+
+        assert "role" not in user_details
+        assert "available_workspaces" not in user_details
+
+    def test_get_user_details_with_colon_in_argilla_workspace_name(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["argilla_workspace:team:ws1"]},
+            }
+        )
+
+        assert user_details["available_workspaces"] == ["team:ws1"]
+
+    def test_get_user_details_with_empty_argilla_role_and_workspace(self):
+        backend = KeycloakOpenId(strategy=Strategy())
+
+        user_details = backend.get_user_details(
+            {
+                "realm_access": {"roles": ["argilla_role:", "argilla_workspace:", "argilla_role:admin"]},
+            }
+        )
+
+        assert user_details["role"] == "admin"
+        assert "available_workspaces" not in user_details