feat(crouton): add orgId, userId, Compaction_Settings and LLM_Config (#9022)

* LC one shot?

* api changes

* fix summarizer nameerror
This commit is contained in:
Kian Jones
2026-01-21 21:57:22 -08:00
committed by Caren Thomas
parent 194fa7d1c6
commit e3fb00f970
18 changed files with 257 additions and 4 deletions

View File

@@ -0,0 +1,32 @@
"""Add v2 protocol fields to provider_traces
Revision ID: 9275f62ad282
Revises: 297e8217e952
Create Date: 2026-01-22
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
revision: str = "9275f62ad282"
down_revision: Union[str, None] = "297e8217e952"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add the v2 protocol columns to provider_traces.

    All four columns are nullable so existing rows remain valid:
    org_id / user_id (strings) and compaction_settings / llm_config (JSON blobs).
    """
    new_columns = (
        ("org_id", sa.String()),
        ("user_id", sa.String()),
        ("compaction_settings", sa.JSON()),
        ("llm_config", sa.JSON()),
    )
    for column_name, column_type in new_columns:
        op.add_column("provider_traces", sa.Column(column_name, column_type, nullable=True))
def downgrade() -> None:
    """Drop the v2 protocol columns, in reverse order of their addition."""
    for column_name in ("llm_config", "compaction_settings", "user_id", "org_id"):
        op.drop_column("provider_traces", column_name)

View File

@@ -39340,13 +39340,51 @@
],
"title": "Source",
"description": "Source service that generated this trace (memgpt-server, lettuce-py)"
},
"org_id": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"title": "Org Id",
"description": "ID of the organization"
},
"compaction_settings": {
"anyOf": [
{
"additionalProperties": true,
"type": "object"
},
{
"type": "null"
}
],
"title": "Compaction Settings",
"description": "Compaction/summarization settings (summarization calls only)"
},
"llm_config": {
"anyOf": [
{
"additionalProperties": true,
"type": "object"
},
{
"type": "null"
}
],
"title": "Llm Config",
"description": "LLM configuration used for this call (non-summarization calls only)"
}
},
"additionalProperties": false,
"type": "object",
"required": ["request_json", "response_json"],
"title": "ProviderTrace",
"description": "Letta's internal representation of a provider trace.\n\nAttributes:\n id (str): The unique identifier of the provider trace.\n request_json (Dict[str, Any]): JSON content of the provider request.\n response_json (Dict[str, Any]): JSON content of the provider response.\n step_id (str): ID of the step that this trace is associated with.\n agent_id (str): ID of the agent that generated this trace.\n agent_tags (list[str]): Tags associated with the agent for filtering.\n call_type (str): Type of call (agent_step, summarization, etc.).\n run_id (str): ID of the run this trace is associated with.\n source (str): Source service that generated this trace (memgpt-server, lettuce-py).\n organization_id (str): The unique identifier of the organization.\n created_at (datetime): The timestamp when the object was created."
"description": "Letta's internal representation of a provider trace.\n\nAttributes:\n id (str): The unique identifier of the provider trace.\n request_json (Dict[str, Any]): JSON content of the provider request.\n response_json (Dict[str, Any]): JSON content of the provider response.\n step_id (str): ID of the step that this trace is associated with.\n agent_id (str): ID of the agent that generated this trace.\n agent_tags (list[str]): Tags associated with the agent for filtering.\n call_type (str): Type of call (agent_step, summarization, etc.).\n run_id (str): ID of the run this trace is associated with.\n source (str): Source service that generated this trace (memgpt-server, lettuce-py).\n organization_id (str): The unique identifier of the organization.\n user_id (str): The unique identifier of the user who initiated the request.\n compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).\n llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).\n created_at (datetime): The timestamp when the object was created."
},
"ProviderType": {
"type": "string",

View File

@@ -27,12 +27,16 @@ class LettaLLMAdapter(ABC):
agent_id: str | None = None,
agent_tags: list[str] | None = None,
run_id: str | None = None,
org_id: str | None = None,
user_id: str | None = None,
) -> None:
self.llm_client: LLMClientBase = llm_client
self.llm_config: LLMConfig = llm_config
self.agent_id: str | None = agent_id
self.agent_tags: list[str] | None = agent_tags
self.run_id: str | None = run_id
self.org_id: str | None = org_id
self.user_id: str | None = user_id
self.message_id: str | None = None
self.request_data: dict | None = None
self.response_data: dict | None = None

