diff --git a/fern/openapi.json b/fern/openapi.json index 6837c527..8b2a10b7 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -29385,6 +29385,49 @@ "title": "BedrockModelSettings", "description": "AWS Bedrock model configuration." }, + "BillingContext": { + "properties": { + "plan_type": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Plan Type", + "description": "Subscription tier" + }, + "cost_source": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Cost Source", + "description": "Cost source: 'quota' or 'credits'" + }, + "customer_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Customer Id", + "description": "Customer ID for billing records" + } + }, + "type": "object", + "title": "BillingContext", + "description": "Billing context for LLM request cost tracking." + }, "Block": { "properties": { "value": { @@ -42964,6 +43007,17 @@ ], "title": "Llm Config", "description": "LLM configuration used for this call (non-summarization calls only)" + }, + "billing_context": { + "anyOf": [ + { + "$ref": "#/components/schemas/BillingContext" + }, + { + "type": "null" + } + ], + "description": "Billing context from request headers" } }, "additionalProperties": false, diff --git a/fern/scripts/prepare-openapi.ts b/fern/scripts/prepare-openapi.ts new file mode 100644 index 00000000..99263e85 --- /dev/null +++ b/fern/scripts/prepare-openapi.ts @@ -0,0 +1,220 @@ +import * as fs from 'fs'; +import * as path from 'path'; + +import { omit } from 'lodash'; +import { execSync } from 'child_process'; +import { merge, isErrorResult } from 'openapi-merge'; +import type { Swagger } from 'atlassian-openapi'; +import { RESTRICTED_ROUTE_BASE_PATHS } from '@letta-cloud/sdk-core'; + +const lettaWebOpenAPIPath = path.join( + __dirname, + '..', + '..', + '..', + 'web', + 'autogenerated', + 'letta-web-openapi.json', +); +const lettaAgentsAPIPath = path.join( + __dirname, + '..', + '..', + 'letta', + 'server', + 'openapi_letta.json', +); + +const lettaWebOpenAPI = JSON.parse( + fs.readFileSync(lettaWebOpenAPIPath, 'utf8'), +) as Swagger.SwaggerV3; +const lettaAgentsAPI = JSON.parse( + fs.readFileSync(lettaAgentsAPIPath, 'utf8'), +) as Swagger.SwaggerV3; + +// removes any routes that are restricted +lettaAgentsAPI.paths = Object.fromEntries( + Object.entries(lettaAgentsAPI.paths).filter(([path]) => + RESTRICTED_ROUTE_BASE_PATHS.every( + (restrictedPath) => !path.startsWith(restrictedPath), + ), + ), +); + +const lettaAgentsAPIWithNoEndslash = Object.keys(lettaAgentsAPI.paths).reduce( + (acc, path) => { + const pathWithoutSlash = path.endsWith('/') + ? path.slice(0, path.length - 1) + : path; + acc[pathWithoutSlash] = lettaAgentsAPI.paths[path]; + return acc; + }, + {} as Swagger.SwaggerV3['paths'], +); + +// remove duplicate paths, delete from letta-web-openapi if it exists in sdk-core +// some paths will have an extra / at the end, so we need to remove that as well +lettaWebOpenAPI.paths = Object.fromEntries( + Object.entries(lettaWebOpenAPI.paths).filter(([path]) => { + const pathWithoutSlash = path.endsWith('/') + ? path.slice(0, path.length - 1) + : path; + return !lettaAgentsAPIWithNoEndslash[pathWithoutSlash]; + }), +); + +const agentStatePathsToOverride: Array<[string, string]> = [ + ['/v1/templates/{project}/{template_version}/agents', '201'], + ['/v1/agents/search', '200'], +]; + +for (const [path, responseCode] of agentStatePathsToOverride) { + if (lettaWebOpenAPI.paths[path]?.post?.responses?.[responseCode]) { + // Get direct reference to the schema object + const responseSchema = + lettaWebOpenAPI.paths[path].post.responses[responseCode]; + const contentSchema = responseSchema.content['application/json'].schema; + + // Replace the entire agents array schema with the reference + if (contentSchema.properties?.agents) { + contentSchema.properties.agents = { + type: 'array', + items: { + $ref: '#/components/schemas/AgentState', + }, + }; + } + } +} + +// go through the paths and remove "user_id"/"actor_id" from the headers +for (const path of Object.keys(lettaAgentsAPI.paths)) { + for (const method of Object.keys(lettaAgentsAPI.paths[path])) { + // @ts-expect-error - a + if (lettaAgentsAPI.paths[path][method]?.parameters) { + // @ts-expect-error - a + lettaAgentsAPI.paths[path][method].parameters = lettaAgentsAPI.paths[ + path + ][method].parameters.filter( + (param: Record) => + param.in !== 'header' || + ( + param.name !== 'user_id' && + param.name !== 'User-Agent' && + param.name !== 'X-Project-Id' && + param.name !== 'X-Letta-Source' && + param.name !== 'X-Stainless-Package-Version' && + !param.name.startsWith('X-Experimental') && + !param.name.startsWith('X-Billing') + ), + ); + } + } +} + +const result = merge([ + { + oas: lettaAgentsAPI, + }, + { + oas: lettaWebOpenAPI, + }, +]); + +if (isErrorResult(result)) { + console.error(`${result.message} (${result.type})`); + process.exit(1); +} + +result.output.openapi = '3.1.0'; +result.output.info = { + title: 'Letta API', + version: '1.0.0', +}; + +result.output.servers = [ + { + url: 'https://app.letta.com', + description: 'Letta Cloud', + }, + { + url: 'http://localhost:8283', + description: 'Self-hosted', + }, +]; + +result.output.components = { + ...result.output.components, + securitySchemes: { + bearerAuth: { + type: 'http', + scheme: 'bearer', + }, + }, +}; + +result.output.security = [ + ...(result.output.security || []), + { + bearerAuth: [], + }, +]; + +// omit all instances of "user_id" from the openapi.json file +function deepOmitPreserveArrays(obj: unknown, key: string): unknown { + if (Array.isArray(obj)) { + return obj.map((item) => deepOmitPreserveArrays(item, key)); + } + + if (typeof obj !== 'object' || obj === null) { + return obj; + } + + if (key in obj) { + return omit(obj, key); + } + + return Object.fromEntries( + Object.entries(obj).map(([k, v]) => [k, deepOmitPreserveArrays(v, key)]), + ); +} + +// eslint-disable-next-line @typescript-eslint/ban-ts-comment +// @ts-ignore +result.output.components = deepOmitPreserveArrays( + result.output.components, + 'user_id', +); + +// eslint-disable-next-line @typescript-eslint/ban-ts-comment +// @ts-ignore +result.output.components = deepOmitPreserveArrays( + result.output.components, + 'actor_id', +); + +// eslint-disable-next-line @typescript-eslint/ban-ts-comment +// @ts-ignore +result.output.components = deepOmitPreserveArrays( + result.output.components, + 'organization_id', +); + +fs.writeFileSync( + path.join(__dirname, '..', 'openapi.json'), + JSON.stringify(result.output, null, 2), +); + +function formatOpenAPIJson() { + const openApiPath = path.join(__dirname, '..', 'openapi.json'); + + try { + execSync(`npx prettier --write "${openApiPath}"`, { stdio: 'inherit' }); + console.log('Successfully formatted openapi.json with Prettier'); + } catch (error) { + console.error('Error formatting openapi.json:', error); + process.exit(1); + } +} + +formatOpenAPIJson(); diff --git a/letta/adapters/letta_llm_adapter.py b/letta/adapters/letta_llm_adapter.py index 49e99c49..c78796cb 100644 --- a/letta/adapters/letta_llm_adapter.py +++ b/letta/adapters/letta_llm_adapter.py @@ -7,6 +7,7 @@ from letta.schemas.letta_message import LettaMessage from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.llm_config import LLMConfig from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall +from letta.schemas.provider_trace import BillingContext from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.telemetry_manager import TelemetryManager @@ -31,6 +32,7 @@ class LettaLLMAdapter(ABC): run_id: str | None = None, org_id: str | None = None, user_id: str | None = None, + billing_context: BillingContext | None = None, ) -> None: self.llm_client: LLMClientBase = llm_client self.llm_config: LLMConfig = llm_config @@ -40,6 +42,7 @@ class LettaLLMAdapter(ABC): self.run_id: str | None = run_id self.org_id: str | None = org_id self.user_id: str | None = user_id + self.billing_context: BillingContext | None = billing_context self.message_id: str | None = None self.request_data: dict | None = None self.response_data: dict | None = None diff --git a/letta/adapters/letta_llm_stream_adapter.py b/letta/adapters/letta_llm_stream_adapter.py index 76fc6d65..426a15e7 100644 --- a/letta/adapters/letta_llm_stream_adapter.py +++ b/letta/adapters/letta_llm_stream_adapter.py @@ -10,7 +10,7 @@ from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method from letta.schemas.enums import LLMCallType, ProviderType from letta.schemas.letta_message import LettaMessage from letta.schemas.llm_config import LLMConfig -from letta.schemas.provider_trace import ProviderTrace +from letta.schemas.provider_trace import BillingContext, ProviderTrace from letta.schemas.user import User from letta.settings import settings from letta.utils import safe_create_task @@ -36,6 +36,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): run_id: str | None = None, org_id: str | None = None, user_id: str | None = None, + billing_context: "BillingContext | None" = None, ) -> None: super().__init__( llm_client, @@ -46,6 +47,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): run_id=run_id, org_id=org_id, user_id=user_id, + billing_context=billing_context, ) self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None diff --git a/letta/adapters/simple_llm_request_adapter.py b/letta/adapters/simple_llm_request_adapter.py index f67e7dc9..3f57e41f 100644 --- a/letta/adapters/simple_llm_request_adapter.py +++ b/letta/adapters/simple_llm_request_adapter.py @@ -51,6 +51,7 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter): org_id=self.org_id, user_id=self.user_id, llm_config=self.llm_config.model_dump() if self.llm_config else None, + billing_context=self.billing_context, ) try: self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config) diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index 26c054fd..a5d880d8 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -278,6 +278,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): org_id=self.org_id, user_id=self.user_id, llm_config=self.llm_config.model_dump() if self.llm_config else None, + billing_context=self.billing_context, ), ), label="create_provider_trace", diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 326dc60a..3e6019a4 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -15,6 +15,7 @@ from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message, MessageCreate, MessageUpdate +from letta.schemas.provider_trace import BillingContext from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.agent_manager import AgentManager @@ -51,7 +52,11 @@ class BaseAgent(ABC): @abstractmethod async def step( - self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: Optional[str] = None + self, + input_messages: List[MessageCreate], + max_steps: int = DEFAULT_MAX_STEPS, + run_id: Optional[str] = None, + billing_context: "BillingContext | None" = None, ) -> LettaResponse: """ Main execution loop for the agent. diff --git a/letta/agents/base_agent_v2.py b/letta/agents/base_agent_v2.py index b6fe89ce..515edb76 100644 --- a/letta/agents/base_agent_v2.py +++ b/letta/agents/base_agent_v2.py @@ -12,6 +12,7 @@ from letta.schemas.user import User if TYPE_CHECKING: from letta.schemas.letta_request import ClientToolSchema + from letta.schemas.provider_trace import BillingContext class BaseAgentV2(ABC): @@ -52,6 +53,7 @@ class BaseAgentV2(ABC): request_start_timestamp_ns: int | None = None, client_tools: list["ClientToolSchema"] | None = None, include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility + billing_context: "BillingContext | None" = None, ) -> LettaResponse: """ Execute the agent loop in blocking mode, returning all messages at once. @@ -76,6 +78,7 @@ class BaseAgentV2(ABC): conversation_id: str | None = None, client_tools: list["ClientToolSchema"] | None = None, include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility + billing_context: "BillingContext | None" = None, ) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]: """ Execute the agent loop in streaming mode, yielding chunks as they become available. diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index be6a378b..1086d7c9 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -48,6 +48,7 @@ from letta.schemas.openai.chat_completion_response import ( UsageStatisticsCompletionTokenDetails, UsageStatisticsPromptTokenDetails, ) +from letta.schemas.provider_trace import BillingContext from letta.schemas.step import StepProgression from letta.schemas.step_metrics import StepMetrics from letta.schemas.tool_execution_result import ToolExecutionResult @@ -179,6 +180,7 @@ class LettaAgent(BaseAgent): request_start_timestamp_ns: int | None = None, include_return_message_types: list[MessageType] | None = None, dry_run: bool = False, + billing_context: "BillingContext | None" = None, ) -> Union[LettaResponse, dict]: # TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions agent_state = await self.agent_manager.get_agent_by_id_async( diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 686d49fb..13fb9b07 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -44,6 +44,7 @@ from letta.schemas.openai.chat_completion_response import ( UsageStatisticsCompletionTokenDetails, UsageStatisticsPromptTokenDetails, ) +from letta.schemas.provider_trace import BillingContext from letta.schemas.step import Step, StepProgression from letta.schemas.step_metrics import StepMetrics from letta.schemas.tool import Tool @@ -185,6 +186,7 @@ class LettaAgentV2(BaseAgentV2): request_start_timestamp_ns: int | None = None, client_tools: list[ClientToolSchema] | None = None, include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility + billing_context: "BillingContext | None" = None, ) -> LettaResponse: """ Execute the agent loop in blocking mode, returning all messages at once. @@ -290,6 +292,7 @@ class LettaAgentV2(BaseAgentV2): conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility client_tools: list[ClientToolSchema] | None = None, include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility + billing_context: BillingContext | None = None, ) -> AsyncGenerator[str, None]: """ Execute the agent loop in streaming mode, yielding chunks as they become available. diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 7c556710..3a77c011 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -45,6 +45,7 @@ from letta.schemas.letta_response import LettaResponse, TurnTokenData from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message, MessageCreate, ToolReturn from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, ToolCall, ToolCallDenial, UsageStatistics +from letta.schemas.provider_trace import BillingContext from letta.schemas.step import StepProgression from letta.schemas.step_metrics import StepMetrics from letta.schemas.tool_execution_result import ToolExecutionResult @@ -149,6 +150,7 @@ class LettaAgentV3(LettaAgentV2): conversation_id: str | None = None, client_tools: list[ClientToolSchema] | None = None, include_compaction_messages: bool = False, + billing_context: "BillingContext | None" = None, ) -> LettaResponse: """ Execute the agent loop in blocking mode, returning all messages at once. @@ -232,6 +234,7 @@ class LettaAgentV3(LettaAgentV2): run_id=run_id, org_id=self.actor.organization_id, user_id=self.actor.id, + billing_context=billing_context, ) credit_task = None @@ -362,6 +365,7 @@ class LettaAgentV3(LettaAgentV2): conversation_id: str | None = None, client_tools: list[ClientToolSchema] | None = None, include_compaction_messages: bool = False, + billing_context: BillingContext | None = None, ) -> AsyncGenerator[str, None]: """ Execute the agent loop in streaming mode, yielding chunks as they become available. @@ -419,6 +423,7 @@ class LettaAgentV3(LettaAgentV2): run_id=run_id, org_id=self.actor.organization_id, user_id=self.actor.id, + billing_context=billing_context, ) elif use_sglang_native: # Use SGLang native adapter for multi-turn RL training @@ -431,6 +436,7 @@ class LettaAgentV3(LettaAgentV2): run_id=run_id, org_id=self.actor.organization_id, user_id=self.actor.id, + billing_context=billing_context, ) # Reset turns tracking for this step self.turns = [] @@ -444,6 +450,7 @@ class LettaAgentV3(LettaAgentV2): run_id=run_id, org_id=self.actor.organization_id, user_id=self.actor.id, + billing_context=billing_context, ) try: diff --git a/letta/groups/sleeptime_multi_agent_v2.py b/letta/groups/sleeptime_multi_agent_v2.py index 65b33632..e842e0b7 100644 --- a/letta/groups/sleeptime_multi_agent_v2.py +++ b/letta/groups/sleeptime_multi_agent_v2.py @@ -13,6 +13,7 @@ from letta.schemas.letta_message import MessageType from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_response import LettaResponse from letta.schemas.message import Message, MessageCreate +from letta.schemas.provider_trace import BillingContext from letta.schemas.run import Run from letta.schemas.user import User from letta.services.agent_manager import AgentManager @@ -69,6 +70,7 @@ class SleeptimeMultiAgentV2(BaseAgent): use_assistant_message: bool = True, request_start_timestamp_ns: int | None = None, include_return_message_types: list[MessageType] | None = None, + billing_context: "BillingContext | None" = None, ) -> LettaResponse: run_ids = [] @@ -100,6 +102,7 @@ class SleeptimeMultiAgentV2(BaseAgent): run_id=run_id, use_assistant_message=use_assistant_message, include_return_message_types=include_return_message_types, + billing_context=billing_context, ) # Get last response messages diff --git a/letta/groups/sleeptime_multi_agent_v3.py b/letta/groups/sleeptime_multi_agent_v3.py index d1c8c302..257d5d97 100644 --- a/letta/groups/sleeptime_multi_agent_v3.py +++ b/letta/groups/sleeptime_multi_agent_v3.py @@ -15,6 +15,7 @@ from letta.schemas.letta_request import ClientToolSchema from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.message import Message, MessageCreate +from letta.schemas.provider_trace import BillingContext from letta.schemas.run import Run, RunUpdate from letta.schemas.user import User from letta.services.group_manager import GroupManager @@ -47,6 +48,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2): request_start_timestamp_ns: int | None = None, client_tools: list[ClientToolSchema] | None = None, include_compaction_messages: bool = False, + billing_context: "BillingContext | None" = None, ) -> LettaResponse: self.run_ids = [] @@ -62,6 +64,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2): request_start_timestamp_ns=request_start_timestamp_ns, client_tools=client_tools, include_compaction_messages=include_compaction_messages, + billing_context=billing_context, ) await self.run_sleeptime_agents() diff --git a/letta/groups/sleeptime_multi_agent_v4.py b/letta/groups/sleeptime_multi_agent_v4.py index 9995ee15..8fb4d049 100644 --- a/letta/groups/sleeptime_multi_agent_v4.py +++ b/letta/groups/sleeptime_multi_agent_v4.py @@ -14,6 +14,7 @@ from letta.schemas.letta_request import ClientToolSchema from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import StopReasonType from letta.schemas.message import Message, MessageCreate +from letta.schemas.provider_trace import BillingContext from letta.schemas.run import Run, RunUpdate from letta.schemas.user import User from letta.services.group_manager import GroupManager @@ -47,6 +48,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3): conversation_id: str | None = None, client_tools: list[ClientToolSchema] | None = None, include_compaction_messages: bool = False, + billing_context: "BillingContext | None" = None, ) -> LettaResponse: self.run_ids = [] @@ -63,6 +65,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3): conversation_id=conversation_id, client_tools=client_tools, include_compaction_messages=include_compaction_messages, + billing_context=billing_context, ) run_ids = await self.run_sleeptime_agents() diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 0cdbe894..080ab2df 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -14,7 +14,7 @@ from letta.schemas.enums import AgentType, LLMCallType, ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionResponse -from letta.schemas.provider_trace import ProviderTrace +from letta.schemas.provider_trace import BillingContext, ProviderTrace from letta.schemas.usage import LettaUsageStatistics from letta.services.telemetry_manager import TelemetryManager from letta.settings import settings @@ -48,6 +48,7 @@ class LLMClientBase: self._telemetry_user_id: Optional[str] = None self._telemetry_compaction_settings: Optional[Dict] = None self._telemetry_llm_config: Optional[Dict] = None + self._telemetry_billing_context: Optional[BillingContext] = None def set_telemetry_context( self, @@ -62,6 +63,7 @@ class LLMClientBase: compaction_settings: Optional[Dict] = None, llm_config: Optional[Dict] = None, actor: Optional["User"] = None, + billing_context: Optional[BillingContext] = None, ) -> None: """Set telemetry context for provider trace logging.""" if actor is not None: @@ -76,6 +78,7 @@ class LLMClientBase: self._telemetry_user_id = user_id self._telemetry_compaction_settings = compaction_settings self._telemetry_llm_config = llm_config + self._telemetry_billing_context = billing_context def extract_usage_statistics(self, response_data: Optional[dict], llm_config: LLMConfig) -> LettaUsageStatistics: """Provider-specific usage parsing hook (override in subclasses). Returns LettaUsageStatistics.""" @@ -125,6 +128,7 @@ class LLMClientBase: user_id=self._telemetry_user_id, compaction_settings=self._telemetry_compaction_settings, llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config, + billing_context=self._telemetry_billing_context, ), ) except Exception as e: @@ -186,6 +190,7 @@ class LLMClientBase: user_id=self._telemetry_user_id, compaction_settings=self._telemetry_compaction_settings, llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config, + billing_context=self._telemetry_billing_context, ), ) except Exception as e: diff --git a/letta/schemas/llm_trace.py b/letta/schemas/llm_trace.py index 13cbb806..2ba7c520 100644 --- a/letta/schemas/llm_trace.py +++ b/letta/schemas/llm_trace.py @@ -95,6 +95,11 @@ class LLMTrace(LettaBase): response_json: str = Field(..., description="Full response payload as JSON string") llm_config_json: str = Field(default="", description="LLM config as JSON string") + # Billing context + billing_plan_type: Optional[str] = Field(default=None, description="Subscription tier (e.g., 'basic', 'standard', 'max', 'enterprise')") + billing_cost_source: Optional[str] = Field(default=None, description="Cost source: 'quota' or 'credits'") + billing_customer_id: Optional[str] = Field(default=None, description="Customer ID for cross-referencing billing records") + # Timestamp created_at: datetime = Field(default_factory=get_utc_time, description="When the trace was created") @@ -128,6 +133,9 @@ class LLMTrace(LettaBase): self.request_json, self.response_json, self.llm_config_json, + self.billing_plan_type or "", + self.billing_cost_source or "", + self.billing_customer_id or "", self.created_at, ) @@ -162,5 +170,8 @@ class LLMTrace(LettaBase): "request_json", "response_json", "llm_config_json", + "billing_plan_type", + "billing_cost_source", + "billing_customer_id", "created_at", ] diff --git a/letta/schemas/provider_trace.py b/letta/schemas/provider_trace.py index 0f4202e8..9256b032 100644 --- a/letta/schemas/provider_trace.py +++ b/letta/schemas/provider_trace.py @@ -3,13 +3,21 @@ from __future__ import annotations from datetime import datetime from typing import Any, Dict, Optional -from pydantic import Field +from pydantic import BaseModel, Field from letta.helpers.datetime_helpers import get_utc_time from letta.schemas.enums import PrimitiveType from letta.schemas.letta_base import OrmMetadataBase +class BillingContext(BaseModel): + """Billing context for LLM request cost tracking.""" + + plan_type: Optional[str] = Field(None, description="Subscription tier") + cost_source: Optional[str] = Field(None, description="Cost source: 'quota' or 'credits'") + customer_id: Optional[str] = Field(None, description="Customer ID for billing records") + + class BaseProviderTrace(OrmMetadataBase): __id_prefix__ = PrimitiveType.PROVIDER_TRACE.value @@ -53,6 +61,8 @@ class ProviderTrace(BaseProviderTrace): compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)") llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)") + billing_context: Optional[BillingContext] = Field(None, description="Billing context from request headers") + created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") diff --git a/letta/server/rest_api/dependencies.py b/letta/server/rest_api/dependencies.py index b6f6b6cc..66ea43bb 100644 --- a/letta/server/rest_api/dependencies.py +++ b/letta/server/rest_api/dependencies.py @@ -6,6 +6,7 @@ from pydantic import BaseModel from letta.errors import LettaInvalidArgumentError from letta.otel.tracing import tracer from letta.schemas.enums import PrimitiveType +from letta.schemas.provider_trace import BillingContext from letta.validators import PRIMITIVE_ID_PATTERNS if TYPE_CHECKING: @@ -30,18 +31,24 @@ class HeaderParams(BaseModel): letta_source: Optional[str] = None sdk_version: Optional[str] = None experimental_params: Optional[ExperimentalParams] = None + billing_context: Optional[BillingContext] = None def get_headers( actor_id: Optional[str] = Header(None, alias="user_id"), user_agent: Optional[str] = Header(None, alias="User-Agent"), project_id: Optional[str] = Header(None, alias="X-Project-Id"), - letta_source: Optional[str] = Header(None, alias="X-Letta-Source"), - sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version"), - message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async"), - letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent"), - letta_v1_agent_message_async: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent-Message-Async"), - modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox"), + letta_source: Optional[str] = Header(None, alias="X-Letta-Source", include_in_schema=False), + sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version", include_in_schema=False), + message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async", include_in_schema=False), + letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent", include_in_schema=False), + letta_v1_agent_message_async: Optional[str] = Header( + None, alias="X-Experimental-Letta-V1-Agent-Message-Async", include_in_schema=False + ), + modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox", include_in_schema=False), + billing_plan_type: Optional[str] = Header(None, alias="X-Billing-Plan-Type", include_in_schema=False), + billing_cost_source: Optional[str] = Header(None, alias="X-Billing-Cost-Source", include_in_schema=False), + billing_customer_id: Optional[str] = Header(None, alias="X-Billing-Customer-Id", include_in_schema=False), ) -> HeaderParams: """Dependency injection function to extract common headers from requests.""" with tracer.start_as_current_span("dependency.get_headers"): @@ -63,6 +70,13 @@ def get_headers( letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None, modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None, ), + billing_context=BillingContext( + plan_type=billing_plan_type, + cost_source=billing_cost_source, + customer_id=billing_customer_id, + ) + if any([billing_plan_type, billing_cost_source, billing_customer_id]) + else None, ) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 9f4a079e..dce47240 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -49,6 +49,7 @@ from letta.schemas.memory import ( ) from letta.schemas.message import Message, MessageCreate, MessageCreateType, MessageSearchRequest, MessageSearchResult from letta.schemas.passage import Passage +from letta.schemas.provider_trace import BillingContext from letta.schemas.run import Run as PydanticRun, RunUpdate from letta.schemas.source import Source from letta.schemas.tool import Tool @@ -1697,6 +1698,7 @@ async def send_message( actor=actor, request=request, run_type="send_message", + billing_context=headers.billing_context, ) return result @@ -1767,6 +1769,7 @@ async def send_message( include_return_message_types=request.include_return_message_types, client_tools=request.client_tools, include_compaction_messages=request.include_compaction_messages, + billing_context=headers.billing_context, ) run_status = result.stop_reason.stop_reason.run_status return result @@ -1845,6 +1848,7 @@ async def send_message_streaming( actor=actor, request=request, run_type="send_message_streaming", + billing_context=headers.billing_context, ) return result @@ -2043,6 +2047,7 @@ async def _process_message_background( include_return_message_types: list[MessageType] | None = None, override_model: str | None = None, include_compaction_messages: bool = False, + billing_context: "BillingContext | None" = None, ) -> None: """Background task to process the message and update run status.""" request_start_timestamp_ns = get_utc_timestamp_ns() @@ -2074,6 +2079,7 @@ async def _process_message_background( request_start_timestamp_ns=request_start_timestamp_ns, include_return_message_types=include_return_message_types, include_compaction_messages=include_compaction_messages, + billing_context=billing_context, ) runs_manager = RunManager() from letta.schemas.enums import RunStatus @@ -2242,6 +2248,7 @@ async def send_message_async( include_return_message_types=request.include_return_message_types, override_model=request.override_model, include_compaction_messages=request.include_compaction_messages, + billing_context=headers.billing_context, ), label=f"process_message_background_{run.id}", ) diff --git a/letta/server/rest_api/routers/v1/conversations.py b/letta/server/rest_api/routers/v1/conversations.py index 65af1d19..fbca76a4 100644 --- a/letta/server/rest_api/routers/v1/conversations.py +++ b/letta/server/rest_api/routers/v1/conversations.py @@ -19,6 +19,7 @@ from letta.schemas.job import LettaRequestConfig from letta.schemas.letta_message import LettaMessageUnion from letta.schemas.letta_request import ConversationMessageRequest, LettaStreamingRequest, RetrieveStreamRequest from letta.schemas.letta_response import LettaResponse +from letta.schemas.provider_trace import BillingContext from letta.schemas.run import Run as PydanticRun from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator @@ -211,6 +212,7 @@ async def _send_agent_direct_message( request: ConversationMessageRequest, server: SyncServer, actor, + billing_context: "BillingContext | None" = None, ) -> StreamingResponse | LettaResponse: """ Handle agent-direct messaging with locking but without conversation features. @@ -244,6 +246,7 @@ async def _send_agent_direct_message( run_type="send_message", conversation_id=None, should_lock=True, + billing_context=billing_context, ) return result @@ -299,6 +302,7 @@ async def _send_agent_direct_message( client_tools=request.client_tools, conversation_id=None, include_compaction_messages=request.include_compaction_messages, + billing_context=billing_context, ) finally: # Release lock @@ -351,6 +355,7 @@ async def send_conversation_message( request=request, server=server, actor=actor, + billing_context=headers.billing_context, ) # Normal conversation mode @@ -383,6 +388,7 @@ async def send_conversation_message( request=streaming_request, run_type="send_conversation_message", conversation_id=conversation_id, + billing_context=headers.billing_context, ) return result @@ -445,6 +451,7 @@ async def send_conversation_message( client_tools=request.client_tools, conversation_id=conversation_id, include_compaction_messages=request.include_compaction_messages, + billing_context=headers.billing_context, ) diff --git a/letta/services/provider_trace_backends/clickhouse.py b/letta/services/provider_trace_backends/clickhouse.py index 3ba84772..e81235bc 100644 --- a/letta/services/provider_trace_backends/clickhouse.py +++ b/letta/services/provider_trace_backends/clickhouse.py @@ -141,6 +141,9 @@ class ClickhouseProviderTraceBackend(ProviderTraceBackendClient): request_json=request_json_str, response_json=response_json_str, llm_config_json=llm_config_json_str, + billing_plan_type=provider_trace.billing_context.plan_type if provider_trace.billing_context else None, + billing_cost_source=provider_trace.billing_context.cost_source if provider_trace.billing_context else None, + billing_customer_id=provider_trace.billing_context.customer_id if provider_trace.billing_context else None, ) def _extract_usage(self, response_json: dict, provider: str) -> dict: diff --git a/letta/services/provider_trace_backends/postgres.py b/letta/services/provider_trace_backends/postgres.py index a70eadf8..938a4874 100644 --- a/letta/services/provider_trace_backends/postgres.py +++ b/letta/services/provider_trace_backends/postgres.py @@ -29,7 +29,7 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient): ) -> ProviderTrace: """Write full provider trace to provider_traces table.""" async with db_registry.async_session() as session: - provider_trace_model = ProviderTraceModel(**provider_trace.model_dump()) + provider_trace_model = ProviderTraceModel(**provider_trace.model_dump(exclude={"billing_context"})) provider_trace_model.organization_id = actor.organization_id if provider_trace.request_json: diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py index 9bb9901e..c025405e 100644 --- a/letta/services/streaming_service.py +++ b/letta/services/streaming_service.py @@ -34,6 +34,7 @@ from letta.schemas.letta_request import ClientToolSchema, LettaStreamingRequest from letta.schemas.letta_response import LettaResponse from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import MessageCreate +from letta.schemas.provider_trace import BillingContext from letta.schemas.run import Run as PydanticRun, RunUpdate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User @@ -78,6 +79,7 @@ class StreamingService: run_type: str = "streaming", conversation_id: Optional[str] = None, should_lock: bool = False, + billing_context: "BillingContext | None" = None, ) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]: """ Create a streaming response for an agent. @@ -176,6 +178,7 @@ class StreamingService: lock_key=lock_key, # For lock release (may differ from conversation_id) client_tools=request.client_tools, include_compaction_messages=request.include_compaction_messages, + billing_context=billing_context, ) # handle background streaming if requested @@ -340,6 +343,7 @@ class StreamingService: lock_key: Optional[str] = None, client_tools: Optional[list[ClientToolSchema]] = None, include_compaction_messages: bool = False, + billing_context: BillingContext | None = None, ) -> AsyncIterator: """ Create a stream with unified error handling. @@ -368,6 +372,7 @@ class StreamingService: conversation_id=conversation_id, client_tools=client_tools, include_compaction_messages=include_compaction_messages, + billing_context=billing_context, ) async for chunk in stream: diff --git a/tests/test_utils.py b/tests/test_utils.py index 2aea57e5..0c34efc0 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -24,6 +24,9 @@ def test_get_headers_user_id_allows_none(): letta_v1_agent=None, letta_v1_agent_message_async=None, modal_sandbox=None, + billing_plan_type=None, + billing_cost_source=None, + billing_customer_id=None, ) assert isinstance(headers, HeaderParams) @@ -40,6 +43,9 @@ def test_get_headers_user_id_rejects_invalid_format(): letta_v1_agent=None, letta_v1_agent_message_async=None, modal_sandbox=None, + billing_plan_type=None, + billing_cost_source=None, + billing_customer_id=None, ) @@ -54,6 +60,9 @@ def test_get_headers_user_id_accepts_valid_format(): letta_v1_agent=None, letta_v1_agent_message_async=None, modal_sandbox=None, + billing_plan_type=None, + billing_cost_source=None, + billing_customer_id=None, ) assert headers.actor_id == "user-123e4567-e89b-42d3-8456-426614174000"