Add billing context to LLM telemetry traces (#9745)
* feat: add billing context to LLM telemetry traces Add billing metadata (plan type, cost source, customer ID) to LLM traces in ClickHouse for cost analytics and attribution. **Data Flow:** - Cloud-API: Extract billing info from subscription in rate limiting, set x-billing-* headers - Core: Parse headers into BillingContext object via dependencies - Adapters: Flow billing_context through all LLM adapters (blocking & streaming) - Agent: Pass billing_context to step() and stream() methods - ClickHouse: Store in billing_plan_type, billing_cost_source, billing_customer_id columns **Changes:** - Add BillingContext schema to provider_trace.py - Add billing columns to llm_traces ClickHouse table DDL - Update getCustomerSubscription to fetch stripeCustomerId from organization_billing_details - Propagate billing_context through agent step flow, adapters, and streaming service - Update ProviderTrace and LLMTrace to include billing metadata - Regenerate SDK with autogen **Production Deployment:** Requires env vars: LETTA_PROVIDER_TRACE_BACKEND=clickhouse, LETTA_STORE_LLM_TRACES=true, CLICKHOUSE_* 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * fix: add billing_context parameter to agent step methods - Add billing_context to BaseAgent and BaseAgentV2 abstract methods - Update LettaAgent, LettaAgentV2, LettaAgentV3 step methods - Update multi-agent groups: SleeptimeMultiAgentV2, V3, V4 - Fix test_utils.py to include billing header parameters - Import BillingContext in all affected files * fix: add billing_context to stream methods - Add billing_context parameter to BaseAgentV2.stream() - Add billing_context parameter to LettaAgentV2.stream() - LettaAgentV3.stream() already has it from previous commit * fix: exclude billing headers from OpenAPI spec Mark billing headers as internal (include_in_schema=False) so they don't appear in the public API. These are internal headers between cloud-api and core, not part of the public SDK. 
Regenerated SDK with stage-api - removes 10,650 lines of bloat that was causing OOM during Next.js build. * refactor: return billing context from handleUnifiedRateLimiting instead of mutating req Instead of passing req into handleUnifiedRateLimiting and mutating headers inside it: - Return billing context fields (billingPlanType, billingCostSource, billingCustomerId) from handleUnifiedRateLimiting - Set headers in handleMessageRateLimiting (middleware layer) after getting the result - This fixes step-orchestrator compatibility since it doesn't have a real Express req object * chore: remove extra gencode * p --------- Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -29385,6 +29385,49 @@
|
||||
"title": "BedrockModelSettings",
|
||||
"description": "AWS Bedrock model configuration."
|
||||
},
|
||||
"BillingContext": {
|
||||
"properties": {
|
||||
"plan_type": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Plan Type",
|
||||
"description": "Subscription tier"
|
||||
},
|
||||
"cost_source": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Cost Source",
|
||||
"description": "Cost source: 'quota' or 'credits'"
|
||||
},
|
||||
"customer_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Customer Id",
|
||||
"description": "Customer ID for billing records"
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"title": "BillingContext",
|
||||
"description": "Billing context for LLM request cost tracking."
|
||||
},
|
||||
"Block": {
|
||||
"properties": {
|
||||
"value": {
|
||||
@@ -42964,6 +43007,17 @@
|
||||
],
|
||||
"title": "Llm Config",
|
||||
"description": "LLM configuration used for this call (non-summarization calls only)"
|
||||
},
|
||||
"billing_context": {
|
||||
"anyOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/BillingContext"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"description": "Billing context from request headers"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
|
||||
220
fern/scripts/prepare-openapi.ts
Normal file
220
fern/scripts/prepare-openapi.ts
Normal file
@@ -0,0 +1,220 @@
|
||||
import * as fs from 'fs';
|
||||
import * as path from 'path';
|
||||
|
||||
import { omit } from 'lodash';
|
||||
import { execSync } from 'child_process';
|
||||
import { merge, isErrorResult } from 'openapi-merge';
|
||||
import type { Swagger } from 'atlassian-openapi';
|
||||
import { RESTRICTED_ROUTE_BASE_PATHS } from '@letta-cloud/sdk-core';
|
||||
|
||||
const lettaWebOpenAPIPath = path.join(
|
||||
__dirname,
|
||||
'..',
|
||||
'..',
|
||||
'..',
|
||||
'web',
|
||||
'autogenerated',
|
||||
'letta-web-openapi.json',
|
||||
);
|
||||
const lettaAgentsAPIPath = path.join(
|
||||
__dirname,
|
||||
'..',
|
||||
'..',
|
||||
'letta',
|
||||
'server',
|
||||
'openapi_letta.json',
|
||||
);
|
||||
|
||||
const lettaWebOpenAPI = JSON.parse(
|
||||
fs.readFileSync(lettaWebOpenAPIPath, 'utf8'),
|
||||
) as Swagger.SwaggerV3;
|
||||
const lettaAgentsAPI = JSON.parse(
|
||||
fs.readFileSync(lettaAgentsAPIPath, 'utf8'),
|
||||
) as Swagger.SwaggerV3;
|
||||
|
||||
// removes any routes that are restricted
|
||||
lettaAgentsAPI.paths = Object.fromEntries(
|
||||
Object.entries(lettaAgentsAPI.paths).filter(([path]) =>
|
||||
RESTRICTED_ROUTE_BASE_PATHS.every(
|
||||
(restrictedPath) => !path.startsWith(restrictedPath),
|
||||
),
|
||||
),
|
||||
);
|
||||
|
||||
const lettaAgentsAPIWithNoEndslash = Object.keys(lettaAgentsAPI.paths).reduce(
|
||||
(acc, path) => {
|
||||
const pathWithoutSlash = path.endsWith('/')
|
||||
? path.slice(0, path.length - 1)
|
||||
: path;
|
||||
acc[pathWithoutSlash] = lettaAgentsAPI.paths[path];
|
||||
return acc;
|
||||
},
|
||||
{} as Swagger.SwaggerV3['paths'],
|
||||
);
|
||||
|
||||
// remove duplicate paths, delete from letta-web-openapi if it exists in sdk-core
|
||||
// some paths will have an extra / at the end, so we need to remove that as well
|
||||
lettaWebOpenAPI.paths = Object.fromEntries(
|
||||
Object.entries(lettaWebOpenAPI.paths).filter(([path]) => {
|
||||
const pathWithoutSlash = path.endsWith('/')
|
||||
? path.slice(0, path.length - 1)
|
||||
: path;
|
||||
return !lettaAgentsAPIWithNoEndslash[pathWithoutSlash];
|
||||
}),
|
||||
);
|
||||
|
||||
const agentStatePathsToOverride: Array<[string, string]> = [
|
||||
['/v1/templates/{project}/{template_version}/agents', '201'],
|
||||
['/v1/agents/search', '200'],
|
||||
];
|
||||
|
||||
for (const [path, responseCode] of agentStatePathsToOverride) {
|
||||
if (lettaWebOpenAPI.paths[path]?.post?.responses?.[responseCode]) {
|
||||
// Get direct reference to the schema object
|
||||
const responseSchema =
|
||||
lettaWebOpenAPI.paths[path].post.responses[responseCode];
|
||||
const contentSchema = responseSchema.content['application/json'].schema;
|
||||
|
||||
// Replace the entire agents array schema with the reference
|
||||
if (contentSchema.properties?.agents) {
|
||||
contentSchema.properties.agents = {
|
||||
type: 'array',
|
||||
items: {
|
||||
$ref: '#/components/schemas/AgentState',
|
||||
},
|
||||
};
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
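// go through the paths and remove internal header parameters (user_id, User-Agent, X-Project-Id, X-Letta-Source, X-Stainless-Package-Version, X-Experimental-*, X-Billing-*) so they do not appear in the public spec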
|
||||
for (const path of Object.keys(lettaAgentsAPI.paths)) {
|
||||
for (const method of Object.keys(lettaAgentsAPI.paths[path])) {
|
||||
// @ts-expect-error - a
|
||||
if (lettaAgentsAPI.paths[path][method]?.parameters) {
|
||||
// @ts-expect-error - a
|
||||
lettaAgentsAPI.paths[path][method].parameters = lettaAgentsAPI.paths[
|
||||
path
|
||||
][method].parameters.filter(
|
||||
(param: Record<string, string>) =>
|
||||
param.in !== 'header' ||
|
||||
(
|
||||
param.name !== 'user_id' &&
|
||||
param.name !== 'User-Agent' &&
|
||||
param.name !== 'X-Project-Id' &&
|
||||
param.name !== 'X-Letta-Source' &&
|
||||
param.name !== 'X-Stainless-Package-Version' &&
|
||||
!param.name.startsWith('X-Experimental') &&
|
||||
!param.name.startsWith('X-Billing')
|
||||
),
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const result = merge([
|
||||
{
|
||||
oas: lettaAgentsAPI,
|
||||
},
|
||||
{
|
||||
oas: lettaWebOpenAPI,
|
||||
},
|
||||
]);
|
||||
|
||||
if (isErrorResult(result)) {
|
||||
console.error(`${result.message} (${result.type})`);
|
||||
process.exit(1);
|
||||
}
|
||||
|
||||
result.output.openapi = '3.1.0';
|
||||
result.output.info = {
|
||||
title: 'Letta API',
|
||||
version: '1.0.0',
|
||||
};
|
||||
|
||||
result.output.servers = [
|
||||
{
|
||||
url: 'https://app.letta.com',
|
||||
description: 'Letta Cloud',
|
||||
},
|
||||
{
|
||||
url: 'http://localhost:8283',
|
||||
description: 'Self-hosted',
|
||||
},
|
||||
];
|
||||
|
||||
result.output.components = {
|
||||
...result.output.components,
|
||||
securitySchemes: {
|
||||
bearerAuth: {
|
||||
type: 'http',
|
||||
scheme: 'bearer',
|
||||
},
|
||||
},
|
||||
};
|
||||
|
||||
result.output.security = [
|
||||
...(result.output.security || []),
|
||||
{
|
||||
bearerAuth: [],
|
||||
},
|
||||
];
|
||||
|
||||
/**
 * Recursively removes every own property named `key` from `obj`.
 *
 * Used below to strip internal identifiers ("user_id", "actor_id",
 * "organization_id") from the merged OpenAPI components.
 *
 * - Arrays are walked element by element and returned as arrays, so their
 *   length and ordering are preserved.
 * - Primitives and `null` are returned unchanged.
 * - Plain objects are rebuilt without `key`; the input is never mutated.
 *
 * @param obj - Any JSON-like value (object, array, primitive or null).
 * @param key - Property name to drop at every nesting level.
 * @returns A copy of `obj` with all occurrences of `key` removed.
 */
function deepOmitPreserveArrays(obj: unknown, key: string): unknown {
  if (Array.isArray(obj)) {
    return obj.map((item) => deepOmitPreserveArrays(item, key));
  }

  if (typeof obj !== 'object' || obj === null) {
    return obj;
  }

  // Drop the matching property but keep descending into its siblings.
  // Returning as soon as `key` is found would leave nested occurrences
  // (e.g. { user_id, nested: { user_id } }) in the output.
  return Object.fromEntries(
    Object.entries(obj)
      .filter(([name]) => name !== key)
      .map(([name, value]) => [name, deepOmitPreserveArrays(value, key)]),
  );
}
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-ignore
|
||||
result.output.components = deepOmitPreserveArrays(
|
||||
result.output.components,
|
||||
'user_id',
|
||||
);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-ignore
|
||||
result.output.components = deepOmitPreserveArrays(
|
||||
result.output.components,
|
||||
'actor_id',
|
||||
);
|
||||
|
||||
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||
// @ts-ignore
|
||||
result.output.components = deepOmitPreserveArrays(
|
||||
result.output.components,
|
||||
'organization_id',
|
||||
);
|
||||
|
||||
fs.writeFileSync(
|
||||
path.join(__dirname, '..', 'openapi.json'),
|
||||
JSON.stringify(result.output, null, 2),
|
||||
);
|
||||
|
||||
/**
 * Formats the generated `openapi.json` (one directory above this script)
 * in place by shelling out to Prettier via `npx`.
 *
 * Prettier's output is forwarded to this process's stdio. If the command
 * fails, the error is logged and the process exits with status 1.
 */
function formatOpenAPIJson() {
  const target = path.join(__dirname, '..', 'openapi.json');
  const command = `npx prettier --write "${target}"`;

  try {
    execSync(command, { stdio: 'inherit' });
    console.log('Successfully formatted openapi.json with Prettier');
  } catch (error) {
    console.error('Error formatting openapi.json:', error);
    process.exit(1);
  }
}
|
||||
|
||||
formatOpenAPIJson();
|
||||
@@ -7,6 +7,7 @@ from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.telemetry_manager import TelemetryManager
|
||||
@@ -31,6 +32,7 @@ class LettaLLMAdapter(ABC):
|
||||
run_id: str | None = None,
|
||||
org_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
billing_context: BillingContext | None = None,
|
||||
) -> None:
|
||||
self.llm_client: LLMClientBase = llm_client
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
@@ -40,6 +42,7 @@ class LettaLLMAdapter(ABC):
|
||||
self.run_id: str | None = run_id
|
||||
self.org_id: str | None = org_id
|
||||
self.user_id: str | None = user_id
|
||||
self.billing_context: BillingContext | None = billing_context
|
||||
self.message_id: str | None = None
|
||||
self.request_data: dict | None = None
|
||||
self.response_data: dict | None = None
|
||||
|
||||
@@ -10,7 +10,7 @@ from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
|
||||
from letta.schemas.enums import LLMCallType, ProviderType
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.provider_trace import ProviderTrace
|
||||
from letta.schemas.provider_trace import BillingContext, ProviderTrace
|
||||
from letta.schemas.user import User
|
||||
from letta.settings import settings
|
||||
from letta.utils import safe_create_task
|
||||
@@ -36,6 +36,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
run_id: str | None = None,
|
||||
org_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
llm_client,
|
||||
@@ -46,6 +47,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
run_id=run_id,
|
||||
org_id=org_id,
|
||||
user_id=user_id,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
|
||||
|
||||
|
||||
@@ -51,6 +51,7 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
billing_context=self.billing_context,
|
||||
)
|
||||
try:
|
||||
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
|
||||
|
||||
@@ -278,6 +278,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
billing_context=self.billing_context,
|
||||
),
|
||||
),
|
||||
label="create_provider_trace",
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -51,7 +52,11 @@ class BaseAgent(ABC):
|
||||
|
||||
@abstractmethod
|
||||
async def step(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: Optional[str] = None
|
||||
self,
|
||||
input_messages: List[MessageCreate],
|
||||
max_steps: int = DEFAULT_MAX_STEPS,
|
||||
run_id: Optional[str] = None,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Main execution loop for the agent.
|
||||
|
||||
@@ -12,6 +12,7 @@ from letta.schemas.user import User
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.letta_request import ClientToolSchema
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
|
||||
|
||||
class BaseAgentV2(ABC):
|
||||
@@ -52,6 +53,7 @@ class BaseAgentV2(ABC):
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list["ClientToolSchema"] | None = None,
|
||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop in blocking mode, returning all messages at once.
|
||||
@@ -76,6 +78,7 @@ class BaseAgentV2(ABC):
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list["ClientToolSchema"] | None = None,
|
||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
|
||||
"""
|
||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||
|
||||
@@ -48,6 +48,7 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
UsageStatisticsCompletionTokenDetails,
|
||||
UsageStatisticsPromptTokenDetails,
|
||||
)
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.step import StepProgression
|
||||
from letta.schemas.step_metrics import StepMetrics
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
@@ -179,6 +180,7 @@ class LettaAgent(BaseAgent):
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
dry_run: bool = False,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> Union[LettaResponse, dict]:
|
||||
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||
|
||||
@@ -44,6 +44,7 @@ from letta.schemas.openai.chat_completion_response import (
|
||||
UsageStatisticsCompletionTokenDetails,
|
||||
UsageStatisticsPromptTokenDetails,
|
||||
)
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.step import Step, StepProgression
|
||||
from letta.schemas.step_metrics import StepMetrics
|
||||
from letta.schemas.tool import Tool
|
||||
@@ -185,6 +186,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop in blocking mode, returning all messages at once.
|
||||
@@ -290,6 +292,7 @@ class LettaAgentV2(BaseAgentV2):
|
||||
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||
billing_context: BillingContext | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||
|
||||
@@ -45,6 +45,7 @@ from letta.schemas.letta_response import LettaResponse, TurnTokenData
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, ToolCall, ToolCallDenial, UsageStatistics
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.step import StepProgression
|
||||
from letta.schemas.step_metrics import StepMetrics
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
@@ -149,6 +150,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> LettaResponse:
|
||||
"""
|
||||
Execute the agent loop in blocking mode, returning all messages at once.
|
||||
@@ -232,6 +234,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
|
||||
credit_task = None
|
||||
@@ -362,6 +365,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
billing_context: BillingContext | None = None,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||
@@ -419,6 +423,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
elif use_sglang_native:
|
||||
# Use SGLang native adapter for multi-turn RL training
|
||||
@@ -431,6 +436,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
# Reset turns tracking for this step
|
||||
self.turns = []
|
||||
@@ -444,6 +450,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.schemas.letta_message import MessageType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.run import Run
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -69,6 +70,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
use_assistant_message: bool = True,
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> LettaResponse:
|
||||
run_ids = []
|
||||
|
||||
@@ -100,6 +102,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
||||
run_id=run_id,
|
||||
use_assistant_message=use_assistant_message,
|
||||
include_return_message_types=include_return_message_types,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
|
||||
# Get last response messages
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.schemas.letta_request import ClientToolSchema
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.run import Run, RunUpdate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.group_manager import GroupManager
|
||||
@@ -47,6 +48,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
request_start_timestamp_ns: int | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> LettaResponse:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -62,6 +64,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
client_tools=client_tools,
|
||||
include_compaction_messages=include_compaction_messages,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
|
||||
await self.run_sleeptime_agents()
|
||||
|
||||
@@ -14,6 +14,7 @@ from letta.schemas.letta_request import ClientToolSchema
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import StopReasonType
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.run import Run, RunUpdate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.group_manager import GroupManager
|
||||
@@ -47,6 +48,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
conversation_id: str | None = None,
|
||||
client_tools: list[ClientToolSchema] | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> LettaResponse:
|
||||
self.run_ids = []
|
||||
|
||||
@@ -63,6 +65,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
||||
conversation_id=conversation_id,
|
||||
client_tools=client_tools,
|
||||
include_compaction_messages=include_compaction_messages,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
|
||||
run_ids = await self.run_sleeptime_agents()
|
||||
|
||||
@@ -14,7 +14,7 @@ from letta.schemas.enums import AgentType, LLMCallType, ProviderCategory
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.provider_trace import ProviderTrace
|
||||
from letta.schemas.provider_trace import BillingContext, ProviderTrace
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.telemetry_manager import TelemetryManager
|
||||
from letta.settings import settings
|
||||
@@ -48,6 +48,7 @@ class LLMClientBase:
|
||||
self._telemetry_user_id: Optional[str] = None
|
||||
self._telemetry_compaction_settings: Optional[Dict] = None
|
||||
self._telemetry_llm_config: Optional[Dict] = None
|
||||
self._telemetry_billing_context: Optional[BillingContext] = None
|
||||
|
||||
def set_telemetry_context(
|
||||
self,
|
||||
@@ -62,6 +63,7 @@ class LLMClientBase:
|
||||
compaction_settings: Optional[Dict] = None,
|
||||
llm_config: Optional[Dict] = None,
|
||||
actor: Optional["User"] = None,
|
||||
billing_context: Optional[BillingContext] = None,
|
||||
) -> None:
|
||||
"""Set telemetry context for provider trace logging."""
|
||||
if actor is not None:
|
||||
@@ -76,6 +78,7 @@ class LLMClientBase:
|
||||
self._telemetry_user_id = user_id
|
||||
self._telemetry_compaction_settings = compaction_settings
|
||||
self._telemetry_llm_config = llm_config
|
||||
self._telemetry_billing_context = billing_context
|
||||
|
||||
def extract_usage_statistics(self, response_data: Optional[dict], llm_config: LLMConfig) -> LettaUsageStatistics:
|
||||
"""Provider-specific usage parsing hook (override in subclasses). Returns LettaUsageStatistics."""
|
||||
@@ -125,6 +128,7 @@ class LLMClientBase:
|
||||
user_id=self._telemetry_user_id,
|
||||
compaction_settings=self._telemetry_compaction_settings,
|
||||
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
|
||||
billing_context=self._telemetry_billing_context,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -186,6 +190,7 @@ class LLMClientBase:
|
||||
user_id=self._telemetry_user_id,
|
||||
compaction_settings=self._telemetry_compaction_settings,
|
||||
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
|
||||
billing_context=self._telemetry_billing_context,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -95,6 +95,11 @@ class LLMTrace(LettaBase):
|
||||
response_json: str = Field(..., description="Full response payload as JSON string")
|
||||
llm_config_json: str = Field(default="", description="LLM config as JSON string")
|
||||
|
||||
# Billing context
|
||||
billing_plan_type: Optional[str] = Field(default=None, description="Subscription tier (e.g., 'basic', 'standard', 'max', 'enterprise')")
|
||||
billing_cost_source: Optional[str] = Field(default=None, description="Cost source: 'quota' or 'credits'")
|
||||
billing_customer_id: Optional[str] = Field(default=None, description="Customer ID for cross-referencing billing records")
|
||||
|
||||
# Timestamp
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="When the trace was created")
|
||||
|
||||
@@ -128,6 +133,9 @@ class LLMTrace(LettaBase):
|
||||
self.request_json,
|
||||
self.response_json,
|
||||
self.llm_config_json,
|
||||
self.billing_plan_type or "",
|
||||
self.billing_cost_source or "",
|
||||
self.billing_customer_id or "",
|
||||
self.created_at,
|
||||
)
|
||||
|
||||
@@ -162,5 +170,8 @@ class LLMTrace(LettaBase):
|
||||
"request_json",
|
||||
"response_json",
|
||||
"llm_config_json",
|
||||
"billing_plan_type",
|
||||
"billing_cost_source",
|
||||
"billing_customer_id",
|
||||
"created_at",
|
||||
]
|
||||
|
||||
@@ -3,13 +3,21 @@ from __future__ import annotations
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.letta_base import OrmMetadataBase
|
||||
|
||||
|
||||
class BillingContext(BaseModel):
    """Billing context for LLM request cost tracking."""

    # NOTE: the docstring above and the field descriptions below appear
    # verbatim in the generated OpenAPI schema, so rewording them changes
    # the published spec.

    # All three fields are optional and default to None.
    plan_type: Optional[str] = Field(default=None, description="Subscription tier")
    cost_source: Optional[str] = Field(default=None, description="Cost source: 'quota' or 'credits'")
    customer_id: Optional[str] = Field(default=None, description="Customer ID for billing records")
|
||||
|
||||
|
||||
class BaseProviderTrace(OrmMetadataBase):
|
||||
__id_prefix__ = PrimitiveType.PROVIDER_TRACE.value
|
||||
|
||||
@@ -53,6 +61,8 @@ class ProviderTrace(BaseProviderTrace):
|
||||
compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
|
||||
llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")
|
||||
|
||||
billing_context: Optional[BillingContext] = Field(None, description="Billing context from request headers")
|
||||
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
||||
from letta.errors import LettaInvalidArgumentError
|
||||
from letta.otel.tracing import tracer
|
||||
from letta.schemas.enums import PrimitiveType
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.validators import PRIMITIVE_ID_PATTERNS
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -30,18 +31,24 @@ class HeaderParams(BaseModel):
|
||||
letta_source: Optional[str] = None
|
||||
sdk_version: Optional[str] = None
|
||||
experimental_params: Optional[ExperimentalParams] = None
|
||||
billing_context: Optional[BillingContext] = None
|
||||
|
||||
|
||||
def get_headers(
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
user_agent: Optional[str] = Header(None, alias="User-Agent"),
|
||||
project_id: Optional[str] = Header(None, alias="X-Project-Id"),
|
||||
letta_source: Optional[str] = Header(None, alias="X-Letta-Source"),
|
||||
sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version"),
|
||||
message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async"),
|
||||
letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent"),
|
||||
letta_v1_agent_message_async: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent-Message-Async"),
|
||||
modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox"),
|
||||
letta_source: Optional[str] = Header(None, alias="X-Letta-Source", include_in_schema=False),
|
||||
sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version", include_in_schema=False),
|
||||
message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async", include_in_schema=False),
|
||||
letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent", include_in_schema=False),
|
||||
letta_v1_agent_message_async: Optional[str] = Header(
|
||||
None, alias="X-Experimental-Letta-V1-Agent-Message-Async", include_in_schema=False
|
||||
),
|
||||
modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox", include_in_schema=False),
|
||||
billing_plan_type: Optional[str] = Header(None, alias="X-Billing-Plan-Type", include_in_schema=False),
|
||||
billing_cost_source: Optional[str] = Header(None, alias="X-Billing-Cost-Source", include_in_schema=False),
|
||||
billing_customer_id: Optional[str] = Header(None, alias="X-Billing-Customer-Id", include_in_schema=False),
|
||||
) -> HeaderParams:
|
||||
"""Dependency injection function to extract common headers from requests."""
|
||||
with tracer.start_as_current_span("dependency.get_headers"):
|
||||
@@ -63,6 +70,13 @@ def get_headers(
|
||||
letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None,
|
||||
modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None,
|
||||
),
|
||||
billing_context=BillingContext(
|
||||
plan_type=billing_plan_type,
|
||||
cost_source=billing_cost_source,
|
||||
customer_id=billing_customer_id,
|
||||
)
|
||||
if any([billing_plan_type, billing_cost_source, billing_customer_id])
|
||||
else None,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -49,6 +49,7 @@ from letta.schemas.memory import (
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, MessageCreateType, MessageSearchRequest, MessageSearchResult
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
@@ -1697,6 +1698,7 @@ async def send_message(
|
||||
actor=actor,
|
||||
request=request,
|
||||
run_type="send_message",
|
||||
billing_context=headers.billing_context,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -1767,6 +1769,7 @@ async def send_message(
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
client_tools=request.client_tools,
|
||||
include_compaction_messages=request.include_compaction_messages,
|
||||
billing_context=headers.billing_context,
|
||||
)
|
||||
run_status = result.stop_reason.stop_reason.run_status
|
||||
return result
|
||||
@@ -1845,6 +1848,7 @@ async def send_message_streaming(
|
||||
actor=actor,
|
||||
request=request,
|
||||
run_type="send_message_streaming",
|
||||
billing_context=headers.billing_context,
|
||||
)
|
||||
|
||||
return result
|
||||
@@ -2043,6 +2047,7 @@ async def _process_message_background(
|
||||
include_return_message_types: list[MessageType] | None = None,
|
||||
override_model: str | None = None,
|
||||
include_compaction_messages: bool = False,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> None:
|
||||
"""Background task to process the message and update run status."""
|
||||
request_start_timestamp_ns = get_utc_timestamp_ns()
|
||||
@@ -2074,6 +2079,7 @@ async def _process_message_background(
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
include_return_message_types=include_return_message_types,
|
||||
include_compaction_messages=include_compaction_messages,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
runs_manager = RunManager()
|
||||
from letta.schemas.enums import RunStatus
|
||||
@@ -2242,6 +2248,7 @@ async def send_message_async(
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
override_model=request.override_model,
|
||||
include_compaction_messages=request.include_compaction_messages,
|
||||
billing_context=headers.billing_context,
|
||||
),
|
||||
label=f"process_message_background_{run.id}",
|
||||
)
|
||||
|
||||
@@ -19,6 +19,7 @@ from letta.schemas.job import LettaRequestConfig
|
||||
from letta.schemas.letta_message import LettaMessageUnion
|
||||
from letta.schemas.letta_request import ConversationMessageRequest, LettaStreamingRequest, RetrieveStreamRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
|
||||
@@ -211,6 +212,7 @@ async def _send_agent_direct_message(
|
||||
request: ConversationMessageRequest,
|
||||
server: SyncServer,
|
||||
actor,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> StreamingResponse | LettaResponse:
|
||||
"""
|
||||
Handle agent-direct messaging with locking but without conversation features.
|
||||
@@ -244,6 +246,7 @@ async def _send_agent_direct_message(
|
||||
run_type="send_message",
|
||||
conversation_id=None,
|
||||
should_lock=True,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -299,6 +302,7 @@ async def _send_agent_direct_message(
|
||||
client_tools=request.client_tools,
|
||||
conversation_id=None,
|
||||
include_compaction_messages=request.include_compaction_messages,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
finally:
|
||||
# Release lock
|
||||
@@ -351,6 +355,7 @@ async def send_conversation_message(
|
||||
request=request,
|
||||
server=server,
|
||||
actor=actor,
|
||||
billing_context=headers.billing_context,
|
||||
)
|
||||
|
||||
# Normal conversation mode
|
||||
@@ -383,6 +388,7 @@ async def send_conversation_message(
|
||||
request=streaming_request,
|
||||
run_type="send_conversation_message",
|
||||
conversation_id=conversation_id,
|
||||
billing_context=headers.billing_context,
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -445,6 +451,7 @@ async def send_conversation_message(
|
||||
client_tools=request.client_tools,
|
||||
conversation_id=conversation_id,
|
||||
include_compaction_messages=request.include_compaction_messages,
|
||||
billing_context=headers.billing_context,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -141,6 +141,9 @@ class ClickhouseProviderTraceBackend(ProviderTraceBackendClient):
|
||||
request_json=request_json_str,
|
||||
response_json=response_json_str,
|
||||
llm_config_json=llm_config_json_str,
|
||||
billing_plan_type=provider_trace.billing_context.plan_type if provider_trace.billing_context else None,
|
||||
billing_cost_source=provider_trace.billing_context.cost_source if provider_trace.billing_context else None,
|
||||
billing_customer_id=provider_trace.billing_context.customer_id if provider_trace.billing_context else None,
|
||||
)
|
||||
|
||||
def _extract_usage(self, response_json: dict, provider: str) -> dict:
|
||||
|
||||
@@ -29,7 +29,7 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
|
||||
) -> ProviderTrace:
|
||||
"""Write full provider trace to provider_traces table."""
|
||||
async with db_registry.async_session() as session:
|
||||
provider_trace_model = ProviderTraceModel(**provider_trace.model_dump())
|
||||
provider_trace_model = ProviderTraceModel(**provider_trace.model_dump(exclude={"billing_context"}))
|
||||
provider_trace_model.organization_id = actor.organization_id
|
||||
|
||||
if provider_trace.request_json:
|
||||
|
||||
@@ -34,6 +34,7 @@ from letta.schemas.letta_request import ClientToolSchema, LettaStreamingRequest
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.provider_trace import BillingContext
|
||||
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
@@ -78,6 +79,7 @@ class StreamingService:
|
||||
run_type: str = "streaming",
|
||||
conversation_id: Optional[str] = None,
|
||||
should_lock: bool = False,
|
||||
billing_context: "BillingContext | None" = None,
|
||||
) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]:
|
||||
"""
|
||||
Create a streaming response for an agent.
|
||||
@@ -176,6 +178,7 @@ class StreamingService:
|
||||
lock_key=lock_key, # For lock release (may differ from conversation_id)
|
||||
client_tools=request.client_tools,
|
||||
include_compaction_messages=request.include_compaction_messages,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
|
||||
# handle background streaming if requested
|
||||
@@ -340,6 +343,7 @@ class StreamingService:
|
||||
lock_key: Optional[str] = None,
|
||||
client_tools: Optional[list[ClientToolSchema]] = None,
|
||||
include_compaction_messages: bool = False,
|
||||
billing_context: BillingContext | None = None,
|
||||
) -> AsyncIterator:
|
||||
"""
|
||||
Create a stream with unified error handling.
|
||||
@@ -368,6 +372,7 @@ class StreamingService:
|
||||
conversation_id=conversation_id,
|
||||
client_tools=client_tools,
|
||||
include_compaction_messages=include_compaction_messages,
|
||||
billing_context=billing_context,
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
|
||||
@@ -24,6 +24,9 @@ def test_get_headers_user_id_allows_none():
|
||||
letta_v1_agent=None,
|
||||
letta_v1_agent_message_async=None,
|
||||
modal_sandbox=None,
|
||||
billing_plan_type=None,
|
||||
billing_cost_source=None,
|
||||
billing_customer_id=None,
|
||||
)
|
||||
assert isinstance(headers, HeaderParams)
|
||||
|
||||
@@ -40,6 +43,9 @@ def test_get_headers_user_id_rejects_invalid_format():
|
||||
letta_v1_agent=None,
|
||||
letta_v1_agent_message_async=None,
|
||||
modal_sandbox=None,
|
||||
billing_plan_type=None,
|
||||
billing_cost_source=None,
|
||||
billing_customer_id=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -54,6 +60,9 @@ def test_get_headers_user_id_accepts_valid_format():
|
||||
letta_v1_agent=None,
|
||||
letta_v1_agent_message_async=None,
|
||||
modal_sandbox=None,
|
||||
billing_plan_type=None,
|
||||
billing_cost_source=None,
|
||||
billing_customer_id=None,
|
||||
)
|
||||
assert headers.actor_id == "user-123e4567-e89b-42d3-8456-426614174000"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user