View File

@@ -127,6 +127,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
),
),
label="create_provider_trace",

View File

@@ -33,8 +33,10 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
agent_id: str | None = None,
agent_tags: list[str] | None = None,
run_id: str | None = None,
org_id: str | None = None,
user_id: str | None = None,
) -> None:
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id)
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
async def invoke_llm(
@@ -236,6 +238,9 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
),
),
label="create_provider_trace",

View File

@@ -46,6 +46,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
agent_tags=self.agent_tags,
run_id=self.run_id,
call_type="agent_step",
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
)
try:
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)

View File

@@ -283,6 +283,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
agent_id=self.agent_id,
agent_tags=self.agent_tags,
run_id=self.run_id,
org_id=self.org_id,
user_id=self.user_id,
llm_config=self.llm_config.model_dump() if self.llm_config else None,
),
),
label="create_provider_trace",

View File

@@ -420,6 +420,9 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE
@@ -770,6 +773,9 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE
@@ -1242,6 +1248,9 @@ class LettaAgent(BaseAgent):
agent_id=self.agent_id,
agent_tags=agent_state.tags,
run_id=self.current_run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
),
)
step_progression = StepProgression.LOGGED_TRACE

View File

@@ -156,7 +156,11 @@ class LettaAgentV2(BaseAgentV2):
run_id=None,
messages=in_context_messages + input_messages_to_persist,
llm_adapter=LettaLLMRequestAdapter(
llm_client=self.llm_client, llm_config=self.agent_state.llm_config, agent_tags=self.agent_state.tags
llm_client=self.llm_client,
llm_config=self.agent_state.llm_config,
agent_tags=self.agent_state.tags,
org_id=self.actor.organization_id,
user_id=self.actor.id,
),
dry_run=True,
enforce_run_id_set=False,
@@ -213,6 +217,8 @@ class LettaAgentV2(BaseAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
),
run_id=run_id,
use_assistant_message=use_assistant_message,
@@ -298,6 +304,8 @@ class LettaAgentV2(BaseAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
else:
llm_adapter = LettaLLMRequestAdapter(
@@ -306,6 +314,8 @@ class LettaAgentV2(BaseAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
try:

View File

@@ -173,6 +173,8 @@ class LettaAgentV3(LettaAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
),
run_id=run_id,
# use_assistant_message=use_assistant_message,
@@ -316,6 +318,8 @@ class LettaAgentV3(LettaAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
else:
llm_adapter = SimpleLLMRequestAdapter(
@@ -324,6 +328,8 @@ class LettaAgentV3(LettaAgentV2):
agent_id=self.agent_state.id,
agent_tags=self.agent_state.tags,
run_id=run_id,
org_id=self.actor.organization_id,
user_id=self.actor.id,
)
try:

View File

@@ -43,6 +43,10 @@ class LLMClientBase:
self._telemetry_run_id: Optional[str] = None
self._telemetry_step_id: Optional[str] = None
self._telemetry_call_type: Optional[str] = None
self._telemetry_org_id: Optional[str] = None
self._telemetry_user_id: Optional[str] = None
self._telemetry_compaction_settings: Optional[Dict] = None
self._telemetry_llm_config: Optional[Dict] = None
def set_telemetry_context(
self,
@@ -52,6 +56,10 @@ class LLMClientBase:
run_id: Optional[str] = None,
step_id: Optional[str] = None,
call_type: Optional[str] = None,
org_id: Optional[str] = None,
user_id: Optional[str] = None,
compaction_settings: Optional[Dict] = None,
llm_config: Optional[Dict] = None,
) -> None:
"""Set telemetry context for provider trace logging."""
self._telemetry_manager = telemetry_manager
@@ -60,6 +68,10 @@ class LLMClientBase:
self._telemetry_run_id = run_id
self._telemetry_step_id = step_id
self._telemetry_call_type = call_type
self._telemetry_org_id = org_id
self._telemetry_user_id = user_id
self._telemetry_compaction_settings = compaction_settings
self._telemetry_llm_config = llm_config
async def request_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""Wrapper around request_async that logs telemetry for all requests including errors.
@@ -96,6 +108,10 @@ class LLMClientBase:
agent_tags=self._telemetry_agent_tags,
run_id=self._telemetry_run_id,
call_type=self._telemetry_call_type,
org_id=self._telemetry_org_id,
user_id=self._telemetry_user_id,
compaction_settings=self._telemetry_compaction_settings,
llm_config=self._telemetry_llm_config,
),
)
except Exception as e:
@@ -137,6 +153,10 @@ class LLMClientBase:
agent_tags=self._telemetry_agent_tags,
run_id=self._telemetry_run_id,
call_type=self._telemetry_call_type,
org_id=self._telemetry_org_id,
user_id=self._telemetry_user_id,
compaction_settings=self._telemetry_compaction_settings,
llm_config=self._telemetry_llm_config,
),
)
except Exception as e:

