From 98f00624160dea7f1df0388e5090f2ece73c5e85 Mon Sep 17 00:00:00 2001 From: Kevin Lin Date: Tue, 18 Feb 2025 15:28:01 -0800 Subject: [PATCH] feat: support deepseek models (#821) Co-authored-by: Charles Packer Co-authored-by: Sarah Wooders Co-authored-by: Shubham Naik Co-authored-by: Shubham Naik --- letta/agent.py | 5 +- letta/constants.py | 2 + letta/llm_api/deepseek.py | 303 ++++++++++++++++++ letta/llm_api/llm_api_tools.py | 56 +++- letta/llm_api/openai.py | 13 + letta/local_llm/chat_completion_proxy.py | 17 +- letta/local_llm/lmstudio/api.py | 76 ++++- letta/schemas/llm_config.py | 2 + .../openai/chat_completion_response.py | 2 + letta/schemas/providers.py | 69 ++++ letta/server/rest_api/interface.py | 123 ++++++- letta/server/server.py | 8 +- letta/settings.py | 3 + letta/utils.py | 4 + .../llm_model_configs/deepseek-reasoner.json | 7 + tests/helpers/endpoints_helper.py | 17 +- tests/test_model_letta_performance.py | 13 + tests/test_providers.py | 9 + 18 files changed, 709 insertions(+), 20 deletions(-) create mode 100644 letta/llm_api/deepseek.py create mode 100644 tests/configs/llm_model_configs/deepseek-reasoner.json diff --git a/letta/agent.py b/letta/agent.py index d4a71b42..986f9e4a 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -492,7 +492,10 @@ class Agent(BaseAgent): try: raw_function_args = function_call.arguments function_args = parse_json(raw_function_args) - except Exception: + if not isinstance(function_args, dict): + raise ValueError(f"Function arguments are not a dictionary: {function_args} (raw={raw_function_args})") + except Exception as e: + print(e) error_msg = f"Error parsing JSON for function '{function_name}' arguments: {function_call.arguments}" function_response = "None" # more like "never ran?" messages = self._handle_function_error_response( diff --git a/letta/constants.py b/letta/constants.py index 1a59137e..35ab7cb4 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -86,6 +86,8 @@ NON_USER_MSG_PREFIX = "[This is an automated system message hidden from the user # The max amount of tokens supported by the underlying model (eg 8k for gpt-4 and Mistral 7B) LLM_MAX_TOKENS = { "DEFAULT": 8192, + "deepseek-chat": 64000, + "deepseek-reasoner": 64000, ## OpenAI models: https://platform.openai.com/docs/models/overview # "o1-preview "chatgpt-4o-latest": 128000, diff --git a/letta/llm_api/deepseek.py b/letta/llm_api/deepseek.py new file mode 100644 index 00000000..f0b2a45a --- /dev/null +++ b/letta/llm_api/deepseek.py @@ -0,0 +1,303 @@ +import json +import re +import warnings +from typing import List, Optional + +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message as _Message +from letta.schemas.openai.chat_completion_request import AssistantMessage, ChatCompletionRequest, ChatMessage +from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall +from letta.schemas.openai.chat_completion_request import Tool, ToolFunctionChoice, ToolMessage, UserMessage, cast_message_to_subtype +from letta.schemas.openai.chat_completion_response import ChatCompletionResponse +from letta.schemas.openai.openai import Function, ToolCall +from letta.utils import get_tool_call_id + + +def merge_tool_message(previous_message: ChatMessage, tool_message: ToolMessage) -> ChatMessage: + """ + Merge `ToolMessage` objects into the previous message. + """ + previous_message.content += ( + f" content: {tool_message.content}, role: {tool_message.role}, tool_call_id: {tool_message.tool_call_id}" + ) + return previous_message + + +def handle_assistant_message(assistant_message: AssistantMessage) -> AssistantMessage: + """ + For `AssistantMessage` objects, remove the `tool_calls` field and add them to the `content` field. + """ + + if "tool_calls" in assistant_message.dict().keys(): + assistant_message.content = "".join( + [ + # f" name: {tool_call.function.name}, function: {tool_call.function}" + f" {json.dumps(tool_call.function.dict())} " + for tool_call in assistant_message.tool_calls + ] + ) + del assistant_message.tool_calls + return assistant_message + + +def map_messages_to_deepseek_format(messages: List[ChatMessage]) -> List[_Message]: + """ + Deepeek API has the following constraints: messages must be interleaved between user and assistant messages, ending on a user message. + Tools are currently unstable for V3 and not supported for R1 in the API: https://api-docs.deepseek.com/guides/function_calling. + + This function merges ToolMessages into AssistantMessages and removes ToolCalls from AssistantMessages, and adds a dummy user message + at the end. + + """ + deepseek_messages = [] + for idx, message in enumerate(messages): + # First message is the system prompt, add it + if idx == 0 and message.role == "system": + deepseek_messages.append(message) + continue + if message.role == "user": + if deepseek_messages[-1].role == "assistant" or deepseek_messages[-1].role == "system": + # User message, add it + deepseek_messages.append(UserMessage(content=message.content)) + else: + # add to the content of the previous message + deepseek_messages[-1].content += message.content + elif message.role == "assistant": + if deepseek_messages[-1].role == "user": + # Assistant message, remove tool calls and add them to the content + deepseek_messages.append(handle_assistant_message(message)) + else: + # add to the content of the previous message + deepseek_messages[-1].content += message.content + elif message.role == "tool" and deepseek_messages[-1].role == "assistant": + # Tool message, add it to the last assistant message + merged_message = merge_tool_message(deepseek_messages[-1], message) + deepseek_messages[-1] = merged_message + else: + print(f"Skipping message: {message}") + + # This needs to end on a user message, add a dummy message if the last was assistant + if deepseek_messages[-1].role == "assistant": + deepseek_messages.append(UserMessage(content="")) + return deepseek_messages + + +def build_deepseek_chat_completions_request( + llm_config: LLMConfig, + messages: List[_Message], + user_id: Optional[str], + functions: Optional[list], + function_call: Optional[str], + use_tool_naming: bool, + max_tokens: Optional[int], +) -> ChatCompletionRequest: + # if functions and llm_config.put_inner_thoughts_in_kwargs: + # # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first + # # TODO(fix) + # inner_thoughts_desc = ( + # INNER_THOUGHTS_KWARG_DESCRIPTION_GO_FIRST if ":1234" in llm_config.model_endpoint else INNER_THOUGHTS_KWARG_DESCRIPTION + # ) + # functions = add_inner_thoughts_to_functions( + # functions=functions, + # inner_thoughts_key=INNER_THOUGHTS_KWARG, + # inner_thoughts_description=inner_thoughts_desc, + # ) + + openai_message_list = [cast_message_to_subtype(m.to_openai_dict(put_inner_thoughts_in_kwargs=False)) for m in messages] + + if llm_config.model: + model = llm_config.model + else: + warnings.warn(f"Model type not set in llm_config: {llm_config.model_dump_json(indent=4)}") + model = None + if use_tool_naming: + if function_call is None: + tool_choice = None + elif function_call not in ["none", "auto", "required"]: + tool_choice = ToolFunctionChoice(type="function", function=ToolFunctionChoiceFunctionCall(name=function_call)) + else: + tool_choice = function_call + + def add_functions_to_system_message(system_message: ChatMessage): + system_message.content += f" {''.join(json.dumps(f) for f in functions)} " + system_message.content += f'Select best function to call simply respond with a single json block with the fields "name" and "arguments". Use double quotes around the arguments.' + + if llm_config.model == "deepseek-reasoner": # R1 currently doesn't support function calling natively + add_functions_to_system_message( + openai_message_list[0] + ) # Inject additional instructions to the system prompt with the available functions + + openai_message_list = map_messages_to_deepseek_format(openai_message_list) + + data = ChatCompletionRequest( + model=model, + messages=openai_message_list, + user=str(user_id), + max_completion_tokens=max_tokens, + temperature=llm_config.temperature, + ) + else: + data = ChatCompletionRequest( + model=model, + messages=openai_message_list, + tools=[Tool(type="function", function=f) for f in functions] if functions else None, + tool_choice=tool_choice, + user=str(user_id), + max_completion_tokens=max_tokens, + temperature=llm_config.temperature, + ) + else: + data = ChatCompletionRequest( + model=model, + messages=openai_message_list, + functions=functions, + function_call=function_call, + user=str(user_id), + max_completion_tokens=max_tokens, + temperature=llm_config.temperature, + ) + + return data + + +def convert_deepseek_response_to_chatcompletion( + response: ChatCompletionResponse, +) -> ChatCompletionResponse: + """ + Example response from DeepSeek: + + ChatCompletion( + id='bc7f7d25-82e4-443a-b217-dfad2b66da8e', + choices=[ + Choice( + finish_reason='stop', + index=0, + logprobs=None, + message=ChatCompletionMessage( + content='{"function": "send_message", "arguments": {"message": "Hey! Whales are such majestic creatures, aren\'t they? How\'s your day going? 🌊 "}}', + refusal=None, + role='assistant', + audio=None, + function_call=None, + tool_calls=None, + reasoning_content='Okay, the user said "hello whales". Hmm, that\'s an interesting greeting. Maybe they meant "hello there" or are they actually talking about whales? Let me check if I misheard. Whales are fascinating creatures. I should respond in a friendly way. Let me ask them how they\'re doing and mention whales to keep the conversation going.' + ) + ) + ], + created=1738266449, + model='deepseek-reasoner', + object='chat.completion', + service_tier=None, + system_fingerprint='fp_7e73fd9a08', + usage=CompletionUsage( + completion_tokens=111, + prompt_tokens=1270, + total_tokens=1381, + completion_tokens_details=CompletionTokensDetails( + accepted_prediction_tokens=None, + audio_tokens=None, + reasoning_tokens=72, + rejected_prediction_tokens=None + ), + prompt_tokens_details=PromptTokensDetails( + audio_tokens=None, + cached_tokens=1088 + ), + prompt_cache_hit_tokens=1088, + prompt_cache_miss_tokens=182 + ) + ) + """ + + def convert_dict_quotes(input_dict: dict): + """ + Convert a dictionary with single-quoted keys to double-quoted keys, + properly handling boolean values and nested structures. + + Args: + input_dict (dict): Input dictionary with single-quoted keys + + Returns: + str: JSON string with double-quoted keys + """ + # First convert the dictionary to a JSON string to handle booleans properly + json_str = json.dumps(input_dict) + + # Function to handle complex string replacements + def replace_quotes(match): + key = match.group(1) + # Escape any existing double quotes in the key + key = key.replace('"', '\\"') + return f'"{key}":' + + # Replace single-quoted keys with double-quoted keys + # This regex looks for single-quoted keys followed by a colon + def strip_json_block(text): + # Check if text starts with ```json or similar + if text.strip().startswith("```"): + # Split by \n to remove the first and last lines + lines = text.split("\n")[1:-1] + return "\n".join(lines) + return text + + pattern = r"'([^']*)':" + converted_str = re.sub(pattern, replace_quotes, strip_json_block(json_str)) + + # Parse the string back to ensure valid JSON format + try: + json.loads(converted_str) + return converted_str + except json.JSONDecodeError as e: + raise ValueError(f"Failed to create valid JSON with double quotes: {str(e)}") + + def extract_json_block(text): + # Find the first { + start = text.find("{") + if start == -1: + return text + + # Track nested braces to find the matching closing brace + brace_count = 0 + end = start + + for i in range(start, len(text)): + if text[i] == "{": + brace_count += 1 + elif text[i] == "}": + brace_count -= 1 + if brace_count == 0: + end = i + 1 + break + + return text[start:end] + + content = response.choices[0].message.content + try: + content_dict = json.loads(extract_json_block(content)) + + if type(content_dict["arguments"]) == str: + content_dict["arguments"] = json.loads(content_dict["arguments"]) + + tool_calls = [ + ToolCall( + id=get_tool_call_id(), + type="function", + function=Function( + name=content_dict["name"], + arguments=convert_dict_quotes(content_dict["arguments"]), + ), + ) + ] + except (json.JSONDecodeError, TypeError, KeyError) as e: + print(e) + tool_calls = response.choices[0].message.tool_calls + raise ValueError(f"Failed to create valid JSON {content}") + + # Move the "reasoning_content" into the "content" field + response.choices[0].message.content = response.choices[0].message.reasoning_content + response.choices[0].message.tool_calls = tool_calls + + # Remove the "reasoning_content" field + response.choices[0].message.reasoning_content = None + + return response diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index e4bc63f6..b92098d4 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -1,3 +1,4 @@ +import json import random import time from typing import List, Optional, Union @@ -13,6 +14,7 @@ from letta.llm_api.anthropic import ( ) from letta.llm_api.aws_bedrock import has_valid_aws_credentials from letta.llm_api.azure_openai import azure_openai_chat_completions_request +from letta.llm_api.deepseek import build_deepseek_chat_completions_request, convert_deepseek_response_to_chatcompletion from letta.llm_api.google_ai import convert_tools_to_google_ai_format, google_ai_chat_completions_request from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs from letta.llm_api.openai import ( @@ -30,7 +32,7 @@ from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.settings import ModelSettings from letta.streaming_interface import AgentChunkStreamingInterface, AgentRefreshStreamingInterface -LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq"] +LLM_API_PROVIDER_OPTIONS = ["openai", "azure", "anthropic", "google_ai", "cohere", "local", "groq", "deepseek"] def retry_with_exponential_backoff( @@ -453,10 +455,62 @@ def create( ), ) + elif llm_config.model_endpoint_type == "deepseek": + if model_settings.deepseek_api_key is None and llm_config.model_endpoint == "": + # only is a problem if we are *not* using an openai proxy + raise LettaConfigurationError(message="DeepSeek key is missing from letta config file", missing_fields=["deepseek_api_key"]) + + data = build_deepseek_chat_completions_request( + llm_config, + messages, + user_id, + functions, + function_call, + use_tool_naming, + llm_config.max_tokens, + ) + if stream: # Client requested token streaming + data.stream = True + assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( + stream_interface, AgentRefreshStreamingInterface + ), type(stream_interface) + response = openai_chat_completions_process_stream( + url=llm_config.model_endpoint, + api_key=model_settings.deepseek_api_key, + chat_completion_request=data, + stream_interface=stream_interface, + ) + else: # Client did not request token streaming (expect a blocking backend response) + data.stream = False + if isinstance(stream_interface, AgentChunkStreamingInterface): + stream_interface.stream_start() + try: + response = openai_chat_completions_request( + url=llm_config.model_endpoint, + api_key=model_settings.deepseek_api_key, + chat_completion_request=data, + ) + finally: + if isinstance(stream_interface, AgentChunkStreamingInterface): + stream_interface.stream_end() + """ + if llm_config.put_inner_thoughts_in_kwargs: + response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) + """ + response = convert_deepseek_response_to_chatcompletion(response) + return response + # local model else: if stream: raise NotImplementedError(f"Streaming not yet implemented for {llm_config.model_endpoint_type}") + + if "DeepSeek-R1".lower() in llm_config.model.lower(): # TODO: move this to the llm_config. + messages[0].content[0].text += f" {''.join(json.dumps(f) for f in functions)} " + messages[0].content[ + 0 + ].text += f'Select best function to call simply by responding with a single json block with the keys "function" and "params". Use double quotes around the arguments.' + return get_chat_completion( model=llm_config.model, messages=messages, diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index c6762872..c793a49f 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -166,6 +166,11 @@ def openai_chat_completions_process_stream( create_message_id: bool = True, create_message_datetime: bool = True, override_tool_call_id: bool = True, + # if we expect reasoning content in the response, + # then we should emit reasoning_content as "inner_thoughts" + # however, we don't necessarily want to put these + # expect_reasoning_content: bool = False, + expect_reasoning_content: bool = True, ) -> ChatCompletionResponse: """Process a streaming completion response, and return a ChatCompletionRequest at the end. @@ -250,6 +255,7 @@ def openai_chat_completions_process_stream( chat_completion_chunk, message_id=chat_completion_response.id if create_message_id else chat_completion_chunk.id, message_date=chat_completion_response.created if create_message_datetime else chat_completion_chunk.created, + expect_reasoning_content=expect_reasoning_content, ) elif isinstance(stream_interface, AgentRefreshStreamingInterface): stream_interface.process_refresh(chat_completion_response) @@ -290,6 +296,13 @@ def openai_chat_completions_process_stream( else: accum_message.content += content_delta + if expect_reasoning_content and message_delta.reasoning_content is not None: + reasoning_content_delta = message_delta.reasoning_content + if accum_message.reasoning_content is None: + accum_message.reasoning_content = reasoning_content_delta + else: + accum_message.reasoning_content += reasoning_content_delta + # TODO(charles) make sure this works for parallel tool calling? if message_delta.tool_calls is not None: tool_calls_delta = message_delta.tool_calls diff --git a/letta/local_llm/chat_completion_proxy.py b/letta/local_llm/chat_completion_proxy.py index 8ba8bfef..c5e7d025 100644 --- a/letta/local_llm/chat_completion_proxy.py +++ b/letta/local_llm/chat_completion_proxy.py @@ -14,7 +14,7 @@ from letta.local_llm.grammars.gbnf_grammar_generator import create_dynamic_model from letta.local_llm.koboldcpp.api import get_koboldcpp_completion from letta.local_llm.llamacpp.api import get_llamacpp_completion from letta.local_llm.llm_chat_completion_wrappers import simple_summary_wrapper -from letta.local_llm.lmstudio.api import get_lmstudio_completion +from letta.local_llm.lmstudio.api import get_lmstudio_completion, get_lmstudio_completion_chatcompletions from letta.local_llm.ollama.api import get_ollama_completion from letta.local_llm.utils import count_tokens, get_available_wrappers from letta.local_llm.vllm.api import get_vllm_completion @@ -141,11 +141,24 @@ def get_chat_completion( f"Failed to convert ChatCompletion messages into prompt string with wrapper {str(llm_wrapper)} - error: {str(e)}" ) + # get the schema for the model + + """ + if functions_python is not None: + model_schema = generate_schema(functions) + else: + model_schema = None + """ + + # Run the LLM try: + result_reasoning = None if endpoint_type == "webui": result, usage = get_webui_completion(endpoint, auth_type, auth_key, prompt, context_window, grammar=grammar) elif endpoint_type == "webui-legacy": result, usage = get_webui_completion_legacy(endpoint, auth_type, auth_key, prompt, context_window, grammar=grammar) + elif endpoint_type == "lmstudio-chatcompletions": + result, usage, result_reasoning = get_lmstudio_completion_chatcompletions(endpoint, auth_type, auth_key, model, messages) elif endpoint_type == "lmstudio": result, usage = get_lmstudio_completion(endpoint, auth_type, auth_key, prompt, context_window, api="completions") elif endpoint_type == "lmstudio-legacy": @@ -214,7 +227,7 @@ def get_chat_completion( index=0, message=Message( role=chat_completion_result["role"], - content=chat_completion_result["content"], + content=result_reasoning if result_reasoning is not None else chat_completion_result["content"], tool_calls=( [ToolCall(id=get_tool_call_id(), type="function", function=chat_completion_result["function_call"])] if "function_call" in chat_completion_result diff --git a/letta/local_llm/lmstudio/api.py b/letta/local_llm/lmstudio/api.py index 0debbd1f..dd0debee 100644 --- a/letta/local_llm/lmstudio/api.py +++ b/letta/local_llm/lmstudio/api.py @@ -1,3 +1,4 @@ +import json from urllib.parse import urljoin from letta.local_llm.settings.settings import get_completions_settings @@ -6,6 +7,73 @@ from letta.utils import count_tokens LMSTUDIO_API_CHAT_SUFFIX = "/v1/chat/completions" LMSTUDIO_API_COMPLETIONS_SUFFIX = "/v1/completions" +LMSTUDIO_API_CHAT_COMPLETIONS_SUFFIX = "/v1/chat/completions" + + +def get_lmstudio_completion_chatcompletions(endpoint, auth_type, auth_key, model, messages): + """ + This is the request we need to send + + { + "model": "deepseek-r1-distill-qwen-7b", + "messages": [ + { "role": "system", "content": "Always answer in rhymes. Today is Thursday" }, + { "role": "user", "content": "What day is it today?" }, + { "role": "user", "content": "What day is it today?" }], + "temperature": 0.7, + "max_tokens": -1, + "stream": false + """ + from letta.utils import printd + + URI = endpoint + LMSTUDIO_API_CHAT_COMPLETIONS_SUFFIX + request = {"model": model, "messages": messages} + + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) + + # Get the reasoning from the model + if response.status_code == 200: + result_full = response.json() + result_reasoning = result_full["choices"][0]["message"].get("reasoning_content") + result = result_full["choices"][0]["message"]["content"] + usage = result_full["usage"] + + # See if result is json + try: + function_call = json.loads(result) + if "function" in function_call and "params" in function_call: + return result, usage, result_reasoning + else: + print("Did not get json on without json constraint, attempting with json decoding") + except Exception as e: + print(f"Did not get json on without json constraint, attempting with json decoding: {e}") + + request["messages"].append({"role": "assistant", "content": result_reasoning}) + request["messages"].append({"role": "user", "content": ""}) # last message must be user + # Now run with json decoding to get the function + request["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": "function_call", + "strict": "true", + "schema": { + "type": "object", + "properties": {"function": {"type": "string"}, "params": {"type": "object"}}, + "required": ["function", "params"], + }, + }, + } + + response = post_json_auth_request(uri=URI, json_payload=request, auth_type=auth_type, auth_key=auth_key) + if response.status_code == 200: + result_full = response.json() + printd(f"JSON API response:\n{result_full}") + result = result_full["choices"][0]["message"]["content"] + # add usage with previous call, merge with prev usage + for key, value in result_full["usage"].items(): + usage[key] += value + + return result, usage, result_reasoning def get_lmstudio_completion(endpoint, auth_type, auth_key, prompt, context_window, api="completions"): @@ -24,7 +92,8 @@ def get_lmstudio_completion(endpoint, auth_type, auth_key, prompt, context_windo # This controls how LM studio handles context overflow # In Letta we handle this ourselves, so this should be disabled # "context_overflow_policy": 0, - "lmstudio": {"context_overflow_policy": 0}, # 0 = stop at limit + # "lmstudio": {"context_overflow_policy": 0}, # 0 = stop at limit + # "lmstudio": {"context_overflow_policy": "stopAtLimit"}, # https://github.com/letta-ai/letta/issues/1782 "stream": False, "model": "local model", } @@ -72,6 +141,11 @@ def get_lmstudio_completion(endpoint, auth_type, auth_key, prompt, context_windo elif api == "completions": result = result_full["choices"][0]["text"] usage = result_full.get("usage", None) + elif api == "chat/completions": + result = result_full["choices"][0]["content"] + result_full["choices"][0]["reasoning_content"] + usage = result_full.get("usage", None) + else: # Example error: msg={"error":"Context length exceeded. Tokens in context: 8000, Context length: 8000"} if "context length" in str(response.text).lower(): diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 8e44b25e..7a941940 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -33,6 +33,7 @@ class LLMConfig(BaseModel): "webui-legacy", "lmstudio", "lmstudio-legacy", + "lmstudio-chatcompletions", "llamacpp", "koboldcpp", "vllm", @@ -40,6 +41,7 @@ class LLMConfig(BaseModel): "mistral", "together", # completions endpoint "bedrock", + "deepseek", ] = Field(..., description="The endpoint type for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") diff --git a/letta/schemas/openai/chat_completion_response.py b/letta/schemas/openai/chat_completion_response.py index e41859f0..9c23c3cf 100644 --- a/letta/schemas/openai/chat_completion_response.py +++ b/letta/schemas/openai/chat_completion_response.py @@ -39,6 +39,7 @@ class Message(BaseModel): tool_calls: Optional[List[ToolCall]] = None role: str function_call: Optional[FunctionCall] = None # Deprecated + reasoning_content: Optional[str] = None # Used in newer reasoning APIs class Choice(BaseModel): @@ -115,6 +116,7 @@ class MessageDelta(BaseModel): """ content: Optional[str] = None + reasoning_content: Optional[str] = None tool_calls: Optional[List[ToolCallDelta]] = None role: Optional[str] = None function_call: Optional[FunctionCallDelta] = None # Deprecated diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 621958cc..e569db04 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -211,6 +211,75 @@ class OpenAIProvider(Provider): return None +class DeepSeekProvider(OpenAIProvider): + """ + DeepSeek ChatCompletions API is similar to OpenAI's reasoning API, + but with slight differences: + * For example, DeepSeek's API requires perfect interleaving of user/assistant + * It also does not support native function calling + """ + + name: str = "deepseek" + base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") + api_key: str = Field(..., description="API key for the DeepSeek API.") + + def get_model_context_window_size(self, model_name: str) -> Optional[int]: + # DeepSeek doesn't return context window in the model listing, + # so these are hardcoded from their website + if model_name == "deepseek-reasoner": + return 64000 + elif model_name == "deepseek-chat": + return 64000 + else: + return None + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list + + response = openai_get_model_list(self.base_url, api_key=self.api_key) + + if "data" in response: + data = response["data"] + else: + data = response + + configs = [] + for model in data: + assert "id" in model, f"DeepSeek model missing 'id' field: {model}" + model_name = model["id"] + + # In case DeepSeek starts supporting it in the future: + if "context_length" in model: + # Context length is returned in OpenRouter as "context_length" + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + if not context_window_size: + warnings.warn(f"Couldn't find context window size for model {model_name}") + continue + + # Not used for deepseek-reasoner, but otherwise is true + put_inner_thoughts_in_kwargs = False if model_name == "deepseek-reasoner" else True + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="deepseek", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + put_inner_thoughts_in_kwargs=put_inner_thoughts_in_kwargs, + ) + ) + + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # No embeddings supported + return [] + + class LMStudioOpenAIProvider(OpenAIProvider): name: str = "lmstudio-openai" base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.") diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 037b15c8..5b6d342a 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -317,6 +317,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.debug = False self.timeout = 10 * 60 # 10 minute timeout + # for expect_reasoning_content, we should accumulate `content` + self.expect_reasoning_content_buffer = None + def _reset_inner_thoughts_json_reader(self): # A buffer for accumulating function arguments (we want to buffer keys and run checks on each one) self.function_args_reader = JSONInnerThoughtsExtractor(inner_thoughts_key=self.inner_thoughts_kwarg, wait_for_first_key=True) @@ -387,6 +390,39 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # Wipe the inner thoughts buffers self._reset_inner_thoughts_json_reader() + # If we were in reasoning mode and accumulated a json block, attempt to release it as chunks + # if self.expect_reasoning_content_buffer is not None: + # try: + # # NOTE: this is hardcoded for our DeepSeek API integration + # json_reasoning_content = json.loads(self.expect_reasoning_content_buffer) + + # if "name" in json_reasoning_content: + # self._push_to_buffer( + # ToolCallMessage( + # id=message_id, + # date=message_date, + # tool_call=ToolCallDelta( + # name=json_reasoning_content["name"], + # arguments=None, + # tool_call_id=None, + # ), + # ) + # ) + # if "arguments" in json_reasoning_content: + # self._push_to_buffer( + # ToolCallMessage( + # id=message_id, + # date=message_date, + # tool_call=ToolCallDelta( + # name=None, + # arguments=json_reasoning_content["arguments"], + # tool_call_id=None, + # ), + # ) + # ) + # except Exception as e: + # print(f"Failed to interpret reasoning content ({self.expect_reasoning_content_buffer}) as JSON: {e}") + def step_complete(self): """Signal from the agent that one 'step' finished (step = LLM response + tool execution)""" if not self.multi_step: @@ -410,7 +446,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface): return def _process_chunk_to_letta_style( - self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime + self, + chunk: ChatCompletionChunkResponse, + message_id: str, + message_date: datetime, + # if we expect `reasoning_content``, then that's what gets mapped to ReasoningMessage + # and `content` needs to be handled outside the interface + expect_reasoning_content: bool = False, ) -> Optional[Union[ReasoningMessage, ToolCallMessage, AssistantMessage]]: """ Example data from non-streaming response looks like: @@ -426,6 +468,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): if ( message_delta.content is None + and (expect_reasoning_content and message_delta.reasoning_content is None) and message_delta.tool_calls is None and message_delta.function_call is None and choice.finish_reason is None @@ -435,17 +478,68 @@ class StreamingServerInterface(AgentChunkStreamingInterface): return None # inner thoughts - if message_delta.content is not None: - if message_delta.content == "": - print("skipping empty content") - processed_chunk = None + if expect_reasoning_content and message_delta.reasoning_content is not None: + processed_chunk = ReasoningMessage( + id=message_id, + date=message_date, + reasoning=message_delta.reasoning_content, + ) + elif expect_reasoning_content and message_delta.content is not None: + # "ignore" content if we expect reasoning content + if self.expect_reasoning_content_buffer is None: + self.expect_reasoning_content_buffer = message_delta.content else: - processed_chunk = ReasoningMessage( + self.expect_reasoning_content_buffer += message_delta.content + + # we expect this to be pure JSON + # OptimisticJSONParser + + # If we can pull a name out, pull it + + try: + # NOTE: this is hardcoded for our DeepSeek API integration + json_reasoning_content = json.loads(self.expect_reasoning_content_buffer) + print(f"json_reasoning_content: {json_reasoning_content}") + + processed_chunk = ToolCallMessage( id=message_id, date=message_date, - reasoning=message_delta.content, + tool_call=ToolCallDelta( + name=json_reasoning_content.get("name"), + arguments=json.dumps(json_reasoning_content.get("arguments")), + tool_call_id=None, + ), ) + except json.JSONDecodeError as e: + print(f"Failed to interpret reasoning content ({self.expect_reasoning_content_buffer}) as JSON: {e}") + + return None + # Else, + # return None + # processed_chunk = ToolCallMessage( + # id=message_id, + # date=message_date, + # tool_call=ToolCallDelta( + # # name=tool_call_delta.get("name"), + # name=None, + # arguments=message_delta.content, + # # tool_call_id=tool_call_delta.get("id"), + # tool_call_id=None, + # ), + # ) + # return processed_chunk + + # TODO eventually output as tool call outputs? + # print(f"Hiding content delta stream: '{message_delta.content}'") + # return None + elif message_delta.content is not None: + processed_chunk = ReasoningMessage( + id=message_id, + date=message_date, + reasoning=message_delta.content, + ) + # tool calls elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0: tool_call = message_delta.tool_calls[0] @@ -890,7 +984,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface): return processed_chunk - def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime): + def process_chunk( + self, + chunk: ChatCompletionChunkResponse, + message_id: str, + message_date: datetime, + expect_reasoning_content: bool = False, + ): """Process a streaming chunk from an OpenAI-compatible server. Example data from non-streaming response looks like: @@ -910,7 +1010,12 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # processed_chunk = self._process_chunk_to_openai_style(chunk) raise NotImplementedError("OpenAI proxy streaming temporarily disabled") else: - processed_chunk = self._process_chunk_to_letta_style(chunk=chunk, message_id=message_id, message_date=message_date) + processed_chunk = self._process_chunk_to_letta_style( + chunk=chunk, + message_id=message_id, + message_date=message_date, + expect_reasoning_content=expect_reasoning_content, + ) if processed_chunk is None: return diff --git a/letta/server/server.py b/letta/server/server.py index b3590999..38fe4626 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -48,6 +48,7 @@ from letta.schemas.providers import ( AnthropicBedrockProvider, AnthropicProvider, AzureProvider, + DeepSeekProvider, GoogleAIProvider, GoogleVertexProvider, GroqProvider, @@ -305,6 +306,8 @@ class SyncServer(Server): else model_settings.lmstudio_base_url + "/v1" ) self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url)) + if model_settings.deepseek_api_key: + self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key)) def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent: """Updated method to load agents from persisted storage""" @@ -1182,11 +1185,12 @@ class SyncServer(Server): # Disable token streaming if not OpenAI or Anthropic # TODO: cleanup this logic llm_config = letta_agent.agent_state.llm_config + supports_token_streaming = ["openai", "anthropic", "deepseek"] if stream_tokens and ( - llm_config.model_endpoint_type not in ["openai", "anthropic"] or "inference.memgpt.ai" in llm_config.model_endpoint + llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint ): warnings.warn( - f"Token streaming is only supported for models with type 'openai' or 'anthropic' in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." + f"Token streaming is only supported for models with type {' or '.join(supports_token_streaming)} in the model_endpoint: agent has endpoint type {llm_config.model_endpoint_type} and {llm_config.model_endpoint}. Setting stream_tokens to False." ) stream_tokens = False diff --git a/letta/settings.py b/letta/settings.py index 4e9f0d0b..7dd756e7 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -60,6 +60,9 @@ class ModelSettings(BaseSettings): openai_api_key: Optional[str] = None openai_api_base: str = "https://api.openai.com/v1" + # deepseek + deepseek_api_key: Optional[str] = None + # groq groq_api_key: Optional[str] = None diff --git a/letta/utils.py b/letta/utils.py index b61660e3..33927f83 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -824,12 +824,16 @@ def parse_json(string) -> dict: result = None try: result = json_loads(string) + if not isinstance(result, dict): + raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}") return result except Exception as e: print(f"Error parsing json with json package: {e}") try: result = demjson.decode(string) + if not isinstance(result, dict): + raise ValueError(f"JSON from string input ({string}) is not a dictionary (type {type(result)}): {result}") return result except demjson.JSONDecodeError as e: print(f"Error parsing json with demjson package: {e}") diff --git a/tests/configs/llm_model_configs/deepseek-reasoner.json b/tests/configs/llm_model_configs/deepseek-reasoner.json new file mode 100644 index 00000000..99dac148 --- /dev/null +++ b/tests/configs/llm_model_configs/deepseek-reasoner.json @@ -0,0 +1,7 @@ +{ + "model": "deepseek-reasoner", + "model_endpoint_type": "deepseek", + "model_endpoint": "https://api.deepseek.com/v1", + "context_window": 64000, + "put_inner_thoughts_in_kwargs": false +} diff --git a/tests/helpers/endpoints_helper.py b/tests/helpers/endpoints_helper.py index b59a16af..36fd66c7 100644 --- a/tests/helpers/endpoints_helper.py +++ b/tests/helpers/endpoints_helper.py @@ -83,7 +83,7 @@ def setup_agent( # ====================================================================================================================== -def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatCompletionResponse: +def check_first_response_is_valid_for_llm_endpoint(filename: str, validate_inner_monologue_contents: bool = True) -> ChatCompletionResponse: """ Checks that the first response is valid: @@ -126,7 +126,11 @@ def check_first_response_is_valid_for_llm_endpoint(filename: str) -> ChatComplet assert_contains_valid_function_call(choice.message, validator_func) # Assert that the message has an inner monologue - assert_contains_correct_inner_monologue(choice, agent_state.llm_config.put_inner_thoughts_in_kwargs) + assert_contains_correct_inner_monologue( + choice, + agent_state.llm_config.put_inner_thoughts_in_kwargs, + validate_inner_monologue_contents=validate_inner_monologue_contents, + ) return response @@ -470,7 +474,11 @@ def assert_inner_monologue_is_valid(message: Message) -> None: raise InvalidInnerMonologueError(messages=[message], explanation=f"{phrase} is in monologue") -def assert_contains_correct_inner_monologue(choice: Choice, inner_thoughts_in_kwargs: bool) -> None: +def assert_contains_correct_inner_monologue( + choice: Choice, + inner_thoughts_in_kwargs: bool, + validate_inner_monologue_contents: bool = True, +) -> None: """ Helper function to check that the inner monologue exists and is valid. """ @@ -483,4 +491,5 @@ def assert_contains_correct_inner_monologue(choice: Choice, inner_thoughts_in_kw if not monologue or monologue is None or monologue == "": raise MissingInnerMonologueError(messages=[message]) - assert_inner_monologue_is_valid(message) + if validate_inner_monologue_contents: + assert_inner_monologue_is_valid(message) diff --git a/tests/test_model_letta_performance.py b/tests/test_model_letta_performance.py index 369552c6..3d425f49 100644 --- a/tests/test_model_letta_performance.py +++ b/tests/test_model_letta_performance.py @@ -315,6 +315,19 @@ def test_vertex_gemini_pro_20_returns_valid_first_message(): print(f"Got successful response from client: \n\n{response}") +# ====================================================================================================================== +# DEEPSEEK TESTS +# ====================================================================================================================== +@pytest.mark.deepseek_basic +def test_deepseek_reasoner_returns_valid_first_message(): + filename = os.path.join(llm_config_dir, "deepseek-reasoner.json") + # Don't validate that the inner monologue doesn't contain things like "function", since + # for the reasoners it might be quite meta (have analysis about functions etc.) + response = check_first_response_is_valid_for_llm_endpoint(filename, validate_inner_monologue_contents=False) + # Log out successful response + print(f"Got successful response from client: \n\n{response}") + + # ====================================================================================================================== # TOGETHER TESTS # ====================================================================================================================== diff --git a/tests/test_providers.py b/tests/test_providers.py index 5dd99fbe..0394dec0 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -4,6 +4,7 @@ from letta.schemas.providers import ( AnthropicBedrockProvider, AnthropicProvider, AzureProvider, + DeepSeekProvider, GoogleAIProvider, GoogleVertexProvider, GroqProvider, @@ -23,6 +24,14 @@ def test_openai(): print(models) +def test_deepseek(): + api_key = os.getenv("DEEPSEEK_API_KEY") + assert api_key is not None + provider = DeepSeekProvider(api_key=api_key) + models = provider.list_llm_models() + print(models) + + def test_anthropic(): api_key = os.getenv("ANTHROPIC_API_KEY") assert api_key is not None