Add billing context to LLM telemetry traces (#9745)

* feat: add billing context to LLM telemetry traces

Add billing metadata (plan type, cost source, customer ID) to LLM traces in ClickHouse for cost analytics and attribution.

**Data Flow:**
- Cloud-API: Extract billing info from subscription in rate limiting, set x-billing-* headers
- Core: Parse headers into BillingContext object via dependencies
- Adapters: Flow billing_context through all LLM adapters (blocking & streaming)
- Agent: Pass billing_context to step() and stream() methods
- ClickHouse: Store in billing_plan_type, billing_cost_source, billing_customer_id columns

**Changes:**
- Add BillingContext schema to provider_trace.py
- Add billing columns to llm_traces ClickHouse table DDL
- Update getCustomerSubscription to fetch stripeCustomerId from organization_billing_details
- Propagate billing_context through agent step flow, adapters, and streaming service
- Update ProviderTrace and LLMTrace to include billing metadata
- Regenerate SDK with autogen

**Production Deployment:**
Requires env vars: LETTA_PROVIDER_TRACE_BACKEND=clickhouse, LETTA_STORE_LLM_TRACES=true, CLICKHOUSE_*

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix: add billing_context parameter to agent step methods

- Add billing_context to BaseAgent and BaseAgentV2 abstract methods
- Update LettaAgent, LettaAgentV2, LettaAgentV3 step methods
- Update multi-agent groups: SleeptimeMultiAgentV2, V3, V4
- Fix test_utils.py to include billing header parameters
- Import BillingContext in all affected files

* fix: add billing_context to stream methods

- Add billing_context parameter to BaseAgentV2.stream()
- Add billing_context parameter to LettaAgentV2.stream()
- LettaAgentV3.stream() already has it from previous commit

* fix: exclude billing headers from OpenAPI spec

Mark billing headers as internal (include_in_schema=False) so they don't appear in the public API.
These are internal headers between cloud-api and core, not part of the public SDK.

Regenerated SDK with stage-api - removes 10,650 lines of bloat that was causing OOM during Next.js build.

* refactor: return billing context from handleUnifiedRateLimiting instead of mutating req

Instead of passing req into handleUnifiedRateLimiting and mutating headers inside it:
- Return billing context fields (billingPlanType, billingCostSource, billingCustomerId) from handleUnifiedRateLimiting
- Set headers in handleMessageRateLimiting (middleware layer) after getting the result
- This fixes step-orchestrator compatibility, since the step-orchestrator doesn't have a real Express req object

* chore: remove extra gencode

* p

---------

Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
cthomas
2026-03-03 13:05:43 -08:00
committed by Caren Thomas
parent db9e0f42af
commit 416ffc7cd7
24 changed files with 392 additions and 11 deletions

View File

@@ -29385,6 +29385,49 @@
"title": "BedrockModelSettings",
"description": "AWS Bedrock model configuration."
},
"BillingContext": {
"properties": {
"plan_type": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Plan Type",
"description": "Subscription tier"
},
"cost_source": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Cost Source",
"description": "Cost source: 'quota' or 'credits'"
},
"customer_id": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Customer Id",
"description": "Customer ID for billing records"
}
},
"type": "object",
"title": "BillingContext",
"description": "Billing context for LLM request cost tracking."
},
"Block": {
"properties": {
"value": {
@@ -42964,6 +43007,17 @@
],
"title": "Llm Config",
"description": "LLM configuration used for this call (non-summarization calls only)"
},
"billing_context": {
"anyOf": [
{
"$ref": "#/components/schemas/BillingContext"
},
{
"type": "null"
}
],
"description": "Billing context from request headers"
}
},
"additionalProperties": false,

View File

