feat(crouton): add orgId, userId, Compaction_Settings and LLM_Config (#9022)
* LC one shot? * api changes * fix summarizer nameerror
This commit is contained in:
@@ -0,0 +1,32 @@
|
||||
"""Add v2 protocol fields to provider_traces
|
||||
|
||||
Revision ID: 9275f62ad282
|
||||
Revises: 297e8217e952
|
||||
Create Date: 2026-01-22
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
revision: str = "9275f62ad282"
|
||||
down_revision: Union[str, None] = "297e8217e952"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
op.add_column("provider_traces", sa.Column("org_id", sa.String(), nullable=True))
|
||||
op.add_column("provider_traces", sa.Column("user_id", sa.String(), nullable=True))
|
||||
op.add_column("provider_traces", sa.Column("compaction_settings", sa.JSON(), nullable=True))
|
||||
op.add_column("provider_traces", sa.Column("llm_config", sa.JSON(), nullable=True))
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("provider_traces", "llm_config")
|
||||
op.drop_column("provider_traces", "compaction_settings")
|
||||
op.drop_column("provider_traces", "user_id")
|
||||
op.drop_column("provider_traces", "org_id")
|
||||
@@ -39340,13 +39340,51 @@
|
||||
],
|
||||
"title": "Source",
|
||||
"description": "Source service that generated this trace (memgpt-server, lettuce-py)"
|
||||
},
|
||||
"org_id": {
|
||||
"anyOf": [
|
||||
{
|
||||
"type": "string"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Org Id",
|
||||
"description": "ID of the organization"
|
||||
},
|
||||
"compaction_settings": {
|
||||
"anyOf": [
|
||||
{
|
||||
"additionalProperties": true,
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Compaction Settings",
|
||||
"description": "Compaction/summarization settings (summarization calls only)"
|
||||
},
|
||||
"llm_config": {
|
||||
"anyOf": [
|
||||
{
|
||||
"additionalProperties": true,
|
||||
"type": "object"
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Llm Config",
|
||||
"description": "LLM configuration used for this call (non-summarization calls only)"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"type": "object",
|
||||
"required": ["request_json", "response_json"],
|
||||
"title": "ProviderTrace",
|
||||
"description": "Letta's internal representation of a provider trace.\n\nAttributes:\n id (str): The unique identifier of the provider trace.\n request_json (Dict[str, Any]): JSON content of the provider request.\n response_json (Dict[str, Any]): JSON content of the provider response.\n step_id (str): ID of the step that this trace is associated with.\n agent_id (str): ID of the agent that generated this trace.\n agent_tags (list[str]): Tags associated with the agent for filtering.\n call_type (str): Type of call (agent_step, summarization, etc.).\n run_id (str): ID of the run this trace is associated with.\n source (str): Source service that generated this trace (memgpt-server, lettuce-py).\n organization_id (str): The unique identifier of the organization.\n created_at (datetime): The timestamp when the object was created."
|
||||
"description": "Letta's internal representation of a provider trace.\n\nAttributes:\n id (str): The unique identifier of the provider trace.\n request_json (Dict[str, Any]): JSON content of the provider request.\n response_json (Dict[str, Any]): JSON content of the provider response.\n step_id (str): ID of the step that this trace is associated with.\n agent_id (str): ID of the agent that generated this trace.\n agent_tags (list[str]): Tags associated with the agent for filtering.\n call_type (str): Type of call (agent_step, summarization, etc.).\n run_id (str): ID of the run this trace is associated with.\n source (str): Source service that generated this trace (memgpt-server, lettuce-py).\n organization_id (str): The unique identifier of the organization.\n user_id (str): The unique identifier of the user who initiated the request.\n compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).\n llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).\n created_at (datetime): The timestamp when the object was created."
|
||||
},
|
||||
"ProviderType": {
|
||||
"type": "string",
|
||||
|
||||
@@ -27,12 +27,16 @@ class LettaLLMAdapter(ABC):
|
||||
agent_id: str | None = None,
|
||||
agent_tags: list[str] | None = None,
|
||||
run_id: str | None = None,
|
||||
org_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
self.llm_client: LLMClientBase = llm_client
|
||||
self.llm_config: LLMConfig = llm_config
|
||||
self.agent_id: str | None = agent_id
|
||||
self.agent_tags: list[str] | None = agent_tags
|
||||
self.run_id: str | None = run_id
|
||||
self.org_id: str | None = org_id
|
||||
self.user_id: str | None = user_id
|
||||
self.message_id: str | None = None
|
||||
self.request_data: dict | None = None
|
||||
self.response_data: dict | None = None
|
||||
|
||||
@@ -127,6 +127,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=self.agent_tags,
|
||||
run_id=self.run_id,
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
),
|
||||
),
|
||||
label="create_provider_trace",
|
||||
|
||||
@@ -33,8 +33,10 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
agent_id: str | None = None,
|
||||
agent_tags: list[str] | None = None,
|
||||
run_id: str | None = None,
|
||||
org_id: str | None = None,
|
||||
user_id: str | None = None,
|
||||
) -> None:
|
||||
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id)
|
||||
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
|
||||
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
|
||||
|
||||
async def invoke_llm(
|
||||
@@ -236,6 +238,9 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=self.agent_tags,
|
||||
run_id=self.run_id,
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
),
|
||||
),
|
||||
label="create_provider_trace",
|
||||
|
||||
@@ -46,6 +46,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
||||
agent_tags=self.agent_tags,
|
||||
run_id=self.run_id,
|
||||
call_type="agent_step",
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
)
|
||||
try:
|
||||
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
|
||||
|
||||
@@ -283,6 +283,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=self.agent_tags,
|
||||
run_id=self.run_id,
|
||||
org_id=self.org_id,
|
||||
user_id=self.user_id,
|
||||
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||
),
|
||||
),
|
||||
label="create_provider_trace",
|
||||
|
||||
@@ -420,6 +420,9 @@ class LettaAgent(BaseAgent):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=self.current_run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
@@ -770,6 +773,9 @@ class LettaAgent(BaseAgent):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=self.current_run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
@@ -1242,6 +1248,9 @@ class LettaAgent(BaseAgent):
|
||||
agent_id=self.agent_id,
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=self.current_run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||
),
|
||||
)
|
||||
step_progression = StepProgression.LOGGED_TRACE
|
||||
|
||||
@@ -156,7 +156,11 @@ class LettaAgentV2(BaseAgentV2):
|
||||
run_id=None,
|
||||
messages=in_context_messages + input_messages_to_persist,
|
||||
llm_adapter=LettaLLMRequestAdapter(
|
||||
llm_client=self.llm_client, llm_config=self.agent_state.llm_config, agent_tags=self.agent_state.tags
|
||||
llm_client=self.llm_client,
|
||||
llm_config=self.agent_state.llm_config,
|
||||
agent_tags=self.agent_state.tags,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
),
|
||||
dry_run=True,
|
||||
enforce_run_id_set=False,
|
||||
@@ -213,6 +217,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
),
|
||||
run_id=run_id,
|
||||
use_assistant_message=use_assistant_message,
|
||||
@@ -298,6 +304,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
)
|
||||
else:
|
||||
llm_adapter = LettaLLMRequestAdapter(
|
||||
@@ -306,6 +314,8 @@ class LettaAgentV2(BaseAgentV2):
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -173,6 +173,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
),
|
||||
run_id=run_id,
|
||||
# use_assistant_message=use_assistant_message,
|
||||
@@ -316,6 +318,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
)
|
||||
else:
|
||||
llm_adapter = SimpleLLMRequestAdapter(
|
||||
@@ -324,6 +328,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
agent_id=self.agent_state.id,
|
||||
agent_tags=self.agent_state.tags,
|
||||
run_id=run_id,
|
||||
org_id=self.actor.organization_id,
|
||||
user_id=self.actor.id,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -43,6 +43,10 @@ class LLMClientBase:
|
||||
self._telemetry_run_id: Optional[str] = None
|
||||
self._telemetry_step_id: Optional[str] = None
|
||||
self._telemetry_call_type: Optional[str] = None
|
||||
self._telemetry_org_id: Optional[str] = None
|
||||
self._telemetry_user_id: Optional[str] = None
|
||||
self._telemetry_compaction_settings: Optional[Dict] = None
|
||||
self._telemetry_llm_config: Optional[Dict] = None
|
||||
|
||||
def set_telemetry_context(
|
||||
self,
|
||||
@@ -52,6 +56,10 @@ class LLMClientBase:
|
||||
run_id: Optional[str] = None,
|
||||
step_id: Optional[str] = None,
|
||||
call_type: Optional[str] = None,
|
||||
org_id: Optional[str] = None,
|
||||
user_id: Optional[str] = None,
|
||||
compaction_settings: Optional[Dict] = None,
|
||||
llm_config: Optional[Dict] = None,
|
||||
) -> None:
|
||||
"""Set telemetry context for provider trace logging."""
|
||||
self._telemetry_manager = telemetry_manager
|
||||
@@ -60,6 +68,10 @@ class LLMClientBase:
|
||||
self._telemetry_run_id = run_id
|
||||
self._telemetry_step_id = step_id
|
||||
self._telemetry_call_type = call_type
|
||||
self._telemetry_org_id = org_id
|
||||
self._telemetry_user_id = user_id
|
||||
self._telemetry_compaction_settings = compaction_settings
|
||||
self._telemetry_llm_config = llm_config
|
||||
|
||||
async def request_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||
"""Wrapper around request_async that logs telemetry for all requests including errors.
|
||||
@@ -96,6 +108,10 @@ class LLMClientBase:
|
||||
agent_tags=self._telemetry_agent_tags,
|
||||
run_id=self._telemetry_run_id,
|
||||
call_type=self._telemetry_call_type,
|
||||
org_id=self._telemetry_org_id,
|
||||
user_id=self._telemetry_user_id,
|
||||
compaction_settings=self._telemetry_compaction_settings,
|
||||
llm_config=self._telemetry_llm_config,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
@@ -137,6 +153,10 @@ class LLMClientBase:
|
||||
agent_tags=self._telemetry_agent_tags,
|
||||
run_id=self._telemetry_run_id,
|
||||
call_type=self._telemetry_call_type,
|
||||
org_id=self._telemetry_org_id,
|
||||
user_id=self._telemetry_user_id,
|
||||
compaction_settings=self._telemetry_compaction_settings,
|
||||
llm_config=self._telemetry_llm_config,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -32,5 +32,15 @@ class ProviderTrace(SqlalchemyBase, OrganizationMixin):
|
||||
String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)"
|
||||
)
|
||||
|
||||
# v2 protocol fields
|
||||
org_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the organization")
|
||||
user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
|
||||
compaction_settings: Mapped[Optional[dict]] = mapped_column(
|
||||
JSON, nullable=True, doc="Compaction/summarization settings (summarization calls only)"
|
||||
)
|
||||
llm_config: Mapped[Optional[dict]] = mapped_column(
|
||||
JSON, nullable=True, doc="LLM configuration used for this call (non-summarization calls only)"
|
||||
)
|
||||
|
||||
# Relationships
|
||||
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
|
||||
|
||||
@@ -29,6 +29,9 @@ class ProviderTrace(BaseProviderTrace):
|
||||
run_id (str): ID of the run this trace is associated with.
|
||||
source (str): Source service that generated this trace (memgpt-server, lettuce-py).
|
||||
organization_id (str): The unique identifier of the organization.
|
||||
user_id (str): The unique identifier of the user who initiated the request.
|
||||
compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).
|
||||
llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).
|
||||
created_at (datetime): The timestamp when the object was created.
|
||||
"""
|
||||
|
||||
@@ -44,4 +47,10 @@ class ProviderTrace(BaseProviderTrace):
|
||||
run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with")
|
||||
source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)")
|
||||
|
||||
# v2 protocol fields
|
||||
org_id: Optional[str] = Field(None, description="ID of the organization")
|
||||
user_id: Optional[str] = Field(None, description="ID of the user who initiated the request")
|
||||
compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
|
||||
llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")
|
||||
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||
|
||||
@@ -17,7 +17,8 @@ logger = get_logger(__name__)
|
||||
# Protocol version for crouton communication.
|
||||
# Bump this when making breaking changes to the record schema.
|
||||
# Must match ProtocolVersion in apps/crouton/main.go.
|
||||
PROTOCOL_VERSION = 1
|
||||
# v2: Added user_id, compaction_settings (summarization), llm_config (non-summarization)
|
||||
PROTOCOL_VERSION = 2
|
||||
|
||||
|
||||
class SocketProviderTraceBackend(ProviderTraceBackendClient):
|
||||
@@ -94,6 +95,11 @@ class SocketProviderTraceBackend(ProviderTraceBackendClient):
|
||||
"error": error,
|
||||
"error_type": error_type,
|
||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||
# v2 protocol fields
|
||||
"org_id": provider_trace.org_id,
|
||||
"user_id": provider_trace.user_id,
|
||||
"compaction_settings": provider_trace.compaction_settings,
|
||||
"llm_config": provider_trace.llm_config,
|
||||
}
|
||||
|
||||
# Fire-and-forget in background thread
|
||||
|
||||
@@ -181,6 +181,8 @@ class Summarizer:
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
|
||||
|
||||
# TODO if we do this via the "agent", then we can more easily allow toggling on the memory block version
|
||||
from letta.settings import summarizer_settings
|
||||
|
||||
summary_message_str = await simple_summary(
|
||||
messages=messages_to_summarize,
|
||||
llm_config=agent_state.llm_config,
|
||||
@@ -190,6 +192,12 @@ class Summarizer:
|
||||
agent_tags=agent_state.tags,
|
||||
run_id=run_id if run_id is not None else self.run_id,
|
||||
step_id=step_id if step_id is not None else self.step_id,
|
||||
compaction_settings={
|
||||
"mode": str(summarizer_settings.mode.value),
|
||||
"message_buffer_limit": summarizer_settings.message_buffer_limit,
|
||||
"message_buffer_min": summarizer_settings.message_buffer_min,
|
||||
"partial_evict_summarizer_percentage": summarizer_settings.partial_evict_summarizer_percentage,
|
||||
},
|
||||
)
|
||||
|
||||
# TODO add counts back
|
||||
@@ -450,6 +458,7 @@ async def simple_summary(
|
||||
agent_tags: List[str] | None = None,
|
||||
run_id: str | None = None,
|
||||
step_id: str | None = None,
|
||||
compaction_settings: dict | None = None,
|
||||
) -> str:
|
||||
"""Generate a simple summary from a list of messages.
|
||||
|
||||
@@ -474,6 +483,9 @@ async def simple_summary(
|
||||
run_id=run_id,
|
||||
step_id=step_id,
|
||||
call_type="summarization",
|
||||
org_id=actor.organization_id if actor else None,
|
||||
user_id=actor.id if actor else None,
|
||||
compaction_settings=compaction_settings,
|
||||
)
|
||||
|
||||
# Prepare the messages payload to send to the LLM
|
||||
|
||||
@@ -68,6 +68,10 @@ async def summarize_all(
|
||||
agent_tags=agent_tags,
|
||||
run_id=run_id,
|
||||
step_id=step_id,
|
||||
compaction_settings={
|
||||
"mode": "summarize_all",
|
||||
"clip_chars": summarizer_config.clip_chars,
|
||||
},
|
||||
)
|
||||
logger.info(f"Summarized {len(messages_to_summarize)} messages")
|
||||
|
||||
|
||||
@@ -146,6 +146,13 @@ async def summarize_via_sliding_window(
|
||||
agent_tags=agent_tags,
|
||||
run_id=run_id,
|
||||
step_id=step_id,
|
||||
compaction_settings={
|
||||
"mode": "sliding_window",
|
||||
"messages_summarized": len(messages_to_summarize),
|
||||
"messages_kept": total_message_count - assistant_message_index,
|
||||
"sliding_window_percentage": summarizer_config.sliding_window_percentage,
|
||||
"clip_chars": summarizer_config.clip_chars,
|
||||
},
|
||||
)
|
||||
|
||||
if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
|
||||
|
||||
@@ -92,6 +92,48 @@ class TestProviderTrace:
|
||||
assert trace.call_type == "summarization"
|
||||
assert trace.run_id == "run-789"
|
||||
|
||||
def test_v2_protocol_fields(self):
|
||||
"""Test v2 protocol fields (org_id, user_id, compaction_settings, llm_config)."""
|
||||
trace = ProviderTrace(
|
||||
request_json={},
|
||||
response_json={},
|
||||
step_id="step-123",
|
||||
org_id="org-123",
|
||||
user_id="user-123",
|
||||
compaction_settings={"mode": "sliding_window", "target_message_count": 50},
|
||||
llm_config={"model": "gpt-4", "temperature": 0.7},
|
||||
)
|
||||
assert trace.org_id == "org-123"
|
||||
assert trace.user_id == "user-123"
|
||||
assert trace.compaction_settings == {"mode": "sliding_window", "target_message_count": 50}
|
||||
assert trace.llm_config == {"model": "gpt-4", "temperature": 0.7}
|
||||
|
||||
def test_v2_fields_mutually_exclusive_by_convention(self):
|
||||
"""Test that compaction_settings is set for summarization, llm_config for non-summarization."""
|
||||
summarization_trace = ProviderTrace(
|
||||
request_json={},
|
||||
response_json={},
|
||||
step_id="step-123",
|
||||
call_type="summarization",
|
||||
compaction_settings={"mode": "partial_evict"},
|
||||
llm_config=None,
|
||||
)
|
||||
assert summarization_trace.call_type == "summarization"
|
||||
assert summarization_trace.compaction_settings is not None
|
||||
assert summarization_trace.llm_config is None
|
||||
|
||||
agent_step_trace = ProviderTrace(
|
||||
request_json={},
|
||||
response_json={},
|
||||
step_id="step-456",
|
||||
call_type="agent_step",
|
||||
compaction_settings=None,
|
||||
llm_config={"model": "claude-3"},
|
||||
)
|
||||
assert agent_step_trace.call_type == "agent_step"
|
||||
assert agent_step_trace.compaction_settings is None
|
||||
assert agent_step_trace.llm_config is not None
|
||||
|
||||
|
||||
class TestSocketProviderTraceBackend:
|
||||
"""Tests for SocketProviderTraceBackend."""
|
||||
@@ -246,6 +288,36 @@ class TestSocketProviderTraceBackend:
|
||||
assert captured_records[0]["error"] == "Rate limit exceeded"
|
||||
assert captured_records[0]["response"] is None
|
||||
|
||||
def test_record_includes_v2_protocol_fields(self):
|
||||
"""Test that v2 protocol fields are included in the socket record."""
|
||||
trace = ProviderTrace(
|
||||
request_json={"model": "gpt-4"},
|
||||
response_json={"id": "test"},
|
||||
step_id="step-123",
|
||||
org_id="org-456",
|
||||
user_id="user-456",
|
||||
compaction_settings={"mode": "sliding_window"},
|
||||
llm_config={"model": "gpt-4", "temperature": 0.5},
|
||||
)
|
||||
|
||||
backend = SocketProviderTraceBackend(socket_path="/fake/path")
|
||||
|
||||
captured_records = []
|
||||
|
||||
def capture_record(record):
|
||||
captured_records.append(record)
|
||||
|
||||
with patch.object(backend, "_send_async", side_effect=capture_record):
|
||||
backend._send_to_crouton(trace)
|
||||
|
||||
assert len(captured_records) == 1
|
||||
record = captured_records[0]
|
||||
assert record["protocol_version"] == 2
|
||||
assert record["org_id"] == "org-456"
|
||||
assert record["user_id"] == "user-456"
|
||||
assert record["compaction_settings"] == {"mode": "sliding_window"}
|
||||
assert record["llm_config"] == {"model": "gpt-4", "temperature": 0.5}
|
||||
|
||||
|
||||
class TestBackendFactory:
|
||||
"""Tests for backend factory."""
|
||||
|
||||
Reference in New Issue
Block a user