View File

@@ -32,5 +32,15 @@ class ProviderTrace(SqlalchemyBase, OrganizationMixin):
String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)"
)
# v2 protocol fields
org_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the organization")
user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
compaction_settings: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, doc="Compaction/summarization settings (summarization calls only)"
)
llm_config: Mapped[Optional[dict]] = mapped_column(
JSON, nullable=True, doc="LLM configuration used for this call (non-summarization calls only)"
)
# Relationships
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")

View File

@@ -29,6 +29,9 @@ class ProviderTrace(BaseProviderTrace):
run_id (str): ID of the run this trace is associated with.
source (str): Source service that generated this trace (memgpt-server, lettuce-py).
organization_id (str): The unique identifier of the organization.
user_id (str): The unique identifier of the user who initiated the request.
compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).
llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).
created_at (datetime): The timestamp when the object was created.
"""
@@ -44,4 +47,10 @@ class ProviderTrace(BaseProviderTrace):
run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with")
source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)")
# v2 protocol fields
org_id: Optional[str] = Field(None, description="ID of the organization")
user_id: Optional[str] = Field(None, description="ID of the user who initiated the request")
compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")

View File

@@ -17,7 +17,8 @@ logger = get_logger(__name__)
# Protocol version for crouton communication.
# Bump this when making breaking changes to the record schema.
# Must match ProtocolVersion in apps/crouton/main.go.
PROTOCOL_VERSION = 1
# v2: Added user_id, compaction_settings (summarization), llm_config (non-summarization)
PROTOCOL_VERSION = 2
class SocketProviderTraceBackend(ProviderTraceBackendClient):
@@ -94,6 +95,11 @@ class SocketProviderTraceBackend(ProviderTraceBackendClient):
"error": error,
"error_type": error_type,
"timestamp": datetime.now(timezone.utc).isoformat(),
# v2 protocol fields
"org_id": provider_trace.org_id,
"user_id": provider_trace.user_id,
"compaction_settings": provider_trace.compaction_settings,
"llm_config": provider_trace.llm_config,
}
# Fire-and-forget in background thread

View File

@@ -181,6 +181,8 @@ class Summarizer:
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
# TODO if we do this via the "agent", then we can more easily allow toggling on the memory block version
from letta.settings import summarizer_settings
summary_message_str = await simple_summary(
messages=messages_to_summarize,
llm_config=agent_state.llm_config,
@@ -190,6 +192,12 @@ class Summarizer:
agent_tags=agent_state.tags,
run_id=run_id if run_id is not None else self.run_id,
step_id=step_id if step_id is not None else self.step_id,
compaction_settings={
"mode": str(summarizer_settings.mode.value),
"message_buffer_limit": summarizer_settings.message_buffer_limit,
"message_buffer_min": summarizer_settings.message_buffer_min,
"partial_evict_summarizer_percentage": summarizer_settings.partial_evict_summarizer_percentage,
},
)
# TODO add counts back
@@ -450,6 +458,7 @@ async def simple_summary(
agent_tags: List[str] | None = None,
run_id: str | None = None,
step_id: str | None = None,
compaction_settings: dict | None = None,
) -> str:
"""Generate a simple summary from a list of messages.
@@ -474,6 +483,9 @@ async def simple_summary(
run_id=run_id,
step_id=step_id,
call_type="summarization",
org_id=actor.organization_id if actor else None,
user_id=actor.id if actor else None,
compaction_settings=compaction_settings,
)
# Prepare the messages payload to send to the LLM