@@ -0,0 +1,220 @@
import * as fs from 'fs';
import * as path from 'path';
import { omit } from 'lodash';
import { execSync } from 'child_process';
import { merge, isErrorResult } from 'openapi-merge';
import type { Swagger } from 'atlassian-openapi';
import { RESTRICTED_ROUTE_BASE_PATHS } from '@letta-cloud/sdk-core';
// Absolute path to the web app's autogenerated OpenAPI spec.
const lettaWebOpenAPIPath = path.join(
  __dirname,
  '..',
  '..',
  '..',
  'web',
  'autogenerated',
  'letta-web-openapi.json',
);
// Absolute path to the core agents server's OpenAPI spec.
const lettaAgentsAPIPath = path.join(
  __dirname,
  '..',
  '..',
  'letta',
  'server',
  'openapi_letta.json',
);
// Both specs are read synchronously at module load. NOTE(review): the
// 'as Swagger.SwaggerV3' casts assume the files on disk are valid OpenAPI
// v3 documents — JSON.parse output is not validated here.
const lettaWebOpenAPI = JSON.parse(
  fs.readFileSync(lettaWebOpenAPIPath, 'utf8'),
) as Swagger.SwaggerV3;
const lettaAgentsAPI = JSON.parse(
  fs.readFileSync(lettaAgentsAPIPath, 'utf8'),
) as Swagger.SwaggerV3;
// Drop restricted routes from the agents spec so they never reach the
// public SDK.
lettaAgentsAPI.paths = Object.fromEntries(
  Object.entries(lettaAgentsAPI.paths).filter(
    ([routePath]) =>
      !RESTRICTED_ROUTE_BASE_PATHS.some((restrictedPath) =>
        routePath.startsWith(restrictedPath),
      ),
  ),
);
// Index the agents spec's paths with any trailing slash removed, so the
// duplicate check below is insensitive to a trailing '/'.
const lettaAgentsAPIWithNoEndslash = {} as Swagger.SwaggerV3['paths'];
for (const routePath of Object.keys(lettaAgentsAPI.paths)) {
  const normalized = routePath.endsWith('/')
    ? routePath.slice(0, -1)
    : routePath;
  lettaAgentsAPIWithNoEndslash[normalized] = lettaAgentsAPI.paths[routePath];
}
// Remove any web-spec path that also exists in the agents spec (after
// trailing-slash normalization) — the agents definition wins on duplicates.
lettaWebOpenAPI.paths = Object.fromEntries(
  Object.entries(lettaWebOpenAPI.paths).filter(([routePath]) => {
    const normalized = routePath.endsWith('/')
      ? routePath.slice(0, -1)
      : routePath;
    return !lettaAgentsAPIWithNoEndslash[normalized];
  }),
);
// Responses whose 'agents' array should reference the canonical AgentState
// component schema instead of an inlined copy: [route, response status code].
const agentStatePathsToOverride: Array<[string, string]> = [
  ['/v1/templates/{project}/{template_version}/agents', '201'],
  ['/v1/agents/search', '200'],
];
for (const [route, statusCode] of agentStatePathsToOverride) {
  const responseSchema = lettaWebOpenAPI.paths[route]?.post?.responses?.[statusCode];
  if (!responseSchema) {
    continue;
  }
  const contentSchema = responseSchema.content['application/json'].schema;
  if (contentSchema.properties?.agents) {
    // Point the agents array at the shared AgentState component schema.
    contentSchema.properties.agents = {
      type: 'array',
      items: {
        $ref: '#/components/schemas/AgentState',
      },
    };
  }
}
// Strip internal headers from every operation in the agents spec so they do
// not leak into the public SDK: exact-match internal headers plus any
// X-Experimental-* or X-Billing-* header.
const EXACT_HIDDEN_HEADERS = new Set([
  'user_id',
  'User-Agent',
  'X-Project-Id',
  'X-Letta-Source',
  'X-Stainless-Package-Version',
]);
const isHiddenHeader = (param: Record<string, string>): boolean =>
  param.in === 'header' &&
  (EXACT_HIDDEN_HEADERS.has(param.name) ||
    param.name.startsWith('X-Experimental') ||
    param.name.startsWith('X-Billing'));
