diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 0d423677..f5083e7f 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -111,7 +111,6 @@ def create( # streaming? stream: bool = False, stream_interface: Optional[Union[AgentRefreshStreamingInterface, AgentChunkStreamingInterface]] = None, - max_tokens: Optional[int] = None, model_settings: Optional[dict] = None, # TODO: eventually pass from server ) -> ChatCompletionResponse: """Return response to chat completion with backoff""" @@ -157,7 +156,7 @@ def create( else: function_call = "required" - data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens) + data = build_openai_chat_completions_request(llm_config, messages, user_id, functions, function_call, use_tool_naming) if stream: # Client requested token streaming data.stream = True assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( @@ -212,7 +211,7 @@ def create( # For Azure, this model_endpoint is required to be configured via env variable, so users don't need to provide it in the LLM config llm_config.model_endpoint = model_settings.azure_base_url chat_completion_request = build_openai_chat_completions_request( - llm_config, messages, user_id, functions, function_call, use_tool_naming, max_tokens + llm_config, messages, user_id, functions, function_call, use_tool_naming ) response = azure_openai_chat_completions_request( @@ -248,7 +247,7 @@ def create( data=dict( contents=[m.to_google_ai_dict() for m in messages], tools=tools, - generation_config={"temperature": llm_config.temperature}, + generation_config={"temperature": llm_config.temperature, "max_output_tokens": llm_config.max_tokens}, ), inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs, ) @@ -268,7 +267,7 @@ def create( messages=[cast_message_to_subtype(m.to_openai_dict()) for m in messages], tools=([{"type": "function", "function": f} for f in functions] if functions else None), tool_choice=tool_call, - max_tokens=1024, # TODO make dynamic + max_tokens=llm_config.max_tokens, # Note: max_tokens is required for Anthropic API temperature=llm_config.temperature, stream=stream, ) @@ -416,7 +415,7 @@ def create( tool_choice=tool_call, # user=str(user_id), # NOTE: max_tokens is required for Anthropic API - max_tokens=1024, # TODO make dynamic + max_tokens=llm_config.max_tokens, ), ) diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 734e7b23..c6762872 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -94,7 +94,6 @@ def build_openai_chat_completions_request( functions: Optional[list], function_call: Optional[str], use_tool_naming: bool, - max_tokens: Optional[int], ) -> ChatCompletionRequest: if functions and llm_config.put_inner_thoughts_in_kwargs: # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first @@ -131,7 +130,7 @@ def build_openai_chat_completions_request( tools=[Tool(type="function", function=f) for f in functions] if functions else None, tool_choice=tool_choice, user=str(user_id), - max_completion_tokens=max_tokens, + max_completion_tokens=llm_config.max_tokens, temperature=llm_config.temperature, ) else: @@ -141,7 +140,7 @@ def build_openai_chat_completions_request( functions=functions, function_call=function_call, user=str(user_id), - max_completion_tokens=max_tokens, + max_completion_tokens=llm_config.max_tokens, temperature=llm_config.temperature, ) # https://platform.openai.com/docs/guides/text-generation/json-mode diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 6e87e629..e3877389 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -15,6 +15,7 @@ class LLMConfig(BaseModel): context_window (int): The context window size for the model. put_inner_thoughts_in_kwargs (bool): Puts `inner_thoughts` as a kwarg in the function call if this is set to True. This helps with function calling performance and also the generation of inner thoughts. temperature (float): The temperature to use when generating text with the model. A higher temperature will result in more random text. + max_tokens (int): The maximum number of tokens to generate. """ # TODO: 🤮 don't default to a vendor! bug city! @@ -51,6 +52,10 @@ class LLMConfig(BaseModel): 0.7, description="The temperature to use when generating text with the model. A higher temperature will result in more random text.", ) + max_tokens: Optional[int] = Field( + 1024, + description="The maximum number of tokens to generate. If not set, the model will use its default value.", + ) # FIXME hack to silence pydantic protected namespace warning model_config = ConfigDict(protected_namespaces=())