diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..f3a8404 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,27 @@ +# Auto detect text files and perform LF normalization +* text=auto + +# Python files +*.py text eol=lf + +# Markdown files +*.md text eol=lf + +# JSON files +*.json text eol=lf + +# YAML files +*.yml text eol=lf +*.yaml text eol=lf + +# Shell scripts +*.sh text eol=lf + +# Configuration files +*.toml text eol=lf +*.cfg text eol=lf +*.ini text eol=lf + +# Keep Windows batch files with CRLF +*.bat text eol=crlf +*.cmd text eol=crlf diff --git a/README.md b/README.md index 4d607f4..89e2819 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,10 @@ Test Dataset + Model Predictions --> [benchmark.py] --> Metrics Report QLoRA (4-bit quantization) for training on 1-2 A100 GPUs. - **`serve.py`** - FastAPI inference server that uses ollama API to generate - docstrings. The server uses a hard-coded system prompt for NumPy-style docstring - generation. + docstrings. Supports multiple Qwen Coder models with model-specific configurations. + +- **`models.py`** - Model configuration registry with sampling parameters for + Qwen 2.5 Coder and Qwen3 Coder variants. ### Evaluation (`src/evaluation/`) @@ -97,15 +99,22 @@ ollama as the backend. The server uses a system prompt stored in ### Prerequisites 1. **Install ollama**: Make sure [ollama](https://ollama.ai/) is installed and running locally -2. **Pull a model**: Download a code model (e.g., `qwen2.5-coder:32b`): +2. 
**Pull a model**: Download one of the supported code models: ```bash - ollama pull qwen2.5-coder:32b + # Qwen 2.5 Coder (dense models) + ollama pull qwen2.5-coder:32b # Default, ~18GB Q4 + ollama pull qwen2.5-coder:14b # Mid-size, ~8GB Q4 + ollama pull qwen2.5-coder:7b # Fast, ~4GB Q4 + + # Qwen3 Coder (MoE model) + ollama pull qwen3-coder:30b-a3b # Best quality, ~18GB Q4, 256K context ``` ### Starting the Server Start the FastAPI server using uvicorn: +**Linux/macOS:** ```bash # Using uvicorn directly uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 @@ -114,6 +123,11 @@ uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 python -m src.training.serve ``` +**Windows (PowerShell):** +```powershell +uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 +``` + The server will start on `http://localhost:8000` by default. ### Configuration @@ -121,12 +135,63 @@ The server will start on `http://localhost:8000` by default. The server can be configured using environment variables: - `OLLAMA_URL` - Ollama API endpoint (default: `http://localhost:11434/api/chat`) -- `OLLAMA_MODEL` - Model name to use (default: `qwen2.5-coder:32b`) +- `OLLAMA_MODEL` - Model key or Ollama model name (default: `qwen2.5-coder-32b`) - `REQUEST_TIMEOUT` - Request timeout in seconds (default: `120.0`) -Example: +**Linux/macOS:** +```bash +OLLAMA_MODEL=qwen3-coder-30b uvicorn src.training.serve:app --port 8000 +``` + +**Windows (PowerShell):** +```powershell +$env:OLLAMA_MODEL="qwen3-coder-30b"; uvicorn src.training.serve:app --port 8000 +``` + +**Windows (CMD):** +```cmd +set OLLAMA_MODEL=qwen3-coder-30b && uvicorn src.training.serve:app --port 8000 +``` + +### Available Models + +| Model Key | Ollama Model | Architecture | Memory (Q4) | Context | Description | +|-----------|--------------|--------------|-------------|---------|-------------| +| `qwen2.5-coder-32b` | `qwen2.5-coder:32b` | Dense | ~18GB | 32K | Default, balanced quality/speed | +| `qwen2.5-coder-14b` | 
`qwen2.5-coder:14b` | Dense | ~8GB | 32K | Mid-size, good performance | +| `qwen2.5-coder-7b` | `qwen2.5-coder:7b` | Dense | ~4GB | 32K | Fast inference | +| `qwen3-coder-30b` | `qwen3-coder:30b-a3b` | MoE | ~18GB | 256K | Best quality, 3.3B active params | + +Each model has optimized sampling parameters: +- **Qwen 2.5 Coder**: temperature=0.7, top_p=0.9, top_k=40 +- **Qwen3 Coder**: temperature=1.0, top_p=0.95, top_k=40 (per official recommendations) + +### Model Selection + +You can select a model in two ways: + +1. **Environment variable** (applies to all requests): + ```bash + OLLAMA_MODEL=qwen3-coder-30b uvicorn src.training.serve:app + ``` + +2. **Per-request** (via API): + ```bash + curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{"code": "def add(x, y): return x + y", "model": "qwen3-coder-30b"}' + ``` + +### List Available Models + +**Via CLI:** +```bash +python scripts/run_ollama.py --list-models +``` + +**Via API:** ```bash -OLLAMA_MODEL=qwen2.5-coder:7b uvicorn src.training.serve:app --port 8000 +curl http://localhost:8000/models ``` ### API Endpoints @@ -143,7 +208,9 @@ curl http://localhost:8000/health ```json { "status": "healthy", - "service": "ollama" + "service": "ollama", + "active_model": "Qwen 2.5 Coder 32B", + "ollama_model": "qwen2.5-coder:32b" } ``` @@ -169,12 +236,14 @@ curl -X POST http://localhost:8000/generate \ **Request Body:** - `code` (required): Python function code as a string -- `max_new_tokens` (optional): Maximum number of tokens to generate (default: 256) +- `max_new_tokens` (optional): Maximum number of tokens to generate (uses model default if not specified) +- `model` (optional): Model key or Ollama model name to use for this request **Response (200 OK):** ```json { - "docstring": "\"\"\"Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\n\"\"\"" + "docstring": "\"\"\"Compute the 
sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\"\"\"", + "model": "qwen2.5-coder:32b" } ``` @@ -185,12 +254,57 @@ curl -X POST http://localhost:8000/generate \ } ``` +#### List Models + +Get available model configurations: + +```bash +curl http://localhost:8000/models +``` + +**Response (200 OK):** +```json +{ + "default": "qwen2.5-coder-32b", + "active": "qwen2.5-coder-32b", + "models": [ + { + "key": "qwen2.5-coder-32b", + "name": "Qwen 2.5 Coder 32B", + "ollama_model": "qwen2.5-coder:32b", + "context_window": 32768, + "architecture": "dense", + "memory_q4": "~18GB", + "description": "Dense 32B model, good balance of quality and speed" + } + ] +} +``` + +### CLI Tool + +The CLI tool allows testing docstring generation directly: + +```bash +# Use default model +python scripts/run_ollama.py --user "def add(x, y): return x + y" + +# Use specific model by key +python scripts/run_ollama.py --model-key qwen3-coder-30b --user "def foo(): pass" + +# Use raw Ollama model name +python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def bar(): pass" + +# List available models +python scripts/run_ollama.py --list-models +``` + ### Testing Run the test suite to verify the API endpoints: ```bash -pytest tests/test_serve.py -v +pytest tests/test_serve.py tests/test_models.py -v ``` ## Dataset diff --git a/scripts/run_ollama.py b/scripts/run_ollama.py index bef6323..40496d0 100644 --- a/scripts/run_ollama.py +++ b/scripts/run_ollama.py @@ -1,10 +1,37 @@ +"""CLI script for testing docstring generation with Ollama. + +Supports model selection via registry keys or raw Ollama model names, +with model-specific sampling parameters. 
+ +Usage: + python scripts/run_ollama.py --model-key qwen2.5-coder-32b --user "def add(x, y): return x + y" + python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def foo(): pass" + python scripts/run_ollama.py --list-models +""" + import argparse -import requests -import sys import json +import sys import time from pathlib import Path +import requests + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.training.models import ( + MODEL_REGISTRY, + DEFAULT_MODEL_KEY, + ModelConfig, + SamplingConfig, + get_model_config, + list_models, +) + + +DEFAULT_URL = "http://localhost:11434/api/chat" + def load_system_prompt() -> str: """Load the default system prompt from the prompts directory.""" @@ -17,10 +44,13 @@ def load_system_prompt() -> str: return prompt_path.read_text(encoding="utf-8") -DEFAULT_URL = "http://localhost:11434/api/chat" # Changed from /api/generate - - -def build_payload(model, system_msgs, user_msgs, stream): +def build_payload( + model_config: ModelConfig, + system_msgs: list[str], + user_msgs: list[str], + stream: bool +) -> dict: + """Build the request payload with model-specific sampling parameters.""" messages = [] for s in system_msgs: messages.append({"role": "system", "content": s}) @@ -28,24 +58,131 @@ def build_payload(model, system_msgs, user_msgs, stream): messages.append({"role": "user", "content": u}) return { - "model": model, + "model": model_config.ollama_model, "messages": messages, "stream": stream, - "keep_alive": 0 # Unload model after request + "keep_alive": model_config.keep_alive, + "options": { + "temperature": model_config.sampling.temperature, + "top_p": model_config.sampling.top_p, + "top_k": model_config.sampling.top_k, + } } +def print_models_list(): + """Print all available models in a formatted table.""" + print("\nAvailable Models:") + print("=" * 80) + print(f"{'Key':<22} {'Ollama Model':<25} {'Arch':<6} {'Memory':<8} {'Context'}") + print("-" * 80) 
+ + for model_info in list_models(): + key = model_info["key"] + ollama = model_info["ollama_model"] + arch = model_info["architecture"] + memory = model_info.get("memory_q4", "N/A") + context = f"{model_info['context_window']:,}" + + # Mark default model + marker = " *" if key == DEFAULT_MODEL_KEY else "" + print(f"{key:<22} {ollama:<25} {arch:<6} {memory:<8} {context}{marker}") + + print("-" * 80) + print(f"* = default model ({DEFAULT_MODEL_KEY})") + print() + + def main(): - parser = argparse.ArgumentParser(description="Send system and user prompts to model endpoint") - parser.add_argument("--url", default=DEFAULT_URL, help="API endpoint URL") - parser.add_argument("--model", required=True, help="Model name, e.g. qwen2.5-coder:32b") - parser.add_argument("--system", action="append", default=None, help="System prompt (repeatable). Overrides default.") - parser.add_argument("--user", action="append", default=[], help="User prompt (repeatable)") - parser.add_argument("--stream", action="store_true", help="Enable streaming mode") - parser.add_argument("--timeout", type=float, default=120.0, help="Request timeout in seconds") + parser = argparse.ArgumentParser( + description="Send prompts to Ollama for docstring generation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use a registered model by key + python scripts/run_ollama.py --model-key qwen2.5-coder-32b --user "def add(x, y): return x + y" + + # Use a raw Ollama model name + python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def foo(): pass" + + # Use Qwen3 MoE model + python scripts/run_ollama.py --model-key qwen3-coder-30b --user "def hello(): print('hi')" + + # List available models + python scripts/run_ollama.py --list-models +""" + ) + parser.add_argument( + "--url", + default=DEFAULT_URL, + help="API endpoint URL (default: %(default)s)" + ) + parser.add_argument( + "--model-key", + help=f"Model key from registry (e.g., {DEFAULT_MODEL_KEY}, qwen3-coder-30b)" 
+ ) + parser.add_argument( + "--model", + help="Raw Ollama model name (e.g., qwen2.5-coder:32b). Overrides --model-key" + ) + parser.add_argument( + "--list-models", + action="store_true", + help="List available model configurations and exit" + ) + parser.add_argument( + "--system", + action="append", + default=None, + help="System prompt (repeatable). Overrides default." + ) + parser.add_argument( + "--user", + action="append", + default=[], + help="User prompt (repeatable)" + ) + parser.add_argument( + "--stream", + action="store_true", + help="Enable streaming mode" + ) + parser.add_argument( + "--timeout", + type=float, + default=120.0, + help="Request timeout in seconds (default: %(default)s)" + ) args = parser.parse_args() + # Handle --list-models + if args.list_models: + print_models_list() + sys.exit(0) + + # Determine model configuration + if args.model: + # Raw Ollama model name provided + model_config = get_model_config(args.model) + print(f"Using model: {model_config.ollama_model}", file=sys.stderr) + elif args.model_key: + # Registry key provided + model_config = get_model_config(args.model_key) + print(f"Using model: {model_config.name} ({model_config.ollama_model})", file=sys.stderr) + else: + # Use default + model_config = get_model_config(DEFAULT_MODEL_KEY) + print(f"Using default model: {model_config.name} ({model_config.ollama_model})", file=sys.stderr) + + # Print sampling parameters + print( + f"Sampling: temp={model_config.sampling.temperature}, " + f"top_p={model_config.sampling.top_p}, " + f"top_k={model_config.sampling.top_k}", + file=sys.stderr + ) + # Use default system prompt if none provided if args.system: system_msgs = args.system @@ -58,9 +195,10 @@ def main(): if not args.user: print("Error: at least one --user prompt is required.", file=sys.stderr) + print("Use --help for usage information.", file=sys.stderr) sys.exit(2) - payload = build_payload(args.model, system_msgs, args.user, args.stream) + payload = 
build_payload(model_config, system_msgs, args.user, args.stream) # Track execution time start_time = time.time() @@ -93,8 +231,8 @@ def main(): # Print execution time elapsed_time = time.time() - start_time - print(f"Execution time: {elapsed_time:.2f}s", file=sys.stdout) + print(f"\nExecution time: {elapsed_time:.2f}s", file=sys.stderr) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git a/src/training/models.py b/src/training/models.py new file mode 100644 index 0000000..3040ca2 --- /dev/null +++ b/src/training/models.py @@ -0,0 +1,193 @@ +"""Model configuration registry for docstring generation. + +Provides model-specific configurations including sampling parameters, +context windows, and memory requirements for different LLM backends. +""" + +from dataclasses import dataclass +from typing import Optional + + +@dataclass(frozen=True) +class SamplingConfig: + """Sampling parameters for text generation. + + Parameters + ---------- + temperature : float + Controls randomness in generation. Higher values = more random. + top_p : float + Nucleus sampling parameter. Considers tokens with cumulative probability >= top_p. + top_k : int + Limits sampling to top k most likely tokens. + num_predict : int + Maximum number of tokens to generate. + """ + temperature: float = 0.7 + top_p: float = 0.9 + top_k: int = 40 + num_predict: int = 256 + + +@dataclass(frozen=True) +class ModelConfig: + """Configuration for a specific model. + + Parameters + ---------- + name : str + Human-readable display name for the model. + ollama_model : str + Ollama model identifier (e.g., "qwen2.5-coder:32b"). + context_window : int + Maximum context length in tokens. + sampling : SamplingConfig + Default sampling parameters for this model. + keep_alive : int + Model keep-alive time in seconds (0 = unload immediately after request). + architecture : str + Model architecture type: "dense" or "moe". 
+ total_params : str, optional + Total parameter count description (e.g., "32B", "30B (3.3B active)"). + memory_q4 : str, optional + Approximate memory requirement at Q4 quantization. + description : str + Brief description of the model's characteristics. + """ + name: str + ollama_model: str + context_window: int + sampling: SamplingConfig + keep_alive: int = 0 + architecture: str = "dense" + total_params: Optional[str] = None + memory_q4: Optional[str] = None + description: str = "" + + +# Pre-defined model configurations +MODEL_REGISTRY: dict[str, ModelConfig] = { + "qwen2.5-coder-32b": ModelConfig( + name="Qwen 2.5 Coder 32B", + ollama_model="qwen2.5-coder:32b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="32B", + memory_q4="~18GB", + description="Dense 32B model, good balance of quality and speed" + ), + "qwen2.5-coder-14b": ModelConfig( + name="Qwen 2.5 Coder 14B", + ollama_model="qwen2.5-coder:14b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="14B", + memory_q4="~8GB", + description="Mid-size dense model, balanced performance" + ), + "qwen2.5-coder-7b": ModelConfig( + name="Qwen 2.5 Coder 7B", + ollama_model="qwen2.5-coder:7b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="7B", + memory_q4="~4GB", + description="Smaller variant, faster inference" + ), + "qwen3-coder-30b": ModelConfig( + name="Qwen3 Coder 30B-A3B", + ollama_model="qwen3-coder:30b-a3b", + context_window=262144, # 256K tokens + sampling=SamplingConfig(temperature=1.0, top_p=0.95, top_k=40), + keep_alive=300, # Longer keep_alive for MoE model to avoid reload overhead + architecture="moe", + total_params="30B (3.3B active)", + memory_q4="~18GB", + description="MoE model with 256K context, best quality" 
+ ), +} + +# Default model key +DEFAULT_MODEL_KEY = "qwen2.5-coder-32b" + + +def get_model_config(model_key: str) -> ModelConfig: + """Get configuration for a model by registry key or Ollama model name. + + Parameters + ---------- + model_key : str + Either a registry key (e.g., "qwen2.5-coder-32b") or a raw Ollama + model name (e.g., "qwen2.5-coder:32b"). + + Returns + ------- + ModelConfig + The model configuration. If the key is not found in the registry, + creates a fallback configuration with default sampling parameters. + """ + # First, try direct registry lookup + if model_key in MODEL_REGISTRY: + return MODEL_REGISTRY[model_key] + + # Check if it matches any ollama_model in the registry + for config in MODEL_REGISTRY.values(): + if config.ollama_model == model_key: + return config + + # Fallback: create a default config for unknown models + return create_fallback_config(model_key) + + +def create_fallback_config(ollama_model: str) -> ModelConfig: + """Create a fallback configuration for unknown models. + + Parameters + ---------- + ollama_model : str + The Ollama model identifier. + + Returns + ------- + ModelConfig + A configuration with default sampling parameters. + """ + return ModelConfig( + name=ollama_model, + ollama_model=ollama_model, + context_window=32768, # Conservative default + sampling=SamplingConfig(), # Use defaults + keep_alive=0, + architecture="unknown", + description="Custom model (using default configuration)" + ) + + +def list_models() -> list[dict]: + """List all available model configurations. + + Returns + ------- + list of dict + List of model information dictionaries containing key, name, + ollama_model, context_window, architecture, memory_q4, and description. 
+ """ + return [ + { + "key": key, + "name": config.name, + "ollama_model": config.ollama_model, + "context_window": config.context_window, + "architecture": config.architecture, + "total_params": config.total_params, + "memory_q4": config.memory_q4, + "description": config.description, + } + for key, config in MODEL_REGISTRY.items() + ] diff --git a/src/training/serve.py b/src/training/serve.py index ea5f9e8..cbcde12 100644 --- a/src/training/serve.py +++ b/src/training/serve.py @@ -1,7 +1,7 @@ -"""FastAPI inference server for the fine-tuned docstring generation model. +"""FastAPI inference server for docstring generation using multiple LLM backends. -Loads a LoRA-adapted model and serves predictions via HTTP. Designed to be -called by the VS Code extension for local, offline docstring generation. +Serves predictions via HTTP using Ollama as the backend. Supports multiple models +including Qwen 2.5 Coder and Qwen3 Coder variants with model-specific configurations. Usage: uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 @@ -9,6 +9,7 @@ Endpoints: POST /generate - Generate a docstring for a given code snippet GET /health - Health check + GET /models - List available models """ import os @@ -19,12 +20,25 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel +from src.training.models import ( + MODEL_REGISTRY, + DEFAULT_MODEL_KEY, + ModelConfig, + get_model_config, + list_models, +) + # Configuration OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/api/chat") -OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen2.5-coder:32b") +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", DEFAULT_MODEL_KEY) REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "120.0")) +def get_active_model() -> ModelConfig: + """Get the currently active model configuration from environment.""" + return get_model_config(OLLAMA_MODEL) + + def load_system_prompt() -> str: """Load the default system prompt from the prompts directory.""" prompt_path = 
Path(__file__).parent / "prompts" / "system_prompt.md" @@ -40,26 +54,62 @@ def load_system_prompt() -> str: SYSTEM_PROMPT = load_system_prompt() -def generate_docstring(code: str, max_new_tokens: int = 256) -> str: - """Generate a docstring for the given code snippet using ollama API.""" +def generate_docstring( + code: str, + max_new_tokens: Optional[int] = None, + model_config: Optional[ModelConfig] = None +) -> str: + """Generate a docstring for the given code snippet using ollama API. + + Parameters + ---------- + code : str + The Python code snippet to generate a docstring for. + max_new_tokens : int, optional + Maximum number of tokens to generate. If None, uses the model's default. + model_config : ModelConfig, optional + Model configuration to use. If None, uses the active model from environment. + + Returns + ------- + str + The generated docstring. + + Raises + ------ + RuntimeError + If the request to ollama fails. + ValueError + If the response format is unexpected. + """ + if model_config is None: + model_config = get_active_model() + + # Use model-specific defaults if not provided + if max_new_tokens is None: + max_new_tokens = model_config.sampling.num_predict + payload = { - "model": OLLAMA_MODEL, + "model": model_config.ollama_model, "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": code} ], "stream": False, - "keep_alive": 0, + "keep_alive": model_config.keep_alive, "options": { - "num_predict": max_new_tokens + "num_predict": max_new_tokens, + "temperature": model_config.sampling.temperature, + "top_p": model_config.sampling.top_p, + "top_k": model_config.sampling.top_k, } } - + try: resp = requests.post(OLLAMA_URL, json=payload, timeout=REQUEST_TIMEOUT) resp.raise_for_status() data = resp.json() - + # Handle /api/chat response format if "message" in data: return data["message"].get("content", "") @@ -88,36 +138,84 @@ def check_ollama_health() -> bool: # FastAPI app -app = FastAPI(title="Docstring Generation 
API", version="0.1.0") +app = FastAPI( + title="Docstring Generation API", + version="0.2.0", + description="Generate Python docstrings using Qwen Coder models via Ollama" +) class GenerateRequest(BaseModel): """Request model for docstring generation.""" code: str - max_new_tokens: Optional[int] = 256 + max_new_tokens: Optional[int] = None + model: Optional[str] = None class GenerateResponse(BaseModel): """Response model for docstring generation.""" docstring: str + model: str + + +class HealthResponse(BaseModel): + """Response model for health check.""" + status: str + service: str + active_model: str + ollama_model: str + + +class ModelsResponse(BaseModel): + """Response model for listing available models.""" + default: str + active: str + models: list @app.post("/generate", response_model=GenerateResponse) async def generate(request: GenerateRequest): - """Generate a docstring for the given code snippet.""" + """Generate a docstring for the given code snippet. + + Optionally specify a model to use for generation. If not specified, + uses the active model from the OLLAMA_MODEL environment variable. + """ try: - docstring = generate_docstring(request.code, request.max_new_tokens) - return GenerateResponse(docstring=docstring) + # Determine which model to use + if request.model: + model_config = get_model_config(request.model) + else: + model_config = get_active_model() + + docstring = generate_docstring( + request.code, + request.max_new_tokens, + model_config + ) + return GenerateResponse( + docstring=docstring, + model=model_config.ollama_model + ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@app.get("/health") +@app.get("/health", response_model=HealthResponse) async def health(): - """Health check endpoint that verifies ollama is running.""" + """Health check endpoint that verifies ollama is running. + + Returns information about the active model configuration. 
+ """ is_healthy = check_ollama_health() + active_model = get_active_model() + if is_healthy: - return {"status": "healthy", "service": "ollama"} + return HealthResponse( + status="healthy", + service="ollama", + active_model=active_model.name, + ollama_model=active_model.ollama_model + ) else: raise HTTPException( status_code=503, @@ -125,6 +223,21 @@ async def health(): ) +@app.get("/models", response_model=ModelsResponse) +async def get_models(): + """List all available model configurations. + + Returns the default model, currently active model, and a list of + all registered models with their configurations. + """ + active_model = get_active_model() + return ModelsResponse( + default=DEFAULT_MODEL_KEY, + active=OLLAMA_MODEL, + models=list_models() + ) + + def main(): """Start the inference server.""" import uvicorn diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..0954689 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,199 @@ +"""Tests for src.training.models module.""" + +import pytest + +from src.training.models import ( + MODEL_REGISTRY, + DEFAULT_MODEL_KEY, + ModelConfig, + SamplingConfig, + get_model_config, + create_fallback_config, + list_models, +) + + +class TestSamplingConfig: + """Tests for SamplingConfig dataclass.""" + + def test_default_values(self): + """SamplingConfig should have sensible defaults.""" + config = SamplingConfig() + assert config.temperature == 0.7 + assert config.top_p == 0.9 + assert config.top_k == 40 + assert config.num_predict == 256 + + def test_custom_values(self): + """SamplingConfig should accept custom values.""" + config = SamplingConfig(temperature=1.0, top_p=0.95, top_k=50, num_predict=512) + assert config.temperature == 1.0 + assert config.top_p == 0.95 + assert config.top_k == 50 + assert config.num_predict == 512 + + def test_immutable(self): + """SamplingConfig should be immutable (frozen).""" + config = SamplingConfig() + with pytest.raises(AttributeError): + 
config.temperature = 0.5 + + +class TestModelConfig: + """Tests for ModelConfig dataclass.""" + + def test_required_fields(self): + """ModelConfig should require essential fields.""" + config = ModelConfig( + name="Test Model", + ollama_model="test:latest", + context_window=4096, + sampling=SamplingConfig() + ) + assert config.name == "Test Model" + assert config.ollama_model == "test:latest" + assert config.context_window == 4096 + + def test_default_values(self): + """ModelConfig should have sensible defaults for optional fields.""" + config = ModelConfig( + name="Test Model", + ollama_model="test:latest", + context_window=4096, + sampling=SamplingConfig() + ) + assert config.keep_alive == 0 + assert config.architecture == "dense" + assert config.total_params is None + assert config.memory_q4 is None + assert config.description == "" + + def test_immutable(self): + """ModelConfig should be immutable (frozen).""" + config = ModelConfig( + name="Test Model", + ollama_model="test:latest", + context_window=4096, + sampling=SamplingConfig() + ) + with pytest.raises(AttributeError): + config.name = "New Name" + + +class TestModelRegistry: + """Tests for MODEL_REGISTRY and related functions.""" + + def test_registry_has_default_model(self): + """Registry should contain the default model.""" + assert DEFAULT_MODEL_KEY in MODEL_REGISTRY + + def test_registry_models_are_valid(self): + """All models in registry should be valid ModelConfig instances.""" + for key, config in MODEL_REGISTRY.items(): + assert isinstance(config, ModelConfig) + assert isinstance(config.sampling, SamplingConfig) + assert config.context_window > 0 + assert config.ollama_model + + def test_qwen25_coder_32b_config(self): + """Qwen 2.5 Coder 32B should have correct configuration.""" + config = MODEL_REGISTRY["qwen2.5-coder-32b"] + assert config.name == "Qwen 2.5 Coder 32B" + assert config.ollama_model == "qwen2.5-coder:32b" + assert config.context_window == 32768 + assert config.architecture == 
"dense" + assert config.sampling.temperature == 0.7 + assert config.sampling.top_p == 0.9 + + def test_qwen25_coder_7b_config(self): + """Qwen 2.5 Coder 7B should have correct configuration.""" + config = MODEL_REGISTRY["qwen2.5-coder-7b"] + assert config.name == "Qwen 2.5 Coder 7B" + assert config.ollama_model == "qwen2.5-coder:7b" + assert config.architecture == "dense" + + def test_qwen3_coder_30b_config(self): + """Qwen3 Coder 30B should have correct MoE configuration.""" + config = MODEL_REGISTRY["qwen3-coder-30b"] + assert config.name == "Qwen3 Coder 30B-A3B" + assert config.ollama_model == "qwen3-coder:30b-a3b" + assert config.context_window == 262144 # 256K + assert config.architecture == "moe" + # Qwen3 uses different sampling parameters + assert config.sampling.temperature == 1.0 + assert config.sampling.top_p == 0.95 + # MoE models have longer keep_alive + assert config.keep_alive == 300 + + +class TestGetModelConfig: + """Tests for get_model_config function.""" + + def test_get_by_registry_key(self): + """Should return config when given a valid registry key.""" + config = get_model_config("qwen2.5-coder-32b") + assert config.ollama_model == "qwen2.5-coder:32b" + + def test_get_by_ollama_model_name(self): + """Should return config when given a valid Ollama model name.""" + config = get_model_config("qwen2.5-coder:32b") + assert config.name == "Qwen 2.5 Coder 32B" + + def test_fallback_for_unknown_model(self): + """Should create fallback config for unknown models.""" + config = get_model_config("unknown-model:latest") + assert config.ollama_model == "unknown-model:latest" + assert config.architecture == "unknown" + assert config.context_window == 32768 # Conservative default + + def test_fallback_uses_default_sampling(self): + """Fallback config should use default sampling parameters.""" + config = get_model_config("custom-model:v1") + assert config.sampling.temperature == 0.7 + assert config.sampling.top_p == 0.9 + assert config.sampling.top_k == 40 + + 
+class TestCreateFallbackConfig: + """Tests for create_fallback_config function.""" + + def test_creates_valid_config(self): + """Should create a valid ModelConfig for any model name.""" + config = create_fallback_config("my-model:latest") + assert isinstance(config, ModelConfig) + assert config.ollama_model == "my-model:latest" + assert config.name == "my-model:latest" + + def test_uses_conservative_defaults(self): + """Fallback should use conservative defaults.""" + config = create_fallback_config("test") + assert config.context_window == 32768 + assert config.keep_alive == 0 + assert config.architecture == "unknown" + + +class TestListModels: + """Tests for list_models function.""" + + def test_returns_list(self): + """Should return a list of model info dictionaries.""" + models = list_models() + assert isinstance(models, list) + assert len(models) == len(MODEL_REGISTRY) + + def test_model_info_structure(self): + """Each model info should have required fields.""" + models = list_models() + for model_info in models: + assert "key" in model_info + assert "name" in model_info + assert "ollama_model" in model_info + assert "context_window" in model_info + assert "architecture" in model_info + assert "description" in model_info + + def test_contains_all_registry_models(self): + """Should contain all models from the registry.""" + models = list_models() + keys = {m["key"] for m in models} + assert keys == set(MODEL_REGISTRY.keys()) diff --git a/tests/test_serve.py b/tests/test_serve.py index 922211d..6d932be 100644 --- a/tests/test_serve.py +++ b/tests/test_serve.py @@ -6,7 +6,8 @@ import pytest from fastapi.testclient import TestClient -from src.training.serve import app, check_ollama_health, generate_docstring +from src.training.serve import app, check_ollama_health, generate_docstring, get_active_model +from src.training.models import MODEL_REGISTRY, DEFAULT_MODEL_KEY, get_model_config @pytest.fixture @@ -32,6 +33,9 @@ def test_health_success(self, mock_get, 
class TestModelsEndpoint:
    """Tests for the /models endpoint."""

    def test_list_models_returns_all(self, client):
        """GET /models should list every model in the registry."""
        response = client.get("/models")

        assert response.status_code == 200
        body = response.json()
        assert "models" in body
        assert len(body["models"]) == len(MODEL_REGISTRY)

    def test_list_models_includes_default(self, client):
        """The response advertises which model key is the default."""
        response = client.get("/models")

        assert response.status_code == 200
        body = response.json()
        assert "default" in body
        assert body["default"] == DEFAULT_MODEL_KEY

    def test_list_models_includes_active(self, client):
        """The response advertises the currently active model."""
        response = client.get("/models")

        assert response.status_code == 200
        assert "active" in response.json()

    def test_list_models_model_structure(self, client):
        """Each listed model carries the expected descriptive fields."""
        response = client.get("/models")

        assert response.status_code == 200
        for model in response.json()["models"]:
            assert "key" in model
            assert "name" in model
            assert "ollama_model" in model
            assert "context_window" in model
            assert "architecture" in model


class TestModelSelection:
    """Tests for per-request model selection."""

    @patch("src.training.serve.requests.post")
    def test_generate_with_specific_model_key(self, mock_post, client):
        """A registry key in the request selects that model's Ollama name."""
        fake = Mock(status_code=200)
        fake.json.return_value = {
            "message": {"content": '"""Test docstring."""'}
        }
        mock_post.return_value = fake

        response = client.post(
            "/generate",
            json={"code": "def test(): pass", "model": "qwen2.5-coder-7b"},
        )

        assert response.status_code == 200
        assert response.json()["model"] == "qwen2.5-coder:7b"

        # The outbound Ollama payload must name the resolved model too.
        sent = mock_post.call_args.kwargs["json"]
        assert sent["model"] == "qwen2.5-coder:7b"

    @patch("src.training.serve.requests.post")
    def test_generate_with_raw_ollama_model(self, mock_post, client):
        """Raw Ollama model names are accepted as-is."""
        fake = Mock(status_code=200)
        fake.json.return_value = {
            "message": {"content": '"""Test docstring."""'}
        }
        mock_post.return_value = fake

        response = client.post(
            "/generate",
            json={"code": "def test(): pass", "model": "qwen2.5-coder:14b"},
        )

        assert response.status_code == 200

    @patch("src.training.serve.requests.post")
    def test_generate_applies_model_sampling(self, mock_post, client):
        """The selected model's sampling parameters reach the payload."""
        fake = Mock(status_code=200)
        fake.json.return_value = {
            "message": {"content": '"""Test docstring."""'}
        }
        mock_post.return_value = fake

        # Qwen3 is configured with sampling params that differ from the default.
        response = client.post(
            "/generate",
            json={"code": "def test(): pass", "model": "qwen3-coder-30b"},
        )

        assert response.status_code == 200

        options = mock_post.call_args.kwargs["json"]["options"]
        assert options["temperature"] == 1.0
        assert options["top_p"] == 0.95
        assert options["top_k"] == 40

    @patch("src.training.serve.requests.post")
    def test_generate_applies_model_keep_alive(self, mock_post, client):
        """The selected model's keep_alive setting reaches the payload."""
        fake = Mock(status_code=200)
        fake.json.return_value = {
            "message": {"content": '"""Test docstring."""'}
        }
        mock_post.return_value = fake

        # Qwen3 is configured with keep_alive=300.
        response = client.post(
            "/generate",
            json={"code": "def test(): pass", "model": "qwen3-coder-30b"},
        )

        assert response.status_code == 200
        assert mock_post.call_args.kwargs["json"]["keep_alive"] == 300

    @patch("src.training.serve.requests.post")
    def test_generate_without_model_uses_default(self, mock_post, client):
        """Omitting the model field falls back to the default model."""
        fake = Mock(status_code=200)
        fake.json.return_value = {
            "message": {"content": '"""Test docstring."""'}
        }
        mock_post.return_value = fake

        response = client.post("/generate", json={"code": "def test(): pass"})

        assert response.status_code == 200

        sent = mock_post.call_args.kwargs["json"]
        expected = get_model_config(DEFAULT_MODEL_KEY)
        assert sent["model"] == expected.ollama_model


class TestHealthWithModel:
    """Tests for health endpoint with model information."""

    @patch("src.training.serve.requests.get")
    def test_health_reports_active_model(self, mock_get, client):
        """The health payload names the model the server is configured for."""
        mock_get.return_value = Mock(status_code=200)

        response = client.get("/health")

        assert response.status_code == 200
        body = response.json()
        assert "active_model" in body
        assert "ollama_model" in body
        # The reported names must match the server's active configuration.
        active = get_active_model()
        assert body["active_model"] == active.name
        assert body["ollama_model"] == active.ollama_model