for (const route of Object.keys(lettaAgentsAPI.paths)) {
  for (const method of Object.keys(lettaAgentsAPI.paths[route])) {
    // @ts-expect-error - path items are indexed by HTTP method string at runtime
    const operation = lettaAgentsAPI.paths[route][method];
    if (operation?.parameters) {
      operation.parameters = operation.parameters.filter(
        (param: Record<string, string>) => !isHiddenHeader(param),
      );
    }
  }
}
// Merge the two specs; the agents spec is listed first so its definitions
// take precedence on conflicts handled by openapi-merge.
const result = merge([
  {
    oas: lettaAgentsAPI,
  },
  {
    oas: lettaWebOpenAPI,
  },
]);
// openapi-merge reports failures as a result object rather than throwing.
if (isErrorResult(result)) {
  console.error(`${result.message} (${result.type})`);
  process.exit(1);
}
// Normalize top-level metadata of the merged document.
result.output.openapi = '3.1.0';
result.output.info = {
  title: 'Letta API',
  version: '1.0.0',
};
result.output.servers = [
  {
    url: 'https://app.letta.com',
    description: 'Letta Cloud',
  },
  {
    url: 'http://localhost:8283',
    description: 'Self-hosted',
  },
];
// Advertise bearer-token auth and apply it as a default security
// requirement on top of whatever the merged specs already declared.
result.output.components = {
  ...result.output.components,
  securitySchemes: {
    bearerAuth: {
      type: 'http',
      scheme: 'bearer',
    },
  },
};
result.output.security = [
  ...(result.output.security || []),
  {
    bearerAuth: [],
  },
];
/**
 * Recursively removes every property named `key` from a JSON-like value,
 * preserving array structure (arrays are mapped, not converted to objects).
 *
 * Bug fix: the previous version returned `omit(obj, key)` as soon as `key`
 * existed at the current level, WITHOUT recursing into the remaining
 * properties — so nested occurrences of `key` under an object that itself
 * carried the key survived the scrub. It also used `key in obj`, which
 * consults the prototype chain. Now the key is filtered at every level and
 * all remaining values are recursed unconditionally (lodash `omit` is no
 * longer needed).
 *
 * @param obj - any JSON-like value (object, array, or primitive)
 * @param key - property name to remove at every depth
 * @returns a new structure with all instances of `key` removed
 */
function deepOmitPreserveArrays(obj: unknown, key: string): unknown {
  if (Array.isArray(obj)) {
    return obj.map((item) => deepOmitPreserveArrays(item, key));
  }
  if (typeof obj !== 'object' || obj === null) {
    return obj;
  }
  return Object.fromEntries(
    Object.entries(obj)
      .filter(([k]) => k !== key)
      .map(([k, v]) => [k, deepOmitPreserveArrays(v, key)]),
  );
}
// Strip internal identifier fields from every component schema before
// publishing the merged spec.
for (const internalKey of ['user_id', 'actor_id', 'organization_id']) {
  // eslint-disable-next-line @typescript-eslint/ban-ts-comment
  // @ts-ignore
  result.output.components = deepOmitPreserveArrays(
    result.output.components,
    internalKey,
  );
}
// Emit the final merged spec alongside this script's parent directory.
fs.writeFileSync(
  path.join(__dirname, '..', 'openapi.json'),
  JSON.stringify(result.output, null, 2),
);
/**
 * Formats the generated openapi.json with Prettier so the output is stable
 * across regenerations. Exits the process with code 1 if formatting fails.
 */
function formatOpenAPIJson() {
  const outputPath = path.join(__dirname, '..', 'openapi.json');
  try {
    execSync(`npx prettier --write "${outputPath}"`, { stdio: 'inherit' });
    console.log('Successfully formatted openapi.json with Prettier');
  } catch (error) {
    console.error('Error formatting openapi.json:', error);
    process.exit(1);
  }
}
formatOpenAPIJson();

View File

@@ -7,6 +7,7 @@ from letta.schemas.letta_message import LettaMessage
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall
from letta.schemas.provider_trace import BillingContext
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.telemetry_manager import TelemetryManager
@@ -31,6 +32,7 @@ class LettaLLMAdapter(ABC):
run_id: str | None = None,
org_id: str | None = None,
user_id: str | None = None,
billing_context: BillingContext | None = None,
) -> None:
self.llm_client: LLMClientBase = llm_client
self.llm_config: LLMConfig = llm_config
@@ -40,6 +42,7 @@ class LettaLLMAdapter(ABC):
self.run_id: str | None = run_id
self.org_id: str | None = org_id
self.user_id: str | None = user_id
self.billing_context: BillingContext | None = billing_context
self.message_id: str | None = None
self.request_data: dict | None = None
self.response_data: dict | None = None

