feat: add auth to chatui (#1065)

This commit is contained in:
Robin Goetz
2024-02-28 23:33:11 +01:00
committed by GitHub
parent 15b0aebf52
commit a9146fb902
27 changed files with 425 additions and 174 deletions

View File

@@ -1,13 +1,15 @@
import { PropsWithChildren } from 'react';
import { useLocation, useNavigate } from 'react-router-dom';
import { useAuthStoreActions, useAuthStoreState } from './libs/auth/auth.store';
import { useAuthQuery } from './libs/auth/use-auth.query';
const Auth = (props: PropsWithChildren) => {
const result = useAuthQuery();
const { uuid } = useAuthStoreState();
const { setAsAuthenticated } = useAuthStoreActions();
if (result.isSuccess && uuid !== result.data.uuid) {
setAsAuthenticated(result.data.uuid);
const { loggedIn } = useAuthStoreState();
const { logout } = useAuthStoreActions();
const location = useLocation();
const navigate = useNavigate();
if (!loggedIn && location.pathname !== '/login') {
logout();
navigate('/login');
}
return props.children;
};

View File

@@ -6,6 +6,16 @@ export const AgentSchema = z.object({
human: z.string(),
persona: z.string(),
created_at: z.string(),
// TODO: Remove optional once API response returns necessary data
memories: z.number().optional(),
data_sources: z.number().optional(),
last_run: z.string().optional(),
tools: z
.object({
core: z.number(),
user_defined: z.number(),
})
.optional(),
});
export type Agent = z.infer<typeof AgentSchema>;

View File

@@ -1,17 +1,20 @@
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { AgentMemoryUpdate } from './agent-memory-update';
export const useAgentMemoryUpdateMutation = (userId: string | null | undefined) => {
const queryClient = useQueryClient();
const bearerToken = useAuthBearerToken();
return useMutation({
mutationFn: async (params: AgentMemoryUpdate) =>
await fetch(API_BASE_URL + `/agents/memory?${userId}`, {
method: 'POST',
headers: { 'Content-Type': ' application/json' },
headers: { 'Content-Type': ' application/json', Authorization: bearerToken },
body: JSON.stringify(params),
}).then((res) => res.json()),
onSuccess: (res, { agent_id }) =>
onSuccess: (_, { agent_id }) =>
queryClient.invalidateQueries({ queryKey: [userId, 'agents', 'entry', agent_id, 'memory'] }),
});
};

View File

@@ -1,13 +1,18 @@
import { useQuery } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { AgentMemory } from './agent-memory';
export const useAgentMemoryQuery = (userId: string | null | undefined, agentId: string | null | undefined) =>
useQuery({
export const useAgentMemoryQuery = (userId: string | null | undefined, agentId: string | null | undefined) => {
const bearerToken = useAuthBearerToken();
return useQuery({
queryKey: [userId, 'agents', 'entry', agentId, 'memory'],
queryFn: async () =>
(await fetch(API_BASE_URL + `/agents/memory?agent_id=${agentId}&user_id=${userId}`).then((res) =>
res.json()
)) as Promise<AgentMemory>,
(await fetch(API_BASE_URL + `/agents/memory?agent_id=${agentId}&user_id=${userId}`, {
headers: {
Authorization: bearerToken,
},
}).then((res) => res.json())) as Promise<AgentMemory>,
enabled: !!userId && !!agentId,
});
};

View File

@@ -1,14 +1,16 @@
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { Agent } from './agent';
export const useAgentsCreateMutation = (userId: string | null | undefined) => {
const queryClient = useQueryClient();
const bearerToken = useAuthBearerToken();
return useMutation({
mutationFn: async (params: { name: string; human: string; persona: string; model: string }) => {
mutationFn: async (params: { name: string; human: string; persona: string; model: string }): Promise<Agent> => {
const response = await fetch(API_BASE_URL + '/agents', {
method: 'POST',
headers: { 'Content-Type': ' application/json' },
headers: { 'Content-Type': ' application/json', Authorization: bearerToken },
body: JSON.stringify({ config: params, user_id: userId }),
});
@@ -18,7 +20,7 @@ export const useAgentsCreateMutation = (userId: string | null | undefined) => {
throw new Error(errorBody || 'Error creating agent');
}
return response.json() as Promise<Agent>;
return await response.json();
},
onSuccess: () => queryClient.invalidateQueries({ queryKey: [userId, 'agents', 'list'] }),
});

View File

@@ -1,14 +1,21 @@
import { useQuery } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { Agent } from './agent';
export const useAgentsQuery = (userId: string | null | undefined) =>
useQuery({
export const useAgentsQuery = (userId: string | null | undefined) => {
const bearerToken = useAuthBearerToken();
return useQuery({
queryKey: [userId, 'agents', 'list'],
enabled: !!userId,
queryFn: async () =>
(await fetch(API_BASE_URL + `/agents?user_id=${userId}`).then((res) => res.json())) as Promise<{
(await fetch(API_BASE_URL + `/agents?user_id=${userId}`, {
headers: {
Authorization: bearerToken,
},
}).then((res) => res.json())) as Promise<{
num_agents: number;
agents: Agent[];
}>,
});
};

View File

@@ -3,19 +3,44 @@ import { createJSONStorage, persist } from 'zustand/middleware';
export type AuthState = {
uuid: string | null;
token: string | null;
loggedIn: boolean;
};
export type AuthActions = {
setAsAuthenticated: (uuid: string, token?: string) => void;
setToken: (token: string) => void;
logout: () => void;
};
export type AuthActions = { setAsAuthenticated: (uuid: string) => void };
const useAuthStore = create(
persist<{ auth: AuthState; actions: AuthActions }>(
(set, get) => ({
auth: { uuid: null },
auth: { uuid: null, token: null, loggedIn: false },
actions: {
setAsAuthenticated: (uuid: string) =>
setToken: (token: string) =>
set((prev) => ({
...prev,
auth: {
...prev.auth,
token,
},
})),
setAsAuthenticated: (uuid: string, token?: string) =>
set((prev) => ({
...prev,
auth: {
token: token ?? prev.auth.token,
uuid,
loggedIn: true,
},
})),
logout: () =>
set((prev) => ({
...prev,
auth: {
token: null,
uuid: null,
loggedIn: false,
},
})),
},
@@ -30,3 +55,7 @@ const useAuthStore = create(
export const useAuthStoreState = () => useAuthStore().auth;
export const useAuthStoreActions = () => useAuthStore().actions;
export const useAuthBearerToken = () => {
const { auth } = useAuthStore();
return auth.token ? `Bearer ${auth.token}` : '';
};

View File

@@ -0,0 +1,23 @@
import { useMutation } from '@tanstack/react-query';
import { API_BASE_URL } from '../constants';
import { useAuthStoreActions } from './auth.store';
export type AuthResponse = { uuid: string };
export const useAuthMutation = () => {
const { setAsAuthenticated } = useAuthStoreActions();
return useMutation({
mutationKey: ['auth'],
mutationFn: (password: string) =>
fetch(API_BASE_URL + `/auth`, {
method: 'POST',
headers: { 'Content-Type': ' application/json' },
body: JSON.stringify({ password }),
}).then((res) => {
if (!res.ok) {
throw new Error('Network response was not ok');
}
return res.json();
}) as Promise<AuthResponse>,
onSuccess: (data, password) => setAsAuthenticated(data.uuid, password),
});
};

View File

@@ -0,0 +1,27 @@
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { Human } from './human';
export const useHumansCreateMutation = (userId: string | null | undefined) => {
const queryClient = useQueryClient();
const bearerToken = useAuthBearerToken();
return useMutation({
mutationFn: async (params: { name: string; text: string }): Promise<Human> => {
const response = await fetch(API_BASE_URL + '/agents', {
method: 'POST',
headers: { 'Content-Type': ' application/json', Authorization: bearerToken },
body: JSON.stringify({ config: params, user_id: userId }),
});
if (!response.ok) {
// Throw an error if the response is not OK
const errorBody = await response.text();
throw new Error(errorBody || 'Error creating human');
}
return await response.json();
},
onSuccess: () => queryClient.invalidateQueries({ queryKey: [userId, 'humans', 'list'] }),
});
};

View File

@@ -1,13 +1,19 @@
import { useQuery } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { Human } from './human';
export const useHumansQuery = (userId: string | null | undefined) =>
useQuery({
export const useHumansQuery = (userId: string | null | undefined) => {
const bearerToken = useAuthBearerToken();
return useQuery({
queryKey: [userId, 'humans', 'list'],
enabled: !!userId,
queryFn: async () => {
const response = await fetch(`${API_BASE_URL}/humans?user_id=${encodeURIComponent(userId || '')}`);
const response = await fetch(`${API_BASE_URL}/humans?user_id=${encodeURIComponent(userId || '')}`, {
headers: {
Authorization: bearerToken,
},
});
if (!response.ok) {
throw new Error('Network response was not ok for fetching humans');
}
@@ -16,3 +22,4 @@ export const useHumansQuery = (userId: string | null | undefined) =>
}>;
},
});
};

View File

@@ -13,7 +13,7 @@ export type MessageHistory = {
const useMessageHistoryStore = create(
persist<{ history: MessageHistory; actions: { addMessage: (key: string, message: Message) => void } }>(
(set, get) => ({
(set) => ({
history: {},
actions: {
addMessage: (key: string, message: Message) =>

View File

@@ -18,7 +18,7 @@ const useMessageStreamStore = create(
{
socket: null as EventSource | null,
socketURL: null as string | null,
readyState: ReadyState.IDLE,
readyState: ReadyState.IDLE as ReadyState,
abortController: null as AbortController | null,
onMessageCallback: ((message: Message) =>
console.warn('No message callback set up. Simply logging message', message)) as (message: Message) => void,
@@ -30,10 +30,12 @@ const useMessageStreamStore = create(
agentId,
message,
role,
bearerToken,
}: {
userId: string;
agentId: string;
message: string;
bearerToken: string;
role?: 'user' | 'system';
}) => {
const abortController = new AbortController();
@@ -49,7 +51,7 @@ const useMessageStreamStore = create(
});
void fetchEventSource(ENDPOINT_URL, {
method: 'POST',
headers: { 'Content-Type': 'application/json', Accept: 'text/event-stream' },
headers: { 'Content-Type': 'application/json', Accept: 'text/event-stream', Authorization: bearerToken },
body: JSON.stringify({
user_id: userId,
agent_id: agentId,

View File

@@ -1,4 +1,5 @@
import { useQuery } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
export const useMessagesQuery = (
@@ -6,14 +7,21 @@ export const useMessagesQuery = (
agentId: string | null | undefined,
start = 0,
count = 10
) =>
useQuery({
) => {
const bearerToken = useAuthBearerToken();
return useQuery({
queryKey: [userId, 'agents', 'item', agentId, 'messages', 'list', start, count],
queryFn: async () =>
(await fetch(
API_BASE_URL + `/agents/message?agent_id=${agentId}&user_id=${userId}&start=${start}&count=${count}`
API_BASE_URL + `/agents/message?agent_id=${agentId}&user_id=${userId}&start=${start}&count=${count}`,
{
headers: {
Authorization: bearerToken,
},
}
).then((res) => res.json())) as Promise<{
messages: { role: string; name: string; content: string }[];
}>,
enabled: !!userId && !!agentId,
});
};

View File

@@ -1,13 +1,19 @@
import { useQuery } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { Model } from './model';
export const useModelsQuery = (userId: string | null | undefined) =>
useQuery({
export const useModelsQuery = (userId: string | null | undefined) => {
const bearerToken = useAuthBearerToken();
return useQuery({
queryKey: [userId, 'models', 'list'],
enabled: !!userId,
queryFn: async () => {
const response = await fetch(`${API_BASE_URL}/models?user_id=${encodeURIComponent(userId || '')}`);
const response = await fetch(`${API_BASE_URL}/models?user_id=${encodeURIComponent(userId || '')}`, {
headers: {
Authorization: bearerToken,
},
});
if (!response.ok) {
throw new Error('Network response was not ok for fetching models');
}
@@ -16,3 +22,4 @@ export const useModelsQuery = (userId: string | null | undefined) =>
}>;
},
});
};

View File

@@ -0,0 +1,27 @@
import { useMutation, useQueryClient } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { Persona } from './persona';
export const usePersonasCreateMutation = (userId: string | null | undefined) => {
const queryClient = useQueryClient();
const bearerToken = useAuthBearerToken();
return useMutation({
mutationFn: async (params: { name: string; text: string }): Promise<Persona> => {
const response = await fetch(API_BASE_URL + '/agents', {
method: 'POST',
headers: { 'Content-Type': ' application/json', Authorization: bearerToken },
body: JSON.stringify({ config: params, user_id: userId }),
});
if (!response.ok) {
// Throw an error if the response is not OK
const errorBody = await response.text();
throw new Error(errorBody || 'Error creating persona');
}
return await response.json();
},
onSuccess: () => queryClient.invalidateQueries({ queryKey: [userId, 'personas', 'list'] }),
});
};

View File

@@ -1,13 +1,19 @@
import { useQuery } from '@tanstack/react-query';
import { useAuthBearerToken } from '../auth/auth.store';
import { API_BASE_URL } from '../constants';
import { Persona } from './persona';
export const usePersonasQuery = (userId: string | null | undefined) =>
useQuery({
export const usePersonasQuery = (userId: string | null | undefined) => {
const bearerToken = useAuthBearerToken();
return useQuery({
queryKey: [userId, 'personas', 'list'],
enabled: !!userId, // The query will not execute unless userId is truthy
queryFn: async () => {
const response = await fetch(`${API_BASE_URL}/personas?user_id=${encodeURIComponent(userId || '')}`);
const response = await fetch(`${API_BASE_URL}/personas?user_id=${encodeURIComponent(userId || '')}`, {
headers: {
Authorization: bearerToken,
},
});
if (!response.ok) {
throw new Error('Network response was not ok');
}
@@ -16,3 +22,4 @@ export const usePersonasQuery = (userId: string | null | undefined) =>
}>;
},
});
};

View File

@@ -1,6 +1,6 @@
import { useCallback, useEffect, useRef } from 'react';
import { useAgentActions, useCurrentAgent, useLastAgentInitMessage } from '../../libs/agents/agent.store';
import { useAuthStoreState } from '../../libs/auth/auth.store';
import { useAuthBearerToken, useAuthStoreState } from '../../libs/auth/auth.store';
import { useMessageHistoryActions, useMessagesForKey } from '../../libs/messages/message-history.store';
import {
ReadyState,
@@ -21,12 +21,13 @@ const Chat = () => {
const { sendMessage } = useMessageSocketActions();
const { addMessage } = useMessageHistoryActions();
const { setLastAgentInitMessage } = useAgentActions();
const bearerToken = useAuthBearerToken();
const sendMessageAndAddToHistory = useCallback(
(message: string, role: 'user' | 'system' = 'user') => {
if (!currentAgent || !auth.uuid) return;
const date = new Date();
sendMessage({ userId: auth.uuid, agentId: currentAgent.id, message, role });
sendMessage({ userId: auth.uuid, agentId: currentAgent.id, message, role, bearerToken });
addMessage(currentAgent.id, {
type: role === 'user' ? 'user_message' : 'system_message',
message_type: 'user_message',

View File

@@ -0,0 +1,71 @@
import { Avatar, AvatarFallback, AvatarImage } from '@memgpt/components/avatar';
import { Button } from '@memgpt/components/button';
import { Input } from '@memgpt/components/input';
import { Label } from '@memgpt/components/label';
import { cnH3, cnMuted } from '@memgpt/components/typography';
import { Loader2 } from 'lucide-react';
import { useNavigate } from 'react-router-dom';
import { useAuthMutation } from '../../../libs/auth/use-auth.mutation';
const year = new Date().getFullYear();
const LoginPage = () => {
const mutation = useAuthMutation();
const navigate = useNavigate();
return (
<div className="relative flex h-full w-full items-center justify-center">
<div className="-mt-40 flex max-w-sm flex-col items-center justify-center">
<Avatar className="mb-2 h-16 w-16 border bg-white">
<AvatarImage alt="MemGPT logo." src="/memgpt_logo_transparent.png" />
<AvatarFallback className="border">MG</AvatarFallback>
</Avatar>
<h1 className={cnH3('mb-2')}>Welcome to MemGPT</h1>
<p className="mb-6 text-muted-foreground">Sign in below to start chatting with your agent</p>
<form
className="w-full"
onSubmit={(e) => {
e.preventDefault();
const password = new FormData(e.currentTarget).get('password') as string;
if (!password || password.length === 0) return;
mutation.mutate(password, {
onSuccess: ({ uuid }, password) => setTimeout(() => navigate('/'), 600),
});
}}
>
<Label className="sr-only" htmlFor="password">
Password
</Label>
<Input
name="password"
className="mb-2 w-full"
type="password"
autoComplete="off"
autoCorrect="off"
id="password"
/>
<Button type="submit" className="mb-6 w-full">
{mutation.isPending ? (
<span className="flex items-center animate-in slide-in-from-bottom-2">
{/* eslint-disable-next-line react/jsx-no-undef */}
<Loader2 className="mr-2 h-4 w-4 animate-spin" />
Signing in
</span>
) : null}
{mutation.isSuccess ? <span className="animate-in slide-in-from-bottom-2">Signed in!</span> : null}
{!mutation.isPending && mutation.isError ? (
<span className="animate-in slide-in-from-bottom-2">Sign In Failed. Try again...</span>
) : null}
{!mutation.isPending && !mutation.isSuccess && !mutation.isError ? (
<span>Sign In with Password</span>
) : null}
</Button>
</form>
<p className="text-center text-muted-foreground">
By clicking continue, you agree to our Terms of Service and Privacy Policy.
</p>
</div>
<p className={cnMuted('absolute inset-x-0 bottom-3 text-center')}>&copy; {year} MemGPT</p>
</div>
);
};
export default LoginPage;

View File

@@ -0,0 +1,7 @@
import { RouteObject } from 'react-router-dom';
import LoginPage from './login.page';
export const loginRoutes: RouteObject = {
path: 'login',
element: <LoginPage />,
};

View File

@@ -2,6 +2,7 @@ import { createBrowserRouter, Outlet } from 'react-router-dom';
import Auth from './auth';
import { chatRoute } from './modules/chat/chat.routes';
import Home from './modules/home/home';
import { loginRoutes } from './modules/public/login/login.routes';
import { settingsRoute } from './modules/settings/settings.routes';
import Footer from './shared/layout/footer';
import Header from './shared/layout/header';
@@ -30,4 +31,5 @@ export const router = createBrowserRouter([
settingsRoute,
],
},
loginRoutes,
]);

View File

@@ -1,12 +1,14 @@
import { Avatar, AvatarFallback, AvatarImage } from '@memgpt/components/avatar';
import { Button } from '@memgpt/components/button';
import { MoonStar, Sun } from 'lucide-react';
import { LucideLogOut, MoonStar, Sun } from 'lucide-react';
import { NavLink } from 'react-router-dom';
import { useAuthStoreActions } from '../../libs/auth/auth.store';
import { useTheme } from '../theme';
const twNavLink = '[&.active]:opacity-100 opacity-60';
const Header = () => {
const { theme, toggleTheme } = useTheme();
const { logout } = useAuthStoreActions();
return (
<div className="flex items-start justify-between border-b py-2 sm:px-8">
<NavLink to="/">
@@ -29,7 +31,6 @@ const Header = () => {
</NavLink>
</Button>
<Button size="sm" asChild variant="link">
{/* @ts-ignore */}
<NavLink className={twNavLink} to="/settings/agents">
Settings
</NavLink>
@@ -37,6 +38,9 @@ const Header = () => {
<Button size="icon" variant="ghost" onClick={toggleTheme}>
{theme === 'light' ? <MoonStar className="h-4 w-4" /> : <Sun className="w-4 w-4" />}
</Button>
<Button size="icon" variant="ghost" onClick={logout}>
<LucideLogOut className="h-4 w-4" />
</Button>
</nav>
</div>
);

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

File diff suppressed because one or more lines are too long

Binary file not shown.

Before

Width:  |  Height:  |  Size: 15 KiB

After

Width:  |  Height:  |  Size: 28 KiB

View File

@@ -2,7 +2,7 @@
<html lang="en">
<head>
<meta charset="utf-8" />
<title>MemgptFrontend</title>
<title>MemGPT</title>
<base href="/" />
<meta name="viewport" content="width=device-width, initial-scale=1" />
@@ -29,8 +29,8 @@
}
}
</script>
<script type="module" crossorigin src="/assets/index-273ebfe0.js"></script>
<link rel="stylesheet" href="/assets/index-9ace7bf7.css">
<script type="module" crossorigin src="/assets/index-f6a3d52a.js"></script>
<link rel="stylesheet" href="/assets/index-57df4f6c.css">
</head>
<body>
<div class="h-full w-full" id="root"></div>