diff --git a/fern/scripts/prepare-openapi.ts b/fern/scripts/prepare-openapi.ts new file mode 100644 index 00000000..d449f221 --- /dev/null +++ b/fern/scripts/prepare-openapi.ts @@ -0,0 +1,218 @@ +import * as fs from 'fs'; +import * as path from 'path'; + +import { omit } from 'lodash'; +import { execSync } from 'child_process'; +import { merge, isErrorResult } from 'openapi-merge'; +import type { Swagger } from 'atlassian-openapi'; +import { RESTRICTED_ROUTE_BASE_PATHS } from '@letta-cloud/sdk-core'; + +const lettaWebOpenAPIPath = path.join( + __dirname, + '..', + '..', + '..', + 'web', + 'autogenerated', + 'letta-web-openapi.json', +); +const lettaAgentsAPIPath = path.join( + __dirname, + '..', + '..', + 'letta', + 'server', + 'openapi_letta.json', +); + +const lettaWebOpenAPI = JSON.parse( + fs.readFileSync(lettaWebOpenAPIPath, 'utf8'), +) as Swagger.SwaggerV3; +const lettaAgentsAPI = JSON.parse( + fs.readFileSync(lettaAgentsAPIPath, 'utf8'), +) as Swagger.SwaggerV3; + +// removes any routes that are restricted +lettaAgentsAPI.paths = Object.fromEntries( + Object.entries(lettaAgentsAPI.paths).filter(([path]) => + RESTRICTED_ROUTE_BASE_PATHS.every( + (restrictedPath) => !path.startsWith(restrictedPath), + ), + ), +); + +const lettaAgentsAPIWithNoEndslash = Object.keys(lettaAgentsAPI.paths).reduce( + (acc, path) => { + const pathWithoutSlash = path.endsWith('/') + ? path.slice(0, path.length - 1) + : path; + acc[pathWithoutSlash] = lettaAgentsAPI.paths[path]; + return acc; + }, + {} as Swagger.SwaggerV3['paths'], +); + +// remove duplicate paths, delete from letta-web-openapi if it exists in sdk-core +// some paths will have an extra / at the end, so we need to remove that as well +lettaWebOpenAPI.paths = Object.fromEntries( + Object.entries(lettaWebOpenAPI.paths).filter(([path]) => { + const pathWithoutSlash = path.endsWith('/') + ? path.slice(0, path.length - 1) + : path; + return !lettaAgentsAPIWithNoEndslash[pathWithoutSlash]; + }), +); + +const agentStatePathsToOverride: Array<[string, string]> = [ + ['/v1/templates/{project}/{template_version}/agents', '201'], + ['/v1/agents/search', '200'], +]; + +for (const [path, responseCode] of agentStatePathsToOverride) { + if (lettaWebOpenAPI.paths[path]?.post?.responses?.[responseCode]) { + // Get direct reference to the schema object + const responseSchema = + lettaWebOpenAPI.paths[path].post.responses[responseCode]; + const contentSchema = responseSchema.content['application/json'].schema; + + // Replace the entire agents array schema with the reference + if (contentSchema.properties?.agents) { + contentSchema.properties.agents = { + type: 'array', + items: { + $ref: '#/components/schemas/AgentState', + }, + }; + } + } +} + +// go through the paths and remove "user_id"/"actor_id" from the headers +for (const path of Object.keys(lettaAgentsAPI.paths)) { + for (const method of Object.keys(lettaAgentsAPI.paths[path])) { + // @ts-expect-error - a + if (lettaAgentsAPI.paths[path][method]?.parameters) { + // @ts-expect-error - a + lettaAgentsAPI.paths[path][method].parameters = lettaAgentsAPI.paths[ + path + ][method].parameters.filter( + (param: Record) => + param.in !== 'header' || + ( + param.name !== 'user_id' && + param.name !== 'User-Agent' && + param.name !== 'X-Project-Id' && + param.name !== 'X-Stainless-Package-Version' && + !param.name.startsWith('X-Experimental') + ), + ); + } + } +} + +const result = merge([ + { + oas: lettaAgentsAPI, + }, + { + oas: lettaWebOpenAPI, + }, +]); + +if (isErrorResult(result)) { + console.error(`${result.message} (${result.type})`); + process.exit(1); +} + +result.output.openapi = '3.1.0'; +result.output.info = { + title: 'Letta API', + version: '1.0.0', +}; + +result.output.servers = [ + { + url: 'https://app.letta.com', + description: 'Letta Cloud', + }, + { + url: 'http://localhost:8283', + description: 'Self-hosted', + }, +]; + +result.output.components = { + ...result.output.components, + securitySchemes: { + bearerAuth: { + type: 'http', + scheme: 'bearer', + }, + }, +}; + +result.output.security = [ + ...(result.output.security || []), + { + bearerAuth: [], + }, +]; + +// omit all instances of "user_id" from the openapi.json file +function deepOmitPreserveArrays(obj: unknown, key: string): unknown { + if (Array.isArray(obj)) { + return obj.map((item) => deepOmitPreserveArrays(item, key)); + } + + if (typeof obj !== 'object' || obj === null) { + return obj; + } + + if (key in obj) { + return omit(obj, key); + } + + return Object.fromEntries( + Object.entries(obj).map(([k, v]) => [k, deepOmitPreserveArrays(v, key)]), + ); +} + +// eslint-disable-next-line @typescript-eslint/ban-ts-comment +// @ts-ignore +result.output.components = deepOmitPreserveArrays( + result.output.components, + 'user_id', +); + +// eslint-disable-next-line @typescript-eslint/ban-ts-comment +// @ts-ignore +result.output.components = deepOmitPreserveArrays( + result.output.components, + 'actor_id', +); + +// eslint-disable-next-line @typescript-eslint/ban-ts-comment +// @ts-ignore +result.output.components = deepOmitPreserveArrays( + result.output.components, + 'organization_id', +); + +fs.writeFileSync( + path.join(__dirname, '..', 'openapi.json'), + JSON.stringify(result.output, null, 2), +); + +function formatOpenAPIJson() { + const openApiPath = path.join(__dirname, '..', 'openapi.json'); + + try { + execSync(`npx prettier --write "${openApiPath}"`, { stdio: 'inherit' }); + console.log('Successfully formatted openapi.json with Prettier'); + } catch (error) { + console.error('Error formatting openapi.json:', error); + process.exit(1); + } +} + +formatOpenAPIJson(); diff --git a/letta/server/rest_api/dependencies.py b/letta/server/rest_api/dependencies.py index d6c87466..aa2888d7 100644 --- a/letta/server/rest_api/dependencies.py +++ b/letta/server/rest_api/dependencies.py @@ -20,6 +20,7 @@ class HeaderParams(BaseModel): actor_id: Optional[str] = None user_agent: Optional[str] = None project_id: Optional[str] = None + sdk_version: Optional[str] = None experimental_params: Optional[ExperimentalParams] = None @@ -27,6 +28,7 @@ def get_headers( actor_id: Optional[str] = Header(None, alias="user_id"), user_agent: Optional[str] = Header(None, alias="User-Agent"), project_id: Optional[str] = Header(None, alias="X-Project-Id"), + sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version"), message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async"), letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent"), ) -> HeaderParams: @@ -35,6 +37,7 @@ def get_headers( actor_id=actor_id, user_agent=user_agent, project_id=project_id, + sdk_version=sdk_version, experimental_params=ExperimentalParams( message_async=(message_async == "true") if message_async else None, letta_v1_agent=(letta_v1_agent == "true") if letta_v1_agent else None, diff --git a/letta/utils.py b/letta/utils.py index b5b619e5..184a98fb 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -41,6 +41,7 @@ from letta.helpers.json_helpers import json_dumps, json_loads from letta.log import get_logger from letta.otel.tracing import log_attributes, trace_method from letta.schemas.openai.chat_completion_response import ChatCompletionResponse +from letta.server.rest_api.dependencies import HeaderParams logger = get_logger(__name__) @@ -1400,3 +1401,51 @@ def fire_and_forget(coro, task_name: Optional[str] = None, error_callback: Optio task.add_done_callback(callback) return task + + +def is_1_0_sdk_version(headers: HeaderParams): + """ + Check if the SDK version is 1.0.0 or above. + 1. If sdk_version is provided from stainless (all stainless versions are 1.0.0+) + 2. If user_agent is provided and in the format + @letta-ai/letta-client/version (node) or + letta-client/version (python) + """ + sdk_version = headers.sdk_version + if sdk_version: + return True + + client = headers.user_agent + if "/" not in client: + return False + + # Split into parts to validate format + parts = client.split("/") + + # Should have at least 2 parts (client-name/version) + if len(parts) < 2: + return False + + if len(parts) == 3: + # Format: @letta-ai/letta-client/version + if parts[0] != "@letta-ai" or parts[1] != "letta-client": + return False + elif len(parts) == 2: + # Format: letta-client/version + if parts[0] != "letta-client": + return False + else: + return False + + # Extract and validate version + maybe_version = parts[-1] + if "." not in maybe_version: + return False + + # Extract major version (handle alpha/beta suffixes like 1.0.0-alpha.2 or 1.0.0a5) + version_base = maybe_version.split("-")[0].split("a")[0].split("b")[0] + if "." not in version_base: + return False + + major_version = version_base.split(".")[0] + return major_version == "1" diff --git a/tests/test_utils.py b/tests/test_utils.py index 225dcec4..bf89ed0b 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -3,9 +3,10 @@ import pytest from letta.constants import MAX_FILENAME_LENGTH from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source from letta.schemas.file import FileMetadata +from letta.server.rest_api.dependencies import HeaderParams from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.helpers.agent_manager_helper import safe_format -from letta.utils import sanitize_filename, validate_function_response +from letta.utils import is_1_0_sdk_version, sanitize_filename, validate_function_response CORE_MEMORY_VAR = "My core memory is that I like to eat bananas" VARS_DICT = {"CORE_MEMORY": CORE_MEMORY_VAR} @@ -669,3 +670,15 @@ def test_validate_function_response_whitespace(): """Test whitespace-only string handling""" response = validate_function_response(" \n\t ", return_char_limit=100) assert response == " \n\t " + + +def test_sdk_version_check(): + """Test SDK version check""" + assert not is_1_0_sdk_version(HeaderParams(user_agent="letta-client/0.0.200")) + assert is_1_0_sdk_version(HeaderParams(user_agent="letta-client/1.0.0a5")) + assert not is_1_0_sdk_version(HeaderParams(user_agent="@letta-ai/letta-client/0.0.200")) + assert is_1_0_sdk_version(HeaderParams(user_agent="@letta-ai/letta-client/1.0.0-alpha.5")) + assert is_1_0_sdk_version(HeaderParams(sdk_version="v1.0.0")) + assert is_1_0_sdk_version(HeaderParams(sdk_version="v1.0.0-alpha.7")) + assert is_1_0_sdk_version(HeaderParams(sdk_version="v1.0.0a7")) + assert is_1_0_sdk_version(HeaderParams(sdk_version="v2.0.0"))