View File

@@ -68,6 +68,10 @@ async def summarize_all(
agent_tags=agent_tags,
run_id=run_id,
step_id=step_id,
compaction_settings={
"mode": "summarize_all",
"clip_chars": summarizer_config.clip_chars,
},
)
logger.info(f"Summarized {len(messages_to_summarize)} messages")

View File

@@ -146,6 +146,13 @@ async def summarize_via_sliding_window(
agent_tags=agent_tags,
run_id=run_id,
step_id=step_id,
compaction_settings={
"mode": "sliding_window",
"messages_summarized": len(messages_to_summarize),
"messages_kept": total_message_count - assistant_message_index,
"sliding_window_percentage": summarizer_config.sliding_window_percentage,
"clip_chars": summarizer_config.clip_chars,
},
)
if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:

View File

@@ -92,6 +92,48 @@ class TestProviderTrace:
assert trace.call_type == "summarization"
assert trace.run_id == "run-789"
def test_v2_protocol_fields(self):
    """Test v2 protocol fields (org_id, user_id, compaction_settings, llm_config)."""
    expected_compaction = {"mode": "sliding_window", "target_message_count": 50}
    expected_llm_config = {"model": "gpt-4", "temperature": 0.7}

    trace = ProviderTrace(
        request_json={},
        response_json={},
        step_id="step-123",
        org_id="org-123",
        user_id="user-123",
        compaction_settings=expected_compaction,
        llm_config=expected_llm_config,
    )

    # Each v2 field should round-trip through the model unchanged.
    expectations = (
        ("org_id", "org-123"),
        ("user_id", "user-123"),
        ("compaction_settings", expected_compaction),
        ("llm_config", expected_llm_config),
    )
    for attribute, expected in expectations:
        assert getattr(trace, attribute) == expected
def test_v2_fields_mutually_exclusive_by_convention(self):
    """Test that compaction_settings is set for summarization, llm_config for non-summarization."""
    # Summarization call: compaction_settings populated, llm_config omitted.
    summary_trace = ProviderTrace(
        request_json={},
        response_json={},
        step_id="step-123",
        call_type="summarization",
        compaction_settings={"mode": "partial_evict"},
        llm_config=None,
    )
    assert summary_trace.call_type == "summarization"
    assert summary_trace.llm_config is None
    assert summary_trace.compaction_settings is not None

    # Agent-step call: the inverse convention.
    step_trace = ProviderTrace(
        request_json={},
        response_json={},
        step_id="step-456",
        call_type="agent_step",
        compaction_settings=None,
        llm_config={"model": "claude-3"},
    )
    assert step_trace.call_type == "agent_step"
    assert step_trace.compaction_settings is None
    assert step_trace.llm_config is not None
class TestSocketProviderTraceBackend:
"""Tests for SocketProviderTraceBackend."""
@@ -246,6 +288,36 @@ class TestSocketProviderTraceBackend:
assert captured_records[0]["error"] == "Rate limit exceeded"
assert captured_records[0]["response"] is None
def test_record_includes_v2_protocol_fields(self):
    """Test that v2 protocol fields are included in the socket record."""
    trace = ProviderTrace(
        request_json={"model": "gpt-4"},
        response_json={"id": "test"},
        step_id="step-123",
        org_id="org-456",
        user_id="user-456",
        compaction_settings={"mode": "sliding_window"},
        llm_config={"model": "gpt-4", "temperature": 0.5},
    )
    backend = SocketProviderTraceBackend(socket_path="/fake/path")

    sent_records = []
    # Capture the record instead of actually sending it over the socket.
    with patch.object(backend, "_send_async", side_effect=sent_records.append):
        backend._send_to_crouton(trace)

    assert len(sent_records) == 1
    (record,) = sent_records
    assert record["protocol_version"] == 2
    assert record["org_id"] == "org-456"
    assert record["user_id"] == "user-456"
    assert record["compaction_settings"] == {"mode": "sliding_window"}
    assert record["llm_config"] == {"model": "gpt-4", "temperature": 0.5}
class TestBackendFactory:
"""Tests for backend factory."""