From 807c5c18d9cec387d6208153190e1623c372a9c8 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 28 Nov 2025 19:49:02 -0800 Subject: [PATCH] feat: add gemini token counting [LET-6371] (#6444) --- letta/llm_api/google_vertex_client.py | 53 +++- letta/schemas/message.py | 4 +- letta/services/agent_manager.py | 57 +++-- .../token_counter.py | 49 ++++ .../data/__pycache__/1_to_100.cpython-310.pyc | Bin 1712 -> 0 bytes .../__pycache__/data_analysis.cpython-310.pyc | Bin 11094 -> 0 bytes .../__pycache__/dump_json.cpython-310.pyc | Bin 533 -> 0 bytes tests/integration_test_token_counters.py | 235 ++++++++++++++++++ 8 files changed, 379 insertions(+), 19 deletions(-) delete mode 100644 tests/data/__pycache__/1_to_100.cpython-310.pyc delete mode 100644 tests/data/__pycache__/data_analysis.cpython-310.pyc delete mode 100644 tests/data/functions/__pycache__/dump_json.cpython-310.pyc create mode 100644 tests/integration_test_token_counters.py diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 26b63f44..c5c7b60b 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -37,7 +37,7 @@ from letta.otel.tracing import trace_method from letta.schemas.agent import AgentType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage -from letta.schemas.openai.chat_completion_request import Tool +from letta.schemas.openai.chat_completion_request import Tool, Tool as OpenAITool from letta.schemas.openai.chat_completion_response import ChatCompletionResponse, Choice, FunctionCall, Message, ToolCall, UsageStatistics from letta.settings import model_settings, settings from letta.utils import get_tool_call_id @@ -832,3 +832,54 @@ class GoogleVertexClient(LLMClientBase): # Fallback to base implementation for other errors return super().handle_llm_error(e) + + async def count_tokens(self, messages: List[dict] = None, model: str = None, tools: List[OpenAITool] = None) -> int: + """ + Count tokens for the given messages and tools using the Gemini token counting API. + + Args: + messages: List of message dicts in Google AI format (with 'role' and 'parts' keys) + model: The model to use for token counting (defaults to gemini-2.0-flash-lite) + tools: List of OpenAI-style Tool objects to include in the count + + Returns: + The total token count for the input + """ + from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK + + client = self._get_client() + + # Default model for token counting if not specified + count_model = model or GOOGLE_MODEL_FOR_API_KEY_CHECK + + # Build the contents parameter + # If no messages provided, use empty string (like the API key check) + if messages is None or len(messages) == 0: + contents = "" + else: + # Messages should already be in Google format (role + parts) + contents = messages + + try: + # Count message tokens + result = await client.aio.models.count_tokens( + model=count_model, + contents=contents, + ) + total_tokens = result.total_tokens + + # Count tool tokens separately by serializing to text + # The Gemini count_tokens API doesn't support a tools parameter directly + if tools and len(tools) > 0: + # Serialize tools to JSON text and count those tokens + tools_text = json.dumps([t.model_dump() for t in tools]) + tools_result = await client.aio.models.count_tokens( + model=count_model, + contents=tools_text, + ) + total_tokens += tools_result.total_tokens + + except Exception as e: + raise self.handle_llm_error(e) + + return total_tokens diff --git a/letta/schemas/message.py b/letta/schemas/message.py index 3ee289d9..1f9b64bf 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -1795,7 +1795,9 @@ class Message(BaseMessage): parts.append(function_call_part) else: - if not native_content: + # Only add single text_content if we don't have multiple content items + # (multi-content case is handled below at the len(self.content) > 1 block) + if not native_content and not (self.content and len(self.content) > 1): assert text_content is not None parts.append({"text": text_content}) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 3f2d0e2d..194afff5 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -80,7 +80,7 @@ from letta.server.db import db_registry from letta.services.archive_manager import ArchiveManager from letta.services.block_manager import BlockManager, validate_block_limit_constraint from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator -from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, TiktokenCounter +from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, GeminiTokenCounter, TiktokenCounter from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.files_agents_manager import FileAgentManager from letta.services.helpers.agent_manager_helper import ( @@ -3286,37 +3286,60 @@ class AgentManager: ) calculator = ContextWindowCalculator() + # Determine which token counter to use based on provider + model_endpoint_type = agent_state.llm_config.model_endpoint_type + + # Use Gemini token counter for Google Vertex and Google AI + use_gemini = model_endpoint_type in ("google_vertex", "google_ai") + # Use Anthropic token counter if: # 1. The model endpoint type is anthropic, OR - # 2. We're in PRODUCTION and anthropic_api_key is available - use_anthropic = agent_state.llm_config.model_endpoint_type == "anthropic" or ( - settings.environment == "PRODUCTION" and model_settings.anthropic_api_key is not None + # 2. We're in PRODUCTION and anthropic_api_key is available (and not using Gemini) + use_anthropic = model_endpoint_type == "anthropic" or ( + not use_gemini and settings.environment == "PRODUCTION" and model_settings.anthropic_api_key is not None ) - if use_anthropic: + if use_gemini: + # Use native Gemini token counting API + + client = LLMClient.create(provider_type=agent_state.llm_config.model_endpoint_type, actor=actor) + model = agent_state.llm_config.model + + token_counter = GeminiTokenCounter(client, model) + logger.info( + f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, " + f"model_endpoint_type={model_endpoint_type}, " + f"environment={settings.environment}" + ) + elif use_anthropic: anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor) - model = agent_state.llm_config.model if agent_state.llm_config.model_endpoint_type == "anthropic" else None + model = agent_state.llm_config.model if model_endpoint_type == "anthropic" else None token_counter = AnthropicTokenCounter(anthropic_client, model) # noqa logger.info( f"Using AnthropicTokenCounter for agent_id={agent_id}, model={model}, " - f"model_endpoint_type={agent_state.llm_config.model_endpoint_type}, " + f"model_endpoint_type={model_endpoint_type}, " f"environment={settings.environment}" ) else: token_counter = TiktokenCounter(agent_state.llm_config.model) logger.info( f"Using TiktokenCounter for agent_id={agent_id}, model={agent_state.llm_config.model}, " - f"model_endpoint_type={agent_state.llm_config.model_endpoint_type}, " + f"model_endpoint_type={model_endpoint_type}, " f"environment={settings.environment}" ) - return await calculator.calculate_context_window( - agent_state=agent_state, - actor=actor, - token_counter=token_counter, - message_manager=self.message_manager, - system_message_compiled=system_message, - num_archival_memories=num_archival_memories, - num_messages=num_messages, - ) + try: + result = await calculator.calculate_context_window( + agent_state=agent_state, + actor=actor, + token_counter=token_counter, + message_manager=self.message_manager, + system_message_compiled=system_message, + num_archival_memories=num_archival_memories, + num_messages=num_messages, + ) + except Exception as e: + raise e + + return result diff --git a/letta/services/context_window_calculator/token_counter.py b/letta/services/context_window_calculator/token_counter.py index 12833fda..33c9a70f 100644 --- a/letta/services/context_window_calculator/token_counter.py +++ b/letta/services/context_window_calculator/token_counter.py @@ -5,6 +5,7 @@ from typing import Any, Dict, List from letta.helpers.decorators import async_redis_cache from letta.llm_api.anthropic_client import AnthropicClient +from letta.llm_api.google_vertex_client import GoogleVertexClient from letta.otel.tracing import trace_method from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import Tool as OpenAITool @@ -77,6 +78,54 @@ class AnthropicTokenCounter(TokenCounter): return Message.to_anthropic_dicts_from_list(messages, current_model=self.model) +class GeminiTokenCounter(TokenCounter): + """Token counter using Google's Gemini token counting API""" + + def __init__(self, gemini_client: GoogleVertexClient, model: str): + self.client = gemini_client + self.model = model + + @trace_method + @async_redis_cache( + key_func=lambda self, text: f"gemini_text_tokens:{self.model}:{hashlib.sha256(text.encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) + async def count_text_tokens(self, text: str) -> int: + if not text: + return 0 + # For text counting, wrap in a simple user message format for Google + return await self.client.count_tokens(model=self.model, messages=[{"role": "user", "parts": [{"text": text}]}]) + + @trace_method + @async_redis_cache( + key_func=lambda self, + messages: f"gemini_message_tokens:{self.model}:{hashlib.sha256(json.dumps(messages, sort_keys=True).encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) + async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int: + if not messages: + return 0 + return await self.client.count_tokens(model=self.model, messages=messages) + + @trace_method + @async_redis_cache( + key_func=lambda self, + tools: f"gemini_tool_tokens:{self.model}:{hashlib.sha256(json.dumps([t.model_dump() for t in tools], sort_keys=True).encode()).hexdigest()[:16]}", + prefix="token_counter", + ttl_s=3600, # cache for 1 hour + ) + async def count_tool_tokens(self, tools: List[OpenAITool]) -> int: + if not tools: + return 0 + return await self.client.count_tokens(model=self.model, tools=tools) + + def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: + google_messages = Message.to_google_dicts_from_list(messages, current_model=self.model) + return google_messages + + class TiktokenCounter(TokenCounter): """Token counter using tiktoken""" diff --git a/tests/data/__pycache__/1_to_100.cpython-310.pyc b/tests/data/__pycache__/1_to_100.cpython-310.pyc deleted file mode 100644 index 431649a9e700c0c0a5af97ad7411a75581f9c968..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 1712 zcmdVa$x>7Y9EI^_6chwxP!tpyWTxpmUj;-FL_uU84wJ~CDvP9`gjAXL$V=okvdS}L zn-#B+RqmHnULm*Z{HkwNf2)60Up6=2Efvqu^zT31)Ranp@$%oNs(6_%cKvJd%_Wsw zr7Bme+BK?it!iDTI@hb-cXY=MYH*_(eOGsVPxsuUCO50u_jTVbYH_Pt-KI7_&;z%t z-5u(1r#juGEd;3z^0AJ6q7$F$ z)L-<)XFBs&ef7D{?XA-2#Vt3wt%l-)rQ2GlgL=3F4bTXj`}Q6*K{MQk7HEYwcmVCt z0iDnV58)AX!((^?J`eihv@Z2q*%IfFhs>C<2O*BBTf@LW+W zB8r$IridwGikKp%h$&)`_DN>4* zBBjVEGK!2MqsS;Sii{$o$S5+3oFb>lDRPROBB#hHa*CWH&sAGoaQe5>w~ONH|MGX= z+11r``TX+RMfv99`sTXqvzxPW<@DzAwD{@N)h~ZG%=~oepMSiVA2ZCL032j&oj$ZPspTn|>5E?b(3#Bisnb88Pkq>@Hl1m!KJ=k|={UFkzH@*D zC^5BNkbCyr_uRhUcRPc*IakB)>9c>gx$qTD`xi>g{z^z(z!U!i0;c&I)0xrJeI0MJ zXY@_q)G5#ES$)g5WW3ZX^=;plal2RUJHFGe_!XHpdahrUeboH9*ECjU&IcNExV>xo zb#5Z3!d&FIyT-2W&m&!BHKc31hTlMHj@6N>BXtC+dDcLxfz$#!$rkugc9hR`7dsYP zWXC=z`N!B2JN|*@FY)8-#0^7q_|mS1aTsgzX`V{&x|f3_@GeF{Z#xcS?~MT)^|*K1 zyU1<^QJXU_PJ$$ildv81ydYv;KWJ};5%+pL5K$Oyc!MDqs4|FR_tGFrf-s7`n}G-i zqu68A`A`hnJdUX@br4Lx)fq)?+0N?>geQ269^0GVa3qFZjcc z;ur}B37&*~j<~fFjrvU^w=Re6BzN8%Qh!0OO$}#%IuaM~#7`mM8n#dK4W|1hGklAg zeu-JW%}Rcm*}lWdYg*IEYm-G;-yZVxS^Cm=qTZn6v9QzOf=9`8J_+i@XRy6IZi^5@ zxyk)%dmsew$@TGF!8b6?7-RI^c5f6XTwqo5D$S-9Bf;aa~w2@uCys~!b+SRw#uYP+auU>ld+O?I}FRovGbJcgRt-N+^Wo=Ey zFTMWu+WN}1tE;d1)%B}ythCmUSXsNLiFquE3_gB37!Kp-F=GCF!s8@_;Ykw_R#8^sCR)Hd=J57_1^~B_$Ie_i5cfLZsTp@UFHt2 zbc~Z4($+Zbj?<;{8uq8kYbYxt?)Y=O&K+5|LTT)qo7X4%c8$lQUXp%(d@8w+=WH?) zI&dhC&*XEHtQPHHUZLHz_Ilp9jOBSt@G-RyL>jvueg%-w@&OAb z6$lPSN!SaMZBLT+8Cm`7fQTNCk;Maom*>1h1W^ox=*!i6PGU%h2gyiqk4GE8B1uK6 z`M+Ue`2EwX`H_AoF{hPmVy8C;J=Qi+ z^Wz7H5Vt@ly^u2ko|yMKvewKV(S{Hahf$ccTBnivZ(4a3|IH0qT19A>J#4>e|eWbW&_mX*4d#O+o$bXK~fcWXN~Q_BGKSr|bLLPDRDL^I3)X+Z_Y_s_#`Y2j6nlx3(^1W%Tv@Lw)_( zxUa^();$uxKa*Nu0ES3s)g8C^hI{BYYU6KoFvHTWzEjOgdq*+0QdU(tC$ybfR+Fpo z*W=O6Wi^bhu39FGY(6uPZeRtp54Df5cCqpYRlt-0L_oYD}M|zV-G0IgW(W$Q^y+@+|&JX7=a5hp4)@#UEWUCn=)_M3)(y_UkQ6X zYEkoUg~_JZ=PV2&zx;M|BO2U_Qg@s^U_5oM@&4c@Ri5|Q2(-}#i}u9eR-B%o(G}E; zkN1Z?p73QawY<}>c+F#ZDc%eKyLK2yqh1eyIO^wSk4K93N@2wK?c8FRfOs5p$!$;{ zk61nj(!-!zO!D2_h=#ch9(^6C{0!;|!p7Vl58LiwHc^mUV5dAU@%}K`&aJRFX#4Zj zb*mV5oR@~+%1LuhJVD(&jli!DLPQ>p_cqX z*nwPuQ_*5r^%pSm7(%UPyNf2Vg_?0fUqH-}sRghJ<9&)hrBC7xYB|QDz66e*K0;eJ zi|A*jeIEcQ2j8Ga0@QyZ0%+1eY({2u^~C6!fDb@W*V-|YQf5jJDv37}o8o|g7v42? zERZGuikZg&Mn3?VBmXJDkCBy5Y6$I<8Y>+K!=oIKrJdOb%O62Tn5}dCV7Wu(TTgBs zp8${qIa5Md8STn|u*yCl380t|@Toi!x=GZ&V^JF4ng0kf%D9gP<$PpNPg~#0AW8!e z?iHkWNq{wRFXFenF_VhDju`Ykew(CdFg@Cu&k6~jfCRKr3OXdIPPifnax#WueBxQ` z&wb*2_X!zRH@o2Yeoaa@Ra_9Zv3id5vh0}72|!&NS}C!I3DV1^`0)Mr>1+T zWO$0k^CSgV5ag9&6UA>+<~J#r4p+u#%;IwhXsKe_0TRC(r6uds3o^frtOGwH%L{m7 z`o^rp$1P^?QrCvuM-m>%ec*7ugZ%2aq|7U_)RyuexF~?l#c{9YN7trGp_H~$F47C~ zjLnoxH6v2{Hn)d^WRsJUNy^1k{Dab?WH$-5;tLpPn~vt>M`w)WX^AfXg-H4iC{auF zuAy*cM~9eWic^V|U_uE5n|&?BA2XQAtULBQkmi6$pz}Rw25~RXZd{e^BmGs2wz#B8diTnX&Y!tD3wphTHilY=PQgDm{ zQi7!Vv)Y8h*C&o9d1LImal$6?xmmUu`%Zj=--@7(<<5OL zlcwdn2^6w1W1UlRpascH`TPviri>}SSEIOR9X9Ibu?G>2n6K2tq#7tlke8}oZbF2D zEP{FEWs(}mdI3$4i7%syZ_z}=TU49iPyyqYDE1tJys~n;&84M7oJ6`=l7y4HlR>l2 zbR~10&wz-;yrr4_Y8BR<_=o69qJ&Xx=(b@%pSN{m{O5e^x-Qfg`ZIL|R{ z)|f$4}__q54i>sVQM?I_n@{j`IZQCB!FR1E9mwuE=c7Du@>g5DV$&D z$t}(!Y2^#9LtFpMAVZqZGP``0#uNYr87>vbB6$!m|Y+7CK019-o;D2jNkwlBaTK=*BYLf zxN-*j*GVLs#;W|FB-49(uHSx;V4d^5pnshOulzH$rA{0zT|SJwW~Mf^*qw)x6ZGO& z&P=*lzh~s;4ZfY11;5Ef%vZ&$=(K723$q=--Xg-=`Qm1{0k)k$T`Be_CQK$mZH;0w zgZK+%@RdhFeK+OIuPRI(2BC<>HTu9q_&7y3GG-}|B%`9TkI_<*>?U9!gAHhdi?32n z35sG6NtqB9v+FTzFawc@w`J`fglzE+bs?c~PRSH6G}ez8!tO^L6+T(nOZ2ip|5`s}COco@#b@8bK->S5w%UX#$@R{JPLQ1J%MP40wG zF-!f;t*b<)T3g>~3)*Q7yx!h6z}B)Cz5r%dfrKqL7SaL3yq z^g{>)&)GI)X#_h}zo+R9~V-dZchlhe4SULy)&JBMJ|hFeDD}^5;mihoPVoCJE7v zQeUWOW{Ig6sTcf`BIK}0A9;8>Q?sG1W}^gW8q4HYMuc<(Ikta9B>e_ho{<1}5~$^5 zGi1gojhXN@5tBOw5mx4ACe-033Dpo{EwD-}vGz)t1;*CQj6JZz47Osa0MC<{`~j5i z%Mg;G!k5TuEPYz=V@dW_B4t`Kwvt!@Yn@XTyv!grHx`eb?Jk%XJZ9-)QrUB1d4qDV zRBYA&L$-JIP=5zTPzcA>9s4~StV_NX>q>uG@U6?REwvMDZ~$=v9;;!1l@{g z{mIsUq7tqab=jLeFVyH07!LWN$OR>~m25m22wAzqsIX!e;)SK zIERUyBlEH39CRGWM*ZD7Y$iHwW(G$g)39OKy3FVpe`;*OCLzee5B+lkFR13mnX_UG zQNbvflCy*__dKB{S>|3jhA~JE^OTmpbjYD&a(UsZQp<~3fvXhkWzx%m;S=MdDahdy zR$$1I_3i44`4OuWf`ac-aBy&M{1Ei0wO;^EF3xl2aIzb4t>9PkSBe{o~HNo-8 z-0X*ul;(41(w{PYbZ9_lDcGcdcvtS+QodF2Hi{`;R#s>+Lrj@p@bFIRw92G?8?I$MO6Ec^raG)0bucQ|L`-*nS+^TtS`hGml|KEc zcH%x+Kyv#Q%(#PFZ8`{(A3;+APd`H_Y#h{3-wqpk@?huoPw`(8oDu!F5pDDFcM;8>YAbTO>+O0JyT!Bji@F zA~8|9|5LL5jTki%t?8{+e%ZS==)oDWF;07Nm~!b4RiB51D2)m@ z!JiS~RQtoQ4f02ZZC$-jyY9PiHo!JY*8pHg;*+?d#Zo>m`++75E=&lT*Hs2+5{9*8 zGp`9X0r*>?HCHDoD9yQhb3h(;awVq6*X42s#U+C=If@Ym(X9IAaYM0=wD>j!BvmR! zo6u4Ul1hZUP04Kvq%i4D2~bWq!woAgH#V1(}Q`h9xM9VQ4W zG4DwlBH0LidrgW>(mO4EC?Z96Zo&vGcB$4Bpk<6$i0C8uA3Rf=r1hxvg~w(7{eQ>) z=ar}x!=(;o@o)?vF`0X}M0P?}g&(1|IO!GQ2(^XR09yBXnG1)hmQ)~uXp2Aw zNSsIRnB>WTfOG@`5`B}tvhWj(4uX+J_Jj=X)(p27w!R7LV>LqwSliL`uOzL%#v{5n z0?|XxU)uj;o+#oQTS6y0<`s8h#)N;+W>>7?vbQQdU33#=ysYFk+!%jKk?=zXyfrm| z!q6;-cmIo6@cSee+?V@#9J%^WXLrG+VKFnuO+`T}H(W#fIf~nOiXU-E8;_(75`;*6Cs%X= z;TL-h2uwp>79sM0&q4X+3mTjH8rq7x;K)Zyr}MpJaC02hV%Q0%QM#cVq_EVxHLPZ zm5ap0U`p}Cx2U_ugBGatu}f2%%v75#dP(goWMlUPov8FAczz)n&Kj=bDLtk~iK$GK zB*{&dJ;}s!kdAWv*S=kFTa^7Fx|VQ}nrF^+R-23Bw|mJ$jM;p(|_UR^$X1yfe40q*#fNHM5krt-R(X}CbKYW>~%Xld2)oG-b$_3Qry DpqD2W diff --git a/tests/data/functions/__pycache__/dump_json.cpython-310.pyc b/tests/data/functions/__pycache__/dump_json.cpython-310.pyc deleted file mode 100644 index 1a3446d4f0273f9d71c570c9264f8f439abfa96e..0000000000000000000000000000000000000000 GIT binary patch literal 0 HcmV?d00001 literal 533 zcmYjOy-ve05VrG28sG((ydV)FC9qZqQNhGeB)VlmA$HnSX&l)OQV}B$0wV)2!DD!3 z;uV;p9J0LzSIx3jxxmbm&v&Gpzt*zI<)6W&xN9W!)| z>`Jv}@^L9DHw!RNLb)O_|q2=U74tgAa(2^=mu&&+tt&o+?Qc v9w`kKek%4`VN98ocx#1q>jPug;`Z&B5mw`imE=P!F5%&cMu?akct`#poneED diff --git a/tests/integration_test_token_counters.py b/tests/integration_test_token_counters.py new file mode 100644 index 00000000..f105ad02 --- /dev/null +++ b/tests/integration_test_token_counters.py @@ -0,0 +1,235 @@ +""" +Integration tests for token counting APIs. + +These tests verify that the token counting implementations actually hit the real APIs +for Anthropic, Google Gemini, and OpenAI (tiktoken) by calling get_context_window +on an imported agent. +""" + +import json +import os + +import pytest + +from letta.config import LettaConfig +from letta.orm import Base +from letta.schemas.agent import UpdateAgent +from letta.schemas.agent_file import AgentFileSchema +from letta.schemas.llm_config import LLMConfig +from letta.schemas.organization import Organization +from letta.schemas.user import User +from letta.server.server import SyncServer + +# ============================================================================ +# LLM Configs to test +# ============================================================================ + + +def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig: + """Load LLM configuration from JSON file.""" + filename = os.path.join(llm_config_dir, filename) + with open(filename, "r") as f: + config_data = json.load(f) + return LLMConfig(**config_data) + + +LLM_CONFIG_FILES = [ + "openai-gpt-4o-mini.json", + "claude-4-5-sonnet.json", + "gemini-2.5-pro.json", +] + +LLM_CONFIGS = [pytest.param(get_llm_config(f), id=f.replace(".json", "")) for f in LLM_CONFIG_FILES] + + +# ============================================================================ +# Fixtures +# ============================================================================ + + +async def _clear_tables(): + from letta.server.db import db_registry + + async with db_registry.async_session() as session: + for table in reversed(Base.metadata.sorted_tables): + await session.execute(table.delete()) + await session.commit() + + +@pytest.fixture(autouse=True) +async def clear_tables(): + await _clear_tables() + + +@pytest.fixture +async def server(): + config = LettaConfig.load() + config.save() + server = SyncServer(init_with_default_org_and_user=True) + await server.init_async() + await server.tool_manager.upsert_base_tools_async(actor=server.default_user) + yield server + + +@pytest.fixture +async def default_organization(server: SyncServer): + """Fixture to create and return the default organization.""" + org = await server.organization_manager.create_default_organization_async() + yield org + + +@pytest.fixture +async def default_user(server: SyncServer, default_organization): + """Fixture to create and return the default user within the default organization.""" + user = await server.user_manager.create_default_actor_async(org_id=default_organization.id) + yield user + + +@pytest.fixture +async def other_organization(server: SyncServer): + """Fixture to create and return another organization.""" + org = await server.organization_manager.create_organization_async(pydantic_org=Organization(name="test_org")) + yield org + + +@pytest.fixture +async def other_user(server: SyncServer, other_organization): + """Fixture to create and return another user within the other organization.""" + user = await server.user_manager.create_actor_async(pydantic_user=User(organization_id=other_organization.id, name="test_user")) + yield user + + +@pytest.fixture +async def imported_agent_id(server: SyncServer, other_user): + """Import the test agent from the .af file and return the agent ID.""" + file_path = os.path.join(os.path.dirname(__file__), "test_agent_files", "test_agent.af") + + with open(file_path, "r") as f: + agent_file_json = json.load(f) + + agent_schema = AgentFileSchema.model_validate(agent_file_json) + + import_result = await server.agent_serialization_manager.import_file( + schema=agent_schema, + actor=other_user, + append_copy_suffix=False, + override_existing_tools=True, + ) + + assert import_result.success, f"Failed to import agent: {import_result.message}" + + # Get the imported agent ID + agent_id = next(db_id for file_id, db_id in import_result.id_mappings.items() if file_id.startswith("agent-")) + yield agent_id + + +# ============================================================================ +# Token Counter Integration Test +# ============================================================================ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("llm_config", LLM_CONFIGS) +async def test_get_context_window(server: SyncServer, imported_agent_id: str, other_user, llm_config: LLMConfig): + """Test get_context_window with different LLM providers.""" + # Update the agent to use the specified LLM config + await server.agent_manager.update_agent_async( + agent_id=imported_agent_id, + agent_update=UpdateAgent(llm_config=llm_config), + actor=other_user, + ) + + # Call get_context_window which will use the appropriate token counting API + context_window = await server.agent_manager.get_context_window(agent_id=imported_agent_id, actor=other_user) + + # Verify we got valid token counts + assert context_window.context_window_size_current > 0 + assert context_window.num_tokens_system >= 0 + assert context_window.num_tokens_messages >= 0 + assert context_window.num_tokens_functions_definitions >= 0 + + print(f"{llm_config.model_endpoint_type} ({llm_config.model}) context window:") + print(f" Total tokens: {context_window.context_window_size_current}") + print(f" System tokens: {context_window.num_tokens_system}") + print(f" Message tokens: {context_window.num_tokens_messages}") + print(f" Function tokens: {context_window.num_tokens_functions_definitions}") + + +# ============================================================================ +# Edge Case Tests +# ============================================================================ + + +@pytest.mark.asyncio +@pytest.mark.parametrize("llm_config", LLM_CONFIGS) +async def test_count_empty_text_tokens(llm_config: LLMConfig): + """Test that empty text returns 0 tokens for all providers.""" + from letta.llm_api.anthropic_client import AnthropicClient + from letta.llm_api.google_ai_client import GoogleAIClient + from letta.llm_api.google_vertex_client import GoogleVertexClient + from letta.services.context_window_calculator.token_counter import ( + AnthropicTokenCounter, + GeminiTokenCounter, + TiktokenCounter, + ) + + if llm_config.model_endpoint_type == "anthropic": + token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model) + elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"): + client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient() + token_counter = GeminiTokenCounter(client, llm_config.model) + else: + token_counter = TiktokenCounter(llm_config.model) + + token_count = await token_counter.count_text_tokens("") + assert token_count == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("llm_config", LLM_CONFIGS) +async def test_count_empty_messages_tokens(llm_config: LLMConfig): + """Test that empty message list returns 0 tokens for all providers.""" + from letta.llm_api.anthropic_client import AnthropicClient + from letta.llm_api.google_ai_client import GoogleAIClient + from letta.llm_api.google_vertex_client import GoogleVertexClient + from letta.services.context_window_calculator.token_counter import ( + AnthropicTokenCounter, + GeminiTokenCounter, + TiktokenCounter, + ) + + if llm_config.model_endpoint_type == "anthropic": + token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model) + elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"): + client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient() + token_counter = GeminiTokenCounter(client, llm_config.model) + else: + token_counter = TiktokenCounter(llm_config.model) + + token_count = await token_counter.count_message_tokens([]) + assert token_count == 0 + + +@pytest.mark.asyncio +@pytest.mark.parametrize("llm_config", LLM_CONFIGS) +async def test_count_empty_tools_tokens(llm_config: LLMConfig): + """Test that empty tools list returns 0 tokens for all providers.""" + from letta.llm_api.anthropic_client import AnthropicClient + from letta.llm_api.google_ai_client import GoogleAIClient + from letta.llm_api.google_vertex_client import GoogleVertexClient + from letta.services.context_window_calculator.token_counter import ( + AnthropicTokenCounter, + GeminiTokenCounter, + TiktokenCounter, + ) + + if llm_config.model_endpoint_type == "anthropic": + token_counter = AnthropicTokenCounter(AnthropicClient(), llm_config.model) + elif llm_config.model_endpoint_type in ("google_vertex", "google_ai"): + client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient() + token_counter = GeminiTokenCounter(client, llm_config.model) + else: + token_counter = TiktokenCounter(llm_config.model) + + token_count = await token_counter.count_tool_tokens([]) + assert token_count == 0