feat: Add code file support to file uploads (#2702)
This commit is contained in:
@@ -21,16 +21,15 @@ from letta.server.server import SyncServer
|
||||
from letta.services.file_processor.chunker.llama_index_chunker import LlamaIndexChunker
|
||||
from letta.services.file_processor.embedder.openai_embedder import OpenAIEmbedder
|
||||
from letta.services.file_processor.file_processor import FileProcessor
|
||||
from letta.services.file_processor.file_types import get_allowed_media_types, get_extension_to_mime_type_map, register_mime_types
|
||||
from letta.services.file_processor.parser.mistral_parser import MistralFileParser
|
||||
from letta.settings import model_settings, settings
|
||||
from letta.utils import safe_create_task, sanitize_filename
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
mimetypes.add_type("text/markdown", ".md")
|
||||
mimetypes.add_type("text/markdown", ".markdown")
|
||||
mimetypes.add_type("application/jsonl", ".jsonl")
|
||||
mimetypes.add_type("application/x-jsonlines", ".jsonl")
|
||||
# Register all supported file types with Python's mimetypes module
|
||||
register_mime_types()
|
||||
|
||||
|
||||
router = APIRouter(prefix="/sources", tags=["sources"])
|
||||
@@ -179,15 +178,7 @@ async def upload_file_to_source(
|
||||
"""
|
||||
Upload a file to a data source.
|
||||
"""
|
||||
allowed_media_types = {
|
||||
"application/pdf",
|
||||
"text/plain",
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"application/json",
|
||||
"application/jsonl",
|
||||
"application/x-jsonlines",
|
||||
}
|
||||
allowed_media_types = get_allowed_media_types()
|
||||
|
||||
# Normalize incoming Content-Type header (strip charset or any parameters).
|
||||
raw_ct = file.content_type or ""
|
||||
@@ -201,21 +192,18 @@ async def upload_file_to_source(
|
||||
|
||||
if media_type not in allowed_media_types:
|
||||
ext = Path(file.filename).suffix.lower()
|
||||
ext_map = {
|
||||
".pdf": "application/pdf",
|
||||
".txt": "text/plain",
|
||||
".json": "application/json",
|
||||
".md": "text/markdown",
|
||||
".markdown": "text/markdown",
|
||||
".jsonl": "application/jsonl",
|
||||
}
|
||||
ext_map = get_extension_to_mime_type_map()
|
||||
media_type = ext_map.get(ext, media_type)
|
||||
|
||||
# If still not allowed, reject with 415.
|
||||
if media_type not in allowed_media_types:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
|
||||
detail=(f"Unsupported file type: {media_type or 'unknown'} " f"(filename: {file.filename}). Only PDF, .txt, or .json allowed."),
|
||||
detail=(
|
||||
f"Unsupported file type: {media_type or 'unknown'} "
|
||||
f"(filename: {file.filename}). "
|
||||
f"Supported types: PDF, text files (.txt, .md), JSON, and code files (.py, .js, .java, etc.)."
|
||||
),
|
||||
)
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
248
letta/services/file_processor/file_types.py
Normal file
248
letta/services/file_processor/file_types.py
Normal file
@@ -0,0 +1,248 @@
|
||||
"""
|
||||
Centralized file type configuration for supported file formats.
|
||||
|
||||
This module provides a single source of truth for file type definitions,
|
||||
mime types, and file processing capabilities across the Letta codebase.
|
||||
"""
|
||||
|
||||
import mimetypes
|
||||
from dataclasses import dataclass
|
||||
from typing import Dict, Set
|
||||
|
||||
|
||||
@dataclass
class FileTypeInfo:
    """Information about a supported file type."""

    extension: str  # extension including the leading dot, e.g. ".py" (normalized by FileTypeRegistry.register)
    mime_type: str  # canonical MIME type reported for this extension
    is_simple_text: bool  # True when the content can be decoded and read directly as text
    description: str  # human-readable label, e.g. "Python source code"
|
||||
|
||||
|
||||
class FileTypeRegistry:
    """Central registry for supported file types.

    Holds one FileTypeInfo per extension and answers queries about allowed
    MIME types, extension-to-MIME mappings, and simple-text detection.
    """

    # (extension, mime type, simple-text flag, description) rows used to
    # pre-populate every new registry instance.
    _DEFAULT_TYPES = (
        # Document formats
        (".pdf", "application/pdf", False, "PDF document"),
        (".txt", "text/plain", True, "Plain text file"),
        (".md", "text/markdown", True, "Markdown document"),
        (".markdown", "text/markdown", True, "Markdown document"),
        (".json", "application/json", True, "JSON data file"),
        (".jsonl", "application/jsonl", True, "JSON Lines file"),
        # Programming languages
        (".py", "text/x-python", True, "Python source code"),
        (".js", "text/javascript", True, "JavaScript source code"),
        (".ts", "text/x-typescript", True, "TypeScript source code"),
        (".java", "text/x-java-source", True, "Java source code"),
        (".cpp", "text/x-c++", True, "C++ source code"),
        (".cxx", "text/x-c++", True, "C++ source code"),
        (".c", "text/x-c", True, "C source code"),
        (".h", "text/x-c", True, "C/C++ header file"),
        (".cs", "text/x-csharp", True, "C# source code"),
        (".php", "text/x-php", True, "PHP source code"),
        (".rb", "text/x-ruby", True, "Ruby source code"),
        (".go", "text/x-go", True, "Go source code"),
        (".rs", "text/x-rust", True, "Rust source code"),
        (".swift", "text/x-swift", True, "Swift source code"),
        (".kt", "text/x-kotlin", True, "Kotlin source code"),
        (".scala", "text/x-scala", True, "Scala source code"),
        (".r", "text/x-r", True, "R source code"),
        (".m", "text/x-objective-c", True, "Objective-C source code"),
        # Web technologies
        (".html", "text/html", True, "HTML document"),
        (".htm", "text/html", True, "HTML document"),
        (".css", "text/css", True, "CSS stylesheet"),
        (".scss", "text/x-scss", True, "SCSS stylesheet"),
        (".sass", "text/x-sass", True, "Sass stylesheet"),
        (".less", "text/x-less", True, "Less stylesheet"),
        (".vue", "text/x-vue", True, "Vue.js component"),
        (".jsx", "text/x-jsx", True, "JSX source code"),
        (".tsx", "text/x-tsx", True, "TSX source code"),
        # Configuration and data formats
        (".xml", "application/xml", True, "XML document"),
        (".yaml", "text/x-yaml", True, "YAML configuration"),
        (".yml", "text/x-yaml", True, "YAML configuration"),
        (".toml", "application/toml", True, "TOML configuration"),
        (".ini", "text/x-ini", True, "INI configuration"),
        (".cfg", "text/x-conf", True, "Configuration file"),
        (".conf", "text/x-conf", True, "Configuration file"),
        # Scripts and SQL
        (".sh", "text/x-shellscript", True, "Shell script"),
        (".bash", "text/x-shellscript", True, "Bash script"),
        (".ps1", "text/x-powershell", True, "PowerShell script"),
        (".bat", "text/x-batch", True, "Batch script"),
        (".cmd", "text/x-batch", True, "Command script"),
        (".dockerfile", "text/x-dockerfile", True, "Dockerfile"),
        (".sql", "text/x-sql", True, "SQL script"),
    )

    # MIME-type aliases accepted in addition to the canonical types above,
    # mapped to the extension they alias.
    _MIME_ALIASES = {
        "text/x-markdown": ".md",  # Alternative markdown MIME type
        "application/x-jsonlines": ".jsonl",  # Alternative JSONL MIME type
        "text/xml": ".xml",  # Alternative XML MIME type
    }

    def __init__(self):
        """Initialize the registry pre-populated with the default file types."""
        self._file_types: Dict[str, FileTypeInfo] = {}
        for ext, mime, simple, desc in self._DEFAULT_TYPES:
            self.register(ext, mime, simple, desc)

    def register(self, extension: str, mime_type: str, is_simple_text: bool, description: str) -> None:
        """
        Register a new file type (replacing any existing entry).

        Args:
            extension: File extension; a missing leading dot is added.
            mime_type: MIME type for the file
            is_simple_text: Whether this is a simple text file that can be read directly
            description: Human-readable description of the file type
        """
        normalized = extension if extension.startswith(".") else f".{extension}"
        self._file_types[normalized] = FileTypeInfo(
            extension=normalized,
            mime_type=mime_type,
            is_simple_text=is_simple_text,
            description=description,
        )

    def register_mime_types(self) -> None:
        """Register all file types (and known aliases) with Python's mimetypes module."""
        for info in self._file_types.values():
            mimetypes.add_type(info.mime_type, info.extension)
        for alias_mime, alias_ext in self._MIME_ALIASES.items():
            mimetypes.add_type(alias_mime, alias_ext)

    def get_allowed_media_types(self) -> Set[str]:
        """
        Get set of all allowed MIME types.

        Returns:
            Set of MIME type strings (canonical types plus aliases) supported for upload.
        """
        allowed = {info.mime_type for info in self._file_types.values()}
        allowed.update(self._MIME_ALIASES)
        return allowed

    def get_extension_to_mime_type_map(self) -> Dict[str, str]:
        """
        Get mapping from file extensions (with leading dot) to MIME types.

        Returns:
            Dictionary mapping extensions to canonical MIME types.
        """
        return {ext: info.mime_type for ext, info in self._file_types.items()}

    def get_simple_text_mime_types(self) -> Set[str]:
        """
        Get set of MIME types that represent simple text files.

        Returns:
            Set of MIME type strings for files that can be read as plain text.
        """
        return {info.mime_type for info in self._file_types.values() if info.is_simple_text}

    def is_simple_text_mime_type(self, mime_type: str) -> bool:
        """
        Check if a MIME type represents simple text that can be read directly.

        Any "text/*" type is accepted, as are registered simple-text types
        and the known aliases.

        Args:
            mime_type: MIME type to check

        Returns:
            True if the MIME type represents simple text.
        """
        if mime_type.startswith("text/"):
            return True
        return mime_type in self.get_simple_text_mime_types() or mime_type in self._MIME_ALIASES

    def get_supported_extensions(self) -> Set[str]:
        """
        Get set of all supported file extensions.

        Returns:
            Set of file extensions (with leading dots).
        """
        return set(self._file_types)

    def is_supported_extension(self, extension: str) -> bool:
        """
        Check if a file extension is supported.

        Args:
            extension: File extension (with or without leading dot)

        Returns:
            True if the extension is supported.
        """
        normalized = extension if extension.startswith(".") else f".{extension}"
        return normalized in self._file_types

    def get_file_type_info(self, extension: str) -> FileTypeInfo:
        """
        Get information about a file type by extension.

        Args:
            extension: File extension (with or without leading dot)

        Returns:
            FileTypeInfo object with details about the file type.

        Raises:
            KeyError: If the extension is not supported.
        """
        normalized = extension if extension.startswith(".") else f".{extension}"
        return self._file_types[normalized]
|
||||
|
||||
|
||||
# Global registry instance
# Module-level singleton; the convenience functions in this module delegate to it.
file_type_registry = FileTypeRegistry()
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility and ease of use
|
||||
def register_mime_types() -> None:
    """Register all supported file types with Python's mimetypes module.

    Thin wrapper delegating to the module-level ``file_type_registry``.
    """
    file_type_registry.register_mime_types()
|
||||
|
||||
|
||||
def get_allowed_media_types() -> Set[str]:
    """Get set of all allowed MIME types for file uploads.

    Thin wrapper delegating to the module-level ``file_type_registry``.
    """
    return file_type_registry.get_allowed_media_types()
|
||||
|
||||
|
||||
def get_extension_to_mime_type_map() -> Dict[str, str]:
    """Get mapping from file extensions (with leading dot) to MIME types.

    Thin wrapper delegating to the module-level ``file_type_registry``.
    """
    return file_type_registry.get_extension_to_mime_type_map()
|
||||
|
||||
|
||||
def get_simple_text_mime_types() -> Set[str]:
    """Get set of MIME types that represent simple text files.

    Thin wrapper delegating to the module-level ``file_type_registry``.
    """
    return file_type_registry.get_simple_text_mime_types()
|
||||
|
||||
|
||||
def is_simple_text_mime_type(mime_type: str) -> bool:
    """Check if a MIME type represents simple text that can be read directly.

    Thin wrapper delegating to the module-level ``file_type_registry``.
    """
    return file_type_registry.is_simple_text_mime_type(mime_type)
|
||||
@@ -3,22 +3,13 @@ import base64
|
||||
from mistralai import Mistral, OCRPageObject, OCRResponse, OCRUsageInfo
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.services.file_processor.file_types import is_simple_text_mime_type
|
||||
from letta.services.file_processor.parser.base_parser import FileParser
|
||||
from letta.settings import settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
SIMPLE_TEXT_MIME_TYPES = {
|
||||
"text/plain",
|
||||
"text/markdown",
|
||||
"text/x-markdown",
|
||||
"application/json",
|
||||
"application/jsonl",
|
||||
"application/x-jsonlines",
|
||||
}
|
||||
|
||||
|
||||
class MistralFileParser(FileParser):
|
||||
"""Mistral-based OCR extraction"""
|
||||
|
||||
@@ -33,7 +24,7 @@ class MistralFileParser(FileParser):
|
||||
|
||||
# TODO: Kind of hacky...we try to exit early here?
|
||||
# TODO: Create our internal file parser representation we return instead of OCRResponse
|
||||
if mime_type in SIMPLE_TEXT_MIME_TYPES or mime_type.startswith("text/"):
|
||||
if is_simple_text_mime_type(mime_type):
|
||||
text = content.decode("utf-8", errors="replace")
|
||||
return OCRResponse(
|
||||
model=self.model,
|
||||
|
||||
371
tests/data/api_server.go
Normal file
371
tests/data/api_server.go
Normal file
@@ -0,0 +1,371 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"log"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
)
|
||||
|
||||
// User represents a user in the system.
type User struct {
	ID        int       `json:"id"`         // assigned by UserService, starting at 1
	Name      string    `json:"name"`
	Email     string    `json:"email"`      // unique across all users
	CreatedAt time.Time `json:"created_at"`
	UpdatedAt time.Time `json:"updated_at"` // bumped on every UpdateUser call
}
|
||||
|
||||
// UserService handles user-related operations backed by an in-memory store.
type UserService struct {
	users  map[int]*User // store keyed by user ID
	nextID int           // next ID to assign; starts at 1
	mutex  sync.RWMutex  // guards users and nextID
}
|
||||
|
||||
// NewUserService creates a new instance of UserService
|
||||
func NewUserService() *UserService {
|
||||
return &UserService{
|
||||
users: make(map[int]*User),
|
||||
nextID: 1,
|
||||
}
|
||||
}
|
||||
|
||||
// CreateUser adds a new user to the service.
// Both name and email are required, and the email must be unique across
// all existing users. Returns the stored user with its assigned ID.
func (us *UserService) CreateUser(name, email string) (*User, error) {
	us.mutex.Lock()
	defer us.mutex.Unlock()

	if name == "" || email == "" {
		return nil, fmt.Errorf("name and email are required")
	}

	// Check for duplicate email (linear scan over the in-memory store).
	for _, user := range us.users {
		if user.Email == email {
			return nil, fmt.Errorf("user with email %s already exists", email)
		}
	}

	user := &User{
		ID:        us.nextID,
		Name:      name,
		Email:     email,
		CreatedAt: time.Now(),
		UpdatedAt: time.Now(),
	}

	us.users[us.nextID] = user
	us.nextID++

	return user, nil
}
|
||||
|
||||
// GetUser retrieves a user by ID
|
||||
func (us *UserService) GetUser(id int) (*User, error) {
|
||||
us.mutex.RLock()
|
||||
defer us.mutex.RUnlock()
|
||||
|
||||
user, exists := us.users[id]
|
||||
if !exists {
|
||||
return nil, fmt.Errorf("user with ID %d not found", id)
|
||||
}
|
||||
|
||||
return user, nil
|
||||
}
|
||||
|
||||
// GetAllUsers returns all users
|
||||
func (us *UserService) GetAllUsers() []*User {
|
||||
us.mutex.RLock()
|
||||
defer us.mutex.RUnlock()
|
||||
|
||||
users := make([]*User, 0, len(us.users))
|
||||
for _, user := range us.users {
|
||||
users = append(users, user)
|
||||
}
|
||||
|
||||
return users
|
||||
}
|
||||
|
||||
// UpdateUser modifies an existing user.
// Empty name or email arguments mean "leave that field unchanged"; a
// non-empty email must remain unique across all other users. UpdatedAt
// is refreshed even when no field changed.
func (us *UserService) UpdateUser(id int, name, email string) (*User, error) {
	us.mutex.Lock()
	defer us.mutex.Unlock()

	user, exists := us.users[id]
	if !exists {
		return nil, fmt.Errorf("user with ID %d not found", id)
	}

	// Check for duplicate email (excluding current user)
	if email != user.Email {
		for _, u := range us.users {
			if u.Email == email && u.ID != id {
				return nil, fmt.Errorf("user with email %s already exists", email)
			}
		}
	}

	if name != "" {
		user.Name = name
	}
	if email != "" {
		user.Email = email
	}
	user.UpdatedAt = time.Now()

	return user, nil
}
|
||||
|
||||
// DeleteUser removes a user from the service
|
||||
func (us *UserService) DeleteUser(id int) error {
|
||||
us.mutex.Lock()
|
||||
defer us.mutex.Unlock()
|
||||
|
||||
if _, exists := us.users[id]; !exists {
|
||||
return fmt.Errorf("user with ID %d not found", id)
|
||||
}
|
||||
|
||||
delete(us.users, id)
|
||||
return nil
|
||||
}
|
||||
|
||||
// APIServer represents the HTTP server and its dependencies.
type APIServer struct {
	userService *UserService // backing in-memory user store
	router      *mux.Router  // gorilla/mux router with all routes registered
}
|
||||
|
||||
// NewAPIServer creates a new API server instance with all routes and
// middleware already configured.
func NewAPIServer(userService *UserService) *APIServer {
	server := &APIServer{
		userService: userService,
		router:      mux.NewRouter(),
	}
	server.setupRoutes()
	return server
}
|
||||
|
||||
// setupRoutes configures the API routes under the /api/v1 prefix and
// installs CORS and logging middleware.
func (s *APIServer) setupRoutes() {
	api := s.router.PathPrefix("/api/v1").Subrouter()

	// User routes; {id} is constrained to digits by the mux pattern.
	api.HandleFunc("/users", s.handleGetUsers).Methods("GET")
	api.HandleFunc("/users", s.handleCreateUser).Methods("POST")
	api.HandleFunc("/users/{id:[0-9]+}", s.handleGetUser).Methods("GET")
	api.HandleFunc("/users/{id:[0-9]+}", s.handleUpdateUser).Methods("PUT")
	api.HandleFunc("/users/{id:[0-9]+}", s.handleDeleteUser).Methods("DELETE")

	// Health check
	api.HandleFunc("/health", s.handleHealthCheck).Methods("GET")

	// Middleware is attached to the root router, so it applies to all routes.
	s.router.Use(s.corsMiddleware)
	s.router.Use(s.loggingMiddleware)
}
|
||||
|
||||
// HTTP Handlers
|
||||
|
||||
// handleGetUsers responds with all users plus a count field.
func (s *APIServer) handleGetUsers(w http.ResponseWriter, r *http.Request) {
	users := s.userService.GetAllUsers()
	s.writeJSON(w, http.StatusOK, map[string]interface{}{
		"users": users,
		"count": len(users),
	})
}
|
||||
|
||||
// handleCreateUser decodes a {"name", "email"} JSON body and creates a user.
// Responds 201 with the created user, or 400 on bad JSON / validation errors.
func (s *APIServer) handleCreateUser(w http.ResponseWriter, r *http.Request) {
	var req struct {
		Name  string `json:"name"`
		Email string `json:"email"`
	}

	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.writeError(w, http.StatusBadRequest, "Invalid JSON payload")
		return
	}

	user, err := s.userService.CreateUser(req.Name, req.Email)
	if err != nil {
		s.writeError(w, http.StatusBadRequest, err.Error())
		return
	}

	s.writeJSON(w, http.StatusCreated, map[string]*User{"user": user})
}
|
||||
|
||||
// handleGetUser looks up a single user by the {id} path variable.
// Responds 400 for a non-numeric ID, 404 when the user does not exist.
func (s *APIServer) handleGetUser(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	id, err := strconv.Atoi(vars["id"])
	if err != nil {
		s.writeError(w, http.StatusBadRequest, "Invalid user ID")
		return
	}

	user, err := s.userService.GetUser(id)
	if err != nil {
		s.writeError(w, http.StatusNotFound, err.Error())
		return
	}

	s.writeJSON(w, http.StatusOK, map[string]*User{"user": user})
}
|
||||
|
||||
// handleUpdateUser updates the user identified by the {id} path variable.
// Maps "not found" service errors to 404 and all other failures to 400.
func (s *APIServer) handleUpdateUser(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	id, err := strconv.Atoi(vars["id"])
	if err != nil {
		s.writeError(w, http.StatusBadRequest, "Invalid user ID")
		return
	}

	var req struct {
		Name  string `json:"name"`
		Email string `json:"email"`
	}

	if err := json.NewDecoder(r.Body).Decode(&req); err != nil {
		s.writeError(w, http.StatusBadRequest, "Invalid JSON payload")
		return
	}

	user, err := s.userService.UpdateUser(id, req.Name, req.Email)
	if err != nil {
		// Distinguish a missing user (404) from a validation failure (400)
		// by inspecting the error text.
		status := http.StatusBadRequest
		if strings.Contains(err.Error(), "not found") {
			status = http.StatusNotFound
		}
		s.writeError(w, status, err.Error())
		return
	}

	s.writeJSON(w, http.StatusOK, map[string]*User{"user": user})
}
|
||||
|
||||
// handleDeleteUser removes the user identified by the {id} path variable.
// Responds 400 for a non-numeric ID, 404 when the user does not exist.
func (s *APIServer) handleDeleteUser(w http.ResponseWriter, r *http.Request) {
	vars := mux.Vars(r)
	id, err := strconv.Atoi(vars["id"])
	if err != nil {
		s.writeError(w, http.StatusBadRequest, "Invalid user ID")
		return
	}

	if err := s.userService.DeleteUser(id); err != nil {
		s.writeError(w, http.StatusNotFound, err.Error())
		return
	}

	s.writeJSON(w, http.StatusOK, map[string]string{"message": "User deleted successfully"})
}
|
||||
|
||||
// handleHealthCheck reports service liveness along with the current time.
func (s *APIServer) handleHealthCheck(w http.ResponseWriter, r *http.Request) {
	s.writeJSON(w, http.StatusOK, map[string]interface{}{
		"status":    "healthy",
		"timestamp": time.Now(),
		"service":   "user-api",
	})
}
|
||||
|
||||
// Middleware
|
||||
|
||||
func (s *APIServer) corsMiddleware(next http.Handler) http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Access-Control-Allow-Origin", "*")
|
||||
w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization")
|
||||
|
||||
if r.Method == "OPTIONS" {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
// loggingMiddleware logs method, path, status code, and latency per request.
func (s *APIServer) loggingMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		start := time.Now()

		// Wrap ResponseWriter to capture the status code; defaults to 200
		// because handlers that never call WriteHeader imply StatusOK.
		ww := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}

		next.ServeHTTP(ww, r)

		log.Printf("%s %s %d %v", r.Method, r.URL.Path, ww.statusCode, time.Since(start))
	})
}
|
||||
|
||||
// Helper methods
|
||||
|
||||
func (s *APIServer) writeJSON(w http.ResponseWriter, status int, data interface{}) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(status)
|
||||
json.NewEncoder(w).Encode(data)
|
||||
}
|
||||
|
||||
// writeError sends a JSON error body {"error": message} with the given status.
func (s *APIServer) writeError(w http.ResponseWriter, status int, message string) {
	s.writeJSON(w, status, map[string]string{"error": message})
}
|
||||
|
||||
// responseWriter wraps http.ResponseWriter to capture the status code
// written by a handler, for request logging.
type responseWriter struct {
	http.ResponseWriter
	statusCode int // last status passed to WriteHeader
}
|
||||
|
||||
// WriteHeader records the status code before delegating to the wrapped writer.
func (rw *responseWriter) WriteHeader(code int) {
	rw.statusCode = code
	rw.ResponseWriter.WriteHeader(code)
}
|
||||
|
||||
// Start starts the HTTP server on addr and blocks until it stops.
// Cancelling ctx triggers a graceful shutdown with a 10-second deadline;
// a clean shutdown surfaces as http.ErrServerClosed from ListenAndServe.
func (s *APIServer) Start(ctx context.Context, addr string) error {
	server := &http.Server{
		Addr:         addr,
		Handler:      s.router,
		ReadTimeout:  15 * time.Second,
		WriteTimeout: 15 * time.Second,
		IdleTimeout:  60 * time.Second,
	}

	// Shutdown watcher: runs concurrently and reacts to context cancellation.
	go func() {
		<-ctx.Done()
		log.Println("Shutting down server...")

		shutdownCtx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
		defer cancel()

		if err := server.Shutdown(shutdownCtx); err != nil {
			log.Printf("Server shutdown error: %v", err)
		}
	}()

	log.Printf("Server starting on %s", addr)
	return server.ListenAndServe()
}
|
||||
|
||||
// main wires up the in-memory user service, seeds sample data, and serves
// the API on :8080 until the process is terminated.
func main() {
	userService := NewUserService()

	// Add some sample data (inputs are valid here, so errors are ignored)
	userService.CreateUser("John Doe", "john@example.com")
	userService.CreateUser("Jane Smith", "jane@example.com")

	server := NewAPIServer(userService)

	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()

	// http.ErrServerClosed indicates a graceful shutdown, not a failure.
	if err := server.Start(ctx, ":8080"); err != nil && err != http.ErrServerClosed {
		log.Fatalf("Server failed to start: %v", err)
	}
}
|
||||
402
tests/data/data_analysis.py
Normal file
402
tests/data/data_analysis.py
Normal file
@@ -0,0 +1,402 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Data Analysis Module - Advanced statistical and machine learning operations
|
||||
Contains various data processing and analysis functions for research purposes.
|
||||
"""
|
||||
|
||||
import warnings
from dataclasses import dataclass
from datetime import datetime
from enum import Enum
from typing import Any, Dict, Optional

import numpy as np
import pandas as pd
|
||||
|
||||
|
||||
class AnalysisType(Enum):
    """Enumeration of different analysis types."""

    DESCRIPTIVE = "descriptive"  # summary stats: mean/std/median/skew/kurtosis
    CORRELATION = "correlation"  # pairwise correlation between numeric columns
    REGRESSION = "regression"
    CLUSTERING = "clustering"
    TIME_SERIES = "time_series"
|
||||
|
||||
|
||||
@dataclass
class AnalysisResult:
    """Container for the outcome of a single analysis run.

    Failures are reported in-band: ``success`` is set to False and
    ``error_message`` is populated instead of raising.
    """

    analysis_type: AnalysisType  # which analysis produced this result
    timestamp: datetime  # when the analysis was performed
    # Metric values may be nested dicts/lists (e.g. per-column means, the
    # full correlation matrix), so the value type is Any rather than float.
    metrics: Dict[str, Any]
    # Contextual information (row counts, column lists, missing-value maps).
    # Was ``Dict[str, any]`` — the *builtin* ``any`` function, not a type.
    metadata: Dict[str, Any]
    success: bool = True  # False when the analysis failed
    error_message: Optional[str] = None  # populated only when success is False
|
||||
|
||||
|
||||
class DataPreprocessor:
    """
    Advanced data preprocessing utility class.
    Handles cleaning, transformation, and feature engineering.
    """

    def __init__(self, missing_threshold: float = 0.5):
        # Columns whose missing-value ratio exceeds this threshold are dropped.
        self.missing_threshold = missing_threshold
        # Human-readable log of every transformation applied, in order.
        self.transformations_applied = []

    def clean_data(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Comprehensive data cleaning pipeline.

        Drops mostly-missing columns, fills numeric NaNs with the median and
        categorical NaNs with the mode, then removes duplicate rows.

        Args:
            df: Input DataFrame to clean

        Returns:
            Cleaned DataFrame (the input is not modified)
        """
        original_shape = df.shape

        # Remove columns with excessive missing values
        missing_ratios = df.isnull().sum() / len(df)
        cols_to_drop = missing_ratios[missing_ratios > self.missing_threshold].index
        df_cleaned = df.drop(columns=cols_to_drop)

        if len(cols_to_drop) > 0:
            self.transformations_applied.append(f"Dropped {len(cols_to_drop)} columns")

        # Handle remaining missing values
        numeric_cols = df_cleaned.select_dtypes(include=[np.number]).columns
        categorical_cols = df_cleaned.select_dtypes(include=["object"]).columns

        # Fill numeric missing values with median.
        # NOTE: assign back instead of `df[col].fillna(..., inplace=True)` —
        # in-place fills on a column selection are chained assignment, which
        # is deprecated and ineffective under pandas copy-on-write.
        for col in numeric_cols:
            if df_cleaned[col].isnull().any():
                median_value = df_cleaned[col].median()
                df_cleaned[col] = df_cleaned[col].fillna(median_value)
                self.transformations_applied.append(f"Filled {col} with median")

        # Fill categorical missing values with mode (or "Unknown" when the
        # column has no non-null values at all).
        for col in categorical_cols:
            if df_cleaned[col].isnull().any():
                mode_value = df_cleaned[col].mode().iloc[0] if not df_cleaned[col].mode().empty else "Unknown"
                df_cleaned[col] = df_cleaned[col].fillna(mode_value)
                self.transformations_applied.append(f"Filled {col} with mode")

        # Remove duplicates
        initial_rows = len(df_cleaned)
        df_cleaned = df_cleaned.drop_duplicates()
        duplicates_removed = initial_rows - len(df_cleaned)

        if duplicates_removed > 0:
            self.transformations_applied.append(f"Removed {duplicates_removed} duplicate rows")

        print(f"Data cleaning complete: {original_shape} -> {df_cleaned.shape}")
        return df_cleaned

    def engineer_features(self, df: pd.DataFrame) -> pd.DataFrame:
        """
        Create new features from existing data.

        Adds pairwise ratio and sum columns for every numeric column pair,
        and quantile-binned versions of high-cardinality numeric columns.

        Args:
            df: Input DataFrame

        Returns:
            DataFrame with engineered features (the input is not modified)
        """
        df_featured = df.copy()

        # Numeric feature engineering
        numeric_cols = df_featured.select_dtypes(include=[np.number]).columns

        if len(numeric_cols) >= 2:
            # Create interaction features for every unordered column pair.
            for i, col1 in enumerate(numeric_cols):
                for col2 in numeric_cols[i + 1 :]:
                    # Small epsilon avoids division by zero in the ratio.
                    df_featured[f"{col1}_{col2}_ratio"] = df_featured[col1] / (df_featured[col2] + 1e-8)
                    df_featured[f"{col1}_{col2}_sum"] = df_featured[col1] + df_featured[col2]

            self.transformations_applied.append("Created interaction features")

        # Binning continuous variables
        for col in numeric_cols:
            if df_featured[col].nunique() > 10:  # Only bin if many unique values
                df_featured[f"{col}_binned"] = pd.qcut(df_featured[col], q=5, labels=False, duplicates="drop")
                self.transformations_applied.append(f"Binned {col}")

        return df_featured
