Add billing context to LLM telemetry traces (#9745)
* feat: add billing context to LLM telemetry traces Add billing metadata (plan type, cost source, customer ID) to LLM traces in ClickHouse for cost analytics and attribution. **Data Flow:** - Cloud-API: Extract billing info from subscription in rate limiting, set x-billing-* headers - Core: Parse headers into BillingContext object via dependencies - Adapters: Flow billing_context through all LLM adapters (blocking & streaming) - Agent: Pass billing_context to step() and stream() methods - ClickHouse: Store in billing_plan_type, billing_cost_source, billing_customer_id columns **Changes:** - Add BillingContext schema to provider_trace.py - Add billing columns to llm_traces ClickHouse table DDL - Update getCustomerSubscription to fetch stripeCustomerId from organization_billing_details - Propagate billing_context through agent step flow, adapters, and streaming service - Update ProviderTrace and LLMTrace to include billing metadata - Regenerate SDK with autogen **Production Deployment:** Requires env vars: LETTA_PROVIDER_TRACE_BACKEND=clickhouse, LETTA_STORE_LLM_TRACES=true, CLICKHOUSE_* 🐾 Generated with [Letta Code](https://letta.com) Co-Authored-By: Letta <noreply@letta.com> * fix: add billing_context parameter to agent step methods - Add billing_context to BaseAgent and BaseAgentV2 abstract methods - Update LettaAgent, LettaAgentV2, LettaAgentV3 step methods - Update multi-agent groups: SleeptimeMultiAgentV2, V3, V4 - Fix test_utils.py to include billing header parameters - Import BillingContext in all affected files * fix: add billing_context to stream methods - Add billing_context parameter to BaseAgentV2.stream() - Add billing_context parameter to LettaAgentV2.stream() - LettaAgentV3.stream() already has it from previous commit * fix: exclude billing headers from OpenAPI spec Mark billing headers as internal (include_in_schema=False) so they don't appear in the public API. These are internal headers between cloud-api and core, not part of the public SDK. 
Regenerated SDK with stage-api - removes 10,650 lines of bloat that was causing OOM during Next.js build. * refactor: return billing context from handleUnifiedRateLimiting instead of mutating req Instead of passing req into handleUnifiedRateLimiting and mutating headers inside it: - Return billing context fields (billingPlanType, billingCostSource, billingCustomerId) from handleUnifiedRateLimiting - Set headers in handleMessageRateLimiting (middleware layer) after getting the result - This fixes step-orchestrator compatibility since it doesn't have a real Express req object * chore: remove extra gencode * p --------- Co-authored-by: Letta <noreply@letta.com>
This commit is contained in:
@@ -29385,6 +29385,49 @@
|
|||||||
"title": "BedrockModelSettings",
|
"title": "BedrockModelSettings",
|
||||||
"description": "AWS Bedrock model configuration."
|
"description": "AWS Bedrock model configuration."
|
||||||
},
|
},
|
||||||
|
"BillingContext": {
|
||||||
|
"properties": {
|
||||||
|
"plan_type": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"title": "Plan Type",
|
||||||
|
"description": "Subscription tier"
|
||||||
|
},
|
||||||
|
"cost_source": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"title": "Cost Source",
|
||||||
|
"description": "Cost source: 'quota' or 'credits'"
|
||||||
|
},
|
||||||
|
"customer_id": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"title": "Customer Id",
|
||||||
|
"description": "Customer ID for billing records"
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"type": "object",
|
||||||
|
"title": "BillingContext",
|
||||||
|
"description": "Billing context for LLM request cost tracking."
|
||||||
|
},
|
||||||
"Block": {
|
"Block": {
|
||||||
"properties": {
|
"properties": {
|
||||||
"value": {
|
"value": {
|
||||||
@@ -42964,6 +43007,17 @@
|
|||||||
],
|
],
|
||||||
"title": "Llm Config",
|
"title": "Llm Config",
|
||||||
"description": "LLM configuration used for this call (non-summarization calls only)"
|
"description": "LLM configuration used for this call (non-summarization calls only)"
|
||||||
|
},
|
||||||
|
"billing_context": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"$ref": "#/components/schemas/BillingContext"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"description": "Billing context from request headers"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
|
|||||||
220
fern/scripts/prepare-openapi.ts
Normal file
220
fern/scripts/prepare-openapi.ts
Normal file
@@ -0,0 +1,220 @@
|
|||||||
|
import * as fs from 'fs';
|
||||||
|
import * as path from 'path';
|
||||||
|
|
||||||
|
import { omit } from 'lodash';
|
||||||
|
import { execSync } from 'child_process';
|
||||||
|
import { merge, isErrorResult } from 'openapi-merge';
|
||||||
|
import type { Swagger } from 'atlassian-openapi';
|
||||||
|
import { RESTRICTED_ROUTE_BASE_PATHS } from '@letta-cloud/sdk-core';
|
||||||
|
|
||||||
|
const lettaWebOpenAPIPath = path.join(
|
||||||
|
__dirname,
|
||||||
|
'..',
|
||||||
|
'..',
|
||||||
|
'..',
|
||||||
|
'web',
|
||||||
|
'autogenerated',
|
||||||
|
'letta-web-openapi.json',
|
||||||
|
);
|
||||||
|
const lettaAgentsAPIPath = path.join(
|
||||||
|
__dirname,
|
||||||
|
'..',
|
||||||
|
'..',
|
||||||
|
'letta',
|
||||||
|
'server',
|
||||||
|
'openapi_letta.json',
|
||||||
|
);
|
||||||
|
|
||||||
|
const lettaWebOpenAPI = JSON.parse(
|
||||||
|
fs.readFileSync(lettaWebOpenAPIPath, 'utf8'),
|
||||||
|
) as Swagger.SwaggerV3;
|
||||||
|
const lettaAgentsAPI = JSON.parse(
|
||||||
|
fs.readFileSync(lettaAgentsAPIPath, 'utf8'),
|
||||||
|
) as Swagger.SwaggerV3;
|
||||||
|
|
||||||
|
// removes any routes that are restricted
|
||||||
|
lettaAgentsAPI.paths = Object.fromEntries(
|
||||||
|
Object.entries(lettaAgentsAPI.paths).filter(([path]) =>
|
||||||
|
RESTRICTED_ROUTE_BASE_PATHS.every(
|
||||||
|
(restrictedPath) => !path.startsWith(restrictedPath),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
);
|
||||||
|
|
||||||
|
const lettaAgentsAPIWithNoEndslash = Object.keys(lettaAgentsAPI.paths).reduce(
|
||||||
|
(acc, path) => {
|
||||||
|
const pathWithoutSlash = path.endsWith('/')
|
||||||
|
? path.slice(0, path.length - 1)
|
||||||
|
: path;
|
||||||
|
acc[pathWithoutSlash] = lettaAgentsAPI.paths[path];
|
||||||
|
return acc;
|
||||||
|
},
|
||||||
|
{} as Swagger.SwaggerV3['paths'],
|
||||||
|
);
|
||||||
|
|
||||||
|
// remove duplicate paths, delete from letta-web-openapi if it exists in sdk-core
|
||||||
|
// some paths will have an extra / at the end, so we need to remove that as well
|
||||||
|
lettaWebOpenAPI.paths = Object.fromEntries(
|
||||||
|
Object.entries(lettaWebOpenAPI.paths).filter(([path]) => {
|
||||||
|
const pathWithoutSlash = path.endsWith('/')
|
||||||
|
? path.slice(0, path.length - 1)
|
||||||
|
: path;
|
||||||
|
return !lettaAgentsAPIWithNoEndslash[pathWithoutSlash];
|
||||||
|
}),
|
||||||
|
);
|
||||||
|
|
||||||
|
const agentStatePathsToOverride: Array<[string, string]> = [
|
||||||
|
['/v1/templates/{project}/{template_version}/agents', '201'],
|
||||||
|
['/v1/agents/search', '200'],
|
||||||
|
];
|
||||||
|
|
||||||
|
for (const [path, responseCode] of agentStatePathsToOverride) {
|
||||||
|
if (lettaWebOpenAPI.paths[path]?.post?.responses?.[responseCode]) {
|
||||||
|
// Get direct reference to the schema object
|
||||||
|
const responseSchema =
|
||||||
|
lettaWebOpenAPI.paths[path].post.responses[responseCode];
|
||||||
|
const contentSchema = responseSchema.content['application/json'].schema;
|
||||||
|
|
||||||
|
// Replace the entire agents array schema with the reference
|
||||||
|
if (contentSchema.properties?.agents) {
|
||||||
|
contentSchema.properties.agents = {
|
||||||
|
type: 'array',
|
||||||
|
items: {
|
||||||
|
$ref: '#/components/schemas/AgentState',
|
||||||
|
},
|
||||||
|
};
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
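// go through the paths and strip internal-only header parameters: user_id, User-Agent, X-Project-Id, X-Letta-Source, X-Stainless-Package-Version, and anything starting with X-Experimental or X-Billing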
|
||||||
|
for (const path of Object.keys(lettaAgentsAPI.paths)) {
|
||||||
|
for (const method of Object.keys(lettaAgentsAPI.paths[path])) {
|
||||||
|
// @ts-expect-error - a
|
||||||
|
if (lettaAgentsAPI.paths[path][method]?.parameters) {
|
||||||
|
// @ts-expect-error - a
|
||||||
|
lettaAgentsAPI.paths[path][method].parameters = lettaAgentsAPI.paths[
|
||||||
|
path
|
||||||
|
][method].parameters.filter(
|
||||||
|
(param: Record<string, string>) =>
|
||||||
|
param.in !== 'header' ||
|
||||||
|
(
|
||||||
|
param.name !== 'user_id' &&
|
||||||
|
param.name !== 'User-Agent' &&
|
||||||
|
param.name !== 'X-Project-Id' &&
|
||||||
|
param.name !== 'X-Letta-Source' &&
|
||||||
|
param.name !== 'X-Stainless-Package-Version' &&
|
||||||
|
!param.name.startsWith('X-Experimental') &&
|
||||||
|
!param.name.startsWith('X-Billing')
|
||||||
|
),
|
||||||
|
);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const result = merge([
|
||||||
|
{
|
||||||
|
oas: lettaAgentsAPI,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
oas: lettaWebOpenAPI,
|
||||||
|
},
|
||||||
|
]);
|
||||||
|
|
||||||
|
if (isErrorResult(result)) {
|
||||||
|
console.error(`${result.message} (${result.type})`);
|
||||||
|
process.exit(1);
|
||||||
|
}
|
||||||
|
|
||||||
|
result.output.openapi = '3.1.0';
|
||||||
|
result.output.info = {
|
||||||
|
title: 'Letta API',
|
||||||
|
version: '1.0.0',
|
||||||
|
};
|
||||||
|
|
||||||
|
result.output.servers = [
|
||||||
|
{
|
||||||
|
url: 'https://app.letta.com',
|
||||||
|
description: 'Letta Cloud',
|
||||||
|
},
|
||||||
|
{
|
||||||
|
url: 'http://localhost:8283',
|
||||||
|
description: 'Self-hosted',
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
result.output.components = {
|
||||||
|
...result.output.components,
|
||||||
|
securitySchemes: {
|
||||||
|
bearerAuth: {
|
||||||
|
type: 'http',
|
||||||
|
scheme: 'bearer',
|
||||||
|
},
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
|
result.output.security = [
|
||||||
|
...(result.output.security || []),
|
||||||
|
{
|
||||||
|
bearerAuth: [],
|
||||||
|
},
|
||||||
|
];
|
||||||
|
|
||||||
|
// omit all instances of "user_id" from the openapi.json file
|
||||||
|
/**
 * Recursively removes every property named `key` from `obj`.
 *
 * Arrays are kept as arrays (each element is processed in turn) and
 * non-object values (strings, numbers, booleans, null, undefined) are
 * returned unchanged, so string array entries that merely equal `key`
 * are NOT removed.
 *
 * The input is never mutated; a new structure is returned.
 *
 * @param obj - Any JSON-like value to scrub.
 * @param key - Property name to drop wherever it appears as an object key.
 * @returns A copy of `obj` without any own enumerable property named `key`.
 */
function deepOmitPreserveArrays(obj: unknown, key: string): unknown {
  if (Array.isArray(obj)) {
    return obj.map((item) => deepOmitPreserveArrays(item, key));
  }

  if (typeof obj !== 'object' || obj === null) {
    return obj;
  }

  // Drop the matching key at this level AND keep descending into the
  // remaining properties. Returning early after the first match would leave
  // nested occurrences under sibling properties untouched.
  return Object.fromEntries(
    Object.entries(obj)
      .filter(([k]) => k !== key)
      .map(([k, v]) => [k, deepOmitPreserveArrays(v, key)]),
  );
}
|
||||||
|
|
||||||
|
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||||
|
// @ts-ignore
|
||||||
|
result.output.components = deepOmitPreserveArrays(
|
||||||
|
result.output.components,
|
||||||
|
'user_id',
|
||||||
|
);
|
||||||
|
|
||||||
|
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||||
|
// @ts-ignore
|
||||||
|
result.output.components = deepOmitPreserveArrays(
|
||||||
|
result.output.components,
|
||||||
|
'actor_id',
|
||||||
|
);
|
||||||
|
|
||||||
|
// eslint-disable-next-line @typescript-eslint/ban-ts-comment
|
||||||
|
// @ts-ignore
|
||||||
|
result.output.components = deepOmitPreserveArrays(
|
||||||
|
result.output.components,
|
||||||
|
'organization_id',
|
||||||
|
);
|
||||||
|
|
||||||
|
fs.writeFileSync(
|
||||||
|
path.join(__dirname, '..', 'openapi.json'),
|
||||||
|
JSON.stringify(result.output, null, 2),
|
||||||
|
);
|
||||||
|
|
||||||
|
/**
 * Formats the generated `openapi.json` (one directory above this script)
 * in place by shelling out to Prettier via `npx`.
 *
 * Prettier's output is streamed to this process's stdio. If the command
 * fails, the error is logged and the process exits with status 1.
 */
function formatOpenAPIJson() {
  const target = path.join(__dirname, '..', 'openapi.json');
  const command = `npx prettier --write "${target}"`;

  try {
    // stdio is inherited so Prettier's own output/errors reach the console.
    execSync(command, { stdio: 'inherit' });
    console.log('Successfully formatted openapi.json with Prettier');
  } catch (error) {
    console.error('Error formatting openapi.json:', error);
    process.exit(1);
  }
}
|
||||||
|
|
||||||
|
formatOpenAPIJson();
|
||||||
@@ -7,6 +7,7 @@ from letta.schemas.letta_message import LettaMessage
|
|||||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, ChoiceLogprobs, ToolCall
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.services.telemetry_manager import TelemetryManager
|
from letta.services.telemetry_manager import TelemetryManager
|
||||||
@@ -31,6 +32,7 @@ class LettaLLMAdapter(ABC):
|
|||||||
run_id: str | None = None,
|
run_id: str | None = None,
|
||||||
org_id: str | None = None,
|
org_id: str | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
|
billing_context: BillingContext | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm_client: LLMClientBase = llm_client
|
self.llm_client: LLMClientBase = llm_client
|
||||||
self.llm_config: LLMConfig = llm_config
|
self.llm_config: LLMConfig = llm_config
|
||||||
@@ -40,6 +42,7 @@ class LettaLLMAdapter(ABC):
|
|||||||
self.run_id: str | None = run_id
|
self.run_id: str | None = run_id
|
||||||
self.org_id: str | None = org_id
|
self.org_id: str | None = org_id
|
||||||
self.user_id: str | None = user_id
|
self.user_id: str | None = user_id
|
||||||
|
self.billing_context: BillingContext | None = billing_context
|
||||||
self.message_id: str | None = None
|
self.message_id: str | None = None
|
||||||
self.request_data: dict | None = None
|
self.request_data: dict | None = None
|
||||||
self.response_data: dict | None = None
|
self.response_data: dict | None = None
|
||||||
|
|||||||
@@ -10,7 +10,7 @@ from letta.otel.tracing import log_attributes, safe_json_dumps, trace_method
|
|||||||
from letta.schemas.enums import LLMCallType, ProviderType
|
from letta.schemas.enums import LLMCallType, ProviderType
|
||||||
from letta.schemas.letta_message import LettaMessage
|
from letta.schemas.letta_message import LettaMessage
|
||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.provider_trace import ProviderTrace
|
from letta.schemas.provider_trace import BillingContext, ProviderTrace
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
from letta.utils import safe_create_task
|
from letta.utils import safe_create_task
|
||||||
@@ -36,6 +36,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
|||||||
run_id: str | None = None,
|
run_id: str | None = None,
|
||||||
org_id: str | None = None,
|
org_id: str | None = None,
|
||||||
user_id: str | None = None,
|
user_id: str | None = None,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(
|
super().__init__(
|
||||||
llm_client,
|
llm_client,
|
||||||
@@ -46,6 +47,7 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
org_id=org_id,
|
org_id=org_id,
|
||||||
user_id=user_id,
|
user_id=user_id,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
|
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -51,6 +51,7 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
|||||||
org_id=self.org_id,
|
org_id=self.org_id,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||||
|
billing_context=self.billing_context,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
|
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
|
||||||
|
|||||||
@@ -278,6 +278,7 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
|||||||
org_id=self.org_id,
|
org_id=self.org_id,
|
||||||
user_id=self.user_id,
|
user_id=self.user_id,
|
||||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||||
|
billing_context=self.billing_context,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
label="create_provider_trace",
|
label="create_provider_trace",
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from letta.schemas.letta_message_content import TextContent
|
|||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.services.agent_manager import AgentManager
|
from letta.services.agent_manager import AgentManager
|
||||||
@@ -51,7 +52,11 @@ class BaseAgent(ABC):
|
|||||||
|
|
||||||
@abstractmethod
|
@abstractmethod
|
||||||
async def step(
|
async def step(
|
||||||
self, input_messages: List[MessageCreate], max_steps: int = DEFAULT_MAX_STEPS, run_id: Optional[str] = None
|
self,
|
||||||
|
input_messages: List[MessageCreate],
|
||||||
|
max_steps: int = DEFAULT_MAX_STEPS,
|
||||||
|
run_id: Optional[str] = None,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> LettaResponse:
|
) -> LettaResponse:
|
||||||
"""
|
"""
|
||||||
Main execution loop for the agent.
|
Main execution loop for the agent.
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ from letta.schemas.user import User
|
|||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from letta.schemas.letta_request import ClientToolSchema
|
from letta.schemas.letta_request import ClientToolSchema
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
|
|
||||||
|
|
||||||
class BaseAgentV2(ABC):
|
class BaseAgentV2(ABC):
|
||||||
@@ -52,6 +53,7 @@ class BaseAgentV2(ABC):
|
|||||||
request_start_timestamp_ns: int | None = None,
|
request_start_timestamp_ns: int | None = None,
|
||||||
client_tools: list["ClientToolSchema"] | None = None,
|
client_tools: list["ClientToolSchema"] | None = None,
|
||||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> LettaResponse:
|
) -> LettaResponse:
|
||||||
"""
|
"""
|
||||||
Execute the agent loop in blocking mode, returning all messages at once.
|
Execute the agent loop in blocking mode, returning all messages at once.
|
||||||
@@ -76,6 +78,7 @@ class BaseAgentV2(ABC):
|
|||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
client_tools: list["ClientToolSchema"] | None = None,
|
client_tools: list["ClientToolSchema"] | None = None,
|
||||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
|
) -> AsyncGenerator[LettaMessage | LegacyLettaMessage | MessageStreamStatus, None]:
|
||||||
"""
|
"""
|
||||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||||
|
|||||||
@@ -48,6 +48,7 @@ from letta.schemas.openai.chat_completion_response import (
|
|||||||
UsageStatisticsCompletionTokenDetails,
|
UsageStatisticsCompletionTokenDetails,
|
||||||
UsageStatisticsPromptTokenDetails,
|
UsageStatisticsPromptTokenDetails,
|
||||||
)
|
)
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.step import StepProgression
|
from letta.schemas.step import StepProgression
|
||||||
from letta.schemas.step_metrics import StepMetrics
|
from letta.schemas.step_metrics import StepMetrics
|
||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
@@ -179,6 +180,7 @@ class LettaAgent(BaseAgent):
|
|||||||
request_start_timestamp_ns: int | None = None,
|
request_start_timestamp_ns: int | None = None,
|
||||||
include_return_message_types: list[MessageType] | None = None,
|
include_return_message_types: list[MessageType] | None = None,
|
||||||
dry_run: bool = False,
|
dry_run: bool = False,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> Union[LettaResponse, dict]:
|
) -> Union[LettaResponse, dict]:
|
||||||
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
|
# TODO (cliandy): pass in run_id and use at send_message endpoints for all step functions
|
||||||
agent_state = await self.agent_manager.get_agent_by_id_async(
|
agent_state = await self.agent_manager.get_agent_by_id_async(
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from letta.schemas.openai.chat_completion_response import (
|
|||||||
UsageStatisticsCompletionTokenDetails,
|
UsageStatisticsCompletionTokenDetails,
|
||||||
UsageStatisticsPromptTokenDetails,
|
UsageStatisticsPromptTokenDetails,
|
||||||
)
|
)
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.step import Step, StepProgression
|
from letta.schemas.step import Step, StepProgression
|
||||||
from letta.schemas.step_metrics import StepMetrics
|
from letta.schemas.step_metrics import StepMetrics
|
||||||
from letta.schemas.tool import Tool
|
from letta.schemas.tool import Tool
|
||||||
@@ -185,6 +186,7 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
request_start_timestamp_ns: int | None = None,
|
request_start_timestamp_ns: int | None = None,
|
||||||
client_tools: list[ClientToolSchema] | None = None,
|
client_tools: list[ClientToolSchema] | None = None,
|
||||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> LettaResponse:
|
) -> LettaResponse:
|
||||||
"""
|
"""
|
||||||
Execute the agent loop in blocking mode, returning all messages at once.
|
Execute the agent loop in blocking mode, returning all messages at once.
|
||||||
@@ -290,6 +292,7 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
|
conversation_id: str | None = None, # Not used in V2, but accepted for API compatibility
|
||||||
client_tools: list[ClientToolSchema] | None = None,
|
client_tools: list[ClientToolSchema] | None = None,
|
||||||
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
include_compaction_messages: bool = False, # Not used in V2, but accepted for API compatibility
|
||||||
|
billing_context: BillingContext | None = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||||
|
|||||||
@@ -45,6 +45,7 @@ from letta.schemas.letta_response import LettaResponse, TurnTokenData
|
|||||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||||
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, ToolCall, ToolCallDenial, UsageStatistics
|
from letta.schemas.openai.chat_completion_response import ChoiceLogprobs, ToolCall, ToolCallDenial, UsageStatistics
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.step import StepProgression
|
from letta.schemas.step import StepProgression
|
||||||
from letta.schemas.step_metrics import StepMetrics
|
from letta.schemas.step_metrics import StepMetrics
|
||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
@@ -149,6 +150,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
client_tools: list[ClientToolSchema] | None = None,
|
client_tools: list[ClientToolSchema] | None = None,
|
||||||
include_compaction_messages: bool = False,
|
include_compaction_messages: bool = False,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> LettaResponse:
|
) -> LettaResponse:
|
||||||
"""
|
"""
|
||||||
Execute the agent loop in blocking mode, returning all messages at once.
|
Execute the agent loop in blocking mode, returning all messages at once.
|
||||||
@@ -232,6 +234,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
org_id=self.actor.organization_id,
|
org_id=self.actor.organization_id,
|
||||||
user_id=self.actor.id,
|
user_id=self.actor.id,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
credit_task = None
|
credit_task = None
|
||||||
@@ -362,6 +365,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
client_tools: list[ClientToolSchema] | None = None,
|
client_tools: list[ClientToolSchema] | None = None,
|
||||||
include_compaction_messages: bool = False,
|
include_compaction_messages: bool = False,
|
||||||
|
billing_context: BillingContext | None = None,
|
||||||
) -> AsyncGenerator[str, None]:
|
) -> AsyncGenerator[str, None]:
|
||||||
"""
|
"""
|
||||||
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
Execute the agent loop in streaming mode, yielding chunks as they become available.
|
||||||
@@ -419,6 +423,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
org_id=self.actor.organization_id,
|
org_id=self.actor.organization_id,
|
||||||
user_id=self.actor.id,
|
user_id=self.actor.id,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
elif use_sglang_native:
|
elif use_sglang_native:
|
||||||
# Use SGLang native adapter for multi-turn RL training
|
# Use SGLang native adapter for multi-turn RL training
|
||||||
@@ -431,6 +436,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
org_id=self.actor.organization_id,
|
org_id=self.actor.organization_id,
|
||||||
user_id=self.actor.id,
|
user_id=self.actor.id,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
# Reset turns tracking for this step
|
# Reset turns tracking for this step
|
||||||
self.turns = []
|
self.turns = []
|
||||||
@@ -444,6 +450,7 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
org_id=self.actor.organization_id,
|
org_id=self.actor.organization_id,
|
||||||
user_id=self.actor.id,
|
user_id=self.actor.id,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ from letta.schemas.letta_message import MessageType
|
|||||||
from letta.schemas.letta_message_content import TextContent
|
from letta.schemas.letta_message_content import TextContent
|
||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
from letta.schemas.message import Message, MessageCreate
|
from letta.schemas.message import Message, MessageCreate
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.run import Run
|
from letta.schemas.run import Run
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.services.agent_manager import AgentManager
|
from letta.services.agent_manager import AgentManager
|
||||||
@@ -69,6 +70,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
|||||||
use_assistant_message: bool = True,
|
use_assistant_message: bool = True,
|
||||||
request_start_timestamp_ns: int | None = None,
|
request_start_timestamp_ns: int | None = None,
|
||||||
include_return_message_types: list[MessageType] | None = None,
|
include_return_message_types: list[MessageType] | None = None,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> LettaResponse:
|
) -> LettaResponse:
|
||||||
run_ids = []
|
run_ids = []
|
||||||
|
|
||||||
@@ -100,6 +102,7 @@ class SleeptimeMultiAgentV2(BaseAgent):
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
use_assistant_message=use_assistant_message,
|
use_assistant_message=use_assistant_message,
|
||||||
include_return_message_types=include_return_message_types,
|
include_return_message_types=include_return_message_types,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Get last response messages
|
# Get last response messages
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ from letta.schemas.letta_request import ClientToolSchema
|
|||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
from letta.schemas.letta_stop_reason import StopReasonType
|
from letta.schemas.letta_stop_reason import StopReasonType
|
||||||
from letta.schemas.message import Message, MessageCreate
|
from letta.schemas.message import Message, MessageCreate
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.run import Run, RunUpdate
|
from letta.schemas.run import Run, RunUpdate
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.services.group_manager import GroupManager
|
from letta.services.group_manager import GroupManager
|
||||||
@@ -47,6 +48,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
|||||||
request_start_timestamp_ns: int | None = None,
|
request_start_timestamp_ns: int | None = None,
|
||||||
client_tools: list[ClientToolSchema] | None = None,
|
client_tools: list[ClientToolSchema] | None = None,
|
||||||
include_compaction_messages: bool = False,
|
include_compaction_messages: bool = False,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> LettaResponse:
|
) -> LettaResponse:
|
||||||
self.run_ids = []
|
self.run_ids = []
|
||||||
|
|
||||||
@@ -62,6 +64,7 @@ class SleeptimeMultiAgentV3(LettaAgentV2):
|
|||||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||||
client_tools=client_tools,
|
client_tools=client_tools,
|
||||||
include_compaction_messages=include_compaction_messages,
|
include_compaction_messages=include_compaction_messages,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.run_sleeptime_agents()
|
await self.run_sleeptime_agents()
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ from letta.schemas.letta_request import ClientToolSchema
|
|||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
from letta.schemas.letta_stop_reason import StopReasonType
|
from letta.schemas.letta_stop_reason import StopReasonType
|
||||||
from letta.schemas.message import Message, MessageCreate
|
from letta.schemas.message import Message, MessageCreate
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.run import Run, RunUpdate
|
from letta.schemas.run import Run, RunUpdate
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.services.group_manager import GroupManager
|
from letta.services.group_manager import GroupManager
|
||||||
@@ -47,6 +48,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
|||||||
conversation_id: str | None = None,
|
conversation_id: str | None = None,
|
||||||
client_tools: list[ClientToolSchema] | None = None,
|
client_tools: list[ClientToolSchema] | None = None,
|
||||||
include_compaction_messages: bool = False,
|
include_compaction_messages: bool = False,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> LettaResponse:
|
) -> LettaResponse:
|
||||||
self.run_ids = []
|
self.run_ids = []
|
||||||
|
|
||||||
@@ -63,6 +65,7 @@ class SleeptimeMultiAgentV4(LettaAgentV3):
|
|||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
client_tools=client_tools,
|
client_tools=client_tools,
|
||||||
include_compaction_messages=include_compaction_messages,
|
include_compaction_messages=include_compaction_messages,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
run_ids = await self.run_sleeptime_agents()
|
run_ids = await self.run_sleeptime_agents()
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ from letta.schemas.enums import AgentType, LLMCallType, ProviderCategory
|
|||||||
from letta.schemas.llm_config import LLMConfig
|
from letta.schemas.llm_config import LLMConfig
|
||||||
from letta.schemas.message import Message
|
from letta.schemas.message import Message
|
||||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||||
from letta.schemas.provider_trace import ProviderTrace
|
from letta.schemas.provider_trace import BillingContext, ProviderTrace
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.services.telemetry_manager import TelemetryManager
|
from letta.services.telemetry_manager import TelemetryManager
|
||||||
from letta.settings import settings
|
from letta.settings import settings
|
||||||
@@ -48,6 +48,7 @@ class LLMClientBase:
|
|||||||
self._telemetry_user_id: Optional[str] = None
|
self._telemetry_user_id: Optional[str] = None
|
||||||
self._telemetry_compaction_settings: Optional[Dict] = None
|
self._telemetry_compaction_settings: Optional[Dict] = None
|
||||||
self._telemetry_llm_config: Optional[Dict] = None
|
self._telemetry_llm_config: Optional[Dict] = None
|
||||||
|
self._telemetry_billing_context: Optional[BillingContext] = None
|
||||||
|
|
||||||
def set_telemetry_context(
|
def set_telemetry_context(
|
||||||
self,
|
self,
|
||||||
@@ -62,6 +63,7 @@ class LLMClientBase:
|
|||||||
compaction_settings: Optional[Dict] = None,
|
compaction_settings: Optional[Dict] = None,
|
||||||
llm_config: Optional[Dict] = None,
|
llm_config: Optional[Dict] = None,
|
||||||
actor: Optional["User"] = None,
|
actor: Optional["User"] = None,
|
||||||
|
billing_context: Optional[BillingContext] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set telemetry context for provider trace logging."""
|
"""Set telemetry context for provider trace logging."""
|
||||||
if actor is not None:
|
if actor is not None:
|
||||||
@@ -76,6 +78,7 @@ class LLMClientBase:
|
|||||||
self._telemetry_user_id = user_id
|
self._telemetry_user_id = user_id
|
||||||
self._telemetry_compaction_settings = compaction_settings
|
self._telemetry_compaction_settings = compaction_settings
|
||||||
self._telemetry_llm_config = llm_config
|
self._telemetry_llm_config = llm_config
|
||||||
|
self._telemetry_billing_context = billing_context
|
||||||
|
|
||||||
def extract_usage_statistics(self, response_data: Optional[dict], llm_config: LLMConfig) -> LettaUsageStatistics:
|
def extract_usage_statistics(self, response_data: Optional[dict], llm_config: LLMConfig) -> LettaUsageStatistics:
|
||||||
"""Provider-specific usage parsing hook (override in subclasses). Returns LettaUsageStatistics."""
|
"""Provider-specific usage parsing hook (override in subclasses). Returns LettaUsageStatistics."""
|
||||||
@@ -125,6 +128,7 @@ class LLMClientBase:
|
|||||||
user_id=self._telemetry_user_id,
|
user_id=self._telemetry_user_id,
|
||||||
compaction_settings=self._telemetry_compaction_settings,
|
compaction_settings=self._telemetry_compaction_settings,
|
||||||
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
|
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
|
||||||
|
billing_context=self._telemetry_billing_context,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -186,6 +190,7 @@ class LLMClientBase:
|
|||||||
user_id=self._telemetry_user_id,
|
user_id=self._telemetry_user_id,
|
||||||
compaction_settings=self._telemetry_compaction_settings,
|
compaction_settings=self._telemetry_compaction_settings,
|
||||||
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
|
llm_config=llm_config.model_dump() if llm_config else self._telemetry_llm_config,
|
||||||
|
billing_context=self._telemetry_billing_context,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -95,6 +95,11 @@ class LLMTrace(LettaBase):
|
|||||||
response_json: str = Field(..., description="Full response payload as JSON string")
|
response_json: str = Field(..., description="Full response payload as JSON string")
|
||||||
llm_config_json: str = Field(default="", description="LLM config as JSON string")
|
llm_config_json: str = Field(default="", description="LLM config as JSON string")
|
||||||
|
|
||||||
|
# Billing context
|
||||||
|
billing_plan_type: Optional[str] = Field(default=None, description="Subscription tier (e.g., 'basic', 'standard', 'max', 'enterprise')")
|
||||||
|
billing_cost_source: Optional[str] = Field(default=None, description="Cost source: 'quota' or 'credits'")
|
||||||
|
billing_customer_id: Optional[str] = Field(default=None, description="Customer ID for cross-referencing billing records")
|
||||||
|
|
||||||
# Timestamp
|
# Timestamp
|
||||||
created_at: datetime = Field(default_factory=get_utc_time, description="When the trace was created")
|
created_at: datetime = Field(default_factory=get_utc_time, description="When the trace was created")
|
||||||
|
|
||||||
@@ -128,6 +133,9 @@ class LLMTrace(LettaBase):
|
|||||||
self.request_json,
|
self.request_json,
|
||||||
self.response_json,
|
self.response_json,
|
||||||
self.llm_config_json,
|
self.llm_config_json,
|
||||||
|
self.billing_plan_type or "",
|
||||||
|
self.billing_cost_source or "",
|
||||||
|
self.billing_customer_id or "",
|
||||||
self.created_at,
|
self.created_at,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -162,5 +170,8 @@ class LLMTrace(LettaBase):
|
|||||||
"request_json",
|
"request_json",
|
||||||
"response_json",
|
"response_json",
|
||||||
"llm_config_json",
|
"llm_config_json",
|
||||||
|
"billing_plan_type",
|
||||||
|
"billing_cost_source",
|
||||||
|
"billing_customer_id",
|
||||||
"created_at",
|
"created_at",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -3,13 +3,21 @@ from __future__ import annotations
|
|||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from pydantic import Field
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
from letta.helpers.datetime_helpers import get_utc_time
|
from letta.helpers.datetime_helpers import get_utc_time
|
||||||
from letta.schemas.enums import PrimitiveType
|
from letta.schemas.enums import PrimitiveType
|
||||||
from letta.schemas.letta_base import OrmMetadataBase
|
from letta.schemas.letta_base import OrmMetadataBase
|
||||||
|
|
||||||
|
|
||||||
|
class BillingContext(BaseModel):
|
||||||
|
"""Billing context for LLM request cost tracking."""
|
||||||
|
|
||||||
|
plan_type: Optional[str] = Field(None, description="Subscription tier")
|
||||||
|
cost_source: Optional[str] = Field(None, description="Cost source: 'quota' or 'credits'")
|
||||||
|
customer_id: Optional[str] = Field(None, description="Customer ID for billing records")
|
||||||
|
|
||||||
|
|
||||||
class BaseProviderTrace(OrmMetadataBase):
|
class BaseProviderTrace(OrmMetadataBase):
|
||||||
__id_prefix__ = PrimitiveType.PROVIDER_TRACE.value
|
__id_prefix__ = PrimitiveType.PROVIDER_TRACE.value
|
||||||
|
|
||||||
@@ -53,6 +61,8 @@ class ProviderTrace(BaseProviderTrace):
|
|||||||
compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
|
compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
|
||||||
llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")
|
llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")
|
||||||
|
|
||||||
|
billing_context: Optional[BillingContext] = Field(None, description="Billing context from request headers")
|
||||||
|
|
||||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ from pydantic import BaseModel
|
|||||||
from letta.errors import LettaInvalidArgumentError
|
from letta.errors import LettaInvalidArgumentError
|
||||||
from letta.otel.tracing import tracer
|
from letta.otel.tracing import tracer
|
||||||
from letta.schemas.enums import PrimitiveType
|
from letta.schemas.enums import PrimitiveType
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.validators import PRIMITIVE_ID_PATTERNS
|
from letta.validators import PRIMITIVE_ID_PATTERNS
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
@@ -30,18 +31,24 @@ class HeaderParams(BaseModel):
|
|||||||
letta_source: Optional[str] = None
|
letta_source: Optional[str] = None
|
||||||
sdk_version: Optional[str] = None
|
sdk_version: Optional[str] = None
|
||||||
experimental_params: Optional[ExperimentalParams] = None
|
experimental_params: Optional[ExperimentalParams] = None
|
||||||
|
billing_context: Optional[BillingContext] = None
|
||||||
|
|
||||||
|
|
||||||
def get_headers(
|
def get_headers(
|
||||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||||
user_agent: Optional[str] = Header(None, alias="User-Agent"),
|
user_agent: Optional[str] = Header(None, alias="User-Agent"),
|
||||||
project_id: Optional[str] = Header(None, alias="X-Project-Id"),
|
project_id: Optional[str] = Header(None, alias="X-Project-Id"),
|
||||||
letta_source: Optional[str] = Header(None, alias="X-Letta-Source"),
|
letta_source: Optional[str] = Header(None, alias="X-Letta-Source", include_in_schema=False),
|
||||||
sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version"),
|
sdk_version: Optional[str] = Header(None, alias="X-Stainless-Package-Version", include_in_schema=False),
|
||||||
message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async"),
|
message_async: Optional[str] = Header(None, alias="X-Experimental-Message-Async", include_in_schema=False),
|
||||||
letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent"),
|
letta_v1_agent: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent", include_in_schema=False),
|
||||||
letta_v1_agent_message_async: Optional[str] = Header(None, alias="X-Experimental-Letta-V1-Agent-Message-Async"),
|
letta_v1_agent_message_async: Optional[str] = Header(
|
||||||
modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox"),
|
None, alias="X-Experimental-Letta-V1-Agent-Message-Async", include_in_schema=False
|
||||||
|
),
|
||||||
|
modal_sandbox: Optional[str] = Header(None, alias="X-Experimental-Modal-Sandbox", include_in_schema=False),
|
||||||
|
billing_plan_type: Optional[str] = Header(None, alias="X-Billing-Plan-Type", include_in_schema=False),
|
||||||
|
billing_cost_source: Optional[str] = Header(None, alias="X-Billing-Cost-Source", include_in_schema=False),
|
||||||
|
billing_customer_id: Optional[str] = Header(None, alias="X-Billing-Customer-Id", include_in_schema=False),
|
||||||
) -> HeaderParams:
|
) -> HeaderParams:
|
||||||
"""Dependency injection function to extract common headers from requests."""
|
"""Dependency injection function to extract common headers from requests."""
|
||||||
with tracer.start_as_current_span("dependency.get_headers"):
|
with tracer.start_as_current_span("dependency.get_headers"):
|
||||||
@@ -63,6 +70,13 @@ def get_headers(
|
|||||||
letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None,
|
letta_v1_agent_message_async=(letta_v1_agent_message_async == "true") if letta_v1_agent_message_async else None,
|
||||||
modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None,
|
modal_sandbox=(modal_sandbox == "true") if modal_sandbox else None,
|
||||||
),
|
),
|
||||||
|
billing_context=BillingContext(
|
||||||
|
plan_type=billing_plan_type,
|
||||||
|
cost_source=billing_cost_source,
|
||||||
|
customer_id=billing_customer_id,
|
||||||
|
)
|
||||||
|
if any([billing_plan_type, billing_cost_source, billing_customer_id])
|
||||||
|
else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ from letta.schemas.memory import (
|
|||||||
)
|
)
|
||||||
from letta.schemas.message import Message, MessageCreate, MessageCreateType, MessageSearchRequest, MessageSearchResult
|
from letta.schemas.message import Message, MessageCreate, MessageCreateType, MessageSearchRequest, MessageSearchResult
|
||||||
from letta.schemas.passage import Passage
|
from letta.schemas.passage import Passage
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
||||||
from letta.schemas.source import Source
|
from letta.schemas.source import Source
|
||||||
from letta.schemas.tool import Tool
|
from letta.schemas.tool import Tool
|
||||||
@@ -1697,6 +1698,7 @@ async def send_message(
|
|||||||
actor=actor,
|
actor=actor,
|
||||||
request=request,
|
request=request,
|
||||||
run_type="send_message",
|
run_type="send_message",
|
||||||
|
billing_context=headers.billing_context,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -1767,6 +1769,7 @@ async def send_message(
|
|||||||
include_return_message_types=request.include_return_message_types,
|
include_return_message_types=request.include_return_message_types,
|
||||||
client_tools=request.client_tools,
|
client_tools=request.client_tools,
|
||||||
include_compaction_messages=request.include_compaction_messages,
|
include_compaction_messages=request.include_compaction_messages,
|
||||||
|
billing_context=headers.billing_context,
|
||||||
)
|
)
|
||||||
run_status = result.stop_reason.stop_reason.run_status
|
run_status = result.stop_reason.stop_reason.run_status
|
||||||
return result
|
return result
|
||||||
@@ -1845,6 +1848,7 @@ async def send_message_streaming(
|
|||||||
actor=actor,
|
actor=actor,
|
||||||
request=request,
|
request=request,
|
||||||
run_type="send_message_streaming",
|
run_type="send_message_streaming",
|
||||||
|
billing_context=headers.billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
return result
|
return result
|
||||||
@@ -2043,6 +2047,7 @@ async def _process_message_background(
|
|||||||
include_return_message_types: list[MessageType] | None = None,
|
include_return_message_types: list[MessageType] | None = None,
|
||||||
override_model: str | None = None,
|
override_model: str | None = None,
|
||||||
include_compaction_messages: bool = False,
|
include_compaction_messages: bool = False,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Background task to process the message and update run status."""
|
"""Background task to process the message and update run status."""
|
||||||
request_start_timestamp_ns = get_utc_timestamp_ns()
|
request_start_timestamp_ns = get_utc_timestamp_ns()
|
||||||
@@ -2074,6 +2079,7 @@ async def _process_message_background(
|
|||||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||||
include_return_message_types=include_return_message_types,
|
include_return_message_types=include_return_message_types,
|
||||||
include_compaction_messages=include_compaction_messages,
|
include_compaction_messages=include_compaction_messages,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
runs_manager = RunManager()
|
runs_manager = RunManager()
|
||||||
from letta.schemas.enums import RunStatus
|
from letta.schemas.enums import RunStatus
|
||||||
@@ -2242,6 +2248,7 @@ async def send_message_async(
|
|||||||
include_return_message_types=request.include_return_message_types,
|
include_return_message_types=request.include_return_message_types,
|
||||||
override_model=request.override_model,
|
override_model=request.override_model,
|
||||||
include_compaction_messages=request.include_compaction_messages,
|
include_compaction_messages=request.include_compaction_messages,
|
||||||
|
billing_context=headers.billing_context,
|
||||||
),
|
),
|
||||||
label=f"process_message_background_{run.id}",
|
label=f"process_message_background_{run.id}",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ from letta.schemas.job import LettaRequestConfig
|
|||||||
from letta.schemas.letta_message import LettaMessageUnion
|
from letta.schemas.letta_message import LettaMessageUnion
|
||||||
from letta.schemas.letta_request import ConversationMessageRequest, LettaStreamingRequest, RetrieveStreamRequest
|
from letta.schemas.letta_request import ConversationMessageRequest, LettaStreamingRequest, RetrieveStreamRequest
|
||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.run import Run as PydanticRun
|
from letta.schemas.run import Run as PydanticRun
|
||||||
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
from letta.server.rest_api.dependencies import HeaderParams, get_headers, get_letta_server
|
||||||
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
|
from letta.server.rest_api.redis_stream_manager import redis_sse_stream_generator
|
||||||
@@ -211,6 +212,7 @@ async def _send_agent_direct_message(
|
|||||||
request: ConversationMessageRequest,
|
request: ConversationMessageRequest,
|
||||||
server: SyncServer,
|
server: SyncServer,
|
||||||
actor,
|
actor,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> StreamingResponse | LettaResponse:
|
) -> StreamingResponse | LettaResponse:
|
||||||
"""
|
"""
|
||||||
Handle agent-direct messaging with locking but without conversation features.
|
Handle agent-direct messaging with locking but without conversation features.
|
||||||
@@ -244,6 +246,7 @@ async def _send_agent_direct_message(
|
|||||||
run_type="send_message",
|
run_type="send_message",
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
should_lock=True,
|
should_lock=True,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -299,6 +302,7 @@ async def _send_agent_direct_message(
|
|||||||
client_tools=request.client_tools,
|
client_tools=request.client_tools,
|
||||||
conversation_id=None,
|
conversation_id=None,
|
||||||
include_compaction_messages=request.include_compaction_messages,
|
include_compaction_messages=request.include_compaction_messages,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
finally:
|
finally:
|
||||||
# Release lock
|
# Release lock
|
||||||
@@ -351,6 +355,7 @@ async def send_conversation_message(
|
|||||||
request=request,
|
request=request,
|
||||||
server=server,
|
server=server,
|
||||||
actor=actor,
|
actor=actor,
|
||||||
|
billing_context=headers.billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Normal conversation mode
|
# Normal conversation mode
|
||||||
@@ -383,6 +388,7 @@ async def send_conversation_message(
|
|||||||
request=streaming_request,
|
request=streaming_request,
|
||||||
run_type="send_conversation_message",
|
run_type="send_conversation_message",
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
|
billing_context=headers.billing_context,
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
@@ -445,6 +451,7 @@ async def send_conversation_message(
|
|||||||
client_tools=request.client_tools,
|
client_tools=request.client_tools,
|
||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
include_compaction_messages=request.include_compaction_messages,
|
include_compaction_messages=request.include_compaction_messages,
|
||||||
|
billing_context=headers.billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -141,6 +141,9 @@ class ClickhouseProviderTraceBackend(ProviderTraceBackendClient):
|
|||||||
request_json=request_json_str,
|
request_json=request_json_str,
|
||||||
response_json=response_json_str,
|
response_json=response_json_str,
|
||||||
llm_config_json=llm_config_json_str,
|
llm_config_json=llm_config_json_str,
|
||||||
|
billing_plan_type=provider_trace.billing_context.plan_type if provider_trace.billing_context else None,
|
||||||
|
billing_cost_source=provider_trace.billing_context.cost_source if provider_trace.billing_context else None,
|
||||||
|
billing_customer_id=provider_trace.billing_context.customer_id if provider_trace.billing_context else None,
|
||||||
)
|
)
|
||||||
|
|
||||||
def _extract_usage(self, response_json: dict, provider: str) -> dict:
|
def _extract_usage(self, response_json: dict, provider: str) -> dict:
|
||||||
|
|||||||
@@ -29,7 +29,7 @@ class PostgresProviderTraceBackend(ProviderTraceBackendClient):
|
|||||||
) -> ProviderTrace:
|
) -> ProviderTrace:
|
||||||
"""Write full provider trace to provider_traces table."""
|
"""Write full provider trace to provider_traces table."""
|
||||||
async with db_registry.async_session() as session:
|
async with db_registry.async_session() as session:
|
||||||
provider_trace_model = ProviderTraceModel(**provider_trace.model_dump())
|
provider_trace_model = ProviderTraceModel(**provider_trace.model_dump(exclude={"billing_context"}))
|
||||||
provider_trace_model.organization_id = actor.organization_id
|
provider_trace_model.organization_id = actor.organization_id
|
||||||
|
|
||||||
if provider_trace.request_json:
|
if provider_trace.request_json:
|
||||||
|
|||||||
@@ -34,6 +34,7 @@ from letta.schemas.letta_request import ClientToolSchema, LettaStreamingRequest
|
|||||||
from letta.schemas.letta_response import LettaResponse
|
from letta.schemas.letta_response import LettaResponse
|
||||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||||
from letta.schemas.message import MessageCreate
|
from letta.schemas.message import MessageCreate
|
||||||
|
from letta.schemas.provider_trace import BillingContext
|
||||||
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
||||||
from letta.schemas.usage import LettaUsageStatistics
|
from letta.schemas.usage import LettaUsageStatistics
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
@@ -78,6 +79,7 @@ class StreamingService:
|
|||||||
run_type: str = "streaming",
|
run_type: str = "streaming",
|
||||||
conversation_id: Optional[str] = None,
|
conversation_id: Optional[str] = None,
|
||||||
should_lock: bool = False,
|
should_lock: bool = False,
|
||||||
|
billing_context: "BillingContext | None" = None,
|
||||||
) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]:
|
) -> tuple[Optional[PydanticRun], Union[StreamingResponse, LettaResponse]]:
|
||||||
"""
|
"""
|
||||||
Create a streaming response for an agent.
|
Create a streaming response for an agent.
|
||||||
@@ -176,6 +178,7 @@ class StreamingService:
|
|||||||
lock_key=lock_key, # For lock release (may differ from conversation_id)
|
lock_key=lock_key, # For lock release (may differ from conversation_id)
|
||||||
client_tools=request.client_tools,
|
client_tools=request.client_tools,
|
||||||
include_compaction_messages=request.include_compaction_messages,
|
include_compaction_messages=request.include_compaction_messages,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
# handle background streaming if requested
|
# handle background streaming if requested
|
||||||
@@ -340,6 +343,7 @@ class StreamingService:
|
|||||||
lock_key: Optional[str] = None,
|
lock_key: Optional[str] = None,
|
||||||
client_tools: Optional[list[ClientToolSchema]] = None,
|
client_tools: Optional[list[ClientToolSchema]] = None,
|
||||||
include_compaction_messages: bool = False,
|
include_compaction_messages: bool = False,
|
||||||
|
billing_context: BillingContext | None = None,
|
||||||
) -> AsyncIterator:
|
) -> AsyncIterator:
|
||||||
"""
|
"""
|
||||||
Create a stream with unified error handling.
|
Create a stream with unified error handling.
|
||||||
@@ -368,6 +372,7 @@ class StreamingService:
|
|||||||
conversation_id=conversation_id,
|
conversation_id=conversation_id,
|
||||||
client_tools=client_tools,
|
client_tools=client_tools,
|
||||||
include_compaction_messages=include_compaction_messages,
|
include_compaction_messages=include_compaction_messages,
|
||||||
|
billing_context=billing_context,
|
||||||
)
|
)
|
||||||
|
|
||||||
async for chunk in stream:
|
async for chunk in stream:
|
||||||
|
|||||||
@@ -24,6 +24,9 @@ def test_get_headers_user_id_allows_none():
|
|||||||
letta_v1_agent=None,
|
letta_v1_agent=None,
|
||||||
letta_v1_agent_message_async=None,
|
letta_v1_agent_message_async=None,
|
||||||
modal_sandbox=None,
|
modal_sandbox=None,
|
||||||
|
billing_plan_type=None,
|
||||||
|
billing_cost_source=None,
|
||||||
|
billing_customer_id=None,
|
||||||
)
|
)
|
||||||
assert isinstance(headers, HeaderParams)
|
assert isinstance(headers, HeaderParams)
|
||||||
|
|
||||||
@@ -40,6 +43,9 @@ def test_get_headers_user_id_rejects_invalid_format():
|
|||||||
letta_v1_agent=None,
|
letta_v1_agent=None,
|
||||||
letta_v1_agent_message_async=None,
|
letta_v1_agent_message_async=None,
|
||||||
modal_sandbox=None,
|
modal_sandbox=None,
|
||||||
|
billing_plan_type=None,
|
||||||
|
billing_cost_source=None,
|
||||||
|
billing_customer_id=None,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -54,6 +60,9 @@ def test_get_headers_user_id_accepts_valid_format():
|
|||||||
letta_v1_agent=None,
|
letta_v1_agent=None,
|
||||||
letta_v1_agent_message_async=None,
|
letta_v1_agent_message_async=None,
|
||||||
modal_sandbox=None,
|
modal_sandbox=None,
|
||||||
|
billing_plan_type=None,
|
||||||
|
billing_cost_source=None,
|
||||||
|
billing_customer_id=None,
|
||||||
)
|
)
|
||||||
assert headers.actor_id == "user-123e4567-e89b-42d3-8456-426614174000"
|
assert headers.actor_id == "user-123e4567-e89b-42d3-8456-426614174000"
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user