diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..f3a8404 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,27 @@ +# Auto detect text files and perform LF normalization +* text=auto + +# Python files +*.py text eol=lf + +# Markdown files +*.md text eol=lf + +# JSON files +*.json text eol=lf + +# YAML files +*.yml text eol=lf +*.yaml text eol=lf + +# Shell scripts +*.sh text eol=lf + +# Configuration files +*.toml text eol=lf +*.cfg text eol=lf +*.ini text eol=lf + +# Keep Windows batch files with CRLF +*.bat text eol=crlf +*.cmd text eol=crlf diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 0000000..a00ba4b --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,43 @@ +name: CI + +on: + push: + branches: [master, qwen2.5-coder] + pull_request: + branches: [master, qwen2.5-coder] + +jobs: + test: + runs-on: ubuntu-latest + + steps: + - uses: actions/checkout@v4 + + - name: Set up Python 3.12 + uses: actions/setup-python@v5 + with: + python-version: "3.12" + cache: 'pip' + + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install -e ".[dev]" + + - name: Run linter + run: python -m ruff check . + continue-on-error: true + + - name: Run tests with coverage + run: | + python -m pytest tests/ -v --tb=short --cov=src --cov-report=term-missing --cov-report=xml + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v4 + with: + token: ${{ secrets.CODECOV_TOKEN }} + files: ./coverage.xml + flags: unittests + name: ci-coverage + fail_ci_if_error: false + verbose: true diff --git a/.gitignore b/.gitignore index ad3c25b..d9c1a25 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ build/ # IDE .vscode/ +.pytest_cache/ # Jupyter notebooks/.ipynb_checkpoints/ diff --git a/README.md b/README.md index b94317a..89e2819 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,11 @@ Test Dataset + Model Predictions --> [benchmark.py] --> Metrics Report - **`train_lora.py`** - LoRA fine-tuning using HuggingFace Trainer + PEFT. Supports QLoRA (4-bit quantization) for training on 1-2 A100 GPUs. -- **`serve.py`** - FastAPI inference server that loads the fine-tuned model and - serves docstring generation via HTTP. +- **`serve.py`** - FastAPI inference server that uses ollama API to generate + docstrings. Supports multiple Qwen Coder models with model-specific configurations. + +- **`models.py`** - Model configuration registry with sampling parameters for + Qwen 2.5 Coder and Qwen3 Coder variants. ### Evaluation (`src/evaluation/`) @@ -87,6 +90,223 @@ python -m src.data.convert_seed \ --output-dir data/processed/python-method ``` +## Serving + +The FastAPI inference server provides HTTP endpoints for docstring generation using +ollama as the backend. The server uses a system prompt stored in +`src/training/prompts/system_prompt.md` to generate NumPy-style docstrings. + +### Prerequisites + +1. **Install ollama**: Make sure [ollama](https://ollama.ai/) is installed and running locally +2. **Pull a model**: Download one of the supported code models: + ```bash + # Qwen 2.5 Coder (dense models) + ollama pull qwen2.5-coder:32b # Default, ~18GB Q4 + ollama pull qwen2.5-coder:14b # Mid-size, ~8GB Q4 + ollama pull qwen2.5-coder:7b # Fast, ~4GB Q4 + + # Qwen3 Coder (MoE model) + ollama pull qwen3-coder:30b-a3b # Best quality, ~18GB Q4, 256K context + ``` + +### Starting the Server + +Start the FastAPI server using uvicorn: + +**Linux/macOS:** +```bash +# Using uvicorn directly +uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 + +# Or run the module directly +python -m src.training.serve +``` + +**Windows (PowerShell):** +```powershell +uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 +``` + +The server will start on `http://localhost:8000` by default. + +### Configuration + +The server can be configured using environment variables: + +- `OLLAMA_URL` - Ollama API endpoint (default: `http://localhost:11434/api/chat`) +- `OLLAMA_MODEL` - Model key or Ollama model name (default: `qwen2.5-coder-32b`) +- `REQUEST_TIMEOUT` - Request timeout in seconds (default: `120.0`) + +**Linux/macOS:** +```bash +OLLAMA_MODEL=qwen3-coder-30b uvicorn src.training.serve:app --port 8000 +``` + +**Windows (PowerShell):** +```powershell +$env:OLLAMA_MODEL="qwen3-coder-30b"; uvicorn src.training.serve:app --port 8000 +``` + +**Windows (CMD):** +```cmd +set OLLAMA_MODEL=qwen3-coder-30b && uvicorn src.training.serve:app --port 8000 +``` + +### Available Models + +| Model Key | Ollama Model | Architecture | Memory (Q4) | Context | Description | +|-----------|--------------|--------------|-------------|---------|-------------| +| `qwen2.5-coder-32b` | `qwen2.5-coder:32b` | Dense | ~18GB | 32K | Default, balanced quality/speed | +| `qwen2.5-coder-14b` | `qwen2.5-coder:14b` | Dense | ~8GB | 32K | Mid-size, good performance | +| `qwen2.5-coder-7b` | `qwen2.5-coder:7b` | Dense | ~4GB | 32K | Fast inference | +| `qwen3-coder-30b` | `qwen3-coder:30b-a3b` | MoE | ~18GB | 256K | Best quality, 3.3B active params | + +Each model has optimized sampling parameters: +- **Qwen 2.5 Coder**: temperature=0.7, top_p=0.9, top_k=40 +- **Qwen3 Coder**: temperature=1.0, top_p=0.95, top_k=40 (per official recommendations) + +### Model Selection + +You can select a model in two ways: + +1. **Environment variable** (applies to all requests): + ```bash + OLLAMA_MODEL=qwen3-coder-30b uvicorn src.training.serve:app + ``` + +2. **Per-request** (via API): + ```bash + curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{"code": "def add(x, y): return x + y", "model": "qwen3-coder-30b"}' + ``` + +### List Available Models + +**Via CLI:** +```bash +python scripts/run_ollama.py --list-models +``` + +**Via API:** +```bash +curl http://localhost:8000/models +``` + +### API Endpoints + +#### Health Check + +Check if the service is healthy and ollama is accessible: + +```bash +curl http://localhost:8000/health +``` + +**Response (200 OK):** +```json +{ + "status": "healthy", + "service": "ollama", + "active_model": "Qwen 2.5 Coder 32B", + "ollama_model": "qwen2.5-coder:32b" +} +``` + +**Response (503 Service Unavailable):** +```json +{ + "detail": "Service unhealthy: ollama is not running or not accessible" +} +``` + +#### Generate Docstring + +Generate a docstring for a Python function: + +```bash +curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "code": "def add(x, y):\n return x + y", + "max_new_tokens": 256 + }' +``` + +**Request Body:** +- `code` (required): Python function code as a string +- `max_new_tokens` (optional): Maximum number of tokens to generate (uses model default if not specified) +- `model` (optional): Model key or Ollama model name to use for this request + +**Response (200 OK):** +```json +{ + "docstring": "\"\"\"Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\"\"\"", + "model": "qwen2.5-coder:32b" +} +``` + +**Response (500 Internal Server Error):** +```json +{ + "detail": "Failed to generate docstring: " +} +``` + +#### List Models + +Get available model configurations: + +```bash +curl http://localhost:8000/models +``` + +**Response (200 OK):** +```json +{ + "default": "qwen2.5-coder-32b", + "active": "qwen2.5-coder-32b", + "models": [ + { + "key": "qwen2.5-coder-32b", + "name": "Qwen 2.5 Coder 32B", + "ollama_model": "qwen2.5-coder:32b", + "context_window": 32768, + "architecture": "dense", + "memory_q4": "~18GB", + "description": "Dense 32B model, good balance of quality and speed" + } + ] +} +``` + +### CLI Tool + +The CLI tool allows testing docstring generation directly: + +```bash +# Use default model +python scripts/run_ollama.py --user "def add(x, y): return x + y" + +# Use specific model by key +python scripts/run_ollama.py --model-key qwen3-coder-30b --user "def foo(): pass" + +# Use raw Ollama model name +python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def bar(): pass" + +# List available models +python scripts/run_ollama.py --list-models +``` + +### Testing + +Run the test suite to verify the API endpoints: + +```bash +pytest tests/test_serve.py tests/test_models.py -v +``` + ## Dataset The seed dataset comes from the [NeuralCodeSum](https://github.com/wasiahmad/NeuralCodeSum) diff --git a/codecov.yml b/codecov.yml new file mode 100644 index 0000000..3f4c24a --- /dev/null +++ b/codecov.yml @@ -0,0 +1,11 @@ +comment: + layout: "reach,diff,flags,tree" + behavior: default + require_changes: false + require_base: no + require_head: yes + +# Optional: configure thresholds or ignore patterns below +# coverage: +# precision: 2 +# round: down \ No newline at end of file diff --git a/pyproject.toml b/pyproject.toml index 50756b8..8bd2fa4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,15 @@ dependencies = [ "safetensors", "fastapi>=0.104.0", "uvicorn>=0.24.0", + "requests>=2.31.0", ] [project.optional-dependencies] dev = [ "pytest>=7.0", + "pytest-cov>=4.0", "ruff>=0.1.0", + "httpx>=0.24.0", ] [tool.hatch.build.targets.wheel] diff --git a/scripts/run_ollama.py b/scripts/run_ollama.py new file mode 100644 index 0000000..40496d0 --- /dev/null +++ b/scripts/run_ollama.py @@ -0,0 +1,238 @@ +"""CLI script for testing docstring generation with Ollama. + +Supports model selection via registry keys or raw Ollama model names, +with model-specific sampling parameters. + +Usage: + python scripts/run_ollama.py --model-key qwen2.5-coder-32b --user "def add(x, y): return x + y" + python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def foo(): pass" + python scripts/run_ollama.py --list-models +""" + +import argparse +import json +import sys +import time +from pathlib import Path + +import requests + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.training.models import ( + MODEL_REGISTRY, + DEFAULT_MODEL_KEY, + ModelConfig, + SamplingConfig, + get_model_config, + list_models, +) + + +DEFAULT_URL = "http://localhost:11434/api/chat" + + +def load_system_prompt() -> str: + """Load the default system prompt from the prompts directory.""" + prompt_path = Path(__file__).parent.parent / "src" / "training" / "prompts" / "system_prompt.md" + if not prompt_path.exists(): + raise FileNotFoundError( + f"System prompt file not found: {prompt_path}. " + "Please ensure the prompt file exists." + ) + return prompt_path.read_text(encoding="utf-8") + + +def build_payload( + model_config: ModelConfig, + system_msgs: list[str], + user_msgs: list[str], + stream: bool +) -> dict: + """Build the request payload with model-specific sampling parameters.""" + messages = [] + for s in system_msgs: + messages.append({"role": "system", "content": s}) + for u in user_msgs: + messages.append({"role": "user", "content": u}) + + return { + "model": model_config.ollama_model, + "messages": messages, + "stream": stream, + "keep_alive": model_config.keep_alive, + "options": { + "temperature": model_config.sampling.temperature, + "top_p": model_config.sampling.top_p, + "top_k": model_config.sampling.top_k, + } + } + + +def print_models_list(): + """Print all available models in a formatted table.""" + print("\nAvailable Models:") + print("=" * 80) + print(f"{'Key':<22} {'Ollama Model':<25} {'Arch':<6} {'Memory':<8} {'Context'}") + print("-" * 80) + + for model_info in list_models(): + key = model_info["key"] + ollama = model_info["ollama_model"] + arch = model_info["architecture"] + memory = model_info.get("memory_q4", "N/A") + context = f"{model_info['context_window']:,}" + + # Mark default model + marker = " *" if key == DEFAULT_MODEL_KEY else "" + print(f"{key:<22} {ollama:<25} {arch:<6} {memory:<8} {context}{marker}") + + print("-" * 80) + print(f"* = default model ({DEFAULT_MODEL_KEY})") + print() + + +def main(): + parser = argparse.ArgumentParser( + description="Send prompts to Ollama for docstring generation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use a registered model by key + python scripts/run_ollama.py --model-key qwen2.5-coder-32b --user "def add(x, y): return x + y" + + # Use a raw Ollama model name + python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def foo(): pass" + + # Use Qwen3 MoE model + python scripts/run_ollama.py --model-key qwen3-coder-30b --user "def hello(): print('hi')" + + # List available models + python scripts/run_ollama.py --list-models +""" + ) + parser.add_argument( + "--url", + default=DEFAULT_URL, + help="API endpoint URL (default: %(default)s)" + ) + parser.add_argument( + "--model-key", + help=f"Model key from registry (e.g., {DEFAULT_MODEL_KEY}, qwen3-coder-30b)" + ) + parser.add_argument( + "--model", + help="Raw Ollama model name (e.g., qwen2.5-coder:32b). Overrides --model-key" + ) + parser.add_argument( + "--list-models", + action="store_true", + help="List available model configurations and exit" + ) + parser.add_argument( + "--system", + action="append", + default=None, + help="System prompt (repeatable). Overrides default." + ) + parser.add_argument( + "--user", + action="append", + default=[], + help="User prompt (repeatable)" + ) + parser.add_argument( + "--stream", + action="store_true", + help="Enable streaming mode" + ) + parser.add_argument( + "--timeout", + type=float, + default=120.0, + help="Request timeout in seconds (default: %(default)s)" + ) + + args = parser.parse_args() + + # Handle --list-models + if args.list_models: + print_models_list() + sys.exit(0) + + # Determine model configuration + if args.model: + # Raw Ollama model name provided + model_config = get_model_config(args.model) + print(f"Using model: {model_config.ollama_model}", file=sys.stderr) + elif args.model_key: + # Registry key provided + model_config = get_model_config(args.model_key) + print(f"Using model: {model_config.name} ({model_config.ollama_model})", file=sys.stderr) + else: + # Use default + model_config = get_model_config(DEFAULT_MODEL_KEY) + print(f"Using default model: {model_config.name} ({model_config.ollama_model})", file=sys.stderr) + + # Print sampling parameters + print( + f"Sampling: temp={model_config.sampling.temperature}, " + f"top_p={model_config.sampling.top_p}, " + f"top_k={model_config.sampling.top_k}", + file=sys.stderr + ) + + # Use default system prompt if none provided + if args.system: + system_msgs = args.system + else: + try: + system_msgs = [load_system_prompt()] + except FileNotFoundError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) + + if not args.user: + print("Error: at least one --user prompt is required.", file=sys.stderr) + print("Use --help for usage information.", file=sys.stderr) + sys.exit(2) + + payload = build_payload(model_config, system_msgs, args.user, args.stream) + + # Track execution time + start_time = time.time() + try: + resp = requests.post(args.url, json=payload, timeout=args.timeout) + resp.raise_for_status() + except requests.RequestException as e: + elapsed_time = time.time() - start_time + print(f"Request failed: {e}", file=sys.stderr) + print(f"Execution time: {elapsed_time:.2f}s", file=sys.stderr) + sys.exit(1) + + try: + data = resp.json() + except ValueError: + print("Response is not valid JSON", file=sys.stderr) + print(resp.text, file=sys.stderr) + sys.exit(1) + + # Handle /api/chat response format + if "message" in data: + print(data["message"].get("content", "")) + elif "response" in data: + print(data["response"]) + elif "choices" in data and isinstance(data["choices"], list): + for c in data["choices"]: + print(c.get("message", {}).get("content", c.get("text", ""))) + else: + print(json.dumps(data, indent=2)) + + # Print execution time + elapsed_time = time.time() - start_time + print(f"\nExecution time: {elapsed_time:.2f}s", file=sys.stderr) + + +if __name__ == "__main__": + main() diff --git a/src/training/models.py b/src/training/models.py new file mode 100644 index 0000000..3040ca2 --- /dev/null +++ b/src/training/models.py @@ -0,0 +1,193 @@ +"""Model configuration registry for docstring generation. + +Provides model-specific configurations including sampling parameters, +context windows, and memory requirements for different LLM backends. +""" + +from dataclasses import dataclass +from typing import Optional + + +@dataclass(frozen=True) +class SamplingConfig: + """Sampling parameters for text generation. + + Parameters + ---------- + temperature : float + Controls randomness in generation. Higher values = more random. + top_p : float + Nucleus sampling parameter. Considers tokens with cumulative probability >= top_p. + top_k : int + Limits sampling to top k most likely tokens. + num_predict : int + Maximum number of tokens to generate. + """ + temperature: float = 0.7 + top_p: float = 0.9 + top_k: int = 40 + num_predict: int = 256 + + +@dataclass(frozen=True) +class ModelConfig: + """Configuration for a specific model. + + Parameters + ---------- + name : str + Human-readable display name for the model. + ollama_model : str + Ollama model identifier (e.g., "qwen2.5-coder:32b"). + context_window : int + Maximum context length in tokens. + sampling : SamplingConfig + Default sampling parameters for this model. + keep_alive : int + Model keep-alive time in seconds (0 = unload immediately after request). + architecture : str + Model architecture type: "dense" or "moe". + total_params : str, optional + Total parameter count description (e.g., "32B", "30B (3.3B active)"). + memory_q4 : str, optional + Approximate memory requirement at Q4 quantization. + description : str + Brief description of the model's characteristics. + """ + name: str + ollama_model: str + context_window: int + sampling: SamplingConfig + keep_alive: int = 0 + architecture: str = "dense" + total_params: Optional[str] = None + memory_q4: Optional[str] = None + description: str = "" + + +# Pre-defined model configurations +MODEL_REGISTRY: dict[str, ModelConfig] = { + "qwen2.5-coder-32b": ModelConfig( + name="Qwen 2.5 Coder 32B", + ollama_model="qwen2.5-coder:32b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="32B", + memory_q4="~18GB", + description="Dense 32B model, good balance of quality and speed" + ), + "qwen2.5-coder-14b": ModelConfig( + name="Qwen 2.5 Coder 14B", + ollama_model="qwen2.5-coder:14b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="14B", + memory_q4="~8GB", + description="Mid-size dense model, balanced performance" + ), + "qwen2.5-coder-7b": ModelConfig( + name="Qwen 2.5 Coder 7B", + ollama_model="qwen2.5-coder:7b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="7B", + memory_q4="~4GB", + description="Smaller variant, faster inference" + ), + "qwen3-coder-30b": ModelConfig( + name="Qwen3 Coder 30B-A3B", + ollama_model="qwen3-coder:30b-a3b", + context_window=262144, # 256K tokens + sampling=SamplingConfig(temperature=1.0, top_p=0.95, top_k=40), + keep_alive=300, # Longer keep_alive for MoE model to avoid reload overhead + architecture="moe", + total_params="30B (3.3B active)", + memory_q4="~18GB", + description="MoE model with 256K context, best quality" + ), +} + +# Default model key +DEFAULT_MODEL_KEY = "qwen2.5-coder-32b" + + +def get_model_config(model_key: str) -> ModelConfig: + """Get configuration for a model by registry key or Ollama model name. + + Parameters + ---------- + model_key : str + Either a registry key (e.g., "qwen2.5-coder-32b") or a raw Ollama + model name (e.g., "qwen2.5-coder:32b"). + + Returns + ------- + ModelConfig + The model configuration. If the key is not found in the registry, + creates a fallback configuration with default sampling parameters. + """ + # First, try direct registry lookup + if model_key in MODEL_REGISTRY: + return MODEL_REGISTRY[model_key] + + # Check if it matches any ollama_model in the registry + for config in MODEL_REGISTRY.values(): + if config.ollama_model == model_key: + return config + + # Fallback: create a default config for unknown models + return create_fallback_config(model_key) + + +def create_fallback_config(ollama_model: str) -> ModelConfig: + """Create a fallback configuration for unknown models. + + Parameters + ---------- + ollama_model : str + The Ollama model identifier. + + Returns + ------- + ModelConfig + A configuration with default sampling parameters. + """ + return ModelConfig( + name=ollama_model, + ollama_model=ollama_model, + context_window=32768, # Conservative default + sampling=SamplingConfig(), # Use defaults + keep_alive=0, + architecture="unknown", + description="Custom model (using default configuration)" + ) + + +def list_models() -> list[dict]: + """List all available model configurations. + + Returns + ------- + list of dict + List of model information dictionaries containing key, name, + ollama_model, context_window, architecture, memory_q4, and description. + """ + return [ + { + "key": key, + "name": config.name, + "ollama_model": config.ollama_model, + "context_window": config.context_window, + "architecture": config.architecture, + "total_params": config.total_params, + "memory_q4": config.memory_q4, + "description": config.description, + } + for key, config in MODEL_REGISTRY.items() + ] diff --git a/src/training/prompts/system_prompt.md b/src/training/prompts/system_prompt.md new file mode 100644 index 0000000..770e1da --- /dev/null +++ b/src/training/prompts/system_prompt.md @@ -0,0 +1,77 @@ +You are a Python documentation expert. Your task is to generate a docstring for the provided Python function following the NumPy documentation format strictly. + +## Output Rules +- Output ONLY the docstring content (including the triple quotes) +- Do NOT include the function signature or body +- Do NOT add any explanation before or after the docstring + +## NumPy Docstring Format + +### Structure (include sections only when applicable) +""" +Short one-line summary (imperative mood, e.g., "Compute", "Return", "Parse"). + +Extended summary providing more details about the function behavior, +algorithm, or implementation notes. Optional but recommended for +complex functions. + +Parameters +---------- +param_name : type + Description of the parameter. If the description spans multiple + lines, indent continuation lines. +param_name : type, optional + For optional parameters, specify default value in description. + Default is `default_value`. +*args : type + Description of variable positional arguments. +**kwargs : type + Description of variable keyword arguments. + +Returns +------- +type + Description of return value. +name : type + Use this format when returning named values or multiple values. + +Yields +------ +type + For generator functions, describe yielded values. + +Raises +------ +ExceptionType + Explanation of when this exception is raised. + +See Also +-------- +related_function : Brief description of relation. + +Notes +----- +Additional technical notes, mathematical formulas (using LaTeX), +or implementation details. + +Examples +-------- +>>> function_name(arg1, arg2) +expected_output +""" + +### Type Annotation Conventions +- Basic types: `int`, `float`, `str`, `bool`, `None` +- Collections: `list of int`, `dict of {str: int}`, `tuple of (int, str)` +- Multiple types: `int or float`, `str or None` +- Array-like: `array_like`, `numpy.ndarray of shape (n, m)` +- Callable: `callable` +- Optional params: append `, optional` after type + +### Guidelines +1. First line: concise, imperative verb, no variable names, ends with period +2. Leave one blank line after the summary before Parameters +3. Align parameter descriptions consistently +4. Include realistic, runnable Examples when behavior isn't obvious +5. Document all exceptions that may be explicitly raised +6. For boolean params, describe what True/False means diff --git a/src/training/serve.py b/src/training/serve.py index 619ca5c..cbcde12 100644 --- a/src/training/serve.py +++ b/src/training/serve.py @@ -1,7 +1,7 @@ -"""FastAPI inference server for the fine-tuned docstring generation model. +"""FastAPI inference server for docstring generation using multiple LLM backends. -Loads a LoRA-adapted model and serves predictions via HTTP. Designed to be -called by the VS Code extension for local, offline docstring generation. +Serves predictions via HTTP using Ollama as the backend. Supports multiple models +including Qwen 2.5 Coder and Qwen3 Coder variants with model-specific configurations. Usage: uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 @@ -9,32 +9,239 @@ Endpoints: POST /generate - Generate a docstring for a given code snippet GET /health - Health check + GET /models - List available models """ +import os +from pathlib import Path +from typing import Optional -def load_model(base_model: str, adapter_path: str): - """Load base model + LoRA adapter for inference.""" - raise NotImplementedError +import requests +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel +from src.training.models import ( + MODEL_REGISTRY, + DEFAULT_MODEL_KEY, + ModelConfig, + get_model_config, + list_models, +) -def generate_docstring(code: str, max_new_tokens: int = 256) -> str: - """Generate a docstring for the given code snippet.""" - raise NotImplementedError +# Configuration +OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/api/chat") +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", DEFAULT_MODEL_KEY) +REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "120.0")) -# FastAPI app will be defined here once dependencies are implemented. -# app = FastAPI() -# -# @app.post("/generate") -# async def generate(request: dict): ... -# -# @app.get("/health") -# async def health(): ... +def get_active_model() -> ModelConfig: + """Get the currently active model configuration from environment.""" + return get_model_config(OLLAMA_MODEL) + + +def load_system_prompt() -> str: + """Load the default system prompt from the prompts directory.""" + prompt_path = Path(__file__).parent / "prompts" / "system_prompt.md" + if not prompt_path.exists(): + raise FileNotFoundError( + f"System prompt file not found: {prompt_path}. " + "Please ensure the prompt file exists." + ) + return prompt_path.read_text(encoding="utf-8") + + +# Load system prompt at module level +SYSTEM_PROMPT = load_system_prompt() + + +def generate_docstring( + code: str, + max_new_tokens: Optional[int] = None, + model_config: Optional[ModelConfig] = None +) -> str: + """Generate a docstring for the given code snippet using ollama API. + + Parameters + ---------- + code : str + The Python code snippet to generate a docstring for. + max_new_tokens : int, optional + Maximum number of tokens to generate. If None, uses the model's default. + model_config : ModelConfig, optional + Model configuration to use. If None, uses the active model from environment. + + Returns + ------- + str + The generated docstring. + + Raises + ------ + RuntimeError + If the request to ollama fails. + ValueError + If the response format is unexpected. + """ + if model_config is None: + model_config = get_active_model() + + # Use model-specific defaults if not provided + if max_new_tokens is None: + max_new_tokens = model_config.sampling.num_predict + + payload = { + "model": model_config.ollama_model, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": code} + ], + "stream": False, + "keep_alive": model_config.keep_alive, + "options": { + "num_predict": max_new_tokens, + "temperature": model_config.sampling.temperature, + "top_p": model_config.sampling.top_p, + "top_k": model_config.sampling.top_k, + } + } + + try: + resp = requests.post(OLLAMA_URL, json=payload, timeout=REQUEST_TIMEOUT) + resp.raise_for_status() + data = resp.json() + + # Handle /api/chat response format + if "message" in data: + return data["message"].get("content", "") + elif "response" in data: + return data["response"] + elif "choices" in data and isinstance(data["choices"], list): + content = "" + for c in data["choices"]: + content += c.get("message", {}).get("content", c.get("text", "")) + return content + else: + raise ValueError(f"Unexpected response format: {data}") + except requests.RequestException as e: + raise RuntimeError(f"Failed to generate docstring: {e}") from e + + +def check_ollama_health() -> bool: + """Check if ollama is running locally by making a test request.""" + try: + # Try to list models as a health check + health_url = OLLAMA_URL.replace("/api/chat", "/api/tags") + resp = requests.get(health_url, timeout=5.0) + return resp.status_code == 200 + except requests.RequestException: + return False + + +# FastAPI app +app = FastAPI( + title="Docstring Generation API", + version="0.2.0", + description="Generate Python docstrings using Qwen Coder models via Ollama" +) + + +class GenerateRequest(BaseModel): + """Request model for docstring generation.""" + code: str + max_new_tokens: Optional[int] = None + model: Optional[str] = None + + +class GenerateResponse(BaseModel): + """Response model for docstring generation.""" + docstring: str + model: str + + +class HealthResponse(BaseModel): + """Response model for health check.""" + status: str + service: str + active_model: str + ollama_model: str + + +class ModelsResponse(BaseModel): + """Response model for listing available models.""" + default: str + active: str + models: list + + +@app.post("/generate", response_model=GenerateResponse) +async def generate(request: GenerateRequest): + """Generate a docstring for the given code snippet. + + Optionally specify a model to use for generation. If not specified, + uses the active model from the OLLAMA_MODEL environment variable. + """ + try: + # Determine which model to use + if request.model: + model_config = get_model_config(request.model) + else: + model_config = get_active_model() + + docstring = generate_docstring( + request.code, + request.max_new_tokens, + model_config + ) + return GenerateResponse( + docstring=docstring, + model=model_config.ollama_model + ) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) + + +@app.get("/health", response_model=HealthResponse) +async def health(): + """Health check endpoint that verifies ollama is running. + + Returns information about the active model configuration. + """ + is_healthy = check_ollama_health() + active_model = get_active_model() + + if is_healthy: + return HealthResponse( + status="healthy", + service="ollama", + active_model=active_model.name, + ollama_model=active_model.ollama_model + ) + else: + raise HTTPException( + status_code=503, + detail="Service unhealthy: ollama is not running or not accessible" + ) + + +@app.get("/models", response_model=ModelsResponse) +async def get_models(): + """List all available model configurations. + + Returns the default model, currently active model, and a list of + all registered models with their configurations. + """ + active_model = get_active_model() + return ModelsResponse( + default=DEFAULT_MODEL_KEY, + active=OLLAMA_MODEL, + models=list_models() + ) def main(): """Start the inference server.""" - raise NotImplementedError + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) if __name__ == "__main__": diff --git a/tests/test_models.py b/tests/test_models.py new file mode 100644 index 0000000..0954689 --- /dev/null +++ b/tests/test_models.py @@ -0,0 +1,199 @@ +"""Tests for src.training.models module.""" + +import pytest + +from src.training.models import ( + MODEL_REGISTRY, + DEFAULT_MODEL_KEY, + ModelConfig, + SamplingConfig, + get_model_config, + create_fallback_config, + list_models, +) + + +class TestSamplingConfig: + """Tests for SamplingConfig dataclass.""" + + def test_default_values(self): + """SamplingConfig should have sensible defaults.""" + config = SamplingConfig() + assert config.temperature == 0.7 + assert config.top_p == 0.9 + assert config.top_k == 40 + assert config.num_predict == 256 + + def test_custom_values(self): + """SamplingConfig should accept custom values.""" + config = SamplingConfig(temperature=1.0, top_p=0.95, top_k=50, num_predict=512) + assert config.temperature == 1.0 + assert config.top_p == 0.95 + assert config.top_k == 50 + assert config.num_predict == 512 + + def test_immutable(self): + """SamplingConfig should be immutable (frozen).""" + config = SamplingConfig() + with pytest.raises(AttributeError): + config.temperature = 0.5 + + +class TestModelConfig: + """Tests for ModelConfig dataclass.""" + + def test_required_fields(self): + """ModelConfig should require essential fields.""" + config = ModelConfig( + name="Test Model", + ollama_model="test:latest", + context_window=4096, + sampling=SamplingConfig() + ) + assert config.name == "Test Model" + assert config.ollama_model == "test:latest" + assert config.context_window == 4096 + + def test_default_values(self): + """ModelConfig should have sensible defaults for optional fields.""" + config = ModelConfig( + name="Test Model", + ollama_model="test:latest", + context_window=4096, + sampling=SamplingConfig() + ) + assert config.keep_alive == 0 + assert config.architecture == "dense" + assert config.total_params is None + assert config.memory_q4 is None + assert config.description == "" + + def test_immutable(self): + """ModelConfig should be immutable (frozen).""" + config = ModelConfig( + name="Test Model", + ollama_model="test:latest", + context_window=4096, + sampling=SamplingConfig() + ) + with pytest.raises(AttributeError): + config.name = "New Name" + + +class TestModelRegistry: + """Tests for MODEL_REGISTRY and related functions.""" + + def test_registry_has_default_model(self): + """Registry should contain the default model.""" + assert DEFAULT_MODEL_KEY in MODEL_REGISTRY + + def test_registry_models_are_valid(self): + """All models in registry should be valid ModelConfig instances.""" + for key, config in MODEL_REGISTRY.items(): + assert isinstance(config, ModelConfig) + assert isinstance(config.sampling, SamplingConfig) + assert config.context_window > 0 + assert config.ollama_model + + def test_qwen25_coder_32b_config(self): + """Qwen 2.5 Coder 32B should have correct configuration.""" + config = MODEL_REGISTRY["qwen2.5-coder-32b"] + assert config.name == "Qwen 2.5 Coder 32B" + assert config.ollama_model == "qwen2.5-coder:32b" + assert config.context_window == 32768 + assert config.architecture == "dense" + assert config.sampling.temperature == 0.7 + assert config.sampling.top_p == 0.9 + + def test_qwen25_coder_7b_config(self): + """Qwen 2.5 Coder 7B should have correct configuration.""" + config = MODEL_REGISTRY["qwen2.5-coder-7b"] + assert config.name == "Qwen 2.5 Coder 7B" + assert config.ollama_model == "qwen2.5-coder:7b" + assert config.architecture == "dense" + + def test_qwen3_coder_30b_config(self): + """Qwen3 Coder 30B should have correct MoE configuration.""" + config = MODEL_REGISTRY["qwen3-coder-30b"] + assert config.name == "Qwen3 Coder 30B-A3B" + assert config.ollama_model == "qwen3-coder:30b-a3b" + assert config.context_window == 262144 # 256K + assert config.architecture == "moe" + # Qwen3 uses different sampling parameters + assert config.sampling.temperature == 1.0 + assert config.sampling.top_p == 0.95 + # MoE models have longer keep_alive + assert config.keep_alive == 300 + + +class TestGetModelConfig: + """Tests for get_model_config function.""" + + def test_get_by_registry_key(self): + """Should return config when given a valid registry key.""" + config = get_model_config("qwen2.5-coder-32b") + assert config.ollama_model == "qwen2.5-coder:32b" + + def test_get_by_ollama_model_name(self): + """Should return config when given a valid Ollama model name.""" + config = get_model_config("qwen2.5-coder:32b") + assert config.name == "Qwen 2.5 Coder 32B" + + def test_fallback_for_unknown_model(self): + """Should create fallback config for unknown models.""" + config = get_model_config("unknown-model:latest") + assert config.ollama_model == "unknown-model:latest" + assert config.architecture == "unknown" + assert config.context_window == 32768 # Conservative default + + def test_fallback_uses_default_sampling(self): + """Fallback config should use default sampling parameters.""" + config = get_model_config("custom-model:v1") + assert config.sampling.temperature == 0.7 + assert config.sampling.top_p == 0.9 + assert config.sampling.top_k == 40 + + +class TestCreateFallbackConfig: + """Tests for create_fallback_config function.""" + + def test_creates_valid_config(self): + """Should create a valid ModelConfig for any model name.""" + config = create_fallback_config("my-model:latest") + assert isinstance(config, ModelConfig) + assert config.ollama_model == "my-model:latest" + assert config.name == "my-model:latest" + + def test_uses_conservative_defaults(self): + """Fallback should use conservative defaults.""" + config = create_fallback_config("test") + assert config.context_window == 32768 + assert config.keep_alive == 0 + assert config.architecture == "unknown" + + +class TestListModels: + """Tests for list_models function.""" + + def test_returns_list(self): + """Should return a list of model info dictionaries.""" + models = list_models() + assert isinstance(models, list) + assert len(models) == len(MODEL_REGISTRY) + + def test_model_info_structure(self): + """Each model info should have required fields.""" + models = list_models() + for model_info in models: + assert "key" in model_info + assert "name" in model_info + assert "ollama_model" in model_info + assert "context_window" in model_info + assert "architecture" in model_info + assert "description" in model_info + + def test_contains_all_registry_models(self): + """Should contain all models from the registry.""" + models = list_models() + keys = {m["key"] for m in models} + assert keys == set(MODEL_REGISTRY.keys()) diff --git a/tests/test_serve.py b/tests/test_serve.py new file mode 100644 index 0000000..6d932be --- /dev/null +++ b/tests/test_serve.py @@ -0,0 +1,502 @@ +"""Tests for src.training.serve module.""" + +import json +from unittest.mock import Mock, patch + +import pytest +from fastapi.testclient import TestClient + +from src.training.serve import app, check_ollama_health, generate_docstring, get_active_model +from src.training.models import MODEL_REGISTRY, DEFAULT_MODEL_KEY, get_model_config + + +@pytest.fixture +def client(): + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +class TestHealthEndpoint: + """Tests for the /health endpoint.""" + + @patch("src.training.serve.requests.get") + def test_health_success(self, mock_get, client): + """Health check should return 200 when ollama is running.""" + # Mock successful response from ollama + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert data["service"] == "ollama" + # New fields for model info + assert "active_model" in data + assert "ollama_model" in data + mock_get.assert_called_once() + + @patch("src.training.serve.requests.get") + def test_health_failure_connection_error(self, mock_get, client): + """Health check should return 503 when ollama is not accessible.""" + # Mock connection error + import requests + mock_get.side_effect = requests.RequestException("Connection refused") + + response = client.get("/health") + + assert response.status_code == 503 + data = response.json() + assert "unhealthy" in data["detail"].lower() + assert "ollama" in data["detail"].lower() + + @patch("src.training.serve.requests.get") + def test_health_failure_non_200_status(self, mock_get, client): + """Health check should return 503 when ollama returns non-200 status.""" + # Mock non-200 response + mock_response = Mock() + mock_response.status_code = 500 + mock_get.return_value = mock_response + + response = client.get("/health") + + assert response.status_code == 503 + data = response.json() + assert "unhealthy" in data["detail"].lower() + + @patch("src.training.serve.requests.get") + def test_health_failure_timeout(self, mock_get, client): + """Health check should return 503 when ollama request times out.""" + import requests + mock_get.side_effect = requests.Timeout("Request timed out") + + response = client.get("/health") + + assert response.status_code == 503 + + +class TestGenerateEndpoint: + """Tests for the /generate endpoint.""" + + @patch("src.training.serve.requests.post") + def test_generate_success_message_format(self, mock_post, client): + """Generate should return docstring when ollama responds with message format.""" + # Mock successful ollama response with message format + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": { + "content": '"""Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\n"""' + } + } + mock_post.return_value = mock_response + + request_data = { + "code": "def add(x, y):\n return x + y", + "max_new_tokens": 256 + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Compute the sum" in data["docstring"] + # New field for model info + assert "model" in data + mock_post.assert_called_once() + + @patch("src.training.serve.requests.post") + def test_generate_success_response_format(self, mock_post, client): + """Generate should return docstring when ollama responds with response format.""" + # Mock successful ollama response with response format + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "response": '"""Return the product of two numbers.\n\nParameters\n----------\na : float\n First number.\nb : float\n Second number.\n\nReturns\n-------\nfloat\n Product of a and b.\n"""' + } + mock_post.return_value = mock_response + + request_data = { + "code": "def multiply(a, b):\n return a * b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Return the product" in data["docstring"] + + @patch("src.training.serve.requests.post") + def test_generate_success_choices_format(self, mock_post, client): + """Generate should return docstring when ollama responds with choices format.""" + # Mock successful ollama response with choices format + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": '"""Calculate the difference.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Difference of x and y.\n"""' + } + } + ] + } + mock_post.return_value = mock_response + + request_data = { + "code": "def subtract(x, y):\n return x - y", + "max_new_tokens": 128 + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Calculate the difference" in data["docstring"] + + @patch("src.training.serve.requests.post") + def test_generate_default_max_new_tokens(self, mock_post, client): + """Generate should use default max_new_tokens when not provided.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + request_data = {"code": "def test(): pass"} + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + # Verify that max_new_tokens was included in the payload + call_args = mock_post.call_args + assert call_args is not None + payload = call_args[1]["json"] + assert "options" in payload + assert payload["options"]["num_predict"] == 256 + + @patch("src.training.serve.requests.post") + def test_generate_failure_connection_error(self, mock_post, client): + """Generate should return 500 when ollama connection fails.""" + import requests + mock_post.side_effect = requests.ConnectionError("Connection refused") + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + assert "Failed to generate docstring" in data["detail"] + + @patch("src.training.serve.requests.post") + def test_generate_failure_timeout(self, mock_post, client): + """Generate should return 500 when ollama request times out.""" + import requests + mock_post.side_effect = requests.Timeout("Request timed out") + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + + @patch("src.training.serve.requests.post") + def test_generate_failure_non_200_status(self, mock_post, client): + """Generate should return 500 when ollama returns non-200 status.""" + mock_response = Mock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = Exception("Internal server error") + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + + @patch("src.training.serve.requests.post") + def test_generate_failure_unexpected_format(self, mock_post, client): + """Generate should return 500 when ollama returns unexpected format.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "unexpected": "format" + } + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + assert "Unexpected response format" in data["detail"] + + def test_generate_missing_code_field(self, client): + """Generate should return 422 when code field is missing.""" + request_data = {} + response = client.post("/generate", json=request_data) + + assert response.status_code == 422 + + def test_generate_empty_code(self, client): + """Generate should accept empty code string.""" + with patch("src.training.serve.requests.post") as mock_post: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Empty function."""'} + } + mock_post.return_value = mock_response + + request_data = {"code": ""} + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + +class TestHelperFunctions: + """Tests for helper functions.""" + + @patch("src.training.serve.requests.get") + def test_check_ollama_health_success(self, mock_get): + """check_ollama_health should return True when ollama is accessible.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + result = check_ollama_health() + + assert result is True + mock_get.assert_called_once() + + @patch("src.training.serve.requests.get") + def test_check_ollama_health_failure(self, mock_get): + """check_ollama_health should return False when ollama is not accessible.""" + import requests + mock_get.side_effect = requests.RequestException("Connection refused") + + result = check_ollama_health() + + assert result is False + + @patch("src.training.serve.requests.post") + def test_generate_docstring_success(self, mock_post): + """generate_docstring should return docstring content.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + result = generate_docstring("def test(): pass", max_new_tokens=128) + + assert result == '"""Test docstring."""' + mock_post.assert_called_once() + + @patch("src.training.serve.requests.post") + def test_generate_docstring_failure(self, mock_post): + """generate_docstring should raise RuntimeError on failure.""" + import requests + mock_post.side_effect = requests.RequestException("Connection error") + + with pytest.raises(RuntimeError) as exc_info: + generate_docstring("def test(): pass") + + assert "Failed to generate docstring" in str(exc_info.value) + + +class TestModelsEndpoint: + """Tests for the /models endpoint.""" + + def test_list_models_returns_all(self, client): + """GET /models should return all registered models.""" + response = client.get("/models") + + assert response.status_code == 200 + data = response.json() + assert "models" in data + assert len(data["models"]) == len(MODEL_REGISTRY) + + def test_list_models_includes_default(self, client): + """Response should indicate the default model.""" + response = client.get("/models") + + assert response.status_code == 200 + data = response.json() + assert "default" in data + assert data["default"] == DEFAULT_MODEL_KEY + + def test_list_models_includes_active(self, client): + """Response should indicate the active model.""" + response = client.get("/models") + + assert response.status_code == 200 + data = response.json() + assert "active" in data + + def test_list_models_model_structure(self, client): + """Each model in the list should have required fields.""" + response = client.get("/models") + + assert response.status_code == 200 + data = response.json() + for model in data["models"]: + assert "key" in model + assert "name" in model + assert "ollama_model" in model + assert "context_window" in model + assert "architecture" in model + + +class TestModelSelection: + """Tests for per-request model selection.""" + + @patch("src.training.serve.requests.post") + def test_generate_with_specific_model_key(self, mock_post, client): + """Should use the specified model key in the request.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass", + "model": "qwen2.5-coder-7b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert data["model"] == "qwen2.5-coder:7b" + + # Verify the correct model was used in the payload + call_args = mock_post.call_args + payload = call_args[1]["json"] + assert payload["model"] == "qwen2.5-coder:7b" + + @patch("src.training.serve.requests.post") + def test_generate_with_raw_ollama_model(self, mock_post, client): + """Should accept raw Ollama model names.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass", + "model": "qwen2.5-coder:14b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + @patch("src.training.serve.requests.post") + def test_generate_applies_model_sampling(self, mock_post, client): + """Should apply model-specific sampling parameters.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + # Use Qwen3 which has different sampling params + request_data = { + "code": "def test(): pass", + "model": "qwen3-coder-30b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + # Verify Qwen3 sampling parameters were used + call_args = mock_post.call_args + payload = call_args[1]["json"] + options = payload["options"] + assert options["temperature"] == 1.0 + assert options["top_p"] == 0.95 + assert options["top_k"] == 40 + + @patch("src.training.serve.requests.post") + def test_generate_applies_model_keep_alive(self, mock_post, client): + """Should apply model-specific keep_alive setting.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + # Use Qwen3 which has keep_alive=300 + request_data = { + "code": "def test(): pass", + "model": "qwen3-coder-30b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + # Verify Qwen3 keep_alive was used + call_args = mock_post.call_args + payload = call_args[1]["json"] + assert payload["keep_alive"] == 300 + + @patch("src.training.serve.requests.post") + def test_generate_without_model_uses_default(self, mock_post, client): + """Should use default model when none specified.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + request_data = {"code": "def test(): pass"} + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + # Verify default model was used + call_args = mock_post.call_args + payload = call_args[1]["json"] + default_config = get_model_config(DEFAULT_MODEL_KEY) + assert payload["model"] == default_config.ollama_model + + +class TestHealthWithModel: + """Tests for health endpoint with model information.""" + + @patch("src.training.serve.requests.get") + def test_health_reports_active_model(self, mock_get, client): + """Health endpoint should report the active model.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert "active_model" in data + assert "ollama_model" in data + # Should match the configured model + active_model = get_active_model() + assert data["active_model"] == active_model.name + assert data["ollama_model"] == active_model.ollama_model