View File

@@ -10,7 +10,7 @@ from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
from letta.schemas.enums import LLMCallType, ProviderType
from letta.schemas.letta_message import LettaMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.provider_trace import BillingContext, ProviderTrace
from letta.schemas.user import User
from letta.settings import settings
from letta.utils import safe_create_task
@@ -36,6 +36,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
run_id: str | None = None,
org_id: str | None = None,
user_id: str | None = None,
billing_context: "BillingContext | None" = None,
) -> None:
super().__init__(
llm_client,
@@ -46,6 +47,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
run_id=run_id,
org_id=org_id,
user_id=user_id,
billing_context=billing_context,
)
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None

View File

@@ -51,6 +51,7 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
billing_context=self.billing_context,
)
try:
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)

View File

@@ -278,6 +278,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
billing_context=self.billing_context,
),
),
label="create_provider_trace",

View File

@@ -15,6 +15,7 @@ from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.provider_trace import BillingContext
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
@@ -51,7 +52,11 @@ class BaseAgent(ABC):
@abstractmethod
async def step(
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: Optional[str] = None
self,
input_messages: List[MessageCreate],
max_steps: int = DEFAULT_MAX_STEPS,
run_id: Optional[str] = None,
billing_context: "BillingContext | None" = None,
) -> LettaResponse:
"""
Main execution loop for the agent.

View File

@@ -12,6 +12,7 @@ from letta.schemas.user import User
if TYPE_CHECKING:
from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.provider_trace import BillingContext
class BaseAgentV2(ABC):
@@ -52,6 +53,7 @@ class BaseAgentV2(ABC):
request_start_timestamp_ns: int | None = None,
client_tools: list["ClientToolSchema"] | None = None,
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
billing_context: "BillingContext | None" = None,
) -> LettaResponse:
"""
Execute the agent loop in blocking mode, returning all messages at once.
@@ -76,6 +78,7 @@ class BaseAgentV2(ABC):
conversation_id: str | None = None,
client_tools: list["ClientToolSchema"] | None = None,
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
billing_context: "BillingContext | None" = None,
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
"""
Execute the agent loop in streaming mode, yielding chunks as they become available.

View File

@@ -48,6 +48,7 @@ from letta.schemas.openai.chat_completion_response import (
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
from letta.schemas.provider_trace import BillingContext
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
@@ -179,6 +180,7 @@ class LettaAgent(BaseAgent):
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
dry_run: bool = False,
billing_context: "BillingContext | None" = None,
) -> Union[LettaResponse, dict]:
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
agent_state = await self.agent_manager.get_agent_by_id_async(

View File

@@ -44,6 +44,7 @@ from letta.schemas.openai.chat_completion_response import (
UsageStatisticsCompletionTokenDetails,
UsageStatisticsPromptTokenDetails,
)
from letta.schemas.provider_trace import BillingContext
from letta.schemas.step import Step, StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool import Tool
@@ -185,6 +186,7 @@ class LettaAgentV2(BaseAgentV2):
request_start_timestamp_ns: int | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
billing_context: "BillingContext | None" = None,
) -> LettaResponse:
"""
Execute the agent loop in blocking mode, returning all messages at once.
@@ -290,6 +292,7 @@ class LettaAgentV2(BaseAgentV2):
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
billing_context: BillingContext | None = None,
) -> AsyncGenerator[str, None]:
"""
Execute the agent loop in streaming mode, yielding chunks as they become available.

View File

@@ -45,6 +45,7 @@ from letta.schemas.letta_response import LettaResponse, TurnTokenData
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, ToolCall, ToolCallDenial, UsageStatistics
from letta.schemas.provider_trace import BillingContext
from letta.schemas.step import StepProgression
from letta.schemas.step_metrics import StepMetrics
from letta.schemas.tool_execution_result import ToolExecutionResult
@@ -149,6 +150,7 @@ class LettaAgentV3(LettaAgentV2):
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False,
billing_context: "BillingContext | None" = None,
) -> LettaResponse:
"""
Execute the agent loop in blocking mode, returning all messages at once.
@@ -232,6 +234,7 @@ class LettaAgentV3(LettaAgentV2):
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
billing_context=billing_context,
)
credit_task = None
@@ -362,6 +365,7 @@ class LettaAgentV3(LettaAgentV2):
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False,
billing_context: BillingContext | None = None,
) -> AsyncGenerator[str, None]:
"""
Execute the agent loop in streaming mode, yielding chunks as they become available.
@@ -419,6 +423,7 @@ class LettaAgentV3(LettaAgentV2):
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
billing_context=billing_context,
)
elif use_sglang_native:
# Use SGLang native adapter for multi-turn RL training
@@ -431,6 +436,7 @@ class LettaAgentV3(LettaAgentV2):
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
billing_context=billing_context,
)
# Reset turns tracking for this step
self.turns = []
@@ -444,6 +450,7 @@ class LettaAgentV3(LettaAgentV2):
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
billing_context=billing_context,
)
try:

View File

@@ -13,6 +13,7 @@ from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.message import Message, MessageCreate
from letta.schemas.provider_trace import BillingContext
from letta.schemas.run import Run
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
@@ -69,6 +70,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
use_assistant_message: bool = True,
request_start_timestamp_ns: int | None = None,
include_return_message_types: list[MessageType] | None = None,
billing_context: "BillingContext | None" = None,
) -> LettaResponse:
run_ids = []
@@ -100,6 +102,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
run_id=run_id,
use_assistant_message=use_assistant_message,
include_return_message_types=include_return_message_types,
billing_context=billing_context,
)
# Get last response messages

View File

@@ -15,6 +15,7 @@ from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.provider_trace import BillingContext
from letta.schemas.run import Run, RunUpdate
from letta.schemas.user import User
from letta.services.group_manager import GroupManager
@@ -47,6 +48,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
request_start_timestamp_ns: int | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False,
billing_context: "BillingContext | None" = None,
) -> LettaResponse:
self.run_ids = []
@@ -62,6 +64,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
request_start_timestamp_ns=request_start_timestamp_ns,
client_tools=client_tools,
include_compaction_messages=include_compaction_messages,
billing_context=billing_context,
)
await self.run_sleeptime_agents()

View File

@@ -14,6 +14,7 @@ from letta.schemas.letta_request import ClientToolSchema
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import StopReasonType
from letta.schemas.message import Message, MessageCreate
from letta.schemas.provider_trace import BillingContext
from letta.schemas.run import Run, RunUpdate
from letta.schemas.user import User
from letta.services.group_manager import GroupManager
@@ -47,6 +48,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
conversation_id: str | None = None,
client_tools: list[ClientToolSchema] | None = None,
include_compaction_messages: bool = False,
billing_context: "BillingContext | None" = None,
) -> LettaResponse:
self.run_ids = []
@@ -63,6 +65,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
conversation_id=conversation_id,
client_tools=client_tools,
include_compaction_messages=include_compaction_messages,
billing_context=billing_context,
)
run_ids = await self.run_sleeptime_agents()

View File

@@ -14,7 +14,7 @@ from letta.schemas.enums import AgentType, LLMCallType, ProviderCategory
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
from letta.schemas.provider_trace import ProviderTrace
from letta.schemas.provider_trace import BillingContext, ProviderTrace
from letta.schemas.usage import LettaUsageStatistics
from letta.services.telemetry_manager import TelemetryManager
from letta.settings import settings
@@ -48,6 +48,7 @@ class LLMClientBase:
self._telemetry_user_id: Optional[str] = None
self._telemetry_compaction_settings: Optional[Dict] = None
self._telemetry_llm_config: Optional[Dict] = None
self._telemetry_billing_context: Optional[BillingContext] = None
def set_telemetry_context(
self,
@@ -62,6 +63,7 @@ class LLMClientBase:
compaction_settings: Optional[Dict] = None,
llm_config: Optional[Dict] = None,
actor: Optional["User"] = None,
billing_context: Optional[BillingContext] = None,
) -> None:
"""Set telemetry context for provider trace logging."""
if actor is not None:
@@ -76,6 +78,7 @@ class LLMClientBase:
self._telemetry_user_id = user_id
self._telemetry_compaction_settings = compaction_settings
self._telemetry_llm_config = llm_config
self._telemetry_billing_context = billing_context
def extract_usage_statistics(self, response_data: Optional[dict], llm_config: LLMConfig) -> LettaUsageStatistics:
"""Provider-specific usage parsing hook (override in subclasses). Returns LettaUsageStatistics."""
@@ -125,6 +128,7 @@ class LLMClientBase:
user_id=self._telemetry_user_id,
compaction_settings=self._telemetry_compaction_settings,
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
billing_context=self._telemetry_billing_context,
),
)
except Exception as e:
@@ -186,6 +190,7 @@ class LLMClientBase:
user_id=self._telemetry_user_id,
compaction_settings=self._telemetry_compaction_settings,
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
billing_context=self._telemetry_billing_context,
),
)
except Exception as e:

View File

@@ -95,6 +95,11 @@ class LLMTrace(LettaBase):
response_json: str = Field(..., description="Full response payload as JSON string")
llm_config_json: str = Field(default="", description="LLM config as JSON string")
# Billing context
billing_plan_type: Optional[str] = Field(default=None, description="Subscription tier (e.g., 'basic', 'standard', 'max', 'enterprise')")
billing_cost_source: Optional[str] = Field(default=None, description="Cost source: 'quota' or 'credits'")
billing_customer_id: Optional[str] = Field(default=None, description="Customer ID for cross-referencing billing records")
# Timestamp
created_at: datetime = Field(default_factory=get_utc_time, description="When the trace was created")
@@ -128,6 +133,9 @@ class LLMTrace(LettaBase):
self.request_json,
self.response_json,
self.llm_config_json,
self.billing_plan_type or "",
self.billing_cost_source or "",
self.billing_customer_id or "",
self.created_at,
)
@@ -162,5 +170,8 @@ class LLMTrace(LettaBase):
"request_json",
"response_json",
"llm_config_json",
"billing_plan_type",
"billing_cost_source",
"billing_customer_id",
"created_at",
]

View File

@@ -3,13 +3,21 @@ from __future__ import annotations
from datetime import datetime
from typing import Any, Dict, Optional
from pydantic import Field
from pydantic import BaseModel, Field
from letta.helpers.datetime_helpers import get_utc_time
from letta.schemas.enums import PrimitiveType
from letta.schemas.letta_base import OrmMetadataBase
class BillingContext(BaseModel):
"""Billing context for LLM request cost tracking."""
plan_type: Optional[str] = Field(None, description="Subscription tier")
cost_source: Optional[str] = Field(None, description="Cost source: 'quota' or 'credits'")
customer_id: Optional[str] = Field(None, description="Customer ID for billing records")
class BaseProviderTrace(OrmMetadataBase):
__id_prefix__ = PrimitiveType.PROVIDER_TRACE.value
@@ -53,6 +61,8 @@ class ProviderTrace(BaseProviderTrace):
compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")
billing_context: Optional[BillingContext] = Field(None, description="Billing context from request headers")
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")

View File

@@ -6,6 +6,7 @@ from pydantic import BaseModel
from letta.errors import LettaInvalidArgumentError
from letta.otel.tracing import tracer
from letta.schemas.enums import PrimitiveType
from letta.schemas.provider_trace import BillingContext
from letta.validators import PRIMITIVE_ID_PATTERNS
if TYPE_CHECKING:
@@ -30,18 +31,24 @@ class HeaderParams(BaseModel):
letta_source: Optional[str] = None
sdk_version: Optional[str] = None
experimental_params: Optional[ExperimentalParams] = None
billing_context: Optional[BillingContext] = None
def get_headers(
actor_id: Optional[str] = Header(None, alias="user_id"),
user_agent: Optional[str] = Header(None, alias="User-Agent"),
project_id: Optional[str] = Header(None, alias="X-Project-Id"),
letta_source: Optional[str] = Header(None, alias="X-Letta-Source"),
sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version"),
message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async"),
letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent"),
letta_v1_agent_message_async: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent-Message-Async"),
modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox"),
letta_source: Optional[str] = Header(None, alias="X-Letta-Source", include_in_schema=False),
sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version", include_in_schema=False),
message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async", include_in_schema=False),
letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent", include_in_schema=False),
letta_v1_agent_message_async: Optional[str] = Header(
None, alias="X-Experimental-Letta-V1-Agent-Message-Async", include_in_schema=False
),
modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox", include_in_schema=False),
billing_plan_type: Optional[str] = Header(None, alias="X-Billing-Plan-Type", include_in_schema=False),
billing_cost_source: Optional[str] = Header(None, alias="X-Billing-Cost-Source", include_in_schema=False),
billing_customer_id: Optional[str] = Header(None, alias="X-Billing-Customer-Id", include_in_schema=False),
) -> HeaderParams:
"""Dependency injection function to extract common headers from requests."""
with tracer.start_as_current_span("dependency.get_headers"):
@@ -63,6 +70,13 @@ def get_headers(
letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None,
modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None,
),
billing_context=BillingContext(
plan_type=billing_plan_type,
cost_source=billing_cost_source,
customer_id=billing_customer_id,
)
if any([billing_plan_type, billing_cost_source, billing_customer_id])
else None,
)

View File

@@ -49,6 +49,7 @@ from letta.schemas.memory import (
)
from letta.schemas.message import Message, MessageCreate, MessageCreateType, MessageSearchRequest, MessageSearchResult
from letta.schemas.passage import Passage
from letta.schemas.provider_trace import BillingContext
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.source import Source
from letta.schemas.tool import Tool
@@ -1697,6 +1698,7 @@ async def send_message(
actor=actor,
request=request,
run_type="send_message",
billing_context=headers.billing_context,
)
return result
@@ -1767,6 +1769,7 @@ async def send_message(
include_return_message_types=request.include_return_message_types,
client_tools=request.client_tools,
include_compaction_messages=request.include_compaction_messages,
billing_context=headers.billing_context,
)
run_status = result.stop_reason.stop_reason.run_status
return result
@@ -1845,6 +1848,7 @@ async def send_message_streaming(
actor=actor,
request=request,
run_type="send_message_streaming",
billing_context=headers.billing_context,
)
return result
@@ -2043,6 +2047,7 @@ async def _process_message_background(
include_return_message_types: list[MessageType] | None = None,
override_model: str | None = None,
include_compaction_messages: bool = False,
billing_context: "BillingContext | None" = None,
) -> None:
"""Background task to process the message and update run status."""
request_start_timestamp_ns = get_utc_timestamp_ns()
@@ -2074,6 +2079,7 @@ async def _process_message_background(
request_start_timestamp_ns=request_start_timestamp_ns,
include_return_message_types=include_return_message_types,
include_compaction_messages=include_compaction_messages,
billing_context=billing_context,
)
runs_manager = RunManager()
from letta.schemas.enums import RunStatus
@@ -2242,6 +2248,7 @@ async def send_message_async(
include_return_message_types=request.include_return_message_types,
override_model=request.override_model,
include_compaction_messages=request.include_compaction_messages,
billing_context=headers.billing_context,
),
label=f"process_message_background_{run.id}",
)

View File

@@ -19,6 +19,7 @@ from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import LettaMessageUnion
from letta.schemas.letta_request import ConversationMessageRequest, LettaStreamingRequest, RetrieveStreamRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.provider_trace import BillingContext
from letta.schemas.run import Run as PydanticRun
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
@@ -211,6 +212,7 @@ async def _send_agent_direct_message(
request: ConversationMessageRequest,
server: SyncServer,
actor,
billing_context: "BillingContext | None" = None,
) -> StreamingResponse | LettaResponse:
"""
Handle agent-direct messaging with locking but without conversation features.
@@ -244,6 +246,7 @@ async def _send_agent_direct_message(
run_type="send_message",
conversation_id=None,
should_lock=True,
billing_context=billing_context,
)
return result
@@ -299,6 +302,7 @@ async def _send_agent_direct_message(
client_tools=request.client_tools,
conversation_id=None,
include_compaction_messages=request.include_compaction_messages,
billing_context=billing_context,
)
finally:
# Release lock
@@ -351,6 +355,7 @@ async def send_conversation_message(
request=request,
server=server,
actor=actor,
billing_context=headers.billing_context,
)
# Normal conversation mode
@@ -383,6 +388,7 @@ async def send_conversation_message(
request=streaming_request,
run_type="send_conversation_message",
conversation_id=conversation_id,
billing_context=headers.billing_context,
)
return result
@@ -445,6 +451,7 @@ async def send_conversation_message(
client_tools=request.client_tools,
conversation_id=conversation_id,
include_compaction_messages=request.include_compaction_messages,
billing_context=headers.billing_context,
)

