chore: delete legacy bedrock client (#3912)
This commit is contained in:
@@ -1,90 +0,0 @@
|
||||
"""
|
||||
Note that this formally only supports Anthropic Bedrock.
|
||||
TODO (cliandy): determine what other providers are supported and what is needed to add support.
|
||||
"""
|
||||
|
||||
from typing import Any, Optional
|
||||
|
||||
from anthropic import AnthropicBedrock
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.settings import model_settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def get_bedrock_client(
    access_key_id: Optional[str] = None,
    secret_key: Optional[str] = None,
    default_region: Optional[str] = None,
):
    """
    Build an ``AnthropicBedrock`` client from temporary STS credentials.

    Every argument that is falsy falls back to the corresponding value on
    ``model_settings``. The resolved key pair is used to request a session
    token from STS, and the returned temporary credentials (access key,
    secret key, session token) are what the Bedrock client is created with.

    Args:
        access_key_id: AWS access key id; defaults to
            ``model_settings.aws_access_key_id``.
        secret_key: AWS secret access key; defaults to
            ``model_settings.aws_secret_access_key``.
        default_region: AWS region; defaults to
            ``model_settings.aws_default_region``.

    Returns:
        The configured ``AnthropicBedrock`` client.
    """
    # Imported lazily so boto3 is only loaded when a client is requested.
    import boto3

    resolved_key_id = access_key_id or model_settings.aws_access_key_id
    resolved_secret = secret_key or model_settings.aws_secret_access_key
    resolved_region = default_region or model_settings.aws_default_region

    sts = boto3.client(
        "sts",
        aws_access_key_id=resolved_key_id,
        aws_secret_access_key=resolved_secret,
        region_name=resolved_region,
    )
    # Swap the long-lived key pair for short-lived session credentials.
    session_credentials = sts.get_session_token()["Credentials"]

    return AnthropicBedrock(
        aws_access_key=session_credentials["AccessKeyId"],
        aws_secret_key=session_credentials["SecretAccessKey"],
        aws_session_token=session_credentials["SessionToken"],
        aws_region=resolved_region,
    )
|
||||
|
||||
|
||||
async def bedrock_get_model_list_async(
    access_key_id: Optional[str] = None,
    secret_access_key: Optional[str] = None,
    default_region: Optional[str] = None,
) -> list[dict]:
    """
    List the Bedrock inference profiles available to the given credentials.

    Args:
        access_key_id: AWS access key id handed to the client as-is.
        secret_access_key: AWS secret access key handed to the client as-is.
        default_region: AWS region handed to the client as ``region_name``.

    Returns:
        The ``inferenceProfileSummaries`` entry of the
        ``list_inference_profiles`` response.

    Raises:
        Exception: any error raised while creating the client or listing the
            profiles is logged and then re-raised unchanged.
    """
    # Imported lazily so aioboto3 is only loaded when Bedrock is queried.
    from aioboto3.session import Session

    try:
        session = Session()
        # NOTE(review): None values are forwarded untouched; presumably the
        # AWS SDK then resolves credentials/region on its own - confirm.
        async with session.client(
            "bedrock",
            aws_access_key_id=access_key_id,
            aws_secret_access_key=secret_access_key,
            region_name=default_region,
        ) as bedrock:
            response = await bedrock.list_inference_profiles()
            return response["inferenceProfileSummaries"]
    except Exception as e:
        # Plain %-style format string: the logger interpolates lazily, so no
        # f-prefix is needed (the original f-string had no placeholders).
        logger.error("Error getting model list for bedrock: %s", e)
        # Bare raise re-raises the active exception with its traceback intact.
        raise
|
||||
|
||||
|
||||
def bedrock_get_model_details(region_name: str, model_id: str) -> dict[str, Any]:
    """
    Get details for a specific model from Bedrock.

    Args:
        region_name: AWS region the Bedrock client is created for.
        model_id: Identifier passed to ``get_foundation_model`` as
            ``modelIdentifier``.

    Returns:
        The ``modelDetails`` entry of the ``get_foundation_model`` response.

    Raises:
        ClientError: logged with its traceback and re-raised unchanged.
            Other exception types propagate without being logged here.
    """
    # Imported lazily so boto3/botocore are only loaded when needed.
    import boto3
    from botocore.exceptions import ClientError

    try:
        bedrock = boto3.client("bedrock", region_name=region_name)
        response = bedrock.get_foundation_model(modelIdentifier=model_id)
        return response["modelDetails"]
    except ClientError as e:
        # Lazy %-style arguments; logger.exception already attaches the
        # traceback, so the rendered message is the same as before.
        logger.exception("Error getting model details: %s", e)
        # Bare raise re-raises the active exception with its traceback intact.
        raise
|
||||
|
||||
|
||||
def bedrock_get_model_context_window(model_id: str) -> int:
    """
    Return the context window size to assume for a Bedrock model.

    Bedrock doesn't provide this via API, so the value comes from a small
    hard-coded mapping keyed on the model id prefix:

    * ids starting with ``"anthropic"`` -> 200k
      (https://aws.amazon.com/bedrock/anthropic/)
    * anything else -> 100k, the default for unknown models
    """
    # NOTE(review): the match is a plain prefix test, so an id that does not
    # literally begin with "anthropic" gets the default - confirm callers
    # pass bare model ids.
    is_anthropic = model_id.startswith("anthropic")
    return 200_000 if is_anthropic else 100_000
|
||||
@@ -20,20 +20,32 @@ class BedrockProvider(Provider):
|
||||
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
||||
region: str = Field(..., description="AWS region for Bedrock")
|
||||
|
||||
async def bedrock_get_model_list_async(self) -> list[dict]:
    """
    List the Bedrock inference profiles available to this provider.

    The client is created from this provider's own fields: ``access_key``
    as the access key id, ``api_key`` as the secret access key, and
    ``region`` as the region name.

    Returns:
        The ``inferenceProfileSummaries`` entry of the
        ``list_inference_profiles`` response.

    Raises:
        Exception: any error raised while creating the client or listing the
            profiles is logged and then re-raised unchanged.
    """
    # Imported lazily so aioboto3 is only loaded when Bedrock is queried.
    from aioboto3.session import Session

    try:
        session = Session()
        async with session.client(
            "bedrock",
            aws_access_key_id=self.access_key,
            aws_secret_access_key=self.api_key,
            region_name=self.region,
        ) as bedrock:
            response = await bedrock.list_inference_profiles()
            return response["inferenceProfileSummaries"]
    except Exception as e:
        # Plain %-style format string: the logger interpolates lazily, so no
        # f-prefix is needed (the original f-string had no placeholders).
        logger.error("Error getting model list for bedrock: %s", e)
        # Bare raise re-raises the active exception with its traceback intact.
        raise
|
||||
|
||||
async def check_api_key(self):
|
||||
"""Check if the Bedrock credentials are valid"""
|
||||
from letta.errors import LLMAuthenticationError
|
||||
from letta.llm_api.aws_bedrock import bedrock_get_model_list_async
|
||||
|
||||
try:
|
||||
# For BYOK providers, use the custom credentials
|
||||
if self.provider_category == ProviderCategory.byok:
|
||||
# If we can list models, the credentials are valid
|
||||
await bedrock_get_model_list_async(
|
||||
access_key_id=self.access_key,
|
||||
secret_access_key=self.api_key, # api_key stores the secret access key
|
||||
region_name=self.region,
|
||||
)
|
||||
await self.bedrock_get_model_list_async()
|
||||
else:
|
||||
# For base providers, use default credentials
|
||||
bedrock_get_model_list(region_name=self.region)
|
||||
@@ -41,13 +53,7 @@ class BedrockProvider(Provider):
|
||||
raise LLMAuthenticationError(message=f"Failed to authenticate with Bedrock: {e}")
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.aws_bedrock import bedrock_get_model_list_async
|
||||
|
||||
models = await bedrock_get_model_list_async(
|
||||
self.access_key,
|
||||
self.api_key,
|
||||
self.region,
|
||||
)
|
||||
models = await self.bedrock_get_model_list_async()
|
||||
|
||||
configs = []
|
||||
for model_summary in models:
|
||||
@@ -67,10 +73,16 @@ class BedrockProvider(Provider):
|
||||
return configs
|
||||
|
||||
def get_model_context_window(self, model_name: str) -> int | None:
|
||||
# Context windows for Claude models
|
||||
from letta.llm_api.aws_bedrock import bedrock_get_model_context_window
|
||||
"""
|
||||
Get context window size for a specific model.
|
||||
|
||||
return bedrock_get_model_context_window(model_name)
|
||||
Bedrock doesn't provide this via API, so we maintain a mapping
|
||||
200k for anthropic: https://aws.amazon.com/bedrock/anthropic/
|
||||
"""
|
||||
if model_name.startswith("anthropic"):
|
||||
return 200_000
|
||||
else:
|
||||
return 100_000 # default to 100k if unknown
|
||||
|
||||
def get_handle(self, model_name: str, is_embedding: bool = False, base_name: str | None = None) -> str:
|
||||
logger.debug("Getting handle for model_name: %s", model_name)
|
||||
|
||||
Reference in New Issue
Block a user