diff --git a/argilla-frontend/components/features/login/OAuthLogin.vue b/argilla-frontend/components/features/login/OAuthLogin.vue index 9ca0a29b73..3b4053c77e 100644 --- a/argilla-frontend/components/features/login/OAuthLogin.vue +++ b/argilla-frontend/components/features/login/OAuthLogin.vue @@ -1,13 +1,19 @@ @@ -29,6 +35,14 @@ export default { display: flex; flex-direction: column; gap: $base-space * 3; + &__providers { + display: flex; + flex-direction: column; + gap: $base-space; + justify-content: center; + padding: 0; + list-style: none; + } } } diff --git a/argilla-frontend/components/features/login/components/HuggingFaceButton.vue b/argilla-frontend/components/features/login/components/HuggingFaceButton.vue index e50eaa90fd..d825ce9bb6 100644 --- a/argilla-frontend/components/features/login/components/HuggingFaceButton.vue +++ b/argilla-frontend/components/features/login/components/HuggingFaceButton.vue @@ -16,6 +16,7 @@ export default { background: var(--color-black); color: var(--color-white); width: 100%; + min-height: $base-space * 6; padding: calc($base-space / 2) $base-space * 4; justify-content: center; &:hover { diff --git a/argilla-frontend/components/features/login/components/OAuthLoginButton.vue b/argilla-frontend/components/features/login/components/OAuthLoginButton.vue new file mode 100644 index 0000000000..3f906066e9 --- /dev/null +++ b/argilla-frontend/components/features/login/components/OAuthLoginButton.vue @@ -0,0 +1,45 @@ + + + + diff --git a/argilla-frontend/translation/de.js b/argilla-frontend/translation/de.js index 7a66cdf6db..98f1ddf4c7 100644 --- a/argilla-frontend/translation/de.js +++ b/argilla-frontend/translation/de.js @@ -118,6 +118,7 @@ export default { button: { ignore_and_continue: "Ignorieren und fortfahren", login: "Anmelden", + signin_with_provider: "Anmeldung bei {provider} starten", "hf-login": "Mit Hugging Face anmelden", sign_in_with_username: "Mit Benutzername anmelden", cancel: "Abbrechen", diff --git a/argilla-frontend/translation/en.js b/argilla-frontend/translation/en.js index c0f7e5d47e..4693035f75 100644 --- a/argilla-frontend/translation/en.js +++ b/argilla-frontend/translation/en.js @@ -116,6 +116,7 @@ export default { button: { ignore_and_continue: "Ignore and continue", login: "Sign in", + signin_with_provider: "Sign in with {provider}", "hf-login": "Sign in with Hugging Face", sign_in_with_username: "Sign in with username", cancel: "Cancel", diff --git a/argilla-frontend/translation/es.js b/argilla-frontend/translation/es.js index 927249e148..51ff1f41a3 100644 --- a/argilla-frontend/translation/es.js +++ b/argilla-frontend/translation/es.js @@ -115,6 +115,7 @@ export default { button: { ignore_and_continue: "Ignorar y continuar", login: "Iniciar sesión", + signin_with_provider: "Iniciar sesión con {provider}", "hf-login": "Iniciar sesión con Hugging Face", sign_in_with_username: "Iniciar sesión con nombre de usuario", cancel: "Cancelar", diff --git a/argilla-server/src/argilla_server/_app.py b/argilla-server/src/argilla_server/_app.py index 05ad3fae04..39187aeb07 100644 --- a/argilla-server/src/argilla_server/_app.py +++ b/argilla-server/src/argilla_server/_app.py @@ -216,9 +216,6 @@ def _show_telemetry_warning(): async def _create_oauth_allowed_workspaces(db: AsyncSession): from argilla_server.security.settings import settings as security_settings - if not security_settings.oauth.enabled: - return - for allowed_workspace in security_settings.oauth.allowed_workspaces: if await Workspace.get_by(db, name=allowed_workspace.name) is None: _LOGGER.info(f"Creating workspace with name {allowed_workspace.name!r}") diff --git a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py index 6257da9601..0e16f07e28 100644 --- a/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py +++ b/argilla-server/src/argilla_server/api/handlers/v1/oauth2.py @@ -11,20 +11,16 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from typing import Optional from fastapi import APIRouter, Depends, Request, Path from fastapi.responses import RedirectResponse from sqlalchemy.ext.asyncio import AsyncSession from argilla_server.api.schemas.v1.oauth2 import Provider, Providers, Token -from argilla_server.api.schemas.v1.users import UserCreate from argilla_server.contexts import accounts from argilla_server.database import get_async_db -from argilla_server.enums import UserRole from argilla_server.errors.future import NotFoundError from argilla_server.models import User -from pydantic import Field from argilla_server.security.authentication.oauth2 import OAuth2ClientProvider from argilla_server.security.authentication.userinfo import UserInfo from argilla_server.security.settings import settings @@ -32,34 +28,21 @@ router = APIRouter(prefix="/oauth2", tags=["Authentication"]) -class UserOAuthCreate(UserCreate): - """This schema is used to validate the creation of a new user by using the oauth userinfo""" - - username: str = Field(min_length=1) - role: Optional[UserRole] - password: Optional[str] = None - - def get_provider_by_name_or_raise(provider: str = Path()) -> OAuth2ClientProvider: - if not settings.oauth.enabled: - raise NotFoundError(message="OAuth2 is not enabled") - - if provider in settings.oauth.providers: + try: return settings.oauth.providers[provider] - - raise NotFoundError(message=f"OAuth Provider '{provider}' not found") + except KeyError: + raise NotFoundError(message=f"OAuth Provider '{provider}' not found") @router.get("/providers", response_model=Providers) def list_providers() -> Providers: - if not settings.oauth.enabled: - return Providers(items=[]) - - return Providers(items=[Provider(name=provider_name) for provider_name in settings.oauth.providers]) + providers = [Provider(name=provider_name) for provider_name in settings.oauth.providers] + return Providers(items=providers) @router.get("/providers/{provider}/authentication") -def get_authentication( +async def get_authentication( request: Request, provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise), ) -> RedirectResponse: @@ -72,7 +55,8 @@ async def get_access_token( provider: OAuth2ClientProvider = Depends(get_provider_by_name_or_raise), db: AsyncSession = Depends(get_async_db), ) -> Token: - userinfo = UserInfo(await provider.get_user_data(request)).use_claims(provider.claims) + user_data = await provider.get_user_data(request) + userinfo = UserInfo(user_data) if not userinfo.username: raise RuntimeError("OAuth error: Missing username") @@ -81,11 +65,9 @@ async def get_access_token( if user is None: user = await accounts.create_user_with_random_password( db, - **UserOAuthCreate( - username=userinfo.username, - first_name=userinfo.first_name, - role=userinfo.role, - ).model_dump(exclude_unset=True), + username=userinfo.username, + first_name=userinfo.first_name, + role=userinfo.role, workspaces=[workspace.name for workspace in settings.oauth.allowed_workspaces], ) diff --git a/argilla-server/src/argilla_server/contexts/accounts.py b/argilla-server/src/argilla_server/contexts/accounts.py index f12fbadf0c..ddf65c342b 100644 --- a/argilla-server/src/argilla_server/contexts/accounts.py +++ b/argilla-server/src/argilla_server/contexts/accounts.py @@ -26,6 +26,7 @@ from argilla_server.models import User, Workspace, WorkspaceUser from argilla_server.security.authentication.jwt import JWT from argilla_server.security.authentication.userinfo import UserInfo +from argilla_server.validators.users import UserCreateValidator async def create_workspace_user(db: AsyncSession, workspace_user_attrs: dict) -> WorkspaceUser: @@ -52,7 +53,7 @@ async def list_workspaces(db: AsyncSession) -> List[Workspace]: return result.scalars().all() -async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> List[Workspace]: +async def list_workspaces_by_user_id(db: AsyncSession, user_id: UUID) -> Sequence[Workspace]: result = await db.execute( select(Workspace) .join(WorkspaceUser) @@ -102,22 +103,22 @@ async def list_users_by_ids(db: AsyncSession, ids: Iterable[UUID]) -> Sequence[U return result.scalars().all() -# TODO: After removing API v0 implementation we can remove the workspaces attribute. -# With API v1 the workspaces will be created doing additional requests to other endpoints for it. -async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List[str], None] = None) -> User: - if await get_user_by_username(db, user_attrs["username"]) is not None: - raise NotUniqueError(f"User username `{user_attrs['username']}` is not unique") - - user = await User.create( - db, +async def create_user( + db: AsyncSession, + user_attrs: dict, + workspaces: Union[List[str], None] = None, +) -> User: + new_user = User( first_name=user_attrs["first_name"], last_name=user_attrs["last_name"], username=user_attrs["username"], role=user_attrs["role"], password_hash=hash_password(user_attrs["password"]), - autocommit=False, ) + await UserCreateValidator.validate(db, user=new_user) + + await new_user.save(db, autocommit=False) if workspaces is not None: for workspace_name in workspaces: workspace = await Workspace.get_by(db, name=workspace_name) @@ -127,13 +128,13 @@ async def create_user(db: AsyncSession, user_attrs: dict, workspaces: Union[List await WorkspaceUser.create( db, workspace_id=workspace.id, - user_id=user.id, + user_id=new_user.id, autocommit=False, ) await db.commit() - return user + return new_user async def create_user_with_random_password( diff --git a/argilla-server/src/argilla_server/security/authentication/claims.py b/argilla-server/src/argilla_server/security/authentication/claims.py deleted file mode 100644 index a34696a685..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/claims.py +++ /dev/null @@ -1,40 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -import os -from typing import Any, Callable, Union, Optional - -from argilla_server.enums import UserRole - - -def _parse_role_from_environment(userinfo: dict) -> Optional[UserRole]: - """This is a temporal solution, and it will be replaced by a proper Sign up process""" - if userinfo["username"] == os.getenv("USERNAME"): - return UserRole.owner - - -class Claims(dict): - """Claims configuration for a single provider.""" - - display_name: Union[str, Callable[[dict], Any]] - identity: Union[str, Callable[[dict], Any]] - picture: Union[str, Callable[[dict], Any]] - email: Union[str, Callable[[dict], Any]] - - def __init__(self, seq=None, **kwargs) -> None: - super().__init__(seq or {}, **kwargs) - self["display_name"] = kwargs.get("display_name", self.get("display_name", "name")) - self["identity"] = kwargs.get("identity", self.get("identity", "sub")) - self["picture"] = kwargs.get("picture", self.get("picture", "picture")) - self["email"] = kwargs.get("email", self.get("email", "email")) - self["role"] = kwargs.get("role", _parse_role_from_environment) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py b/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py index f8dcc52a18..f5605ab591 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/__init__.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider # noqa +from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider # noqa from argilla_server.security.authentication.oauth2.settings import OAuth2Settings # noqa __all__ = ["OAuth2Settings", "OAuth2ClientProvider"] diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py new file mode 100644 index 0000000000..a663034ae3 --- /dev/null +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/_backends.py @@ -0,0 +1,88 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from typing import Type, Dict, Any + +from social_core.backends.oauth import BaseOAuth2 +from social_core.backends.open_id_connect import OpenIdConnectAuth +from social_core.backends.utils import load_backends +from social_core.strategy import BaseStrategy + +from argilla_server.errors.future import NotFoundError + + +class Strategy(BaseStrategy): + def request_data(self, merge=True) -> Dict[str, Any]: + return {} + + def absolute_uri(self, path=None) -> str: + return path + + def get_setting(self, name): + return os.environ[name] + + +class HuggingfaceOpenId(OpenIdConnectAuth): + """Huggingface OpenID Connect authentication backend.""" + + name = "huggingface" + + AUTHORIZATION_URL = "https://huggingface.co/oauth/authorize" + ACCESS_TOKEN_URL = "https://huggingface.co/oauth/token" + + # OIDC configuration + OIDC_ENDPOINT = "https://huggingface.co" + + DEFAULT_SCOPE = ["openid", "profile"] + + +_SUPPORTED_BACKENDS = {} + + +def load_supported_backends(extra_backends: list = None) -> Dict[str, Type[BaseOAuth2]]: + global _SUPPORTED_BACKENDS + + backends = [ + "argilla_server.security.authentication.oauth2._backends.HuggingfaceOpenId", + "social_core.backends.github.GithubOAuth2", + "social_core.backends.google.GoogleOAuth2", + ] + + if extra_backends: + backends.extend(extra_backends) + + _SUPPORTED_BACKENDS = load_backends(backends, force_load=True) + + for backend in _SUPPORTED_BACKENDS.values(): + if not issubclass(backend, BaseOAuth2): + raise ValueError( + f"Backend {backend} is not a supported OAuth2 backend. " + "Please, make sure it is a subclass of BaseOAuth2." + ) + + return _SUPPORTED_BACKENDS + + +def get_supported_backend_by_name(name: str) -> Type[BaseOAuth2]: + """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" + global _SUPPORTED_BACKENDS + + if not _SUPPORTED_BACKENDS: + _SUPPORTED_BACKENDS = load_supported_backends() + + if provider := _SUPPORTED_BACKENDS.get(name): + return provider + else: + raise NotFoundError(f"Unsupported provider {name}. Supported providers are {_SUPPORTED_BACKENDS.keys()}") diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py b/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py index 53774a2bf8..41f84d0b35 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/auth_backend.py @@ -19,7 +19,7 @@ from starlette.authentication import AuthCredentials, AuthenticationBackend, BaseUser from argilla_server.security.authentication.jwt import JWT -from argilla_server.security.authentication.oauth2.providers import OAuth2ClientProvider +from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider from argilla_server.security.authentication.userinfo import UserInfo @@ -39,7 +39,4 @@ async def authenticate(self, request: Request) -> Optional[Tuple[AuthCredentials token_data = JWT.decode(credentials.credentials) user = UserInfo(token_data) - provider = self.providers.get(user.get("provider")) - claims = provider.claims if provider else {} - - return AuthCredentials(user.pop("scope", [])), user.use_claims(claims) + return AuthCredentials(user.pop("scope", [])), user diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py similarity index 67% rename from argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py rename to argilla-server/src/argilla_server/security/authentication/oauth2/provider.py index ef6586f92d..bad389c09a 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_base.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/provider.py @@ -17,66 +17,67 @@ import random import re import string -from typing import Dict, Any, ClassVar, Type, Optional, Union, List, Tuple +from typing import Dict, Any, ClassVar, Type, Optional, List, Tuple from urllib.parse import urljoin import httpx from oauthlib.oauth2 import WebApplicationClient from social_core.backends.oauth import BaseOAuth2 from social_core.exceptions import AuthException - from social_core.strategy import BaseStrategy from starlette.requests import Request -from starlette.responses import RedirectResponse +from starlette.responses import RedirectResponse, Response from argilla_server.errors import future -from argilla_server.security.authentication.claims import Claims +from argilla_server.security.authentication.oauth2._backends import Strategy from argilla_server.security.settings import settings -class Strategy(BaseStrategy): - def request_data(self, merge=True) -> Dict[str, Any]: - return {} - - def absolute_uri(self, path=None) -> str: - return path - - def get_setting(self, name): - return None - - class OAuth2ClientProvider: - """OAuth2 flow handler of a certain provider.""" + """OAuth2 flow handler of a certain provider.""" OAUTH_STATE_COOKIE_NAME = "oauth2_state" OAUTH_STATE_COOKIE_MAX_AGE = 90 - name: ClassVar[str] - backend_class: ClassVar[Type[BaseOAuth2]] - claims: ClassVar[Optional[Union[Claims, dict]]] = None backend_strategy: ClassVar[BaseStrategy] = Strategy() def __init__( self, + backend_class: Type[BaseOAuth2], client_id: str = None, client_secret: str = None, scope: Optional[List[str]] = None, redirect_uri: str = None, ) -> None: - self.client_id = client_id or self._environment_variable_for_property("client_id") - self.client_secret = client_secret or self._environment_variable_for_property("client_secret") - self.scope = scope or self._environment_variable_for_property("scope", "") - self.scope = self.scope.split(" ") if self.scope else [] - self.redirect_uri = redirect_uri or self._environment_variable_for_property("redirect_uri") - self.redirect_uri = self.redirect_uri or f"/oauth/{self.name}/callback" - - self._backend = self.backend_class(strategy=self.backend_strategy) + self.name = backend_class.name + self._backend = backend_class(strategy=self.backend_strategy) + self._authorization_endpoint = self._backend.authorization_url() self._token_endpoint = self._backend.access_token_url() + # Social Core uses the key and secret names for the client_id and client_secret + # These lines allow the use of the same environment variables as the social_core library. + # See https://python-social-auth.readthedocs.io/en/latest/configuration/settings.html for more information. + self.client_id = (client_id or self._environment_variable_for_property("client_id")) or self._backend.setting( + "key" + ) + + self.client_secret = ( + client_secret or self._environment_variable_for_property("client_secret") + ) or self._backend.setting("secret") + + self.scope = (scope or self._environment_variable_for_property("scope")) or self._backend.setting( + "scope", + default=self._backend.get_scope(), + ) + if isinstance(self.scope, str): + self.scope = self.scope.split(" ") + + self.redirect_uri = redirect_uri or f"/oauth/{self.name}/callback" + @classmethod - def from_dict(cls, provider: dict) -> "OAuth2ClientProvider": - return cls(**provider) + def from_dict(cls, provider: dict, backend_class: Type[BaseOAuth2]) -> "OAuth2ClientProvider": + return cls(backend_class=backend_class, **provider) def new_oauth_client(self) -> WebApplicationClient: return WebApplicationClient(self.client_id) @@ -89,8 +90,12 @@ def authorization_url(self, request: Request) -> Tuple[str, Optional[str]]: redirect_uri = self.get_redirect_uri(request) state = "".join([random.choice(string.ascii_letters) for _ in range(32)]) - oauth2_query_params = dict(state=state, scope=self.scope, redirect_uri=redirect_uri) - oauth2_query_params.update(request.query_params) + oauth2_query_params = { + "state": state, + "scope": self.scope, + "redirect_uri": redirect_uri, + **request.query_params, + } authorization_url = str( self.new_oauth_client().prepare_request_uri(self._authorization_endpoint, **oauth2_query_params) @@ -102,17 +107,18 @@ def authorization_redirect(self, request: Request) -> RedirectResponse: url, state = self.authorization_url(request) response = RedirectResponse(url, 303) - response.set_cookie( - self.OAUTH_STATE_COOKIE_NAME, - value=state, - secure=True, - httponly=True, - max_age=self.OAUTH_STATE_COOKIE_MAX_AGE, - samesite="none", - ) + self._set_state(state, response) return response + def standardize(self, data: Dict[str, Any]) -> Dict[str, Any]: + data = self._backend.get_user_details(data) + + data["provider"] = self.name + data["scope"] = self.scope + + return data + async def get_user_data(self, request: Request) -> dict: self._check_request_params(request) @@ -131,10 +137,13 @@ def _check_request_params(self, request) -> None: if "state" not in request.query_params: raise ValueError("'state' parameter was not found in callback request") - state = request.cookies.get(self.OAUTH_STATE_COOKIE_NAME) + state = self._get_state(request) if request.query_params.get("state") != state: raise ValueError("'state' parameter does not match") + def _get_state(self, request) -> Optional[str]: + return request.cookies.get(self._get_state_cookie_name()) + @staticmethod def _align_url_to_allow_http_redirect(url: str) -> str: """This method is used to align the URL to the HTTP/HTTPS scheme""" @@ -142,27 +151,40 @@ def _align_url_to_allow_http_redirect(url: str) -> str: scheme = "http" if settings.oauth.allow_http_redirect else "https" return re.sub(r"^https?", scheme, url) - def standardize(self, data: Dict[str, Any]) -> Dict[str, Any]: - data["provider"] = self.name - data["scope"] = self.scope + def _set_state(self, state: str, response: Response) -> None: + response.set_cookie( + self._get_state_cookie_name(), + value=state, + secure=True, + httponly=True, + max_age=self.OAUTH_STATE_COOKIE_MAX_AGE, + samesite="none", + ) - return data + def _get_state_cookie_name(self) -> str: + return f"{self.name}_{self.OAUTH_STATE_COOKIE_NAME}" async def _fetch_user_data(self, authorization_response: str, **oauth_query_params) -> dict: oauth_client = self.new_oauth_client() + token_request_params = {**oauth_query_params} + + auth = None + if self._backend.use_basic_auth(): + auth = httpx.BasicAuth(self.client_id, self.client_secret) + else: + token_request_params["client_secret"] = self.client_secret + token_url, headers, content = oauth_client.prepare_token_request( self._token_endpoint, authorization_response=authorization_response, - **oauth_query_params, + **token_request_params, ) headers.update({"Accept": "application/json"}) - auth = httpx.BasicAuth(self.client_id, self.client_secret) async with httpx.AsyncClient(auth=auth) as session: try: response = await session.post(token_url, headers=headers, content=content) oauth_client.parse_request_body_response(json.dumps(response.json())) - return self.standardize(self._backend.user_data(oauth_client.access_token)) except httpx.HTTPError as e: raise ValueError(str(e)) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py deleted file mode 100644 index 0950bc30d2..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/__init__.py +++ /dev/null @@ -1,42 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from typing import Type - -from argilla_server.errors.future import NotFoundError -from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider -from argilla_server.security.authentication.oauth2.providers._github import GitHubClientProvider -from argilla_server.security.authentication.oauth2.providers._huggingface import HuggingfaceClientProvider - -__all__ = [ - "OAuth2ClientProvider", - "GitHubClientProvider", - "HuggingfaceClientProvider", - "get_provider_by_name", -] - -_ALL_SUPPORTED_OAUTH2_PROVIDERS = { - GitHubClientProvider.name: GitHubClientProvider, - HuggingfaceClientProvider.name: HuggingfaceClientProvider, -} - - -def get_provider_by_name(name: str) -> Type["OAuth2ClientProvider"]: - """Get a registered oauth provider by name. Raise a ValueError if provided not found.""" - if provider_class := _ALL_SUPPORTED_OAUTH2_PROVIDERS.get(name): - return provider_class - else: - raise NotFoundError( - f"Unsupported provider {name}. " f"Supported providers are {_ALL_SUPPORTED_OAUTH2_PROVIDERS.keys()}" - ) diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py deleted file mode 100644 index ea4f3f1918..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_github.py +++ /dev/null @@ -1,28 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -from social_core.backends.github import GithubOAuth2 - -from argilla_server.security.authentication.claims import Claims -from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider - - -class GitHubClientProvider(OAuth2ClientProvider): - claims = Claims( - picture="avatar_url", - identity=lambda user: f"{user.provider}:{user.id}", - username="login", - ) - backend_class = GithubOAuth2 - name = "github" diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py b/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py deleted file mode 100644 index 57365930d8..0000000000 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/providers/_huggingface.py +++ /dev/null @@ -1,47 +0,0 @@ -# Copyright 2021-present, the Recognai S.L. team. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. - -import logging - -from social_core.backends.open_id_connect import OpenIdConnectAuth - -from argilla_server.logging import LoggingMixin -from argilla_server.security.authentication.claims import Claims -from argilla_server.security.authentication.oauth2.providers._base import OAuth2ClientProvider - -_LOGGER = logging.getLogger("argilla.security.oauth2.providers.huggingface") - - -class HuggingfaceOpenId(OpenIdConnectAuth): - """Huggingface OpenID Connect authentication backend.""" - - name = "huggingface" - - OIDC_ENDPOINT = "https://huggingface.co" - AUTHORIZATION_URL = "https://huggingface.co/oauth/authorize" - ACCESS_TOKEN_URL = "https://huggingface.co/oauth/token" - - def oidc_endpoint(self) -> str: - return self.OIDC_ENDPOINT - - -_HF_PREFERRED_USERNAME = "preferred_username" - - -class HuggingfaceClientProvider(OAuth2ClientProvider, LoggingMixin): - """Specialized HuggingFace OAuth2 provider.""" - - claims = Claims(username=_HF_PREFERRED_USERNAME, first_name="name") - backend_class = HuggingfaceOpenId - name = "huggingface" diff --git a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py index e4771bad07..61c702ec69 100644 --- a/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py +++ b/argilla-server/src/argilla_server/security/authentication/oauth2/settings.py @@ -16,7 +16,11 @@ import yaml -from argilla_server.security.authentication.oauth2.providers import get_provider_by_name, OAuth2ClientProvider +from argilla_server.security.authentication.oauth2._backends import ( + get_supported_backend_by_name, + load_supported_backends, +) +from argilla_server.security.authentication.oauth2.provider import OAuth2ClientProvider __all__ = ["OAuth2Settings"] @@ -31,8 +35,6 @@ class OAuth2Settings: OAuth2 settings model. Args: - enabled: - Whether OAuth2 authentication is enabled or not. allow_http_redirect: Whether to allow HTTP scheme on redirect urls (for tests purposes). providers: @@ -43,18 +45,18 @@ class OAuth2Settings: ALLOWED_WORKSPACES_KEY = "allowed_workspaces" PROVIDERS_KEY = "providers" + EXTRA_BACKENDS_KEY = "extra_backends" def __init__( self, - enabled: bool = True, allow_http_redirect: bool = False, - providers: List[OAuth2ClientProvider] = None, - allowed_workspaces: List[AllowedWorkspace] = None, + extra_backends: List[str] = None, + **settings, ): - self.enabled = enabled self.allow_http_redirect = allow_http_redirect - self.allowed_workspaces = allowed_workspaces or [] - self._providers = providers or [] + self.extra_backends = extra_backends or [] + self.allowed_workspaces = self._build_workspaces(settings) or [] + self._providers = self._build_providers(settings, extra_backends) or [] if self.allow_http_redirect: # See https://stackoverflow.com/questions/27785375/testing-flask-oauthlib-locally-without-https @@ -69,16 +71,7 @@ def from_yaml(cls, yaml_file: str) -> "OAuth2Settings": """Creates an instance of OAuth2Settings from a YAML file.""" with open(yaml_file) as f: - return cls.from_dict(yaml.safe_load(f)) - - @classmethod - def from_dict(cls, settings: dict) -> "OAuth2Settings": - """Creates an instance of OAuth2Settings from a dictionary.""" - - settings[cls.PROVIDERS_KEY] = cls._build_providers(settings) - settings[cls.ALLOWED_WORKSPACES_KEY] = cls._build_workspaces(settings) - - return cls(**settings) + return cls(**yaml.safe_load(f)) @classmethod def _build_workspaces(cls, settings: dict) -> List[AllowedWorkspace]: @@ -86,13 +79,15 @@ def _build_workspaces(cls, settings: dict) -> List[AllowedWorkspace]: return [AllowedWorkspace(**workspace) for workspace in allowed_workspaces] @classmethod - def _build_providers(cls, settings: dict) -> List["OAuth2ClientProvider"]: + def _build_providers(cls, settings: dict, extra_backends) -> List["OAuth2ClientProvider"]: providers = [] + load_supported_backends(extra_backends=extra_backends) + for provider in settings.pop("providers", []): name = provider.pop("name") - provider_class = get_provider_by_name(name) - providers.append(provider_class.from_dict(provider)) + backend_class = get_supported_backend_by_name(name) + providers.append(OAuth2ClientProvider.from_dict(provider, backend_class)) return providers diff --git a/argilla-server/src/argilla_server/security/authentication/userinfo.py b/argilla-server/src/argilla_server/security/authentication/userinfo.py index 54220fc027..70173cda7e 100644 --- a/argilla-server/src/argilla_server/security/authentication/userinfo.py +++ b/argilla-server/src/argilla_server/security/authentication/userinfo.py @@ -11,13 +11,12 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - +import os from typing import Any, Optional from starlette.authentication import BaseUser from argilla_server.enums import UserRole -from argilla_server.security.authentication.claims import Claims _DEFAULT_USER_ROLE = UserRole.annotator @@ -39,16 +38,14 @@ def first_name(self) -> str: @property def role(self) -> UserRole: - role = self.get("role") or _DEFAULT_USER_ROLE + role = self.get("role") or self._parse_role_from_environment() return UserRole(role) - def use_claims(self, claims: Optional[Claims]) -> "UserInfo": - claims = claims or {} - - for attr, item in claims.items(): - self[attr] = self.__getprop__(item) - - return self + def _parse_role_from_environment(self) -> Optional[UserRole]: + """This is a temporal solution, and it will be replaced by a proper Sign up process""" + if self["username"] == os.getenv("USERNAME"): + return UserRole.owner + return _DEFAULT_USER_ROLE def __getprop__(self, item, default="") -> Any: if callable(item): diff --git a/argilla-server/src/argilla_server/security/settings.py b/argilla-server/src/argilla_server/security/settings.py index 3aab0eed8a..6835bc6186 100644 --- a/argilla-server/src/argilla_server/security/settings.py +++ b/argilla-server/src/argilla_server/security/settings.py @@ -64,7 +64,7 @@ def oauth(self) -> "OAuth2Settings": if not self._oauth_settings and os.path.exists(self.oauth_cfg): self._oauth_settings = OAuth2Settings.from_yaml(self.oauth_cfg) else: - self._oauth_settings = OAuth2Settings(enabled=False) + self._oauth_settings = OAuth2Settings() return self._oauth_settings diff --git a/argilla-server/src/argilla_server/validators/users.py b/argilla-server/src/argilla_server/validators/users.py new file mode 100644 index 0000000000..3d506fb032 --- /dev/null +++ b/argilla-server/src/argilla_server/validators/users.py @@ -0,0 +1,41 @@ +# Copyright 2021-present, the Recognai S.L. team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from sqlalchemy.ext.asyncio import AsyncSession + +from argilla_server.errors.future import NotUniqueError, UnprocessableEntityError +from argilla_server.models import User + + +class UserCreateValidator: + @classmethod + async def validate(cls, db: AsyncSession, user: User) -> None: + await cls._validate_username(db, user) + + @classmethod + async def _validate_username(cls, db, user: User): + await cls._validate_username_length(user) + await cls._validate_unique_username(db, user) + + @classmethod + async def _validate_unique_username(cls, db, user): + from argilla_server.contexts import accounts + + if await accounts.get_user_by_username(db, user.username) is not None: + raise NotUniqueError(f"User username `{user.username}` is not unique") + + @classmethod + async def _validate_username_length(cls, user: User): + if len(user.username) < 1: + raise UnprocessableEntityError("Username must be at least 1 characters long") diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py index b5a3d87477..76bc4c6b4c 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py @@ -28,24 +28,16 @@ from tests.factories import AdminFactory, AnnotatorFactory -@pytest.fixture -def disabled_oauth_settings() -> OAuth2Settings: - return OAuth2Settings(enabled=False) - - @pytest.fixture def default_oauth_settings() -> OAuth2Settings: - return OAuth2Settings.from_dict( - { - "enabled": True, - "providers": [ - { - "name": "huggingface", - "client_id": "client_id", - "client_secret": "client_secret", - } - ], - } + return OAuth2Settings( + providers=[ + { + "name": "huggingface", + "client_id": "client_id", + "client_secret": "client_secret", + } + ] ) @@ -57,25 +49,6 @@ async def tests_list_providers_with_default_config(self, async_client: AsyncClie assert response.status_code == 200 assert response.json() == {"items": []} - async def test_list_providers_with_oauth_disabled( - self, async_client: AsyncClient, owner_auth_header: dict, disabled_oauth_settings: OAuth2Settings - ): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", new_callable=lambda: disabled_oauth_settings - ): - response = await async_client.get("/api/v1/oauth2/providers", headers=owner_auth_header) - assert response.status_code == 200 - assert response.json() == {"items": []} - - async def test_list_provider_with_oauth_disabled_from_settings( - self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings - ): - default_oauth_settings.enabled = False - with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): - response = await async_client.get("/api/v1/oauth2/providers", headers=owner_auth_header) - assert response.status_code == 200 - assert response.json() == {"items": []} - async def test_list_providers( self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings ): @@ -99,33 +72,6 @@ async def test_provider_huggingface_authentication( assert b"/oauth/authorize?response_type=code&client_id=client_id" in redirect_url.target assert b"&extra=params" in redirect_url.target - async def test_provider_authentication_with_oauth_disabled( - self, - async_client: AsyncClient, - owner_auth_header: dict, - disabled_oauth_settings: OAuth2Settings, - ): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", new_callable=lambda: disabled_oauth_settings - ): - response = await async_client.get( - "/api/v1/oauth2/providers/huggingface/authentication", headers=owner_auth_header - ) - assert response.status_code == 404 - - async def test_provider_authentication_with_oauth_disabled_and_provider_defined( - self, - async_client: AsyncClient, - owner_auth_header: dict, - default_oauth_settings: OAuth2Settings, - ): - default_oauth_settings.enabled = False - with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): - response = await async_client.get( - "/api/v1/oauth2/providers/huggingface/authentication", headers=owner_auth_header - ) - assert response.status_code == 404 - async def test_provider_authentication_with_invalid_provider( self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings ): @@ -144,14 +90,14 @@ async def test_provider_huggingface_access_token( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", - return_value={"preferred_username": "username", "name": "name"}, + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", + return_value={"username": "username", "name": "name"}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 200 @@ -173,14 +119,14 @@ async def test_provider_huggingface_access_token_with_missing_username( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", return_value={"name": "name"}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 500 @@ -194,14 +140,14 @@ async def test_provider_huggingface_access_token_with_missing_name( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", - return_value={"preferred_username": "username"}, + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", + return_value={"username": "username"}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 200 @@ -215,20 +161,6 @@ async def test_provider_huggingface_access_token_with_missing_name( assert user.role == UserRole.annotator assert user.first_name == "username" - async def test_provider_access_token_with_oauth_disabled( - self, - async_client: AsyncClient, - owner_auth_header: dict, - disabled_oauth_settings: OAuth2Settings, - ): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", new_callable=lambda: disabled_oauth_settings - ): - response = await async_client.get( - "/api/v1/oauth2/providers/huggingface/access-token", headers=owner_auth_header - ) - assert response.status_code == 404 - async def test_provider_access_token_with_invalid_provider( self, async_client: AsyncClient, owner_auth_header: dict, default_oauth_settings: OAuth2Settings ): @@ -266,7 +198,7 @@ async def test_provider_access_token_with_invalid_state( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "invalid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 422 assert response.json() == {"detail": "'state' parameter does not match"} @@ -276,14 +208,14 @@ async def test_provider_access_token_with_authentication_error( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", side_effect=AuthenticationError("error"), ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 401 assert response.json() == {"detail": "error"} @@ -299,14 +231,14 @@ async def test_provider_access_token_with_already_created_user( with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", - return_value={"preferred_username": admin.username, "name": admin.first_name}, + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", + return_value={"username": admin.username, "name": admin.first_name}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) assert response.status_code == 200 @@ -325,14 +257,14 @@ async def test_provider_access_token_with_same_username( with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): with mock.patch( - "argilla_server.security.authentication.oauth2.providers.OAuth2ClientProvider._fetch_user_data", - return_value={"preferred_username": user.username, "name": user.first_name}, + "argilla_server.security.authentication.oauth2.provider.OAuth2ClientProvider._fetch_user_data", + return_value={"username": user.username, "name": user.first_name}, ): response = await async_client.get( "/api/v1/oauth2/providers/huggingface/access-token", params={"code": "code", "state": "valid"}, headers=owner_auth_header, - cookies={"oauth2_state": "valid"}, + cookies={"huggingface_oauth2_state": "valid"}, ) # This will throw an error once we detect users created by OAuth2 assert response.status_code == 200 diff --git a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py index c0175f273b..d396434706 100644 --- a/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py +++ b/argilla-server/tests/unit/security/authentication/oauth2/test_settings.py @@ -21,20 +21,18 @@ class TestOAuth2Settings: def test_configure_unsupported_provider(self): with pytest.raises(NotFoundError): - OAuth2Settings.from_dict({"providers": [{"name": "unsupported"}]}) + OAuth2Settings(providers=[{"name": "unsupported"}]) def test_configure_github_provider(self): - settings = OAuth2Settings.from_dict( - { - "providers": [ - { - "name": "github", - "client_id": "github_client_id", - "client_secret": "github_client_secret", - "scope": "user:email", - } - ] - } + settings = OAuth2Settings( + providers=[ + { + "name": "github", + "client_id": "github_client_id", + "client_secret": "github_client_secret", + "scope": "user:email", + } + ] ) github_provider = settings.providers["github"] @@ -44,17 +42,15 @@ def test_configure_github_provider(self): assert github_provider.scope == ["user:email"] def test_configure_huggingface_provider(self): - settings = OAuth2Settings.from_dict( - { - "providers": [ - { - "name": "huggingface", - "client_id": "huggingface_client_id", - "client_secret": "huggingface_client_secret", - "scope": "openid profile email", - } - ] - } + settings = OAuth2Settings( + providers=[ + { + "name": "huggingface", + "client_id": "huggingface_client_id", + "client_secret": "huggingface_client_secret", + "scope": "openid profile email", + } + ] ) huggingface_provider = settings.providers["huggingface"] @@ -62,3 +58,38 @@ def test_configure_huggingface_provider(self): assert huggingface_provider.client_id == "huggingface_client_id" assert huggingface_provider.client_secret == "huggingface_client_secret" assert huggingface_provider.scope == ["openid", "profile", "email"] + + def test_configure_extra_backends(self): + from social_core.backends.microsoft import MicrosoftOAuth2 + + provider_name = MicrosoftOAuth2.name + settings = OAuth2Settings( + extra_backends=["social_core.backends.microsoft.MicrosoftOAuth2"], + providers=[ + { + "name": provider_name, + "client_id": "microsoft_client_id", + "client_secret": "microsoft_client_secret", + } + ], + ) + + assert len(settings.providers) == 1 + extra_provider = settings.providers[provider_name] + + assert extra_provider.name == provider_name + assert extra_provider.client_id == "microsoft_client_id" + assert extra_provider.client_secret == "microsoft_client_secret" + + def test_configure_non_supported_extra_backends(self): + with pytest.raises(ValueError): + OAuth2Settings( + extra_backends=["social_core.backends.twitter.TwitterOAuth"], + providers=[ + { + "name": "github", + "client_id": "github_client_id", + "client_secret": "github_client_secret", + } + ], + ) diff --git a/argilla-server/tests/unit/security/authentication/test_userinfo.py b/argilla-server/tests/unit/security/authentication/test_userinfo.py index 8203d56c66..200fb940c1 100644 --- a/argilla-server/tests/unit/security/authentication/test_userinfo.py +++ b/argilla-server/tests/unit/security/authentication/test_userinfo.py @@ -18,7 +18,6 @@ from argilla_server.enums import UserRole from argilla_server.security.authentication import UserInfo -from argilla_server.security.authentication.claims import Claims class TestUserInfo: @@ -43,19 +42,8 @@ def test_get_userinfo_role(self): userinfo = UserInfo({"username": "user", "role": "owner"}) assert userinfo.role == UserRole.owner - def test_get_userinfo_with_claims(self): - userinfo = UserInfo({"username": "user"}).use_claims( - Claims( - first_name=lambda user: user["username"].upper(), - last_name=lambda user: "Peter", - ) - ) - - assert userinfo.first_name == "USER" - assert userinfo.last_name == "Peter" - def test_get_userinfo_role_with_username_env(self, mocker: MockerFixture): mocker.patch.dict(os.environ, {"USERNAME": "user"}) - userinfo = UserInfo({"id": "user"}).use_claims(Claims(username="id")) + userinfo = UserInfo({"username": "user"}) assert userinfo.role == UserRole.owner diff --git a/argilla-server/tests/unit/test_app.py b/argilla-server/tests/unit/test_app.py index 48fbc82bff..93de167ed6 100644 --- a/argilla-server/tests/unit/test_app.py +++ b/argilla-server/tests/unit/test_app.py @@ -78,12 +78,7 @@ def test_server_timing_header(self): async def test_create_allowed_workspaces(self, db: AsyncSession): with mock.patch( "argilla_server.security.settings.Settings.oauth", - new_callable=lambda: OAuth2Settings.from_dict( - { - "enabled": True, - "allowed_workspaces": [{"name": "ws1"}, {"name": "ws2"}], - } - ), + new_callable=lambda: OAuth2Settings(allowed_workspaces=[{"name": "ws1"}, {"name": "ws2"}]), ): await _create_oauth_allowed_workspaces(db) @@ -91,25 +86,8 @@ async def test_create_allowed_workspaces(self, db: AsyncSession): assert len(workspaces) == 2 assert set([ws.name for ws in workspaces]) == {"ws1", "ws2"} - async def test_create_allowed_workspaces_with_oauth_disabled(self, db: AsyncSession): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", - new_callable=lambda: OAuth2Settings.from_dict( - { - "enabled": False, - "allowed_workspaces": [{"name": "ws1"}, {"name": "ws2"}], - } - ), - ): - await _create_oauth_allowed_workspaces(db) - - workspaces = (await db.scalars(select(Workspace))).all() - assert len(workspaces) == 0 - async def test_create_workspaces_with_empty_workspaces_list(self, db: AsyncSession): - with mock.patch( - "argilla_server.security.settings.Settings.oauth", new_callable=lambda: OAuth2Settings(enabled=True) - ): + with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=OAuth2Settings): await _create_oauth_allowed_workspaces(db) workspaces = (await db.scalars(select(Workspace))).all() @@ -120,7 +98,7 @@ async def test_create_workspaces_with_existing_workspaces(self, db: AsyncSession with mock.patch( "argilla_server.security.settings.Settings.oauth", - new_callable=lambda: OAuth2Settings(enabled=True, allowed_workspaces=[AllowedWorkspace(name=ws.name)]), + new_callable=lambda: OAuth2Settings(allowed_workspaces=[{"name": ws.name}]), ): await _create_oauth_allowed_workspaces(db) diff --git a/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md b/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md index 72b5c1ae87..9d55019640 100644 --- a/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md +++ b/argilla/docs/getting_started/how-to-configure-argilla-on-huggingface.md @@ -51,8 +51,6 @@ To restrict access or change the default behaviour, there's two options: **Modify the `.oauth.yml` configuration file**. You can find and modify this file under the `Files` tab of your Space. The default file looks like this: ```yaml -# Change to `false` to disable HF oauth integration -#enabled: false providers: - name: huggingface @@ -61,10 +59,10 @@ providers: allowed_workspaces: - name: argilla ``` -You can modify two things: +You can: -- Uncomment `enabled: false` to completely disable the Sign in with Hugging Face. If you disable it make sure to set the `USERNAME` and `PASSWORD` Space secrets to be able to login as an `owner`. - Change the list of `allowed` workspaces. +- Rename the `.oauth.yml` file to disable OAuth access. For example if you want to let users join a new workspace `community-initiative`: diff --git a/argilla/docs/reference/argilla-server/configuration.md b/argilla/docs/reference/argilla-server/configuration.md index 99cb92cf9b..b81480f76d 100644 --- a/argilla/docs/reference/argilla-server/configuration.md +++ b/argilla/docs/reference/argilla-server/configuration.md @@ -50,9 +50,10 @@ You can set the following environment variables to further configure your server #### Authentication -- `ARGILLA_AUTH_SECRET_KEY`: The secret key used to sign the API token data. You can use `openssl rand -hex 32` to generate a 32 character string to use with this environment variable. By default a random value is generated, so if you are using more than one server worker (or more than one Argilla server) you will need to set the same value for all of them. - `USERNAME`: If provided, the owner username (Default: `None`). - `PASSWORD`: If provided, the owner password (Default: `None`). +- `ARGILLA_AUTH_SECRET_KEY`: The secret key used to sign the API token data. You can use `openssl rand -hex 32` to generate a 32 character string to use with this environment variable. By default a random value is generated, so if you are using more than one server worker (or more than one Argilla server) you will need to set the same value for all of them. +- `ARGILLA_AUTH_OAUTH_CFG`: Path to the OAuth2 configuration file (Default: `$PWD/.oauth.yml`). If `USERNAME` and `PASSWORD` are provided, the owner user will be created with these credentials on the server startup. diff --git a/argilla/docs/reference/argilla-server/oauth2_configuration.md b/argilla/docs/reference/argilla-server/oauth2_configuration.md new file mode 100644 index 0000000000..da14903ed2 --- /dev/null +++ b/argilla/docs/reference/argilla-server/oauth2_configuration.md @@ -0,0 +1,163 @@ +# OAuth2 configuration + +Argilla supports OAuth2 authentication for users. This allows users to authenticate using other services like Google, +GitHub, or Hugging Face. Next sections will guide you through the configuration of the OAuth2 authentication. + +## The OAuth2 configuration file + +The OAuth2 configuration file is a YAML file that contains the configuration for the OAuth2 providers that you want to +enable. The default file name is `.oauth.yml` and it should be placed in the root directory of the Argilla server. You +can also specify a different file name using the `ARGILLA_AUTH_OAUTH_CFG` environment variable. + +The file should have the following structure: + +```yaml +providers: + - name: huggingface + client_id: "" + client_secret: "" + scope: "openid profile" + + - name: google-oauth2 + client_id: "" + client_secret: "" + scope: "openid email profile" + + - name: github + client_id: "" + client_secret: "" + +allowed_workspaces: + - name: argilla + +allow_http_redirect: false +``` + +### Providers + +The `providers` key is a list of dictionaries, each dictionary represents a provider configuration, including the +following fields: + +- `name`: The name of the provider. The available options by default are `huggingface`, `github` and `google-oauth2`. +We will see later how to add more providers not supported by default. +- `client_id`: The client ID provided by the OAuth2 provider. You can get this value by creating an application in the +provider's developer console. This is a required field, but you can also use the +`ARGILLA_OAUTH2__CLIENT_ID` environment variable to set the value. +- `client_secret`: The client secret provided by the OAuth2 provider. You can get this value by creating an application +in the provider's developer console. This is a required field, but you can also use +the `ARGILLA_OAUTH2__CLIENT_SECRET` environment variable to set the value. +- `scope`: The scope of the OAuth2 provider. This is an optional field, and normally you don't need to set it, but +you can use it to request specific permissions from the user access. + +### Allowed Workspaces + +The `allowed_workspaces` key defines the available workspaces when users log in using the OAuth2 provider. This is +a list of `name` fields that should match the workspace name in the Argilla server. By default, the `argilla` workspace +is allowed to authenticate using the OAuth2 provider. + +If the workspace doesn't exist, it will be created automatically on the first server startup. + +### Allow HTTP Redirect + +The `allow_http_redirect` key is a boolean value that allows the OAuth2 provider to redirect the user to an HTTP URL. +By default, this value is set to `false`, and you should set it to `true` only if you are running the Argilla server +behind a proxy that doesn't support HTTPS or if you are running the server locally. + +Enabling this option is not recommended for production environments and should be used only for development purposes. + +## Supported OAuth2 providers configuration + +The following sections will guide you through the configuration of the supported OAuth2 providers. Before diving into +the configuration, you should create an application in the provider's developer console to get the client ID and client +secret. + +A common step when creating an application in the provider's developer console is to set the redirect URI. The +redirect URI is the URL where the OAuth2 provider will redirect the user after the authentication process. + +The redirect URI should be set to the Argilla server URL, followed by `/oauth//callback`. For example, +if the Argilla server is running on `http://localhost:8000`, the redirect URI for provider application should +be `http://localhost:8000/oauth/huggingface/callback`. + +### Hugging Face OAuth2 configuration + +Argilla supports Hugging Face OAuth2 authentication out of the box, and is already configured when running Argilla +on Hugging Face Spaces (See the [Hugging Face Spaces settings](../../getting_started/how-to-configure-argilla-on-huggingface.md) for more information). + +But, if you want to manually configure the Hugging Face OAuth2 provider, you should define the following +fields in the `.oauth.yml` file: + +```yaml + +providers: + - name: huggingface + client_id: "" # You can use the ARGILLA_OAUTH2_HUGGINGFACE_CLIENT_ID environment variable + client_secret: "" # You can use the ARGILLA_OAUTH2_HUGGINGFACE_CLIENT_SECRET environment variable + scope: "openid profile" # This field is optional. But this value must be aligned your OAuth2 application created in Hugging Face. + +... +``` + +To get your client ID and client secret, you need to create an [OAuth2 application](https://huggingface.co/settings/applications/new) in the Hugging Face +settings page. + +The minimal scope required for the Hugging Face OAuth2 provider is `openid profile`, so you don't need to +change the `scope` when creating the application. + +### GitHub OAuth2 configuration + +Argilla also supports GitHub OAuth2 authentication out of the box. To configure the GitHub OAuth2 provider, you should +define the following fields in the `.oauth.yml` file: + +```yaml + +providers: + - name: github + client_id: "" # You can use the ARGILLA_OAUTH2_GITHUB_CLIENT_ID environment variable + client_secret: "" # You can use the ARGILLA_OAUTH2_GITHUB_CLIENT_SECRET environment variable + +... +``` + +To get your client ID and client secret, you need to register a new [OAuth application](https://github.com/settings/applications/new) in the GitHub settings page. + +### Google OAuth2 configuration + +Argilla also supports Google OAuth2 authentication out of the box. To configure the Google OAuth2 provider, you +should define the following fields in the `.oauth.yml` file: + +```yaml + +providers: + - name: google-oauth2 + client_id: "" # You can use the ARGILLA_OAUTH2_GOOGLE_OAUTH2_CLIENT_ID environment variable + client_secret: "" # You can use the ARGILLA_OAUTH2_GOOGLE_OAUTH2_CLIENT_SECRET environment variable + +... +``` + +To get your client ID and client secret, you need to create a new [OAuth2 client](https://console.cloud.google.com/apis/credentials/oauthclient) in the Google Cloud Console. + +### Adding more OAuth2 providers + +If you want to add more OAuth2 providers that are not supported by default, you can do so by adding a new provider +configuration to the `.oauth.yml` file. The Argilla server uses the [Social Auth backends](https://python-social-auth.readthedocs.io/en/latest/backends/index.html) component to define +the provider configuration. You only need to register the provider backend using the `extra_backends` key in +the `.oauth.yml` file. + +For example, to configure the [Apple OAuth2 provider](https://python-social-auth.readthedocs.io/en/latest/backends/apple.html), you should add the following configuration to +the `.oauth.yml` file: + +```yaml + +providers: + - name: apple-id + client_id: "" # You can use the ARGILLA_OAUTH2_APPLE_ID_CLIENT_ID environment variable + client_secret: "" # You can use the ARGILLA_OAUTH2_APPLE_ID_CLIENT_SECRET environment variable + +extra_backends: + - social_core.backends.apple.AppleIdAuth # Register the Apple OAuth2 provider backend + +``` + +All the `SOCIAL_AUTH_*` environment variables are supported by the Argilla server, so you can customize the provider +configuration using these environment variables. diff --git a/argilla/mkdocs.yml b/argilla/mkdocs.yml index b31214619b..09dc10040d 100644 --- a/argilla/mkdocs.yml +++ b/argilla/mkdocs.yml @@ -189,6 +189,7 @@ nav: - Python SDK: reference/argilla/ - FastAPI Server: - Server configuration: reference/argilla-server/configuration.md + - OAuth2 configuration: reference/argilla-server/oauth2_configuration.md - Telemetry: - Server Telemetry: reference/argilla-server/telemetry.md - Community: