feat(crouton): add orgId, userId, Compaction_Settings and LLM_Config (#9022)
* LC one shot? * api changes * fix summarizer nameerror
This commit is contained in:
@@ -0,0 +1,32 @@
|
|||||||
|
"""Add v2 protocol fields to provider_traces
|
||||||
|
|
||||||
|
Revision ID: 9275f62ad282
|
||||||
|
Revises: 297e8217e952
|
||||||
|
Create Date: 2026-01-22
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
from typing import Sequence, Union
|
||||||
|
|
||||||
|
import sqlalchemy as sa
|
||||||
|
|
||||||
|
from alembic import op
|
||||||
|
|
||||||
|
revision: str = "9275f62ad282"
|
||||||
|
down_revision: Union[str, None] = "297e8217e952"
|
||||||
|
branch_labels: Union[str, Sequence[str], None] = None
|
||||||
|
depends_on: Union[str, Sequence[str], None] = None
|
||||||
|
|
||||||
|
|
||||||
|
def upgrade() -> None:
    """Add the v2 protocol columns to the ``provider_traces`` table.

    Four nullable columns are created, in this order: ``org_id`` and
    ``user_id`` (String), then ``compaction_settings`` and ``llm_config``
    (JSON).
    """
    table_name = "provider_traces"
    # (column name, SQLAlchemy type factory) pairs, in creation order.
    column_specs = (
        ("org_id", sa.String),
        ("user_id", sa.String),
        ("compaction_settings", sa.JSON),
        ("llm_config", sa.JSON),
    )
    for column_name, make_type in column_specs:
        op.add_column(table_name, sa.Column(column_name, make_type(), nullable=True))
|
||||||
|
|
||||||
|
|
||||||
|
def downgrade() -> None:
    """Drop the v2 protocol columns from the ``provider_traces`` table.

    Columns are removed in the order ``llm_config``, ``compaction_settings``,
    ``user_id``, ``org_id``.
    """
    table_name = "provider_traces"
    for column_name in ("llm_config", "compaction_settings", "user_id", "org_id"):
        op.drop_column(table_name, column_name)
|
||||||
@@ -39340,13 +39340,51 @@
|
|||||||
],
|
],
|
||||||
"title": "Source",
|
"title": "Source",
|
||||||
"description": "Source service that generated this trace (memgpt-server, lettuce-py)"
|
"description": "Source service that generated this trace (memgpt-server, lettuce-py)"
|
||||||
|
},
|
||||||
|
"org_id": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"type": "string"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"title": "Org Id",
|
||||||
|
"description": "ID of the organization"
|
||||||
|
},
|
||||||
|
"compaction_settings": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"additionalProperties": true,
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"title": "Compaction Settings",
|
||||||
|
"description": "Compaction/summarization settings (summarization calls only)"
|
||||||
|
},
|
||||||
|
"llm_config": {
|
||||||
|
"anyOf": [
|
||||||
|
{
|
||||||
|
"additionalProperties": true,
|
||||||
|
"type": "object"
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"type": "null"
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"title": "Llm Config",
|
||||||
|
"description": "LLM configuration used for this call (non-summarization calls only)"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"additionalProperties": false,
|
"additionalProperties": false,
|
||||||
"type": "object",
|
"type": "object",
|
||||||
"required": ["request_json", "response_json"],
|
"required": ["request_json", "response_json"],
|
||||||
"title": "ProviderTrace",
|
"title": "ProviderTrace",
|
||||||
"description": "Letta's internal representation of a provider trace.\n\nAttributes:\n id (str): The unique identifier of the provider trace.\n request_json (Dict[str, Any]): JSON content of the provider request.\n response_json (Dict[str, Any]): JSON content of the provider response.\n step_id (str): ID of the step that this trace is associated with.\n agent_id (str): ID of the agent that generated this trace.\n agent_tags (list[str]): Tags associated with the agent for filtering.\n call_type (str): Type of call (agent_step, summarization, etc.).\n run_id (str): ID of the run this trace is associated with.\n source (str): Source service that generated this trace (memgpt-server, lettuce-py).\n organization_id (str): The unique identifier of the organization.\n created_at (datetime): The timestamp when the object was created."
|
"description": "Letta's internal representation of a provider trace.\n\nAttributes:\n id (str): The unique identifier of the provider trace.\n request_json (Dict[str, Any]): JSON content of the provider request.\n response_json (Dict[str, Any]): JSON content of the provider response.\n step_id (str): ID of the step that this trace is associated with.\n agent_id (str): ID of the agent that generated this trace.\n agent_tags (list[str]): Tags associated with the agent for filtering.\n call_type (str): Type of call (agent_step, summarization, etc.).\n run_id (str): ID of the run this trace is associated with.\n source (str): Source service that generated this trace (memgpt-server, lettuce-py).\n organization_id (str): The unique identifier of the organization.\n user_id (str): The unique identifier of the user who initiated the request.\n compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).\n llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).\n created_at (datetime): The timestamp when the object was created."
|
||||||
},
|
},
|
||||||
"ProviderType": {
|
"ProviderType": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
|||||||
@@ -27,12 +27,16 @@ class LettaLLMAdapter(ABC):
|
|||||||
agent_id: str | None = None,
|
agent_id: str | None = None,
|
||||||
agent_tags: list[str] | None = None,
|
agent_tags: list[str] | None = None,
|
||||||
run_id: str | None = None,
|
run_id: str | None = None,
|
||||||
|
org_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
self.llm_client: LLMClientBase = llm_client
|
self.llm_client: LLMClientBase = llm_client
|
||||||
self.llm_config: LLMConfig = llm_config
|
self.llm_config: LLMConfig = llm_config
|
||||||
self.agent_id: str | None = agent_id
|
self.agent_id: str | None = agent_id
|
||||||
self.agent_tags: list[str] | None = agent_tags
|
self.agent_tags: list[str] | None = agent_tags
|
||||||
self.run_id: str | None = run_id
|
self.run_id: str | None = run_id
|
||||||
|
self.org_id: str | None = org_id
|
||||||
|
self.user_id: str | None = user_id
|
||||||
self.message_id: str | None = None
|
self.message_id: str | None = None
|
||||||
self.request_data: dict | None = None
|
self.request_data: dict | None = None
|
||||||
self.response_data: dict | None = None
|
self.response_data: dict | None = None
|
||||||
|
|||||||
@@ -127,6 +127,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
|
|||||||
agent_id=self.agent_id,
|
agent_id=self.agent_id,
|
||||||
agent_tags=self.agent_tags,
|
agent_tags=self.agent_tags,
|
||||||
run_id=self.run_id,
|
run_id=self.run_id,
|
||||||
|
org_id=self.org_id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
label="create_provider_trace",
|
label="create_provider_trace",
|
||||||
|
|||||||
@@ -33,8 +33,10 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
|||||||
agent_id: str | None = None,
|
agent_id: str | None = None,
|
||||||
agent_tags: list[str] | None = None,
|
agent_tags: list[str] | None = None,
|
||||||
run_id: str | None = None,
|
run_id: str | None = None,
|
||||||
|
org_id: str | None = None,
|
||||||
|
user_id: str | None = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id)
|
super().__init__(llm_client, llm_config, agent_id=agent_id, agent_tags=agent_tags, run_id=run_id, org_id=org_id, user_id=user_id)
|
||||||
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
|
self.interface: OpenAIStreamingInterface | AnthropicStreamingInterface | None = None
|
||||||
|
|
||||||
async def invoke_llm(
|
async def invoke_llm(
|
||||||
@@ -236,6 +238,9 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
|||||||
agent_id=self.agent_id,
|
agent_id=self.agent_id,
|
||||||
agent_tags=self.agent_tags,
|
agent_tags=self.agent_tags,
|
||||||
run_id=self.run_id,
|
run_id=self.run_id,
|
||||||
|
org_id=self.org_id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
label="create_provider_trace",
|
label="create_provider_trace",
|
||||||
|
|||||||
@@ -46,6 +46,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
|||||||
agent_tags=self.agent_tags,
|
agent_tags=self.agent_tags,
|
||||||
run_id=self.run_id,
|
run_id=self.run_id,
|
||||||
call_type="agent_step",
|
call_type="agent_step",
|
||||||
|
org_id=self.org_id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||||
)
|
)
|
||||||
try:
|
try:
|
||||||
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
|
self.response_data = await self.llm_client.request_async_with_telemetry(request_data, self.llm_config)
|
||||||
|
|||||||
@@ -283,6 +283,9 @@ class SimpleLLMStreamAdapter(LettaLLMStreamAdapter):
|
|||||||
agent_id=self.agent_id,
|
agent_id=self.agent_id,
|
||||||
agent_tags=self.agent_tags,
|
agent_tags=self.agent_tags,
|
||||||
run_id=self.run_id,
|
run_id=self.run_id,
|
||||||
|
org_id=self.org_id,
|
||||||
|
user_id=self.user_id,
|
||||||
|
llm_config=self.llm_config.model_dump() if self.llm_config else None,
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
label="create_provider_trace",
|
label="create_provider_trace",
|
||||||
|
|||||||
@@ -420,6 +420,9 @@ class LettaAgent(BaseAgent):
|
|||||||
agent_id=self.agent_id,
|
agent_id=self.agent_id,
|
||||||
agent_tags=agent_state.tags,
|
agent_tags=agent_state.tags,
|
||||||
run_id=self.current_run_id,
|
run_id=self.current_run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
|
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
step_progression = StepProgression.LOGGED_TRACE
|
step_progression = StepProgression.LOGGED_TRACE
|
||||||
@@ -770,6 +773,9 @@ class LettaAgent(BaseAgent):
|
|||||||
agent_id=self.agent_id,
|
agent_id=self.agent_id,
|
||||||
agent_tags=agent_state.tags,
|
agent_tags=agent_state.tags,
|
||||||
run_id=self.current_run_id,
|
run_id=self.current_run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
|
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
step_progression = StepProgression.LOGGED_TRACE
|
step_progression = StepProgression.LOGGED_TRACE
|
||||||
@@ -1242,6 +1248,9 @@ class LettaAgent(BaseAgent):
|
|||||||
agent_id=self.agent_id,
|
agent_id=self.agent_id,
|
||||||
agent_tags=agent_state.tags,
|
agent_tags=agent_state.tags,
|
||||||
run_id=self.current_run_id,
|
run_id=self.current_run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
|
llm_config=self.agent_state.llm_config.model_dump() if self.agent_state.llm_config else None,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
step_progression = StepProgression.LOGGED_TRACE
|
step_progression = StepProgression.LOGGED_TRACE
|
||||||
|
|||||||
@@ -156,7 +156,11 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
run_id=None,
|
run_id=None,
|
||||||
messages=in_context_messages + input_messages_to_persist,
|
messages=in_context_messages + input_messages_to_persist,
|
||||||
llm_adapter=LettaLLMRequestAdapter(
|
llm_adapter=LettaLLMRequestAdapter(
|
||||||
llm_client=self.llm_client, llm_config=self.agent_state.llm_config, agent_tags=self.agent_state.tags
|
llm_client=self.llm_client,
|
||||||
|
llm_config=self.agent_state.llm_config,
|
||||||
|
agent_tags=self.agent_state.tags,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
),
|
),
|
||||||
dry_run=True,
|
dry_run=True,
|
||||||
enforce_run_id_set=False,
|
enforce_run_id_set=False,
|
||||||
@@ -213,6 +217,8 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
agent_tags=self.agent_state.tags,
|
agent_tags=self.agent_state.tags,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
),
|
),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
use_assistant_message=use_assistant_message,
|
use_assistant_message=use_assistant_message,
|
||||||
@@ -298,6 +304,8 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
agent_tags=self.agent_state.tags,
|
agent_tags=self.agent_state.tags,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
llm_adapter = LettaLLMRequestAdapter(
|
llm_adapter = LettaLLMRequestAdapter(
|
||||||
@@ -306,6 +314,8 @@ class LettaAgentV2(BaseAgentV2):
|
|||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
agent_tags=self.agent_state.tags,
|
agent_tags=self.agent_state.tags,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -173,6 +173,8 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
agent_tags=self.agent_state.tags,
|
agent_tags=self.agent_state.tags,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
),
|
),
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
# use_assistant_message=use_assistant_message,
|
# use_assistant_message=use_assistant_message,
|
||||||
@@ -316,6 +318,8 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
agent_tags=self.agent_state.tags,
|
agent_tags=self.agent_state.tags,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
llm_adapter = SimpleLLMRequestAdapter(
|
llm_adapter = SimpleLLMRequestAdapter(
|
||||||
@@ -324,6 +328,8 @@ class LettaAgentV3(LettaAgentV2):
|
|||||||
agent_id=self.agent_state.id,
|
agent_id=self.agent_state.id,
|
||||||
agent_tags=self.agent_state.tags,
|
agent_tags=self.agent_state.tags,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
|
org_id=self.actor.organization_id,
|
||||||
|
user_id=self.actor.id,
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -43,6 +43,10 @@ class LLMClientBase:
|
|||||||
self._telemetry_run_id: Optional[str] = None
|
self._telemetry_run_id: Optional[str] = None
|
||||||
self._telemetry_step_id: Optional[str] = None
|
self._telemetry_step_id: Optional[str] = None
|
||||||
self._telemetry_call_type: Optional[str] = None
|
self._telemetry_call_type: Optional[str] = None
|
||||||
|
self._telemetry_org_id: Optional[str] = None
|
||||||
|
self._telemetry_user_id: Optional[str] = None
|
||||||
|
self._telemetry_compaction_settings: Optional[Dict] = None
|
||||||
|
self._telemetry_llm_config: Optional[Dict] = None
|
||||||
|
|
||||||
def set_telemetry_context(
|
def set_telemetry_context(
|
||||||
self,
|
self,
|
||||||
@@ -52,6 +56,10 @@ class LLMClientBase:
|
|||||||
run_id: Optional[str] = None,
|
run_id: Optional[str] = None,
|
||||||
step_id: Optional[str] = None,
|
step_id: Optional[str] = None,
|
||||||
call_type: Optional[str] = None,
|
call_type: Optional[str] = None,
|
||||||
|
org_id: Optional[str] = None,
|
||||||
|
user_id: Optional[str] = None,
|
||||||
|
compaction_settings: Optional[Dict] = None,
|
||||||
|
llm_config: Optional[Dict] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Set telemetry context for provider trace logging."""
|
"""Set telemetry context for provider trace logging."""
|
||||||
self._telemetry_manager = telemetry_manager
|
self._telemetry_manager = telemetry_manager
|
||||||
@@ -60,6 +68,10 @@ class LLMClientBase:
|
|||||||
self._telemetry_run_id = run_id
|
self._telemetry_run_id = run_id
|
||||||
self._telemetry_step_id = step_id
|
self._telemetry_step_id = step_id
|
||||||
self._telemetry_call_type = call_type
|
self._telemetry_call_type = call_type
|
||||||
|
self._telemetry_org_id = org_id
|
||||||
|
self._telemetry_user_id = user_id
|
||||||
|
self._telemetry_compaction_settings = compaction_settings
|
||||||
|
self._telemetry_llm_config = llm_config
|
||||||
|
|
||||||
async def request_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
async def request_async_with_telemetry(self, request_data: dict, llm_config: LLMConfig) -> dict:
|
||||||
"""Wrapper around request_async that logs telemetry for all requests including errors.
|
"""Wrapper around request_async that logs telemetry for all requests including errors.
|
||||||
@@ -96,6 +108,10 @@ class LLMClientBase:
|
|||||||
agent_tags=self._telemetry_agent_tags,
|
agent_tags=self._telemetry_agent_tags,
|
||||||
run_id=self._telemetry_run_id,
|
run_id=self._telemetry_run_id,
|
||||||
call_type=self._telemetry_call_type,
|
call_type=self._telemetry_call_type,
|
||||||
|
org_id=self._telemetry_org_id,
|
||||||
|
user_id=self._telemetry_user_id,
|
||||||
|
compaction_settings=self._telemetry_compaction_settings,
|
||||||
|
llm_config=self._telemetry_llm_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
@@ -137,6 +153,10 @@ class LLMClientBase:
|
|||||||
agent_tags=self._telemetry_agent_tags,
|
agent_tags=self._telemetry_agent_tags,
|
||||||
run_id=self._telemetry_run_id,
|
run_id=self._telemetry_run_id,
|
||||||
call_type=self._telemetry_call_type,
|
call_type=self._telemetry_call_type,
|
||||||
|
org_id=self._telemetry_org_id,
|
||||||
|
user_id=self._telemetry_user_id,
|
||||||
|
compaction_settings=self._telemetry_compaction_settings,
|
||||||
|
llm_config=self._telemetry_llm_config,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|||||||
@@ -32,5 +32,15 @@ class ProviderTrace(SqlalchemyBase, OrganizationMixin):
|
|||||||
String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)"
|
String, nullable=True, doc="Source service that generated this trace (memgpt-server, lettuce-py)"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# v2 protocol fields
|
||||||
|
org_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the organization")
|
||||||
|
user_id: Mapped[Optional[str]] = mapped_column(String, nullable=True, doc="ID of the user who initiated the request")
|
||||||
|
compaction_settings: Mapped[Optional[dict]] = mapped_column(
|
||||||
|
JSON, nullable=True, doc="Compaction/summarization settings (summarization calls only)"
|
||||||
|
)
|
||||||
|
llm_config: Mapped[Optional[dict]] = mapped_column(
|
||||||
|
JSON, nullable=True, doc="LLM configuration used for this call (non-summarization calls only)"
|
||||||
|
)
|
||||||
|
|
||||||
# Relationships
|
# Relationships
|
||||||
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
|
organization: Mapped["Organization"] = relationship("Organization", lazy="selectin")
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ class ProviderTrace(BaseProviderTrace):
|
|||||||
run_id (str): ID of the run this trace is associated with.
|
run_id (str): ID of the run this trace is associated with.
|
||||||
source (str): Source service that generated this trace (memgpt-server, lettuce-py).
|
source (str): Source service that generated this trace (memgpt-server, lettuce-py).
|
||||||
organization_id (str): The unique identifier of the organization.
|
organization_id (str): The unique identifier of the organization.
|
||||||
|
user_id (str): The unique identifier of the user who initiated the request.
|
||||||
|
compaction_settings (Dict[str, Any]): Compaction/summarization settings (only for summarization calls).
|
||||||
|
llm_config (Dict[str, Any]): LLM configuration used for this call (only for non-summarization calls).
|
||||||
created_at (datetime): The timestamp when the object was created.
|
created_at (datetime): The timestamp when the object was created.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
@@ -44,4 +47,10 @@ class ProviderTrace(BaseProviderTrace):
|
|||||||
run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with")
|
run_id: Optional[str] = Field(None, description="ID of the run this trace is associated with")
|
||||||
source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)")
|
source: Optional[str] = Field(None, description="Source service that generated this trace (memgpt-server, lettuce-py)")
|
||||||
|
|
||||||
|
# v2 protocol fields
|
||||||
|
org_id: Optional[str] = Field(None, description="ID of the organization")
|
||||||
|
user_id: Optional[str] = Field(None, description="ID of the user who initiated the request")
|
||||||
|
compaction_settings: Optional[Dict[str, Any]] = Field(None, description="Compaction/summarization settings (summarization calls only)")
|
||||||
|
llm_config: Optional[Dict[str, Any]] = Field(None, description="LLM configuration used for this call (non-summarization calls only)")
|
||||||
|
|
||||||
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
created_at: datetime = Field(default_factory=get_utc_time, description="The timestamp when the object was created.")
|
||||||
|
|||||||
@@ -17,7 +17,8 @@ logger = get_logger(__name__)
|
|||||||
# Protocol version for crouton communication.
|
# Protocol version for crouton communication.
|
||||||
# Bump this when making breaking changes to the record schema.
|
# Bump this when making breaking changes to the record schema.
|
||||||
# Must match ProtocolVersion in apps/crouton/main.go.
|
# Must match ProtocolVersion in apps/crouton/main.go.
|
||||||
PROTOCOL_VERSION = 1
|
# v2: Added org_id, user_id, compaction_settings (summarization), llm_config (non-summarization)
|
||||||
|
PROTOCOL_VERSION = 2
|
||||||
|
|
||||||
|
|
||||||
class SocketProviderTraceBackend(ProviderTraceBackendClient):
|
class SocketProviderTraceBackend(ProviderTraceBackendClient):
|
||||||
@@ -94,6 +95,11 @@ class SocketProviderTraceBackend(ProviderTraceBackendClient):
|
|||||||
"error": error,
|
"error": error,
|
||||||
"error_type": error_type,
|
"error_type": error_type,
|
||||||
"timestamp": datetime.now(timezone.utc).isoformat(),
|
"timestamp": datetime.now(timezone.utc).isoformat(),
|
||||||
|
# v2 protocol fields
|
||||||
|
"org_id": provider_trace.org_id,
|
||||||
|
"user_id": provider_trace.user_id,
|
||||||
|
"compaction_settings": provider_trace.compaction_settings,
|
||||||
|
"llm_config": provider_trace.llm_config,
|
||||||
}
|
}
|
||||||
|
|
||||||
# Fire-and-forget in background thread
|
# Fire-and-forget in background thread
|
||||||
|
|||||||
@@ -181,6 +181,8 @@ class Summarizer:
|
|||||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
|
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=self.agent_id, actor=self.actor)
|
||||||
|
|
||||||
# TODO if we do this via the "agent", then we can more easily allow toggling on the memory block version
|
# TODO if we do this via the "agent", then we can more easily allow toggling on the memory block version
|
||||||
|
from letta.settings import summarizer_settings
|
||||||
|
|
||||||
summary_message_str = await simple_summary(
|
summary_message_str = await simple_summary(
|
||||||
messages=messages_to_summarize,
|
messages=messages_to_summarize,
|
||||||
llm_config=agent_state.llm_config,
|
llm_config=agent_state.llm_config,
|
||||||
@@ -190,6 +192,12 @@ class Summarizer:
|
|||||||
agent_tags=agent_state.tags,
|
agent_tags=agent_state.tags,
|
||||||
run_id=run_id if run_id is not None else self.run_id,
|
run_id=run_id if run_id is not None else self.run_id,
|
||||||
step_id=step_id if step_id is not None else self.step_id,
|
step_id=step_id if step_id is not None else self.step_id,
|
||||||
|
compaction_settings={
|
||||||
|
"mode": str(summarizer_settings.mode.value),
|
||||||
|
"message_buffer_limit": summarizer_settings.message_buffer_limit,
|
||||||
|
"message_buffer_min": summarizer_settings.message_buffer_min,
|
||||||
|
"partial_evict_summarizer_percentage": summarizer_settings.partial_evict_summarizer_percentage,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO add counts back
|
# TODO add counts back
|
||||||
@@ -450,6 +458,7 @@ async def simple_summary(
|
|||||||
agent_tags: List[str] | None = None,
|
agent_tags: List[str] | None = None,
|
||||||
run_id: str | None = None,
|
run_id: str | None = None,
|
||||||
step_id: str | None = None,
|
step_id: str | None = None,
|
||||||
|
compaction_settings: dict | None = None,
|
||||||
) -> str:
|
) -> str:
|
||||||
"""Generate a simple summary from a list of messages.
|
"""Generate a simple summary from a list of messages.
|
||||||
|
|
||||||
@@ -474,6 +483,9 @@ async def simple_summary(
|
|||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
call_type="summarization",
|
call_type="summarization",
|
||||||
|
org_id=actor.organization_id if actor else None,
|
||||||
|
user_id=actor.id if actor else None,
|
||||||
|
compaction_settings=compaction_settings,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare the messages payload to send to the LLM
|
# Prepare the messages payload to send to the LLM
|
||||||
|
|||||||
@@ -68,6 +68,10 @@ async def summarize_all(
|
|||||||
agent_tags=agent_tags,
|
agent_tags=agent_tags,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
|
compaction_settings={
|
||||||
|
"mode": "summarize_all",
|
||||||
|
"clip_chars": summarizer_config.clip_chars,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
logger.info(f"Summarized {len(messages_to_summarize)} messages")
|
logger.info(f"Summarized {len(messages_to_summarize)} messages")
|
||||||
|
|
||||||
|
|||||||
@@ -146,6 +146,13 @@ async def summarize_via_sliding_window(
|
|||||||
agent_tags=agent_tags,
|
agent_tags=agent_tags,
|
||||||
run_id=run_id,
|
run_id=run_id,
|
||||||
step_id=step_id,
|
step_id=step_id,
|
||||||
|
compaction_settings={
|
||||||
|
"mode": "sliding_window",
|
||||||
|
"messages_summarized": len(messages_to_summarize),
|
||||||
|
"messages_kept": total_message_count - assistant_message_index,
|
||||||
|
"sliding_window_percentage": summarizer_config.sliding_window_percentage,
|
||||||
|
"clip_chars": summarizer_config.clip_chars,
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
|
if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
|
||||||
|
|||||||
@@ -92,6 +92,48 @@ class TestProviderTrace:
|
|||||||
assert trace.call_type == "summarization"
|
assert trace.call_type == "summarization"
|
||||||
assert trace.run_id == "run-789"
|
assert trace.run_id == "run-789"
|
||||||
|
|
||||||
|
def test_v2_protocol_fields(self):
    """ProviderTrace exposes org_id, user_id, compaction_settings and llm_config as passed in."""
    expected_fields = {
        "org_id": "org-123",
        "user_id": "user-123",
        "compaction_settings": {"mode": "sliding_window", "target_message_count": 50},
        "llm_config": {"model": "gpt-4", "temperature": 0.7},
    }
    trace = ProviderTrace(
        request_json={},
        response_json={},
        step_id="step-123",
        **expected_fields,
    )
    # Each v2 field must read back equal to the value given at construction.
    for field_name, expected_value in expected_fields.items():
        assert getattr(trace, field_name) == expected_value
|
||||||
|
|
||||||
|
def test_v2_fields_mutually_exclusive_by_convention(self):
    """Check the two conventional shapes of a trace's v2 fields.

    A summarization trace is built with compaction_settings and no
    llm_config; an agent_step trace is built with llm_config and no
    compaction_settings. The test only verifies the values read back as
    constructed.
    """
    # Summarization call: compaction settings present, LLM config absent.
    summary_kwargs = dict(
        request_json={},
        response_json={},
        step_id="step-123",
        call_type="summarization",
        compaction_settings={"mode": "partial_evict"},
        llm_config=None,
    )
    summarization_trace = ProviderTrace(**summary_kwargs)
    assert summarization_trace.call_type == "summarization"
    assert summarization_trace.compaction_settings is not None
    assert summarization_trace.llm_config is None

    # Agent step call: LLM config present, compaction settings absent.
    step_kwargs = dict(
        request_json={},
        response_json={},
        step_id="step-456",
        call_type="agent_step",
        compaction_settings=None,
        llm_config={"model": "claude-3"},
    )
    agent_step_trace = ProviderTrace(**step_kwargs)
    assert agent_step_trace.call_type == "agent_step"
    assert agent_step_trace.compaction_settings is None
    assert agent_step_trace.llm_config is not None
|
||||||
|
|
||||||
|
|
||||||
class TestSocketProviderTraceBackend:
|
class TestSocketProviderTraceBackend:
|
||||||
"""Tests for SocketProviderTraceBackend."""
|
"""Tests for SocketProviderTraceBackend."""
|
||||||
@@ -246,6 +288,36 @@ class TestSocketProviderTraceBackend:
|
|||||||
assert captured_records[0]["error"] == "Rate limit exceeded"
|
assert captured_records[0]["error"] == "Rate limit exceeded"
|
||||||
assert captured_records[0]["response"] is None
|
assert captured_records[0]["response"] is None
|
||||||
|
|
||||||
|
def test_record_includes_v2_protocol_fields(self):
    """The record passed to _send_async carries protocol_version 2 and the v2 fields."""
    trace = ProviderTrace(
        request_json={"model": "gpt-4"},
        response_json={"id": "test"},
        step_id="step-123",
        org_id="org-456",
        user_id="user-456",
        compaction_settings={"mode": "sliding_window"},
        llm_config={"model": "gpt-4", "temperature": 0.5},
    )
    backend = SocketProviderTraceBackend(socket_path="/fake/path")

    # Replace the real send with a collector so nothing touches a socket.
    sent_records = []
    with patch.object(backend, "_send_async", side_effect=sent_records.append):
        backend._send_to_crouton(trace)

    assert len(sent_records) == 1
    record = sent_records[0]
    assert record["protocol_version"] == 2
    expected_v2 = {
        "org_id": "org-456",
        "user_id": "user-456",
        "compaction_settings": {"mode": "sliding_window"},
        "llm_config": {"model": "gpt-4", "temperature": 0.5},
    }
    for key, expected_value in expected_v2.items():
        assert record[key] == expected_value
|
||||||
|
|
||||||
|
|
||||||
class TestBackendFactory:
|
class TestBackendFactory:
|
||||||
"""Tests for backend factory."""
|
"""Tests for backend factory."""
|
||||||
|
|||||||
Reference in New Issue
Block a user