From e3fb00f97009cafe527cde93983cda0dfdd7e574 Mon Sep 17 00:00:00 2001 From: Kian Jones <11655409+kianjones9@users.noreply.github.com> Date: Wed, 21 Jan 2026 21:57:22 -0800 Subject: [PATCH] feat(crouton): add orgId, userId, Compaction_Settings and LLM_Config (#9022) * LC one shot? * api changes * fix summarizer nameerror --- ...d_v2_protocol_fields_to_provider_traces.py | 32 +++++++++ fern/openapi.json | 40 ++++++++++- letta/adapters/letta_llm_adapter.py | 4 ++ letta/adapters/letta_llm_request_adapter.py | 3 + letta/adapters/letta_llm_stream_adapter.py | 7 +- letta/adapters/simple_llm_request_adapter.py | 3 + letta/adapters/simple_llm_stream_adapter.py | 3 + letta/agents/letta_agent.py | 9 +++ letta/agents/letta_agent_v2.py | 12 +++- letta/agents/letta_agent_v3.py | 6 ++ letta/llm_api/llm_client_base.py | 20 ++++++ letta/orm/provider_trace.py | 10 +++ letta/schemas/provider_trace.py | 9 +++ .../provider_trace_backends/socket.py | 8 ++- letta/services/summarizer/summarizer.py | 12 ++++ letta/services/summarizer/summarizer_all.py | 4 ++ .../summarizer/summarizer_sliding_window.py | 7 ++ tests/test_provider_trace_backends.py | 72 +++++++++++++++++++ 18 files changed, 257 insertions(+), 4 deletions(-) create mode 100644 alembic/versions/9275f62ad282_add_v2_protocol_fields_to_provider_traces.py diff --git a/alembic/versions/9275f62ad282_add_v2_protocol_fields_to_provider_traces.py b/alembic/versions/9275f62ad282_add_v2_protocol_fields_to_provider_traces.py new file mode 100644 index 00000000..97fa2f73 --- /dev/null +++ b/alembic/versions/9275f62ad282_add_v2_protocol_fields_to_provider_traces.py @@ -0,0 +1,32 @@ +"""Add v2 protocol fields to provider_traces + +Revision ID: 9275f62ad282 +Revises: 297e8217e952 +Create Date: 2026-01-22 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +revision: str = "9275f62ad282" +down_revision: Union[str, None] = "297e8217e952" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + op.add_column("provider_traces", sa.Column("org_id", sa.String(), nullable=True)) + op.add_column("provider_traces", sa.Column("user_id", sa.String(), nullable=True)) + op.add_column("provider_traces", sa.Column("compaction_settings", sa.JSON(), nullable=True)) + op.add_column("provider_traces", sa.Column("llm_config", sa.JSON(), nullable=True)) + + +def downgrade() -> None: + op.drop_column("provider_traces", "llm_config") + op.drop_column("provider_traces", "compaction_settings") + op.drop_column("provider_traces", "user_id") + op.drop_column("provider_traces", "org_id") diff --git a/fern/openapi.json b/fern/openapi.json index 618b0ca0..184200e4 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -39340,13 +39340,51 @@ ], "title": "Source", "description": "Source service that generated this trace (memgpt-server, lettuce-py)" + }, + "org_id": { + "anyOf": [ + { + "type": "string" + }, + { + "type": "null" + } + ], + "title": "Org Id", + "description": "ID of the organization" + }, + "compaction_settings": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Compaction Settings", + "description": "Compaction/summarization settings (summarization calls only)" + }, + "llm_config": { + "anyOf": [ + { + "additionalProperties": true, + "type": "object" + }, + { + "type": "null" + } + ], + "title": "Llm Config", + "description": "LLM configuration used for this call (non-summarization calls only)" } }, "additionalProperties": false, "type": "object", "required": ["request_json", "response_json"], "title": "ProviderTrace", - "description": "Letta's internal representation of a provider trace.\n\nAttributes:\n id (str): The unique identifier of the provider trace.\n request_json (Dict[str, Any]): JSON content of the provider request.\n response_json (Dict[str, Any]): JSON content of the provider response.\n step_id (str): ID of the step that this trace is associated with.\n agent_id (str): ID of the agent that generated this trace.\n agent_tags (list[str]): Tags associated with the agent for filtering.\n call_type (str): Type of call (agent_step, summarization, etc.).\n run_id (str): ID of the run this trace is associated with.\n source (str): Source service that generated this trace (memgpt-server, lettuce-py).\n organization_id (str): The unique identifier of the organization.\n created_at (datetime): The timestamp when the object was created." + "description": "Letta's internal representation of a provider trace.\n\nAttributes:\n id (str): The unique identifier of the provider trace.\n request_json (Dict[str, Any]): JSON content of the provider request.\n response_json (Dict[str, Any]): JSON content of the provider response.\n step_id (str): ID of the step that this trace is associated with.\n agent_id (str): ID of the agent that generated this trace.\n agent_tags (list[str]): Tags associated with the agent for filtering.\n call_type (str): Type of call (agent_step, summarization, etc.).\n run_id (str): ID of the run this trace is associated with.\n source (str): Source service that generated this trace (memgpt-server, lettuce-py).\n organization_id (str): The unique identifier of the organization.\n user_id (str): The unique identifier of the user who initiated the request.\n compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).\n llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).\n created_at (datetime): The timestamp when the object was created." }, "ProviderType": { "type": "string", diff --git a/letta/adapters/letta_llm_adapter.py b/letta/adapters/letta_llm_adapter.py index b782fa3d..b00a8edb 100644 --- a/letta/adapters/letta_llm_adapter.py +++ b/letta/adapters/letta_llm_adapter.py @@ -27,12 +27,16 @@ class LettaLLMAdapter(ABC): agent_id: str | None = None, agent_tags: list[str] | None = None, run_id: str | None = None, + org_id: str | None = None, + user_id: str | None = None, ) -> None: self.llm_client: LLMClientBase = llm_client self.llm_config: LLMConfig = llm_config self.agent_id: str | None = agent_id self.agent_tags: list[str] | None = agent_tags self.run_id: str | None = run_id + self.org_id: str | None = org_id + self.user_id: str | None = user_id self.message_id: str | None = None self.request_data: dict | None = None self.response_data: dict | None = None diff --git a/letta/adapters/letta_llm_request_adapter.py b/letta/adapters/letta_llm_request_adapter.py index 7635d424..5e472a35 100644 --- a/letta/adapters/letta_llm_request_adapter.py +++ b/letta/adapters/letta_llm_request_adapter.py @@ -127,6 +127,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter): agent_id=self.agent_id, agent_tags=self.agent_tags, run_id=self.run_id, + org_id=self.org_id, + user_id=self.user_id, + llm_config=self.llm_config.model_dump() if self.llm_config else None, ), ), label="create_provider_trace", diff --git a/letta/adapters/letta_llm_stream_adapter.py b/letta/adapters/letta_llm_stream_adapter.py index 2929b1c4..46659618 100644 --- a/letta/adapters/letta_llm_stream_adapter.py +++ b/letta/adapters/letta_llm_stream_adapter.py @@ -33,8 +33,10 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): agent_id: str | None = None, agent_tags: list[str] | None = None, run_id: str | None = None, + org_id: str | None = None, + user_id: str | None = None, ) -> None: - super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id) + super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id) self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None async def invoke_llm( @@ -236,6 +238,9 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): agent_id=self.agent_id, agent_tags=self.agent_tags, run_id=self.run_id, + org_id=self.org_id, + user_id=self.user_id, + llm_config=self.llm_config.model_dump() if self.llm_config else None, ), ), label="create_provider_trace", diff --git a/letta/adapters/simple_llm_request_adapter.py b/letta/adapters/simple_llm_request_adapter.py index 30243a9b..cf2dc741 100644 --- a/letta/adapters/simple_llm_request_adapter.py +++ b/letta/adapters/simple_llm_request_adapter.py @@ -46,6 +46,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter): agent_tags=self.agent_tags, run_id=self.run_id, call_type="agent_step", + org_id=self.org_id, + user_id=self.user_id, + llm_config=self.llm_config.model_dump() if self.llm_config else None, ) try: self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config) diff --git a/letta/adapters/simple_llm_stream_adapter.py b/letta/adapters/simple_llm_stream_adapter.py index c2af996c..c3d14ffa 100644 --- a/letta/adapters/simple_llm_stream_adapter.py +++ b/letta/adapters/simple_llm_stream_adapter.py @@ -283,6 +283,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter): agent_id=self.agent_id, agent_tags=self.agent_tags, run_id=self.run_id, + org_id=self.org_id, + user_id=self.user_id, + llm_config=self.llm_config.model_dump() if self.llm_config else None, ), ), label="create_provider_trace", diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index f317bd81..3b359c72 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -420,6 +420,9 @@ class LettaAgent(BaseAgent): agent_id=self.agent_id, agent_tags=agent_state.tags, run_id=self.current_run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, + llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None, ), ) step_progression = StepProgression.LOGGED_TRACE @@ -770,6 +773,9 @@ class LettaAgent(BaseAgent): agent_id=self.agent_id, agent_tags=agent_state.tags, run_id=self.current_run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, + llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None, ), ) step_progression = StepProgression.LOGGED_TRACE @@ -1242,6 +1248,9 @@ class LettaAgent(BaseAgent): agent_id=self.agent_id, agent_tags=agent_state.tags, run_id=self.current_run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, + llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None, ), ) step_progression = StepProgression.LOGGED_TRACE diff --git a/letta/agents/letta_agent_v2.py b/letta/agents/letta_agent_v2.py index 34569655..58379c78 100644 --- a/letta/agents/letta_agent_v2.py +++ b/letta/agents/letta_agent_v2.py @@ -156,7 +156,11 @@ class LettaAgentV2(BaseAgentV2): run_id=None, messages=in_context_messages + input_messages_to_persist, llm_adapter=LettaLLMRequestAdapter( - llm_client=self.llm_client, llm_config=self.agent_state.llm_config, agent_tags=self.agent_state.tags + llm_client=self.llm_client, + llm_config=self.agent_state.llm_config, + agent_tags=self.agent_state.tags, + org_id=self.actor.organization_id, + user_id=self.actor.id, ), dry_run=True, enforce_run_id_set=False, @@ -213,6 +217,8 @@ class LettaAgentV2(BaseAgentV2): agent_id=self.agent_state.id, agent_tags=self.agent_state.tags, run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, ), run_id=run_id, use_assistant_message=use_assistant_message, @@ -298,6 +304,8 @@ class LettaAgentV2(BaseAgentV2): agent_id=self.agent_state.id, agent_tags=self.agent_state.tags, run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, ) else: llm_adapter = LettaLLMRequestAdapter( @@ -306,6 +314,8 @@ class LettaAgentV2(BaseAgentV2): agent_id=self.agent_state.id, agent_tags=self.agent_state.tags, run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, ) try: diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 9c454bf2..b0df650c 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -173,6 +173,8 @@ class LettaAgentV3(LettaAgentV2): agent_id=self.agent_state.id, agent_tags=self.agent_state.tags, run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, ), run_id=run_id, # use_assistant_message=use_assistant_message, @@ -316,6 +318,8 @@ class LettaAgentV3(LettaAgentV2): agent_id=self.agent_state.id, agent_tags=self.agent_state.tags, run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, ) else: llm_adapter = SimpleLLMRequestAdapter( @@ -324,6 +328,8 @@ class LettaAgentV3(LettaAgentV2): agent_id=self.agent_state.id, agent_tags=self.agent_state.tags, run_id=run_id, + org_id=self.actor.organization_id, + user_id=self.actor.id, ) try: diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 697e0961..3491e093 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -43,6 +43,10 @@ class LLMClientBase: self._telemetry_run_id: Optional[str] = None self._telemetry_step_id: Optional[str] = None self._telemetry_call_type: Optional[str] = None + self._telemetry_org_id: Optional[str] = None + self._telemetry_user_id: Optional[str] = None + self._telemetry_compaction_settings: Optional[Dict] = None + self._telemetry_llm_config: Optional[Dict] = None def set_telemetry_context( self, @@ -52,6 +56,10 @@ class LLMClientBase: run_id: Optional[str] = None, step_id: Optional[str] = None, call_type: Optional[str] = None, + org_id: Optional[str] = None, + user_id: Optional[str] = None, + compaction_settings: Optional[Dict] = None, + llm_config: Optional[Dict] = None, ) -> None: """Set telemetry context for provider trace logging.""" self._telemetry_manager = telemetry_manager @@ -60,6 +68,10 @@ class LLMClientBase: self._telemetry_run_id = run_id self._telemetry_step_id = step_id self._telemetry_call_type = call_type + self._telemetry_org_id = org_id + self._telemetry_user_id = user_id + self._telemetry_compaction_settings = compaction_settings + self._telemetry_llm_config = llm_config async def request_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig) -> dict: """Wrapper around request_async that logs telemetry for all requests including errors. @@ -96,6 +108,10 @@ class LLMClientBase: agent_tags=self._telemetry_agent_tags, run_id=self._telemetry_run_id, call_type=self._telemetry_call_type, + org_id=self._telemetry_org_id, + user_id=self._telemetry_user_id, + compaction_settings=self._telemetry_compaction_settings, + llm_config=self._telemetry_llm_config, ), ) except Exception as e: @@ -137,6 +153,10 @@ class LLMClientBase: agent_tags=self._telemetry_agent_tags, run_id=self._telemetry_run_id, call_type=self._telemetry_call_type, + org_id=self._telemetry_org_id, + user_id=self._telemetry_user_id, + compaction_settings=self._telemetry_compaction_settings, + llm_config=self._telemetry_llm_config, ), ) except Exception as e: diff --git a/letta/orm/provider_trace.py b/letta/orm/provider_trace.py index b0cbb181..90399b5d 100644 --- a/letta/orm/provider_trace.py +++ b/letta/orm/provider_trace.py @@ -32,5 +32,15 @@ class ProviderTrace(SqlalchemyBase, OrganizationMixin): String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)" ) + # v2 protocol fields + org_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the organization") + user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request") + compaction_settings: Mapped[Optional[dict]] = mapped_column( + JSON, nullable=True, doc="Compaction/summarization settings (summarization calls only)" + ) + llm_config: Mapped[Optional[dict]] = mapped_column( + JSON, nullable=True, doc="LLM configuration used for this call (non-summarization calls only)" + ) + # Relationships organization: Mapped["Organization"] = relationship("Organization", lazy="selectin") diff --git a/letta/schemas/provider_trace.py b/letta/schemas/provider_trace.py index 10ca5c3a..42ee6672 100644 --- a/letta/schemas/provider_trace.py +++ b/letta/schemas/provider_trace.py @@ -29,6 +29,9 @@ class ProviderTrace(BaseProviderTrace): run_id (str): ID of the run this trace is associated with. source (str): Source service that generated this trace (memgpt-server, lettuce-py). organization_id (str): The unique identifier of the organization. + user_id (str): The unique identifier of the user who initiated the request. + compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls). + llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls). created_at (datetime): The timestamp when the object was created. """ @@ -44,4 +47,10 @@ class ProviderTrace(BaseProviderTrace): run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with") source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)") + # v2 protocol fields + org_id: Optional[str] = Field(None, description="ID of the organization") + user_id: Optional[str] = Field(None, description="ID of the user who initiated the request") + compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)") + llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)") + created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.") diff --git a/letta/services/provider_trace_backends/socket.py b/letta/services/provider_trace_backends/socket.py index dfb4ef8e..1d375e57 100644 --- a/letta/services/provider_trace_backends/socket.py +++ b/letta/services/provider_trace_backends/socket.py @@ -17,7 +17,8 @@ logger = get_logger(__name__) # Protocol version for crouton communication. # Bump this when making breaking changes to the record schema. # Must match ProtocolVersion in apps/crouton/main.go. -PROTOCOL_VERSION = 1 +# v2: Added user_id, compaction_settings (summarization), llm_config (non-summarization) +PROTOCOL_VERSION = 2 class SocketProviderTraceBackend(ProviderTraceBackendClient): @@ -94,6 +95,11 @@ class SocketProviderTraceBackend(ProviderTraceBackendClient): "error": error, "error_type": error_type, "timestamp": datetime.now(timezone.utc).isoformat(), + # v2 protocol fields + "org_id": provider_trace.org_id, + "user_id": provider_trace.user_id, + "compaction_settings": provider_trace.compaction_settings, + "llm_config": provider_trace.llm_config, } # Fire-and-forget in background thread diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index ed855e84..64e9f8ba 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -181,6 +181,8 @@ class Summarizer: agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor) # TODO if we do this via the "agent", then we can more easily allow toggling on the memory block version + from letta.settings import summarizer_settings + summary_message_str = await simple_summary( messages=messages_to_summarize, llm_config=agent_state.llm_config, @@ -190,6 +192,12 @@ class Summarizer: agent_tags=agent_state.tags, run_id=run_id if run_id is not None else self.run_id, step_id=step_id if step_id is not None else self.step_id, + compaction_settings={ + "mode": str(summarizer_settings.mode.value), + "message_buffer_limit": summarizer_settings.message_buffer_limit, + "message_buffer_min": summarizer_settings.message_buffer_min, + "partial_evict_summarizer_percentage": summarizer_settings.partial_evict_summarizer_percentage, + }, ) # TODO add counts back @@ -450,6 +458,7 @@ async def simple_summary( agent_tags: List[str] | None = None, run_id: str | None = None, step_id: str | None = None, + compaction_settings: dict | None = None, ) -> str: """Generate a simple summary from a list of messages. @@ -474,6 +483,9 @@ async def simple_summary( run_id=run_id, step_id=step_id, call_type="summarization", + org_id=actor.organization_id if actor else None, + user_id=actor.id if actor else None, + compaction_settings=compaction_settings, ) # Prepare the messages payload to send to the LLM diff --git a/letta/services/summarizer/summarizer_all.py b/letta/services/summarizer/summarizer_all.py index 3d9b2ffa..fc183214 100644 --- a/letta/services/summarizer/summarizer_all.py +++ b/letta/services/summarizer/summarizer_all.py @@ -68,6 +68,10 @@ async def summarize_all( agent_tags=agent_tags, run_id=run_id, step_id=step_id, + compaction_settings={ + "mode": "summarize_all", + "clip_chars": summarizer_config.clip_chars, + }, ) logger.info(f"Summarized {len(messages_to_summarize)} messages") diff --git a/letta/services/summarizer/summarizer_sliding_window.py b/letta/services/summarizer/summarizer_sliding_window.py index 10a409d2..87739393 100644 --- a/letta/services/summarizer/summarizer_sliding_window.py +++ b/letta/services/summarizer/summarizer_sliding_window.py @@ -146,6 +146,13 @@ async def summarize_via_sliding_window( agent_tags=agent_tags, run_id=run_id, step_id=step_id, + compaction_settings={ + "mode": "sliding_window", + "messages_summarized": len(messages_to_summarize), + "messages_kept": total_message_count - assistant_message_index, + "sliding_window_percentage": summarizer_config.sliding_window_percentage, + "clip_chars": summarizer_config.clip_chars, + }, ) if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars: diff --git a/tests/test_provider_trace_backends.py b/tests/test_provider_trace_backends.py index 3d64e04b..f1051d1c 100644 --- a/tests/test_provider_trace_backends.py +++ b/tests/test_provider_trace_backends.py @@ -92,6 +92,48 @@ class TestProviderTrace: assert trace.call_type == "summarization" assert trace.run_id == "run-789" + def test_v2_protocol_fields(self): + """Test v2 protocol fields (org_id, user_id, compaction_settings, llm_config).""" + trace = ProviderTrace( + request_json={}, + response_json={}, + step_id="step-123", + org_id="org-123", + user_id="user-123", + compaction_settings={"mode": "sliding_window", "target_message_count": 50}, + llm_config={"model": "gpt-4", "temperature": 0.7}, + ) + assert trace.org_id == "org-123" + assert trace.user_id == "user-123" + assert trace.compaction_settings == {"mode": "sliding_window", "target_message_count": 50} + assert trace.llm_config == {"model": "gpt-4", "temperature": 0.7} + + def test_v2_fields_mutually_exclusive_by_convention(self): + """Test that compaction_settings is set for summarization, llm_config for non-summarization.""" + summarization_trace = ProviderTrace( + request_json={}, + response_json={}, + step_id="step-123", + call_type="summarization", + compaction_settings={"mode": "partial_evict"}, + llm_config=None, + ) + assert summarization_trace.call_type == "summarization" + assert summarization_trace.compaction_settings is not None + assert summarization_trace.llm_config is None + + agent_step_trace = ProviderTrace( + request_json={}, + response_json={}, + step_id="step-456", + call_type="agent_step", + compaction_settings=None, + llm_config={"model": "claude-3"}, + ) + assert agent_step_trace.call_type == "agent_step" + assert agent_step_trace.compaction_settings is None + assert agent_step_trace.llm_config is not None + class TestSocketProviderTraceBackend: """Tests for SocketProviderTraceBackend.""" @@ -246,6 +288,36 @@ class TestSocketProviderTraceBackend: assert captured_records[0]["error"] == "Rate limit exceeded" assert captured_records[0]["response"] is None + def test_record_includes_v2_protocol_fields(self): + """Test that v2 protocol fields are included in the socket record.""" + trace = ProviderTrace( + request_json={"model": "gpt-4"}, + response_json={"id": "test"}, + step_id="step-123", + org_id="org-456", + user_id="user-456", + compaction_settings={"mode": "sliding_window"}, + llm_config={"model": "gpt-4", "temperature": 0.5}, + ) + + backend = SocketProviderTraceBackend(socket_path="/fake/path") + + captured_records = [] + + def capture_record(record): + captured_records.append(record) + + with patch.object(backend, "_send_async", side_effect=capture_record): + backend._send_to_crouton(trace) + + assert len(captured_records) == 1 + record = captured_records[0] + assert record["protocol_version"] == 2 + assert record["org_id"] == "org-456" + assert record["user_id"] == "user-456" + assert record["compaction_settings"] == {"mode": "sliding_window"} + assert record["llm_config"] == {"model": "gpt-4", "temperature": 0.5} + class TestBackendFactory: """Tests for backend factory."""