diff --git a/argilla-frontend/CHANGELOG.md b/argilla-frontend/CHANGELOG.md index 0fc3a3bee5..c0da610be7 100644 --- a/argilla-frontend/CHANGELOG.md +++ b/argilla-frontend/CHANGELOG.md @@ -17,7 +17,12 @@ These are the section headers that we use: ## [Unreleased]() ### Added -- Add a high-contrast theme & improvements for the forced-colors mode ([#5661](https://github.com/argilla-io/argilla/pull/5661)) + +- Add a high-contrast theme & improvements for the forced-colors mode. ([#5661](https://github.com/argilla-io/argilla/pull/5661)) + +### Fixed + +- Fixed redirection problems after users sign-in using HF OAuth. ([#5635](https://github.com/argilla-io/argilla/pull/5635)) ## [2.4.0](https://github.com/argilla-io/argilla/compare/v2.3.0...v2.4.0) diff --git a/argilla-frontend/middleware/route-guard.ts b/argilla-frontend/middleware/route-guard.ts index f333e69d1e..562e379af2 100644 --- a/argilla-frontend/middleware/route-guard.ts +++ b/argilla-frontend/middleware/route-guard.ts @@ -17,10 +17,16 @@ import { Context } from "@nuxt/types"; import { useRunningEnvironment } from "~/v1/infrastructure/services/useRunningEnvironment"; +import { useLocalStorage } from "~/v1/infrastructure/services"; + +const { set } = useLocalStorage(); export default ({ $auth, route, redirect }: Context) => { const { isRunningOnHuggingFace } = useRunningEnvironment(); + // By-pass unknown routes. This is needed to avoid errors with API calls. + if (route.name == null) return; + switch (route.name) { case "sign-in": if ($auth.loggedIn) return redirect("/"); @@ -28,12 +34,8 @@ export default ({ $auth, route, redirect }: Context) => { if (route.params.omitCTA) return; if (isRunningOnHuggingFace()) { - // eslint-disable-next-line @typescript-eslint/no-unused-vars - const { redirect: _, ...query } = route.query; - return redirect({ name: "welcome-hf-sign-in", - query, }); } break; @@ -50,14 +52,11 @@ export default ({ $auth, route, redirect }: Context) => { default: if (!$auth.loggedIn) { if (route.path !== "/") { - route.query.redirect = route.fullPath; + set("redirectTo", route.fullPath); } redirect({ name: "sign-in", - query: { - ...route.query, - }, }); } } diff --git a/argilla-frontend/pages/oauth/_provider/useOAuthViewModel.ts b/argilla-frontend/pages/oauth/_provider/useOAuthViewModel.ts index 91ccd6104d..1cbb722af7 100644 --- a/argilla-frontend/pages/oauth/_provider/useOAuthViewModel.ts +++ b/argilla-frontend/pages/oauth/_provider/useOAuthViewModel.ts @@ -2,7 +2,11 @@ import { useFetch, useRoute } from "@nuxtjs/composition-api"; import { useResolve } from "ts-injecty"; import { ProviderType } from "~/v1/domain/entities/oauth/OAuthProvider"; import { OAuthLoginUseCase } from "~/v1/domain/usecases/oauth-login-use-case"; -import { useRoutes, useTranslate } from "~/v1/infrastructure/services"; +import { + useRoutes, + useTranslate, + useLocalStorage, +} from "~/v1/infrastructure/services"; import { useNotifications } from "~/v1/infrastructure/services/useNotifications"; export const useOAuthViewModel = () => { @@ -11,11 +15,17 @@ export const useOAuthViewModel = () => { const routes = useRoute(); const router = useRoutes(); const oauthLoginUseCase = useResolve(OAuthLoginUseCase); + const { pop } = useLocalStorage(); useFetch(async () => { await tryLogin(); }); + const redirect = () => { + const redirect = pop("redirectTo"); + router.go(redirect || "/"); + }; + const tryLogin = async () => { const { params, query } = routes.value; @@ -23,12 +33,12 @@ export const useOAuthViewModel = () => { try { await oauthLoginUseCase.login(provider, query); + redirect(); } catch { notification.notify({ message: t("argilla.api.errors::UnauthorizedError"), type: "danger", }); - } finally { router.go("/"); } }; diff --git a/argilla-frontend/pages/sign-in.vue b/argilla-frontend/pages/sign-in.vue index 6d5d0f71ce..f804223a01 100644 --- a/argilla-frontend/pages/sign-in.vue +++ b/argilla-frontend/pages/sign-in.vue @@ -98,18 +98,8 @@ export default { }, }, methods: { - nextRedirect() { - const redirect_url = this.$nuxt.$route.query.redirect || "/"; - this.$router.push({ - path: redirect_url, - }); - }, async loginUser({ username, password }) { await this.login(username, password); - - this.$notification.clear(); - - this.nextRedirect(); }, async onLoginUser() { try { diff --git a/argilla-frontend/pages/useSignInViewModel.ts b/argilla-frontend/pages/useSignInViewModel.ts index 12075594be..b2d5f6f2ee 100644 --- a/argilla-frontend/pages/useSignInViewModel.ts +++ b/argilla-frontend/pages/useSignInViewModel.ts @@ -1,11 +1,23 @@ import { useResolve } from "ts-injecty"; import { AuthLoginUseCase } from "~/v1/domain/usecases/auth-login-use-case"; +import { useRoutes, useLocalStorage } from "~/v1/infrastructure/services"; +import { useNotifications } from "~/v1/infrastructure/services/useNotifications"; export const useSignInViewModel = () => { const useCase = useResolve(AuthLoginUseCase); + const router = useRoutes(); + const notification = useNotifications(); + const { pop } = useLocalStorage(); + + const redirect = () => { + const redirect = pop("redirectTo"); + router.go(redirect || "/"); + }; const login = async (username: string, password: string) => { await useCase.login(username, password); + notification.clear(); + redirect(); }; return { diff --git a/argilla-frontend/v1/infrastructure/services/useLocalStorage.ts b/argilla-frontend/v1/infrastructure/services/useLocalStorage.ts index f3331172bb..7d6c8a7483 100644 --- a/argilla-frontend/v1/infrastructure/services/useLocalStorage.ts +++ b/argilla-frontend/v1/infrastructure/services/useLocalStorage.ts @@ -1,4 +1,4 @@ -type Options = "showShortcutsHelper" | "layout"; +type Options = "showShortcutsHelper" | "layout" | "redirectTo"; const STORAGE_KEY = "argilla"; @@ -34,8 +34,15 @@ export const useLocalStorage = () => { } catch {} }; + const pop = (key: Options) => { + const value = get(key); + set(key, null); + return value; + }; + return { get, set, + pop, }; }; diff --git a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py index 7f8d6ddee1..b5a3d87477 100644 --- a/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py +++ b/argilla-server/tests/unit/api/handlers/v1/test_oauth2.py @@ -89,7 +89,7 @@ async def test_provider_huggingface_authentication( ): with mock.patch("argilla_server.security.settings.Settings.oauth", new_callable=lambda: default_oauth_settings): response = await async_client.get( - "/api/v1/oauth2/providers/huggingface/authentication", headers=owner_auth_header + "/api/v1/oauth2/providers/huggingface/authentication?extra=params", headers=owner_auth_header ) assert response.status_code == 303 @@ -97,6 +97,7 @@ async def test_provider_huggingface_authentication( assert redirect_url.scheme == b"https" assert redirect_url.host == b"huggingface.co" assert b"/oauth/authorize?response_type=code&client_id=client_id" in redirect_url.target + assert b"&extra=params" in redirect_url.target async def test_provider_authentication_with_oauth_disabled( self,