feat: add tracing to llm clients (#2340)

Author: cthomas
Date: 2025-05-22 13:55:32 -07:00
Committed by: GitHub
Parent: e40b389536
Commit: b554171d41

3 changed files with 13 additions and 0 deletions
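
The trace_method decorator imported from letta.tracing is applied below to the synchronous and asynchronous request paths of each LLM client, but its implementation is not part of this commit. As a rough orientation only, here is a minimal sketch of a decorator with that usage pattern, assuming an OpenTelemetry-style tracer; the tracer, span naming, and wrapper names are illustrative and not taken from letta's actual code:

    # Illustrative sketch only -- letta's real trace_method is not shown in this commit.
    # Assumes an OpenTelemetry-style tracer; span naming is a guess.
    import asyncio
    import functools

    from opentelemetry import trace

    tracer = trace.get_tracer(__name__)


    def trace_method(func):
        """Wrap a sync or async method in a span named after the method."""
        span_name = f"{func.__module__}.{func.__qualname__}"

        if asyncio.iscoroutinefunction(func):
            @functools.wraps(func)
            async def async_wrapper(*args, **kwargs):
                with tracer.start_as_current_span(span_name):
                    return await func(*args, **kwargs)

            return async_wrapper

        @functools.wraps(func)
        def sync_wrapper(*args, **kwargs):
            with tracer.start_as_current_span(span_name):
                return func(*args, **kwargs)

        return sync_wrapper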

View File

@@ -45,11 +45,13 @@ logger = get_logger(__name__)
 class AnthropicClient(LLMClientBase):
+    @trace_method
     def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         client = self._get_anthropic_client(llm_config, async_client=False)
         response = client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
         return response.model_dump()
+    @trace_method
     async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
         client = self._get_anthropic_client(llm_config, async_client=True)
         response = await client.beta.messages.create(**request_data, betas=["tools-2024-04-04"])
@@ -339,6 +341,7 @@ class AnthropicClient(LLMClientBase):
     # TODO: Input messages doesn't get used here
     # TODO: Clean up this interface
+    @trace_method
     def convert_response_to_chat_completion(
         self,
         response_data: dict,
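
All three clients in this commit follow the same pattern: one decorator wraps both the sync request and async request_async entry points, plus the request-building and response-conversion helpers. A toy illustration of that pattern (ToyClient is hypothetical, not one of letta's classes, and assumes the trace_method sketch above is in scope):

    # Hypothetical ToyClient standing in for the decorated client methods in this diff.
    import asyncio


    class ToyClient:
        @trace_method
        def request(self, request_data: dict) -> dict:
            # The span opened by trace_method covers this whole synchronous call.
            return {"echo": request_data}

        @trace_method
        async def request_async(self, request_data: dict) -> dict:
            # The async wrapper awaits the call while the span is open.
            return {"echo": request_data}


    client = ToyClient()
    client.request({"prompt": "hi"})                     # traced sync path
    asyncio.run(client.request_async({"prompt": "hi"}))  # traced async path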

View File

@@ -17,6 +17,7 @@ from letta.schemas.message import Message as PydanticMessage
 from letta.schemas.openai.chat_completion_request import Tool
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics
 from letta.settings import model_settings, settings
+from letta.tracing import trace_method
 from letta.utils import get_tool_call_id
 logger = get_logger(__name__)
@@ -32,6 +33,7 @@ class GoogleVertexClient(LLMClientBase):
             http_options={"api_version": "v1"},
         )
+    @trace_method
     def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying request to llm and returns raw response.
@@ -44,6 +46,7 @@ class GoogleVertexClient(LLMClientBase):
         )
         return response.model_dump()
+    @trace_method
     async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying request to llm and returns raw response.
@@ -189,6 +192,7 @@ class GoogleVertexClient(LLMClientBase):
         return [{"functionDeclarations": function_list}]
+    @trace_method
     def build_request_data(
         self,
         messages: List[PydanticMessage],
@@ -248,6 +252,7 @@ class GoogleVertexClient(LLMClientBase):
         return request_data
+    @trace_method
     def convert_response_to_chat_completion(
         self,
         response_data: dict,

View File

@@ -32,6 +32,7 @@ from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
 from letta.schemas.openai.chat_completion_request import ToolFunctionChoice, cast_message_to_subtype
 from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
 from letta.settings import model_settings
+from letta.tracing import trace_method
 logger = get_logger(__name__)
@@ -124,6 +125,7 @@ class OpenAIClient(LLMClientBase):
         return kwargs
+    @trace_method
     def build_request_data(
         self,
         messages: List[PydanticMessage],
@@ -213,6 +215,7 @@ class OpenAIClient(LLMClientBase):
         return data.model_dump(exclude_unset=True)
+    @trace_method
     def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying synchronous request to OpenAI API and returns raw response dict.
@@ -222,6 +225,7 @@ class OpenAIClient(LLMClientBase):
         response: ChatCompletion = client.chat.completions.create(**request_data)
         return response.model_dump()
+    @trace_method
     async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
         """
         Performs underlying asynchronous request to OpenAI API and returns raw response dict.
@@ -230,6 +234,7 @@ class OpenAIClient(LLMClientBase):
         response: ChatCompletion = await client.chat.completions.create(**request_data)
         return response.model_dump()
+    @trace_method
     def convert_response_to_chat_completion(
         self,
         response_data: dict,
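
How the resulting spans are exported is not configured anywhere in this commit. Purely for local experimentation with the sketches above, an assumed minimal OpenTelemetry setup could print spans to stdout (this is not letta's actual tracing configuration):

    # Assumed local setup; letta's real exporter/backend configuration is not in this diff.
    from opentelemetry import trace
    from opentelemetry.sdk.trace import TracerProvider
    from opentelemetry.sdk.trace.export import ConsoleSpanExporter, SimpleSpanProcessor

    provider = TracerProvider()
    provider.add_span_processor(SimpleSpanProcessor(ConsoleSpanExporter()))
    trace.set_tracer_provider(provider)
    # From here on, spans created via tracer.start_as_current_span(...) are printed to stdout.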