View File

@@ -141,6 +141,9 @@ class ClickhouseProviderTraceBackend(ProviderTraceBackendClient):
request_json=request_json_str,
response_json=response_json_str,
llm_config_json=llm_config_json_str,
billing_plan_type=provider_trace.billing_context.plan_type if provider_trace.billing_context else None,
billing_cost_source=provider_trace.billing_context.cost_source if provider_trace.billing_context else None,
billing_customer_id=provider_trace.billing_context.customer_id if provider_trace.billing_context else None,
)
def _extract_usage(self, response_json: dict, provider: str) -> dict:

View File

@@ -29,7 +29,7 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
) -> ProviderTrace:
"""Write full provider trace to provider_traces table."""
async with db_registry.async_session() as session:
provider_trace_model = ProviderTraceModel(**provider_trace.model_dump())
provider_trace_model = ProviderTraceModel(**provider_trace.model_dump(exclude={"billing_context"}))
provider_trace_model.organization_id = actor.organization_id
if provider_trace.request_json:

View File

@@ -34,6 +34,7 @@ from letta.schemas.letta_request import ClientToolSchema, LettaStreamingRequest
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import MessageCreate
from letta.schemas.provider_trace import BillingContext
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
@@ -78,6 +79,7 @@ class StreamingService:
run_type: str = "streaming",
conversation_id: Optional[str] = None,
should_lock: bool = False,
billing_context: "BillingContext | None" = None,
) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]:
"""
Create a streaming response for an agent.
@@ -176,6 +178,7 @@ class StreamingService:
lock_key=lock_key, # For lock release (may differ from conversation_id)
client_tools=request.client_tools,
include_compaction_messages=request.include_compaction_messages,
billing_context=billing_context,
)
# handle background streaming if requested
@@ -340,6 +343,7 @@ class StreamingService:
lock_key: Optional[str] = None,
client_tools: Optional[list[ClientToolSchema]] = None,
include_compaction_messages: bool = False,
billing_context: BillingContext | None = None,
) -> AsyncIterator:
"""
Create a stream with unified error handling.
@@ -368,6 +372,7 @@ class StreamingService:
conversation_id=conversation_id,
client_tools=client_tools,
include_compaction_messages=include_compaction_messages,
billing_context=billing_context,
)
async for chunk in stream:

View File

@@ -24,6 +24,9 @@ def test_get_headers_user_id_allows_none():
letta_v1_agent=None,
letta_v1_agent_message_async=None,
modal_sandbox=None,
billing_plan_type=None,
billing_cost_source=None,
billing_customer_id=None,
)
assert isinstance(headers, HeaderParams)
@@ -40,6 +43,9 @@ def test_get_headers_user_id_rejects_invalid_format():
letta_v1_agent=None,
letta_v1_agent_message_async=None,
modal_sandbox=None,
billing_plan_type=None,
billing_cost_source=None,
billing_customer_id=None,
)
@@ -54,6 +60,9 @@ def test_get_headers_user_id_accepts_valid_format():
letta_v1_agent=None,
letta_v1_agent_message_async=None,
modal_sandbox=None,
billing_plan_type=None,
billing_cost_source=None,
billing_customer_id=None,
)
assert headers.actor_id == "user-123e4567-e89b-42d3-8456-426614174000"