|
||||
|
||||
|
||||
class StatisticalAnalyzer:
    """
    Statistical analysis and hypothesis testing utilities.

    Both methods are static and stateless, and never raise: failures are
    reported via AnalysisResult.success / error_message instead.
    """

    @staticmethod
    def descriptive_statistics(df: pd.DataFrame) -> AnalysisResult:
        """
        Calculate comprehensive descriptive statistics.

        Only numeric columns are analyzed; a failed result is returned when
        the frame has none.

        Args:
            df: Input DataFrame

        Returns:
            AnalysisResult with descriptive metrics
        """
        try:
            numeric_df = df.select_dtypes(include=[np.number])

            if numeric_df.empty:
                return AnalysisResult(
                    analysis_type=AnalysisType.DESCRIPTIVE,
                    timestamp=datetime.now(),
                    metrics={},
                    metadata={},
                    success=False,
                    error_message="No numeric columns found",
                )

            # Each entry maps column name -> the statistic for that column.
            metrics = {
                "mean_values": numeric_df.mean().to_dict(),
                "std_values": numeric_df.std().to_dict(),
                "median_values": numeric_df.median().to_dict(),
                "skewness": numeric_df.skew().to_dict(),
                "kurtosis": numeric_df.kurtosis().to_dict(),
                "correlation_with_target": None,  # Would need target column
            }

            metadata = {
                "total_rows": len(df),
                "total_columns": len(df.columns),
                "numeric_columns": len(numeric_df.columns),
                "missing_values": df.isnull().sum().to_dict(),
            }

            return AnalysisResult(analysis_type=AnalysisType.DESCRIPTIVE, timestamp=datetime.now(), metrics=metrics, metadata=metadata)

        except Exception as e:
            # Deliberate catch-all: errors are surfaced in the result object.
            return AnalysisResult(
                analysis_type=AnalysisType.DESCRIPTIVE,
                timestamp=datetime.now(),
                metrics={},
                metadata={},
                success=False,
                error_message=str(e),
            )

    @staticmethod
    def correlation_analysis(df: pd.DataFrame, method: str = "pearson") -> AnalysisResult:
        """
        Perform correlation analysis between variables.

        Requires at least two numeric columns; otherwise a failed result
        is returned.

        Args:
            df: Input DataFrame
            method: Correlation method ('pearson', 'spearman', 'kendall')

        Returns:
            AnalysisResult with correlation metrics
        """
        try:
            numeric_df = df.select_dtypes(include=[np.number])

            if len(numeric_df.columns) < 2:
                return AnalysisResult(
                    analysis_type=AnalysisType.CORRELATION,
                    timestamp=datetime.now(),
                    metrics={},
                    metadata={},
                    success=False,
                    error_message="Need at least 2 numeric columns for correlation",
                )

            corr_matrix = numeric_df.corr(method=method)

            # Find highest correlations (excluding diagonal); NaN entries
            # (e.g. from constant columns) are skipped.
            corr_pairs = []
            for i in range(len(corr_matrix.columns)):
                for j in range(i + 1, len(corr_matrix.columns)):
                    col1, col2 = corr_matrix.columns[i], corr_matrix.columns[j]
                    corr_value = corr_matrix.iloc[i, j]
                    if not np.isnan(corr_value):
                        corr_pairs.append((col1, col2, abs(corr_value)))

            # Sort by correlation strength (absolute value, descending).
            corr_pairs.sort(key=lambda x: x[2], reverse=True)

            metrics = {
                "correlation_matrix": corr_matrix.to_dict(),
                "highest_correlations": corr_pairs[:10],  # Top 10
                "method_used": method,
            }

            metadata = {"variables_analyzed": list(numeric_df.columns), "total_pairs": len(corr_pairs)}

            return AnalysisResult(analysis_type=AnalysisType.CORRELATION, timestamp=datetime.now(), metrics=metrics, metadata=metadata)

        except Exception as e:
            # Deliberate catch-all: errors are surfaced in the result object.
            return AnalysisResult(
                analysis_type=AnalysisType.CORRELATION,
                timestamp=datetime.now(),
                metrics={},
                metadata={},
                success=False,
                error_message=str(e),
            )
|
||||
|
||||
|
||||
class TimeSeriesAnalyzer:
    """
    Time series analysis and forecasting utilities.
    """

    def __init__(self, frequency: str = "D"):
        # Expected sampling frequency of input series (pandas offset alias).
        self.frequency = frequency
        # Cache for fitted models, keyed by caller-defined names.
        self.models_fitted = {}

    def detect_seasonality(self, series: pd.Series) -> Dict[str, any]:
        """
        Detect seasonal patterns in time series data.

        Args:
            series: Time series data

        Returns:
            Dictionary with seasonality information
        """
        try:
            # Probe autocorrelation at every lag up to half the series length,
            # capped at one year of daily observations.
            lag_correlations = []
            for lag in range(1, min(len(series) // 2, 365)):
                if lag < len(series):
                    value = series.autocorr(lag=lag)
                    if not np.isnan(value):
                        lag_correlations.append((lag, value))

            # Lags whose autocorrelation magnitude exceeds 0.5 are treated as
            # candidate seasonal periods, strongest first.
            strong = sorted(
                ((lag, corr) for lag, corr in lag_correlations if abs(corr) > 0.5),
                key=lambda pair: abs(pair[1]),
                reverse=True,
            )

            return {
                "seasonal_lags": strong[:5],
                "strongest_seasonality": strong[0] if strong else None,
                "autocorrelation_values": lag_correlations,
            }

        except Exception as e:
            warnings.warn(f"Seasonality detection failed: {e}")
            return {"error": str(e)}

    def trend_analysis(self, series: pd.Series, window: int = 30) -> Dict[str, any]:
        """
        Analyze trend patterns in time series.

        Args:
            series: Time series data
            window: Rolling window size for trend calculation

        Returns:
            Dictionary with trend information
        """
        try:
            roller = series.rolling(window=window)
            mean_track = roller.mean()
            std_track = roller.std()

            # Compare the average level of the first third of the smoothed
            # series against the last third to decide the overall direction.
            early_avg = mean_track.iloc[: len(mean_track) // 3].mean()
            late_avg = mean_track.iloc[-len(mean_track) // 3 :].mean()

            direction = "increasing" if late_avg > early_avg else "decreasing"
            # Relative change, guarded against division by zero.
            strength = abs(late_avg - early_avg) / early_avg if early_avg != 0 else 0

            return {
                "trend_direction": direction,
                "trend_strength": strength,
                "rolling_mean": mean_track.to_dict(),
                "rolling_std": std_track.to_dict(),
                "volatility": std_track.mean(),
            }

        except Exception as e:
            warnings.warn(f"Trend analysis failed: {e}")
            return {"error": str(e)}
|
||||
|
||||
|
||||
def generate_sample_data(n_samples: int = 1000, seed: int = 42) -> pd.DataFrame:
    """
    Generate sample dataset for testing analysis functions.

    Args:
        n_samples: Number of samples to generate
        seed: Seed for NumPy's global RNG; the default (42) reproduces the
            historical output of this function exactly

    Returns:
        Sample DataFrame with three independent numeric features, a
        categorical column, a daily timestamp column, one feature correlated
        with feature_1, and ~5% missing values spread over the independent
        numeric features.
    """
    np.random.seed(seed)

    data = {
        "feature_1": np.random.normal(100, 15, n_samples),
        "feature_2": np.random.exponential(2, n_samples),
        "feature_3": np.random.uniform(0, 100, n_samples),
        "category": np.random.choice(["A", "B", "C"], n_samples),
        "timestamp": pd.date_range("2023-01-01", periods=n_samples, freq="D"),
    }

    # feature_4 is a noisy linear function of feature_1, giving the dataset a
    # known strong correlation for correlation-analysis tests.
    data["feature_4"] = data["feature_1"] * 0.7 + np.random.normal(0, 10, n_samples)

    # Inject ~5% missing values: each sampled row (distinct, since
    # replace=False) gets exactly one NaN in a randomly chosen independent
    # feature. feature_4 is computed above, so it stays NaN-free.
    missing_indices = np.random.choice(n_samples, size=int(0.05 * n_samples), replace=False)
    for idx in missing_indices:
        col = np.random.choice(["feature_1", "feature_2", "feature_3"])
        data[col][idx] = np.nan

    return pd.DataFrame(data)
|
||||
|
||||
|
||||
def main():
    """
    Demonstration of the data analysis pipeline.
    """
    print("=== Data Analysis Pipeline Demo ===")

    # Build a synthetic dataset to drive the demo.
    sample_df = generate_sample_data(1000)
    print(f"Generated dataset with shape: {sample_df.shape}")

    # Clean the raw data and derive engineered features.
    pipeline = DataPreprocessor(missing_threshold=0.1)
    cleaned = pipeline.clean_data(sample_df)
    enriched = pipeline.engineer_features(cleaned)

    print(f"Applied transformations: {pipeline.transformations_applied}")

    stats = StatisticalAnalyzer()

    # Descriptive statistics over the engineered dataset.
    descriptive = stats.descriptive_statistics(enriched)
    if descriptive.success:
        print(f"Descriptive analysis completed at {descriptive.timestamp}")
        print(f"Analyzed {descriptive.metadata['numeric_columns']} numeric columns")

    # Pairwise correlation analysis.
    correlations = stats.correlation_analysis(enriched)
    if correlations.success:
        print("Correlation analysis completed")
        print(f"Found {len(correlations.metrics['highest_correlations'])} significant correlations")

    # Time-series diagnostics on feature_1 indexed by timestamp.
    forecaster = TimeSeriesAnalyzer()
    feature_series = cleaned.set_index("timestamp")["feature_1"]

    forecaster.detect_seasonality(feature_series)
    trend = forecaster.trend_analysis(feature_series)

    print(f"Time series trend: {trend.get('trend_direction', 'unknown')}")
    print(f"Volatility: {trend.get('volatility', 0):.2f}")


if __name__ == "__main__":
    main()
|
||||
286
tests/data/data_structures.cpp
Normal file
286
tests/data/data_structures.cpp
Normal file
@@ -0,0 +1,286 @@
|
||||
#include <iostream>
|
||||
#include <vector>
|
||||
#include <memory>
|
||||
#include <algorithm>
|
||||
#include <stdexcept>
|
||||
|
||||
/**
|
||||
* Binary Search Tree implementation with smart pointers
|
||||
* Template class supporting any comparable type
|
||||
*/
|
||||
template<typename T>
class BinarySearchTree {
private:
    // Tree node. Children are owned exclusively through unique_ptr, so
    // destroying a node releases its entire subtree.
    struct Node {
        T data;
        std::unique_ptr<Node> left;
        std::unique_ptr<Node> right;

        Node(const T& value) : data(value), left(nullptr), right(nullptr) {}
    };

    std::unique_ptr<Node> root;  // owning handle to the whole tree
    size_t size_;                // number of distinct values stored

    // Walk down to the correct null slot and create the node there.
    // `node` is a reference to the parent's child pointer, so assigning to it
    // links the new node into the tree in place.
    void insertHelper(std::unique_ptr<Node>& node, const T& value) {
        if (!node) {
            node = std::make_unique<Node>(value);
            ++size_;
            return;
        }

        if (value < node->data) {
            insertHelper(node->left, value);
        } else if (value > node->data) {
            insertHelper(node->right, value);
        }
        // Ignore duplicates
    }

    // Standard BST lookup: descend left/right by comparison until found or a
    // null child is reached.
    bool searchHelper(const std::unique_ptr<Node>& node, const T& value) const {
        if (!node) return false;

        if (value == node->data) return true;
        else if (value < node->data) return searchHelper(node->left, value);
        else return searchHelper(node->right, value);
    }

    // Left-root-right traversal appending values to `result` in sorted order.
    void inorderHelper(const std::unique_ptr<Node>& node, std::vector<T>& result) const {
        if (!node) return;

        inorderHelper(node->left, result);
        result.push_back(node->data);
        inorderHelper(node->right, result);
    }

    // Remove `value` from the subtree rooted at `node` and return the new
    // subtree root. Ownership is threaded through by value: the caller moves
    // its child pointer in and reassigns it from the return value.
    std::unique_ptr<Node> removeHelper(std::unique_ptr<Node> node, const T& value) {
        if (!node) return nullptr;

        if (value < node->data) {
            node->left = removeHelper(std::move(node->left), value);
        } else if (value > node->data) {
            node->right = removeHelper(std::move(node->right), value);
        } else {
            // Node to delete found
            --size_;

            // Zero or one child: the surviving child (possibly null) replaces
            // this node; the old node is destroyed when `node` goes out of scope.
            if (!node->left) return std::move(node->right);
            if (!node->right) return std::move(node->left);

            // Node has two children
            // Replace this node's value with its inorder successor (smallest
            // value in the right subtree), then delete the successor below.
            Node* successor = findMin(node->right.get());
            node->data = successor->data;
            node->right = removeHelper(std::move(node->right), successor->data);
            ++size_; // Compensate for decrement in recursive call
        }

        return node;
    }

    // Leftmost node of the subtree. Precondition: node != nullptr — callers
    // only pass a non-null right child.
    Node* findMin(Node* node) const {
        while (node->left) {
            node = node->left.get();
        }
        return node;
    }

public:
    BinarySearchTree() : root(nullptr), size_(0) {}

    // Insert `value`; duplicates are silently ignored.
    void insert(const T& value) {
        insertHelper(root, value);
    }

    // True if `value` is present in the tree.
    bool search(const T& value) const {
        return searchHelper(root, value);
    }

    // Remove `value` if present; no-op otherwise.
    void remove(const T& value) {
        root = removeHelper(std::move(root), value);
    }

    // Values in ascending order.
    std::vector<T> inorderTraversal() const {
        std::vector<T> result;
        inorderHelper(root, result);
        return result;
    }

    size_t size() const { return size_; }
    bool empty() const { return size_ == 0; }

    // Destroy every node and reset to an empty tree.
    void clear() {
        root.reset();
        size_ = 0;
    }
};
|
||||
|
||||
/**
|
||||
* Dynamic Array implementation with automatic resizing
|
||||
*/
|
||||
template<typename T>
class DynamicArray {
private:
    std::unique_ptr<T[]> data;  // heap buffer, null when capacity_ == 0
    size_t capacity_;           // allocated slots
    size_t size_;               // slots in use

    // Grow the buffer: double the capacity (1 when empty) and move the
    // existing elements into the new allocation.
    void resize() {
        const size_t grown = (capacity_ == 0) ? 1 : capacity_ * 2;
        auto fresh = std::make_unique<T[]>(grown);
        std::move(data.get(), data.get() + size_, fresh.get());
        data = std::move(fresh);
        capacity_ = grown;
    }

public:
    DynamicArray() : data(nullptr), capacity_(0), size_(0) {}

    explicit DynamicArray(size_t initialCapacity)
        : data(std::make_unique<T[]>(initialCapacity)), capacity_(initialCapacity), size_(0) {}

    // Append a copy of `value`, growing the buffer if full.
    void pushBack(const T& value) {
        if (size_ >= capacity_) resize();
        data[size_] = value;
        ++size_;
    }

    // Append by moving `value`, growing the buffer if full.
    void pushBack(T&& value) {
        if (size_ >= capacity_) resize();
        data[size_] = std::move(value);
        ++size_;
    }

    // Bounds-checked element access; throws std::out_of_range when invalid.
    T& operator[](size_t index) {
        if (index >= size_) throw std::out_of_range("Index out of bounds");
        return data[index];
    }

    const T& operator[](size_t index) const {
        if (index >= size_) throw std::out_of_range("Index out of bounds");
        return data[index];
    }

    // Drop the last element (no-op when empty); storage is retained.
    void popBack() {
        if (size_ != 0) --size_;
    }

    size_t size() const { return size_; }
    size_t capacity() const { return capacity_; }
    bool empty() const { return size_ == 0; }

    // Logically empty the array without releasing the buffer.
    void clear() { size_ = 0; }

    // Iterator support
    T* begin() { return data.get(); }
    T* end() { return data.get() + size_; }
    const T* begin() const { return data.get(); }
    const T* end() const { return data.get() + size_; }
};
|
||||
|
||||
/**
|
||||
* Stack implementation using dynamic array
|
||||
*/
|
||||
template<typename T>
|
||||
class Stack {
|
||||
private:
|
||||
DynamicArray<T> container;
|
||||
|
||||
public:
|
||||
void push(const T& value) {
|
||||
container.pushBack(value);
|
||||
}
|
||||
|
||||
void push(T&& value) {
|
||||
container.pushBack(std::move(value));
|
||||
}
|
||||
|
||||
void pop() {
|
||||
if (empty()) {
|
||||
throw std::runtime_error("Stack underflow");
|
||||
}
|
||||
container.popBack();
|
||||
}
|
||||
|
||||
T& top() {
|
||||
if (empty()) {
|
||||
throw std::runtime_error("Stack is empty");
|
||||
}
|
||||
return container[container.size() - 1];
|
||||
}
|
||||
|
||||
const T& top() const {
|
||||
if (empty()) {
|
||||
throw std::runtime_error("Stack is empty");
|
||||
}
|
||||
return container[container.size() - 1];
|
||||
}
|
||||
|
||||
bool empty() const { return container.empty(); }
|
||||
size_t size() const { return container.size(); }
|
||||
};
|
||||
|
||||
// Demonstration and testing
|
||||
int main() {
|
||||
std::cout << "=== Binary Search Tree Demo ===" << std::endl;
|
||||
|
||||
BinarySearchTree<int> bst;
|
||||
std::vector<int> values = {50, 30, 70, 20, 40, 60, 80, 10, 25, 35};
|
||||
|
||||
for (int val : values) {
|
||||
bst.insert(val);
|
||||
}
|
||||
|
||||
std::cout << "Tree size: " << bst.size() << std::endl;
|
||||
std::cout << "Inorder traversal: ";
|
||||
auto inorder = bst.inorderTraversal();
|
||||
for (size_t i = 0; i < inorder.size(); ++i) {
|
||||
std::cout << inorder[i];
|
||||
if (i < inorder.size() - 1) std::cout << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "\n=== Dynamic Array Demo ===" << std::endl;
|
||||
|
||||
DynamicArray<std::string> arr;
|
||||
arr.pushBack("Hello");
|
||||
arr.pushBack("World");
|
||||
arr.pushBack("C++");
|
||||
arr.pushBack("Templates");
|
||||
|
||||
std::cout << "Array contents: ";
|
||||
for (size_t i = 0; i < arr.size(); ++i) {
|
||||
std::cout << arr[i];
|
||||
if (i < arr.size() - 1) std::cout << ", ";
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
std::cout << "\n=== Stack Demo ===" << std::endl;
|
||||
|
||||
Stack<int> stack;
|
||||
for (int i = 1; i <= 5; ++i) {
|
||||
stack.push(i * 10);
|
||||
}
|
||||
|
||||
std::cout << "Stack contents (top to bottom): ";
|
||||
while (!stack.empty()) {
|
||||
std::cout << stack.top() << " ";
|
||||
stack.pop();
|
||||
}
|
||||
std::cout << std::endl;
|
||||
|
||||
return 0;
|
||||
}
|
||||
123
tests/data/react_component.jsx
Normal file
123
tests/data/react_component.jsx
Normal file
@@ -0,0 +1,123 @@
|
||||
import React, { useState, useEffect } from 'react';
|
||||
import PropTypes from 'prop-types';
|
||||
|
||||
/**
|
||||
* UserProfile component for displaying user information
|
||||
* @param {Object} props - Component props
|
||||
* @param {Object} props.user - User object
|
||||
* @param {Function} props.onEdit - Edit callback function
|
||||
*/
|
||||
const UserProfile = ({ user, onEdit }) => {
  // Local editable copy of the user, plus UI state for the edit form.
  const [isEditing, setIsEditing] = useState(false);
  const [userData, setUserData] = useState(user);
  const [loading, setLoading] = useState(false);

  // Re-sync the local copy whenever the parent passes a new user object.
  useEffect(() => {
    setUserData(user);
  }, [user]);

  // Persist edits through the onEdit callback. On failure the form stays in
  // edit mode with the user's changes intact; the error is only logged.
  const handleSave = async () => {
    setLoading(true);
    try {
      await onEdit(userData);
      setIsEditing(false);
    } catch (error) {
      console.error('Failed to save user data:', error);
    } finally {
      setLoading(false);
    }
  };

  // Discard edits: restore the prop value and leave edit mode.
  const handleCancel = () => {
    setUserData(user);
    setIsEditing(false);
  };

  // Generic controlled-input handler: update a single field immutably.
  const handleInputChange = (field, value) => {
    setUserData(prev => ({
      ...prev,
      [field]: value
    }));
  };

  // While a save is in flight the whole profile is replaced by a spinner.
  if (loading) {
    return <div className="loading-spinner">Saving...</div>;
  }

  return (
    <div className="user-profile">
      <div className="profile-header">
        <h2>{userData.name}</h2>
        {!isEditing && (
          <button onClick={() => setIsEditing(true)} className="edit-btn">
            Edit Profile
          </button>
        )}
      </div>

      <div className="profile-content">
        {isEditing ? (
          <form onSubmit={(e) => { e.preventDefault(); handleSave(); }}>
            <div className="form-group">
              <label htmlFor="name">Name:</label>
              <input
                id="name"
                type="text"
                value={userData.name}
                onChange={(e) => handleInputChange('name', e.target.value)}
                required
              />
            </div>

            <div className="form-group">
              <label htmlFor="email">Email:</label>
              <input
                id="email"
                type="email"
                value={userData.email}
                onChange={(e) => handleInputChange('email', e.target.value)}
                required
              />
            </div>

            <div className="form-group">
              <label htmlFor="bio">Bio:</label>
              <textarea
                id="bio"
                value={userData.bio || ''}
                onChange={(e) => handleInputChange('bio', e.target.value)}
                rows={4}
              />
            </div>

            <div className="form-actions">
              <button type="submit" className="save-btn">Save</button>
              <button type="button" onClick={handleCancel} className="cancel-btn">
                Cancel
              </button>
            </div>
          </form>
        ) : (
          <div className="profile-display">
            <p><strong>Email:</strong> {userData.email}</p>
            <p><strong>Bio:</strong> {userData.bio || 'No bio provided'}</p>
            <p><strong>Member since:</strong> {new Date(userData.joinDate).toLocaleDateString()}</p>
          </div>
        )}
      </div>
    </div>
  );
};

// Runtime prop validation: bio is the only optional user field.
UserProfile.propTypes = {
  user: PropTypes.shape({
    id: PropTypes.number.isRequired,
    name: PropTypes.string.isRequired,
    email: PropTypes.string.isRequired,
    bio: PropTypes.string,
    joinDate: PropTypes.string.isRequired,
  }).isRequired,
  onEdit: PropTypes.func.isRequired,
};

export default UserProfile;
|
||||
177
tests/data/task_manager.java
Normal file
177
tests/data/task_manager.java
Normal file
@@ -0,0 +1,177 @@
|
||||
package com.example.taskmanager;
|
||||
|
||||
import java.util.*;
|
||||
import java.time.LocalDateTime;
|
||||
|
||||
/**
|
||||
* TaskManager class for managing tasks and their lifecycle
|
||||
*
|
||||
* @author Development Team
|
||||
* @version 1.0
|
||||
*/
|
||||
public class TaskManager {
    // Task store keyed by unique task ID.
    private Map<String, Task> tasks;
    // Listeners notified on every task lifecycle event.
    private List<TaskObserver> observers;
    // Hard cap on how many tasks a single manager may hold.
    private static final int MAX_TASKS = 1000;

    public TaskManager() {
        this.tasks = new HashMap<>();
        this.observers = new ArrayList<>();
    }

    /**
     * Adds a new task to the manager
     *
     * @param id Unique task identifier
     * @param title Task title
     * @param description Task description
     * @param priority Task priority level
     * @return true if task was added successfully
     * @throws IllegalArgumentException if task ID already exists
     * @throws IllegalStateException if maximum tasks exceeded
     */
    public boolean addTask(String id, String title, String description, Priority priority) {
        if (tasks.containsKey(id)) {
            throw new IllegalArgumentException("Task with ID " + id + " already exists");
        }

        if (tasks.size() >= MAX_TASKS) {
            throw new IllegalStateException("Maximum number of tasks reached");
        }

        Task newTask = new Task(id, title, description, priority);
        tasks.put(id, newTask);
        notifyObservers(TaskEvent.TASK_ADDED, newTask);
        return true;
    }

    /**
     * Updates an existing task status
     *
     * @param id identifier of the task to update
     * @param newStatus status to transition the task to
     * @throws NoSuchElementException if no task exists with the given id
     */
    public void updateTaskStatus(String id, TaskStatus newStatus) {
        Task task = tasks.get(id);
        if (task == null) {
            throw new NoSuchElementException("Task not found: " + id);
        }

        // NOTE(review): oldStatus is captured but never used — candidate for removal.
        TaskStatus oldStatus = task.getStatus();
        task.setStatus(newStatus);
        task.setLastModified(LocalDateTime.now());

        notifyObservers(TaskEvent.TASK_UPDATED, task);

        // Completing a task may unblock tasks that depend on it.
        if (newStatus == TaskStatus.COMPLETED) {
            handleTaskCompletion(task);
        }
    }

    /**
     * Retrieves tasks by status
     *
     * @param status status to filter on
     * @return matching tasks sorted by priority enum order (LOW first)
     */
    public List<Task> getTasksByStatus(TaskStatus status) {
        return tasks.values()
            .stream()
            .filter(task -> task.getStatus() == status)
            .sorted(Comparator.comparing(Task::getPriority))
            .toList();
    }

    /**
     * Handles task completion logic
     */
    private void handleTaskCompletion(Task task) {
        task.setCompletedAt(LocalDateTime.now());
        System.out.println("Task completed: " + task.getTitle());

        // Check for dependent tasks
        tasks.values().stream()
            .filter(t -> t.getDependencies().contains(task.getId()))
            .forEach(this::checkDependenciesResolved);
    }

    // If every dependency of the task is COMPLETED and the task itself is
    // currently BLOCKED, move it back to TODO (which also notifies observers).
    private void checkDependenciesResolved(Task task) {
        boolean allResolved = task.getDependencies()
            .stream()
            .allMatch(depId -> {
                Task dep = tasks.get(depId);
                return dep != null && dep.getStatus() == TaskStatus.COMPLETED;
            });

        if (allResolved && task.getStatus() == TaskStatus.BLOCKED) {
            updateTaskStatus(task.getId(), TaskStatus.TODO);
        }
    }

    // Register a listener for task lifecycle events.
    public void addObserver(TaskObserver observer) {
        observers.add(observer);
    }

    // Broadcast an event to all registered observers, in registration order.
    private void notifyObservers(TaskEvent event, Task task) {
        observers.forEach(observer -> observer.onTaskEvent(event, task));
    }

    // Inner classes and enums

    // Priority levels; declaration order doubles as sort order in
    // getTasksByStatus (LOW sorts first).
    public enum Priority {
        LOW(1), MEDIUM(2), HIGH(3), CRITICAL(4);

        private final int value;
        Priority(int value) { this.value = value; }
        public int getValue() { return value; }
    }

    // Task lifecycle states.
    public enum TaskStatus {
        TODO, IN_PROGRESS, BLOCKED, COMPLETED, CANCELLED
    }

    // Events delivered to TaskObserver instances.
    public enum TaskEvent {
        TASK_ADDED, TASK_UPDATED, TASK_DELETED
    }

    // Value object for a single task and its lifecycle timestamps.
    public static class Task {
        private String id;
        private String title;
        private String description;
        private Priority priority;
        private TaskStatus status;
        private LocalDateTime createdAt;
        private LocalDateTime lastModified;
        private LocalDateTime completedAt;   // null until the task completes
        private Set<String> dependencies;    // IDs of tasks this one depends on

        public Task(String id, String title, String description, Priority priority) {
            this.id = id;
            this.title = title;
            this.description = description;
            this.priority = priority;
            this.status = TaskStatus.TODO;
            this.createdAt = LocalDateTime.now();
            this.lastModified = LocalDateTime.now();
            this.dependencies = new HashSet<>();
        }

        // Getters and setters
        public String getId() { return id; }
        public String getTitle() { return title; }
        public String getDescription() { return description; }
        public Priority getPriority() { return priority; }
        public TaskStatus getStatus() { return status; }
        public LocalDateTime getCreatedAt() { return createdAt; }
        public LocalDateTime getLastModified() { return lastModified; }
        public LocalDateTime getCompletedAt() { return completedAt; }
        // Returns the live set — callers mutate it directly to add dependencies.
        public Set<String> getDependencies() { return dependencies; }

        public void setStatus(TaskStatus status) { this.status = status; }
        public void setLastModified(LocalDateTime lastModified) { this.lastModified = lastModified; }
        public void setCompletedAt(LocalDateTime completedAt) { this.completedAt = completedAt; }

        @Override
        public String toString() {
            return String.format("Task{id='%s', title='%s', status=%s, priority=%s}",
                    id, title, status, priority);
        }
    }

    public interface TaskObserver {
        void onTaskEvent(TaskEvent event, Task task);
    }
}
|
||||
@@ -137,6 +137,11 @@ def test_auto_attach_detach_files_tools(client: LettaSDKClient):
|
||||
("tests/data/toy_chat_fine_tuning.jsonl", '{"messages"', r"toy_chat_fine_tuning_[a-z0-9]+\.jsonl"),
|
||||
("tests/data/test.md", "h2 Heading", r"test_[a-z0-9]+\.md"),
|
||||
("tests/data/test.json", "glossary", r"test_[a-z0-9]+\.json"),
|
||||
("tests/data/react_component.jsx", "UserProfile", r"react_component_[a-z0-9]+\.jsx"),
|
||||
("tests/data/task_manager.java", "TaskManager", r"task_manager_[a-z0-9]+\.java"),
|
||||
("tests/data/data_structures.cpp", "BinarySearchTree", r"data_structures_[a-z0-9]+\.cpp"),
|
||||
("tests/data/api_server.go", "UserService", r"api_server_[a-z0-9]+\.go"),
|
||||
("tests/data/data_analysis.py", "StatisticalAnalyzer", r"data_analysis_[a-z0-9]+\.py"),
|
||||
],
|
||||
)
|
||||
def test_file_upload_creates_source_blocks_correctly(
|
||||
@@ -158,11 +163,14 @@ def test_file_upload_creates_source_blocks_correctly(
|
||||
job = client.sources.files.upload(source_id=source.id, file=f)
|
||||
|
||||
# Wait for the job to complete
|
||||
while job.status != "completed":
|
||||
while job.status != "completed" and job.status != "failed":
|
||||
time.sleep(1)
|
||||
job = client.jobs.retrieve(job_id=job.id)
|
||||
print("Waiting for jobs to complete...", job.status)
|
||||
|
||||
if job.status == "failed":
|
||||
pytest.fail("Job failed. Check error logs.")
|
||||
|
||||
# Get uploaded files
|
||||
files = client.sources.files.list(source_id=source.id, limit=1)
|
||||
assert len(files) == 1
|
||||
|
||||
Reference in New Issue
Block a user