From 7cb4e8ec3130ad72dd3a0ec7ac56180562c3e4ff Mon Sep 17 00:00:00 2001 From: Marat Saidov Date: Sun, 25 Jan 2026 14:07:35 +0100 Subject: [PATCH 1/6] added a simple script to write docstrings locally based on ollama --- scripts/run_ollama.py | 153 ++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 153 insertions(+) create mode 100644 scripts/run_ollama.py diff --git a/scripts/run_ollama.py b/scripts/run_ollama.py new file mode 100644 index 0000000..85f121a --- /dev/null +++ b/scripts/run_ollama.py @@ -0,0 +1,153 @@ +import argparse +import requests +import sys +import json + + +SYSTEM_PROMPT = r""" +You are a Python documentation expert. Your task is to generate a docstring for the provided Python function following the NumPy documentation format strictly. + +\#\# Output Rules +- Output ONLY the docstring content (including the triple quotes) +- Do NOT include the function signature or body +- Do NOT add any explanation before or after the docstring + +\#\# NumPy Docstring Format + +\#\#\# Structure (include sections only when applicable) +\"\"\" +Short one-line summary (imperative mood, e.g., "Compute", "Return", "Parse"). + +Extended summary providing more details about the function behavior, +algorithm, or implementation notes. Optional but recommended for +complex functions. + +Parameters +---------- +param_name : type + Description of the parameter. If the description spans multiple + lines, indent continuation lines. +param_name : type, optional + For optional parameters, specify default value in description. + Default is `default_value`. +*args : type + Description of variable positional arguments. +**kwargs : type + Description of variable keyword arguments. + +Returns +------- +type + Description of return value. +name : type + Use this format when returning named values or multiple values. + +Yields +------ +type + For generator functions, describe yielded values. + +Raises +------ +ExceptionType + Explanation of when this exception is raised. + +See Also +-------- +related_function : Brief description of relation. + +Notes +----- +Additional technical notes, mathematical formulas (using LaTeX), +or implementation details. + +Examples +-------- +>>> function_name(arg1, arg2) +expected_output +\"\"\" + +\#\#\# Type Annotation Conventions +- Basic types: `int`, `float`, `str`, `bool`, `None` +- Collections: `list of int`, `dict of {str: int}`, `tuple of (int, str)` +- Multiple types: `int or float`, `str or None` +- Array-like: `array_like`, `numpy.ndarray of shape (n, m)` +- Callable: `callable` +- Optional params: append `, optional` after type + +\#\#\# Guidelines +1. First line: concise, imperative verb, no variable names, ends with period +2. Leave one blank line after the summary before Parameters +3. Align parameter descriptions consistently +4. Include realistic, runnable Examples when behavior isn't obvious +5. Document all exceptions that may be explicitly raised +6. 
For boolean params, describe what True/False means +""" + + +DEFAULT_URL = "http://localhost:11434/api/chat" # Changed from /api/generate + + +def build_payload(model, system_msgs, user_msgs, stream): + messages = [] + for s in system_msgs: + messages.append({"role": "system", "content": s}) + for u in user_msgs: + messages.append({"role": "user", "content": u}) + + return { + "model": model, + "messages": messages, + "stream": stream, + "keep_alive": 0 # Unload model after request + } + + +def main(): + parser = argparse.ArgumentParser(description="Send system and user prompts to model endpoint") + parser.add_argument("--url", default=DEFAULT_URL, help="API endpoint URL") + parser.add_argument("--model", required=True, help="Model name, e.g. qwen2.5-coder:32b") + parser.add_argument("--system", action="append", default=None, help="System prompt (repeatable). Overrides default.") + parser.add_argument("--user", action="append", default=[], help="User prompt (repeatable)") + parser.add_argument("--stream", action="store_true", help="Enable streaming mode") + parser.add_argument("--timeout", type=float, default=120.0, help="Request timeout in seconds") + + args = parser.parse_args() + + # Use default system prompt if none provided + system_msgs = args.system if args.system else [SYSTEM_PROMPT] + + if not args.user: + print("Error: at least one --user prompt is required.", file=sys.stderr) + sys.exit(2) + + payload = build_payload(args.model, system_msgs, args.user, args.stream) + + try: + resp = requests.post(args.url, json=payload, timeout=args.timeout) + resp.raise_for_status() + except requests.RequestException as e: + print(f"Request failed: {e}", file=sys.stderr) + sys.exit(1) + + try: + data = resp.json() + except ValueError: + print("Response is not valid JSON", file=sys.stderr) + print(resp.text, file=sys.stderr) + sys.exit(1) + + # Handle /api/chat response format + if "message" in data: + print(data["message"].get("content", "")) + elif "response" in data: + print(data["response"]) + elif "choices" in data and isinstance(data["choices"], list): + for c in data["choices"]: + print(c.get("message", {}).get("content", c.get("text", ""))) + else: + print(json.dumps(data, indent=2)) + + +if __name__ == "__main__": + main() \ No newline at end of file From 6c8901947d59dcf1f6f1ec5087400f162434a343 Mon Sep 17 00:00:00 2001 From: Marat Saidov Date: Sun, 25 Jan 2026 16:19:43 +0100 Subject: [PATCH 2/6] filled the serve.py with fastapi-based endpoints, ollama is used as a backend --- .gitignore | 1 + README.md | 110 ++++++++- pyproject.toml | 2 + scripts/run_ollama.py | 107 +++------ src/training/prompts/system_prompt.md | 77 +++++++ src/training/serve.py | 120 ++++++++-- tests/test_serve.py | 312 ++++++++++++++++++++++++++ 7 files changed, 634 insertions(+), 95 deletions(-) create mode 100644 src/training/prompts/system_prompt.md create mode 100644 tests/test_serve.py diff --git a/.gitignore b/.gitignore index ad3c25b..d9c1a25 100644 --- a/.gitignore +++ b/.gitignore @@ -23,6 +23,7 @@ build/ # IDE .vscode/ +.pytest_cache/ # Jupyter notebooks/.ipynb_checkpoints/ diff --git a/README.md b/README.md index b94317a..4d607f4 100644 --- a/README.md +++ b/README.md @@ -53,8 +53,9 @@ Test Dataset + Model Predictions --> [benchmark.py] --> Metrics Report - **`train_lora.py`** - LoRA fine-tuning using HuggingFace Trainer + PEFT. Supports QLoRA (4-bit quantization) for training on 1-2 A100 GPUs. 
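The QLoRA combination mentioned in the `train_lora.py` bullet is not shown anywhere in this series; as a rough sketch of what a HuggingFace Trainer + PEFT + 4-bit setup typically looks like (the checkpoint name and hyperparameters below are assumptions for illustration, not taken from `train_lora.py`):

```python
import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig
from peft import LoraConfig, get_peft_model

# 4-bit NF4 quantization (the "Q" in QLoRA); compute runs in bfloat16.
bnb = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
model = AutoModelForCausalLM.from_pretrained(
    "Qwen/Qwen2.5-Coder-7B",  # assumed checkpoint, for illustration only
    quantization_config=bnb,
    device_map="auto",
)
# Low-rank adapters on the attention projections; only these weights train.
lora = LoraConfig(
    r=16,
    lora_alpha=32,
    target_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    task_type="CAUSAL_LM",
)
model = get_peft_model(model, lora)
model.print_trainable_parameters()
```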
-- **`serve.py`** - FastAPI inference server that loads the fine-tuned model and - serves docstring generation via HTTP. +- **`serve.py`** - FastAPI inference server that uses ollama API to generate + docstrings. The server uses a hard-coded system prompt for NumPy-style docstring + generation. ### Evaluation (`src/evaluation/`) @@ -87,6 +88,111 @@ python -m src.data.convert_seed \ --output-dir data/processed/python-method ``` +## Serving + +The FastAPI inference server provides HTTP endpoints for docstring generation using +ollama as the backend. The server uses a system prompt stored in +`src/training/prompts/system_prompt.md` to generate NumPy-style docstrings. + +### Prerequisites + +1. **Install ollama**: Make sure [ollama](https://ollama.ai/) is installed and running locally +2. **Pull a model**: Download a code model (e.g., `qwen2.5-coder:32b`): + ```bash + ollama pull qwen2.5-coder:32b + ``` + +### Starting the Server + +Start the FastAPI server using uvicorn: + +```bash +# Using uvicorn directly +uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 + +# Or run the module directly +python -m src.training.serve +``` + +The server will start on `http://localhost:8000` by default. + +### Configuration + +The server can be configured using environment variables: + +- `OLLAMA_URL` - Ollama API endpoint (default: `http://localhost:11434/api/chat`) +- `OLLAMA_MODEL` - Model name to use (default: `qwen2.5-coder:32b`) +- `REQUEST_TIMEOUT` - Request timeout in seconds (default: `120.0`) + +Example: +```bash +OLLAMA_MODEL=qwen2.5-coder:7b uvicorn src.training.serve:app --port 8000 +``` + +### API Endpoints + +#### Health Check + +Check if the service is healthy and ollama is accessible: + +```bash +curl http://localhost:8000/health +``` + +**Response (200 OK):** +```json +{ + "status": "healthy", + "service": "ollama" +} +``` + +**Response (503 Service Unavailable):** +```json +{ + "detail": "Service unhealthy: ollama is not running or not accessible" +} +``` + +#### Generate Docstring + +Generate a docstring for a Python function: + +```bash +curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{ + "code": "def add(x, y):\n return x + y", + "max_new_tokens": 256 + }' +``` + +**Request Body:** +- `code` (required): Python function code as a string +- `max_new_tokens` (optional): Maximum number of tokens to generate (default: 256) + +**Response (200 OK):** +```json +{ + "docstring": "\"\"\"Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\n\"\"\"" +} +``` + +**Response (500 Internal Server Error):** +```json +{ + "detail": "Failed to generate docstring: " +} +``` + +### Testing + +Run the test suite to verify the API endpoints: + +```bash +pytest tests/test_serve.py -v +``` + ## Dataset The seed dataset comes from the [NeuralCodeSum](https://github.com/wasiahmad/NeuralCodeSum) diff --git a/pyproject.toml b/pyproject.toml index 50756b8..4778490 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -24,12 +24,14 @@ dependencies = [ "safetensors", "fastapi>=0.104.0", "uvicorn>=0.24.0", + "requests>=2.31.0", ] [project.optional-dependencies] dev = [ "pytest>=7.0", "ruff>=0.1.0", + "httpx>=0.24.0", ] [tool.hatch.build.targets.wheel] diff --git a/scripts/run_ollama.py b/scripts/run_ollama.py index 85f121a..bef6323 100644 --- a/scripts/run_ollama.py +++ b/scripts/run_ollama.py @@ -2,87 +2,19 @@ import requests import sys import json +import time +from pathlib 
import Path -SYSTEM_PROMPT = r""" -You are a Python documentation expert. Your task is to generate a docstring for the provided Python function following the NumPy documentation format strictly. - -\#\# Output Rules -- Output ONLY the docstring content (including the triple quotes) -- Do NOT include the function signature or body -- Do NOT add any explanation before or after the docstring - -\#\# NumPy Docstring Format - -\#\#\# Structure (include sections only when applicable) -\"\"\" -Short one-line summary (imperative mood, e.g., "Compute", "Return", "Parse"). - -Extended summary providing more details about the function behavior, -algorithm, or implementation notes. Optional but recommended for -complex functions. - -Parameters ----------- -param_name : type - Description of the parameter. If the description spans multiple - lines, indent continuation lines. -param_name : type, optional - For optional parameters, specify default value in description. - Default is `default_value`. -*args : type - Description of variable positional arguments. -**kwargs : type - Description of variable keyword arguments. - -Returns -------- -type - Description of return value. -name : type - Use this format when returning named values or multiple values. - -Yields ------- -type - For generator functions, describe yielded values. - -Raises ------- -ExceptionType - Explanation of when this exception is raised. - -See Also --------- -related_function : Brief description of relation. - -Notes ------ -Additional technical notes, mathematical formulas (using LaTeX), -or implementation details. - -Examples --------- ->>> function_name(arg1, arg2) -expected_output -\"\"\" - -\#\#\# Type Annotation Conventions -- Basic types: `int`, `float`, `str`, `bool`, `None` -- Collections: `list of int`, `dict of {str: int}`, `tuple of (int, str)` -- Multiple types: `int or float`, `str or None` -- Array-like: `array_like`, `numpy.ndarray of shape (n, m)` -- Callable: `callable` -- Optional params: append `, optional` after type - -\#\#\# Guidelines -1. First line: concise, imperative verb, no variable names, ends with period -2. Leave one blank line after the summary before Parameters -3. Align parameter descriptions consistently -4. Include realistic, runnable Examples when behavior isn't obvious -5. Document all exceptions that may be explicitly raised -6. For boolean params, describe what True/False means -""" +def load_system_prompt() -> str: + """Load the default system prompt from the prompts directory.""" + prompt_path = Path(__file__).parent.parent / "src" / "training" / "prompts" / "system_prompt.md" + if not prompt_path.exists(): + raise FileNotFoundError( + f"System prompt file not found: {prompt_path}. " + "Please ensure the prompt file exists." 
+ ) + return prompt_path.read_text(encoding="utf-8") DEFAULT_URL = "http://localhost:11434/api/chat" # Changed from /api/generate @@ -115,7 +47,14 @@ def main(): args = parser.parse_args() # Use default system prompt if none provided - system_msgs = args.system if args.system else [SYSTEM_PROMPT] + if args.system: + system_msgs = args.system + else: + try: + system_msgs = [load_system_prompt()] + except FileNotFoundError as e: + print(f"Error: {e}", file=sys.stderr) + sys.exit(1) if not args.user: print("Error: at least one --user prompt is required.", file=sys.stderr) @@ -123,11 +62,15 @@ def main(): payload = build_payload(args.model, system_msgs, args.user, args.stream) + # Track execution time + start_time = time.time() try: resp = requests.post(args.url, json=payload, timeout=args.timeout) resp.raise_for_status() except requests.RequestException as e: + elapsed_time = time.time() - start_time print(f"Request failed: {e}", file=sys.stderr) + print(f"Execution time: {elapsed_time:.2f}s", file=sys.stderr) sys.exit(1) try: @@ -148,6 +91,10 @@ def main(): else: print(json.dumps(data, indent=2)) + # Print execution time + elapsed_time = time.time() - start_time + print(f"Execution time: {elapsed_time:.2f}s", file=sys.stdout) + if __name__ == "__main__": main() \ No newline at end of file diff --git a/src/training/prompts/system_prompt.md b/src/training/prompts/system_prompt.md new file mode 100644 index 0000000..770e1da --- /dev/null +++ b/src/training/prompts/system_prompt.md @@ -0,0 +1,77 @@ +You are a Python documentation expert. Your task is to generate a docstring for the provided Python function following the NumPy documentation format strictly. + +## Output Rules +- Output ONLY the docstring content (including the triple quotes) +- Do NOT include the function signature or body +- Do NOT add any explanation before or after the docstring + +## NumPy Docstring Format + +### Structure (include sections only when applicable) +""" +Short one-line summary (imperative mood, e.g., "Compute", "Return", "Parse"). + +Extended summary providing more details about the function behavior, +algorithm, or implementation notes. Optional but recommended for +complex functions. + +Parameters +---------- +param_name : type + Description of the parameter. If the description spans multiple + lines, indent continuation lines. +param_name : type, optional + For optional parameters, specify default value in description. + Default is `default_value`. +*args : type + Description of variable positional arguments. +**kwargs : type + Description of variable keyword arguments. + +Returns +------- +type + Description of return value. +name : type + Use this format when returning named values or multiple values. + +Yields +------ +type + For generator functions, describe yielded values. + +Raises +------ +ExceptionType + Explanation of when this exception is raised. + +See Also +-------- +related_function : Brief description of relation. + +Notes +----- +Additional technical notes, mathematical formulas (using LaTeX), +or implementation details. + +Examples +-------- +>>> function_name(arg1, arg2) +expected_output +""" + +### Type Annotation Conventions +- Basic types: `int`, `float`, `str`, `bool`, `None` +- Collections: `list of int`, `dict of {str: int}`, `tuple of (int, str)` +- Multiple types: `int or float`, `str or None` +- Array-like: `array_like`, `numpy.ndarray of shape (n, m)` +- Callable: `callable` +- Optional params: append `, optional` after type + +### Guidelines +1. 
First line: concise, imperative verb, no variable names, ends with period +2. Leave one blank line after the summary before Parameters +3. Align parameter descriptions consistently +4. Include realistic, runnable Examples when behavior isn't obvious +5. Document all exceptions that may be explicitly raised +6. For boolean params, describe what True/False means diff --git a/src/training/serve.py b/src/training/serve.py index 619ca5c..ea5f9e8 100644 --- a/src/training/serve.py +++ b/src/training/serve.py @@ -11,30 +11,124 @@ GET /health - Health check """ +import os +from pathlib import Path +from typing import Optional -def load_model(base_model: str, adapter_path: str): - """Load base model + LoRA adapter for inference.""" - raise NotImplementedError +import requests +from fastapi import FastAPI, HTTPException +from pydantic import BaseModel + +# Configuration +OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/api/chat") +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen2.5-coder:32b") +REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "120.0")) + + +def load_system_prompt() -> str: + """Load the default system prompt from the prompts directory.""" + prompt_path = Path(__file__).parent / "prompts" / "system_prompt.md" + if not prompt_path.exists(): + raise FileNotFoundError( + f"System prompt file not found: {prompt_path}. " + "Please ensure the prompt file exists." + ) + return prompt_path.read_text(encoding="utf-8") + + +# Load system prompt at module level +SYSTEM_PROMPT = load_system_prompt() def generate_docstring(code: str, max_new_tokens: int = 256) -> str: + """Generate a docstring for the given code snippet using ollama API.""" + payload = { + "model": OLLAMA_MODEL, + "messages": [ + {"role": "system", "content": SYSTEM_PROMPT}, + {"role": "user", "content": code} + ], + "stream": False, + "keep_alive": 0, + "options": { + "num_predict": max_new_tokens + } + } + + try: + resp = requests.post(OLLAMA_URL, json=payload, timeout=REQUEST_TIMEOUT) + resp.raise_for_status() + data = resp.json() + + # Handle /api/chat response format + if "message" in data: + return data["message"].get("content", "") + elif "response" in data: + return data["response"] + elif "choices" in data and isinstance(data["choices"], list): + content = "" + for c in data["choices"]: + content += c.get("message", {}).get("content", c.get("text", "")) + return content + else: + raise ValueError(f"Unexpected response format: {data}") + except requests.RequestException as e: + raise RuntimeError(f"Failed to generate docstring: {e}") from e + + +def check_ollama_health() -> bool: + """Check if ollama is running locally by making a test request.""" + try: + # Try to list models as a health check + health_url = OLLAMA_URL.replace("/api/chat", "/api/tags") + resp = requests.get(health_url, timeout=5.0) + return resp.status_code == 200 + except requests.RequestException: + return False + + +# FastAPI app +app = FastAPI(title="Docstring Generation API", version="0.1.0") + + +class GenerateRequest(BaseModel): + """Request model for docstring generation.""" + code: str + max_new_tokens: Optional[int] = 256 + + +class GenerateResponse(BaseModel): + """Response model for docstring generation.""" + docstring: str + + +@app.post("/generate", response_model=GenerateResponse) +async def generate(request: GenerateRequest): """Generate a docstring for the given code snippet.""" - raise NotImplementedError + try: + docstring = generate_docstring(request.code, request.max_new_tokens) + return 
GenerateResponse(docstring=docstring) + except Exception as e: + raise HTTPException(status_code=500, detail=str(e)) -# FastAPI app will be defined here once dependencies are implemented. -# app = FastAPI() -# -# @app.post("/generate") -# async def generate(request: dict): ... -# -# @app.get("/health") -# async def health(): ... +@app.get("/health") +async def health(): + """Health check endpoint that verifies ollama is running.""" + is_healthy = check_ollama_health() + if is_healthy: + return {"status": "healthy", "service": "ollama"} + else: + raise HTTPException( + status_code=503, + detail="Service unhealthy: ollama is not running or not accessible" + ) def main(): """Start the inference server.""" - raise NotImplementedError + import uvicorn + uvicorn.run(app, host="0.0.0.0", port=8000) if __name__ == "__main__": diff --git a/tests/test_serve.py b/tests/test_serve.py new file mode 100644 index 0000000..922211d --- /dev/null +++ b/tests/test_serve.py @@ -0,0 +1,312 @@ +"""Tests for src.training.serve module.""" + +import json +from unittest.mock import Mock, patch + +import pytest +from fastapi.testclient import TestClient + +from src.training.serve import app, check_ollama_health, generate_docstring + + +@pytest.fixture +def client(): + """Create a test client for the FastAPI app.""" + return TestClient(app) + + +class TestHealthEndpoint: + """Tests for the /health endpoint.""" + + @patch("src.training.serve.requests.get") + def test_health_success(self, mock_get, client): + """Health check should return 200 when ollama is running.""" + # Mock successful response from ollama + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert data["status"] == "healthy" + assert data["service"] == "ollama" + mock_get.assert_called_once() + + @patch("src.training.serve.requests.get") + def test_health_failure_connection_error(self, mock_get, client): + """Health check should return 503 when ollama is not accessible.""" + # Mock connection error + import requests + mock_get.side_effect = requests.RequestException("Connection refused") + + response = client.get("/health") + + assert response.status_code == 503 + data = response.json() + assert "unhealthy" in data["detail"].lower() + assert "ollama" in data["detail"].lower() + + @patch("src.training.serve.requests.get") + def test_health_failure_non_200_status(self, mock_get, client): + """Health check should return 503 when ollama returns non-200 status.""" + # Mock non-200 response + mock_response = Mock() + mock_response.status_code = 500 + mock_get.return_value = mock_response + + response = client.get("/health") + + assert response.status_code == 503 + data = response.json() + assert "unhealthy" in data["detail"].lower() + + @patch("src.training.serve.requests.get") + def test_health_failure_timeout(self, mock_get, client): + """Health check should return 503 when ollama request times out.""" + import requests + mock_get.side_effect = requests.Timeout("Request timed out") + + response = client.get("/health") + + assert response.status_code == 503 + + +class TestGenerateEndpoint: + """Tests for the /generate endpoint.""" + + @patch("src.training.serve.requests.post") + def test_generate_success_message_format(self, mock_post, client): + """Generate should return docstring when ollama responds with message format.""" + # Mock successful ollama response with message format + mock_response = Mock() + 
mock_response.status_code = 200 + mock_response.json.return_value = { + "message": { + "content": '"""Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\n"""' + } + } + mock_post.return_value = mock_response + + request_data = { + "code": "def add(x, y):\n return x + y", + "max_new_tokens": 256 + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Compute the sum" in data["docstring"] + mock_post.assert_called_once() + + @patch("src.training.serve.requests.post") + def test_generate_success_response_format(self, mock_post, client): + """Generate should return docstring when ollama responds with response format.""" + # Mock successful ollama response with response format + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "response": '"""Return the product of two numbers.\n\nParameters\n----------\na : float\n First number.\nb : float\n Second number.\n\nReturns\n-------\nfloat\n Product of a and b.\n"""' + } + mock_post.return_value = mock_response + + request_data = { + "code": "def multiply(a, b):\n return a * b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Return the product" in data["docstring"] + + @patch("src.training.serve.requests.post") + def test_generate_success_choices_format(self, mock_post, client): + """Generate should return docstring when ollama responds with choices format.""" + # Mock successful ollama response with choices format + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "choices": [ + { + "message": { + "content": '"""Calculate the difference.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Difference of x and y.\n"""' + } + } + ] + } + mock_post.return_value = mock_response + + request_data = { + "code": "def subtract(x, y):\n return x - y", + "max_new_tokens": 128 + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + data = response.json() + assert "docstring" in data + assert "Calculate the difference" in data["docstring"] + + @patch("src.training.serve.requests.post") + def test_generate_default_max_new_tokens(self, mock_post, client): + """Generate should use default max_new_tokens when not provided.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + request_data = {"code": "def test(): pass"} + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + # Verify that max_new_tokens was included in the payload + call_args = mock_post.call_args + assert call_args is not None + payload = call_args[1]["json"] + assert "options" in payload + assert payload["options"]["num_predict"] == 256 + + @patch("src.training.serve.requests.post") + def test_generate_failure_connection_error(self, mock_post, client): + """Generate should return 500 when ollama connection fails.""" + import requests + mock_post.side_effect = requests.ConnectionError("Connection refused") + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", 
json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + assert "Failed to generate docstring" in data["detail"] + + @patch("src.training.serve.requests.post") + def test_generate_failure_timeout(self, mock_post, client): + """Generate should return 500 when ollama request times out.""" + import requests + mock_post.side_effect = requests.Timeout("Request timed out") + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + + @patch("src.training.serve.requests.post") + def test_generate_failure_non_200_status(self, mock_post, client): + """Generate should return 500 when ollama returns non-200 status.""" + mock_response = Mock() + mock_response.status_code = 500 + mock_response.raise_for_status.side_effect = Exception("Internal server error") + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + + @patch("src.training.serve.requests.post") + def test_generate_failure_unexpected_format(self, mock_post, client): + """Generate should return 500 when ollama returns unexpected format.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "unexpected": "format" + } + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 500 + data = response.json() + assert "detail" in data + assert "Unexpected response format" in data["detail"] + + def test_generate_missing_code_field(self, client): + """Generate should return 422 when code field is missing.""" + request_data = {} + response = client.post("/generate", json=request_data) + + assert response.status_code == 422 + + def test_generate_empty_code(self, client): + """Generate should accept empty code string.""" + with patch("src.training.serve.requests.post") as mock_post: + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Empty function."""'} + } + mock_post.return_value = mock_response + + request_data = {"code": ""} + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + +class TestHelperFunctions: + """Tests for helper functions.""" + + @patch("src.training.serve.requests.get") + def test_check_ollama_health_success(self, mock_get): + """check_ollama_health should return True when ollama is accessible.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + result = check_ollama_health() + + assert result is True + mock_get.assert_called_once() + + @patch("src.training.serve.requests.get") + def test_check_ollama_health_failure(self, mock_get): + """check_ollama_health should return False when ollama is not accessible.""" + import requests + mock_get.side_effect = requests.RequestException("Connection refused") + + result = check_ollama_health() + + assert result is False + + @patch("src.training.serve.requests.post") + def test_generate_docstring_success(self, mock_post): + """generate_docstring should return docstring content.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test 
docstring."""'} + } + mock_post.return_value = mock_response + + result = generate_docstring("def test(): pass", max_new_tokens=128) + + assert result == '"""Test docstring."""' + mock_post.assert_called_once() + + @patch("src.training.serve.requests.post") + def test_generate_docstring_failure(self, mock_post): + """generate_docstring should raise RuntimeError on failure.""" + import requests + mock_post.side_effect = requests.RequestException("Connection error") + + with pytest.raises(RuntimeError) as exc_info: + generate_docstring("def test(): pass") + + assert "Failed to generate docstring" in str(exc_info.value) From 957aa218ce53e824d575cfcd4a93d523d44dac0b Mon Sep 17 00:00:00 2001 From: Marat Saidov Date: Tue, 3 Feb 2026 18:12:53 +0100 Subject: [PATCH 3/6] add multi-model support for Qwen 2.5 and Qwen3 Coder - Add model configuration registry (src/training/models.py) with SamplingConfig and ModelConfig dataclasses - Support Qwen 2.5 Coder (7B, 14B, 32B) and Qwen3 Coder 30B-A3B (MoE) - Add per-request model selection via API and environment variable - Apply model-specific sampling parameters (Qwen3 uses temp=1.0, top_p=0.95) - Add /models endpoint to list available configurations - Update health endpoint to report active model info - Add --model-key and --list-models CLI options - Add .gitattributes for cross-platform line endings - Add Windows PowerShell/CMD examples in documentation - Add 48 new tests (20 for models, 10 for serve model selection) Co-Authored-By: Claude Opus 4.5 --- .gitattributes | 27 ++++++ README.md | 136 +++++++++++++++++++++++++--- scripts/run_ollama.py | 174 +++++++++++++++++++++++++++++++---- src/training/models.py | 193 +++++++++++++++++++++++++++++++++++++++ src/training/serve.py | 151 +++++++++++++++++++++++++++---- tests/test_models.py | 199 +++++++++++++++++++++++++++++++++++++++++ tests/test_serve.py | 192 ++++++++++++++++++++++++++++++++++++++- 7 files changed, 1023 insertions(+), 49 deletions(-) create mode 100644 .gitattributes create mode 100644 src/training/models.py create mode 100644 tests/test_models.py diff --git a/.gitattributes b/.gitattributes new file mode 100644 index 0000000..f3a8404 --- /dev/null +++ b/.gitattributes @@ -0,0 +1,27 @@ +# Auto detect text files and perform LF normalization +* text=auto + +# Python files +*.py text eol=lf + +# Markdown files +*.md text eol=lf + +# JSON files +*.json text eol=lf + +# YAML files +*.yml text eol=lf +*.yaml text eol=lf + +# Shell scripts +*.sh text eol=lf + +# Configuration files +*.toml text eol=lf +*.cfg text eol=lf +*.ini text eol=lf + +# Keep Windows batch files with CRLF +*.bat text eol=crlf +*.cmd text eol=crlf diff --git a/README.md b/README.md index 4d607f4..89e2819 100644 --- a/README.md +++ b/README.md @@ -54,8 +54,10 @@ Test Dataset + Model Predictions --> [benchmark.py] --> Metrics Report QLoRA (4-bit quantization) for training on 1-2 A100 GPUs. - **`serve.py`** - FastAPI inference server that uses ollama API to generate - docstrings. The server uses a hard-coded system prompt for NumPy-style docstring - generation. + docstrings. Supports multiple Qwen Coder models with model-specific configurations. + +- **`models.py`** - Model configuration registry with sampling parameters for + Qwen 2.5 Coder and Qwen3 Coder variants. ### Evaluation (`src/evaluation/`) @@ -97,15 +99,22 @@ ollama as the backend. The server uses a system prompt stored in ### Prerequisites 1. **Install ollama**: Make sure [ollama](https://ollama.ai/) is installed and running locally -2. 
**Pull a model**: Download a code model (e.g., `qwen2.5-coder:32b`): +2. **Pull a model**: Download one of the supported code models: ```bash - ollama pull qwen2.5-coder:32b + # Qwen 2.5 Coder (dense models) + ollama pull qwen2.5-coder:32b # Default, ~18GB Q4 + ollama pull qwen2.5-coder:14b # Mid-size, ~8GB Q4 + ollama pull qwen2.5-coder:7b # Fast, ~4GB Q4 + + # Qwen3 Coder (MoE model) + ollama pull qwen3-coder:30b-a3b # Best quality, ~18GB Q4, 256K context ``` ### Starting the Server Start the FastAPI server using uvicorn: +**Linux/macOS:** ```bash # Using uvicorn directly uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 @@ -114,6 +123,11 @@ uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 python -m src.training.serve ``` +**Windows (PowerShell):** +```powershell +uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 +``` + The server will start on `http://localhost:8000` by default. ### Configuration @@ -121,12 +135,63 @@ The server will start on `http://localhost:8000` by default. The server can be configured using environment variables: - `OLLAMA_URL` - Ollama API endpoint (default: `http://localhost:11434/api/chat`) -- `OLLAMA_MODEL` - Model name to use (default: `qwen2.5-coder:32b`) +- `OLLAMA_MODEL` - Model key or Ollama model name (default: `qwen2.5-coder-32b`) - `REQUEST_TIMEOUT` - Request timeout in seconds (default: `120.0`) -Example: +**Linux/macOS:** +```bash +OLLAMA_MODEL=qwen3-coder-30b uvicorn src.training.serve:app --port 8000 +``` + +**Windows (PowerShell):** +```powershell +$env:OLLAMA_MODEL="qwen3-coder-30b"; uvicorn src.training.serve:app --port 8000 +``` + +**Windows (CMD):** +```cmd +set OLLAMA_MODEL=qwen3-coder-30b && uvicorn src.training.serve:app --port 8000 +``` + +### Available Models + +| Model Key | Ollama Model | Architecture | Memory (Q4) | Context | Description | +|-----------|--------------|--------------|-------------|---------|-------------| +| `qwen2.5-coder-32b` | `qwen2.5-coder:32b` | Dense | ~18GB | 32K | Default, balanced quality/speed | +| `qwen2.5-coder-14b` | `qwen2.5-coder:14b` | Dense | ~8GB | 32K | Mid-size, good performance | +| `qwen2.5-coder-7b` | `qwen2.5-coder:7b` | Dense | ~4GB | 32K | Fast inference | +| `qwen3-coder-30b` | `qwen3-coder:30b-a3b` | MoE | ~18GB | 256K | Best quality, 3.3B active params | + +Each model has optimized sampling parameters: +- **Qwen 2.5 Coder**: temperature=0.7, top_p=0.9, top_k=40 +- **Qwen3 Coder**: temperature=1.0, top_p=0.95, top_k=40 (per official recommendations) + +### Model Selection + +You can select a model in two ways: + +1. **Environment variable** (applies to all requests): + ```bash + OLLAMA_MODEL=qwen3-coder-30b uvicorn src.training.serve:app + ``` + +2. 
**Per-request** (via API): + ```bash + curl -X POST http://localhost:8000/generate \ + -H "Content-Type: application/json" \ + -d '{"code": "def add(x, y): return x + y", "model": "qwen3-coder-30b"}' + ``` + +### List Available Models + +**Via CLI:** +```bash +python scripts/run_ollama.py --list-models +``` + +**Via API:** ```bash -OLLAMA_MODEL=qwen2.5-coder:7b uvicorn src.training.serve:app --port 8000 +curl http://localhost:8000/models ``` ### API Endpoints @@ -143,7 +208,9 @@ curl http://localhost:8000/health ```json { "status": "healthy", - "service": "ollama" + "service": "ollama", + "active_model": "Qwen 2.5 Coder 32B", + "ollama_model": "qwen2.5-coder:32b" } ``` @@ -169,12 +236,14 @@ curl -X POST http://localhost:8000/generate \ **Request Body:** - `code` (required): Python function code as a string -- `max_new_tokens` (optional): Maximum number of tokens to generate (default: 256) +- `max_new_tokens` (optional): Maximum number of tokens to generate (uses model default if not specified) +- `model` (optional): Model key or Ollama model name to use for this request **Response (200 OK):** ```json { - "docstring": "\"\"\"Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\n\"\"\"" + "docstring": "\"\"\"Compute the sum of two numbers.\n\nParameters\n----------\nx : int\n First number.\ny : int\n Second number.\n\nReturns\n-------\nint\n Sum of x and y.\"\"\"", + "model": "qwen2.5-coder:32b" } ``` @@ -185,12 +254,57 @@ curl -X POST http://localhost:8000/generate \ } ``` +#### List Models + +Get available model configurations: + +```bash +curl http://localhost:8000/models +``` + +**Response (200 OK):** +```json +{ + "default": "qwen2.5-coder-32b", + "active": "qwen2.5-coder-32b", + "models": [ + { + "key": "qwen2.5-coder-32b", + "name": "Qwen 2.5 Coder 32B", + "ollama_model": "qwen2.5-coder:32b", + "context_window": 32768, + "architecture": "dense", + "memory_q4": "~18GB", + "description": "Dense 32B model, good balance of quality and speed" + } + ] +} +``` + +### CLI Tool + +The CLI tool allows testing docstring generation directly: + +```bash +# Use default model +python scripts/run_ollama.py --user "def add(x, y): return x + y" + +# Use specific model by key +python scripts/run_ollama.py --model-key qwen3-coder-30b --user "def foo(): pass" + +# Use raw Ollama model name +python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def bar(): pass" + +# List available models +python scripts/run_ollama.py --list-models +``` + ### Testing Run the test suite to verify the API endpoints: ```bash -pytest tests/test_serve.py -v +pytest tests/test_serve.py tests/test_models.py -v ``` ## Dataset diff --git a/scripts/run_ollama.py b/scripts/run_ollama.py index bef6323..40496d0 100644 --- a/scripts/run_ollama.py +++ b/scripts/run_ollama.py @@ -1,10 +1,37 @@ +"""CLI script for testing docstring generation with Ollama. + +Supports model selection via registry keys or raw Ollama model names, +with model-specific sampling parameters. 
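+
+An example of the /api/chat payload this script builds (values shown for the
+default qwen2.5-coder:32b configuration; see build_payload below):
+
+    {"model": "qwen2.5-coder:32b",
+     "messages": [{"role": "system", "content": "..."},
+                  {"role": "user", "content": "def add(x, y): ..."}],
+     "stream": false,
+     "keep_alive": 0,
+     "options": {"temperature": 0.7, "top_p": 0.9, "top_k": 40}}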
+ +Usage: + python scripts/run_ollama.py --model-key qwen2.5-coder-32b --user "def add(x, y): return x + y" + python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def foo(): pass" + python scripts/run_ollama.py --list-models +""" + import argparse -import requests -import sys import json +import sys import time from pathlib import Path +import requests + +# Add parent directory to path for imports +sys.path.insert(0, str(Path(__file__).parent.parent)) + +from src.training.models import ( + MODEL_REGISTRY, + DEFAULT_MODEL_KEY, + ModelConfig, + SamplingConfig, + get_model_config, + list_models, +) + + +DEFAULT_URL = "http://localhost:11434/api/chat" + def load_system_prompt() -> str: """Load the default system prompt from the prompts directory.""" @@ -17,10 +44,13 @@ def load_system_prompt() -> str: return prompt_path.read_text(encoding="utf-8") -DEFAULT_URL = "http://localhost:11434/api/chat" # Changed from /api/generate - - -def build_payload(model, system_msgs, user_msgs, stream): +def build_payload( + model_config: ModelConfig, + system_msgs: list[str], + user_msgs: list[str], + stream: bool +) -> dict: + """Build the request payload with model-specific sampling parameters.""" messages = [] for s in system_msgs: messages.append({"role": "system", "content": s}) @@ -28,24 +58,131 @@ def build_payload(model, system_msgs, user_msgs, stream): messages.append({"role": "user", "content": u}) return { - "model": model, + "model": model_config.ollama_model, "messages": messages, "stream": stream, - "keep_alive": 0 # Unload model after request + "keep_alive": model_config.keep_alive, + "options": { + "temperature": model_config.sampling.temperature, + "top_p": model_config.sampling.top_p, + "top_k": model_config.sampling.top_k, + } } +def print_models_list(): + """Print all available models in a formatted table.""" + print("\nAvailable Models:") + print("=" * 80) + print(f"{'Key':<22} {'Ollama Model':<25} {'Arch':<6} {'Memory':<8} {'Context'}") + print("-" * 80) + + for model_info in list_models(): + key = model_info["key"] + ollama = model_info["ollama_model"] + arch = model_info["architecture"] + memory = model_info.get("memory_q4", "N/A") + context = f"{model_info['context_window']:,}" + + # Mark default model + marker = " *" if key == DEFAULT_MODEL_KEY else "" + print(f"{key:<22} {ollama:<25} {arch:<6} {memory:<8} {context}{marker}") + + print("-" * 80) + print(f"* = default model ({DEFAULT_MODEL_KEY})") + print() + + def main(): - parser = argparse.ArgumentParser(description="Send system and user prompts to model endpoint") - parser.add_argument("--url", default=DEFAULT_URL, help="API endpoint URL") - parser.add_argument("--model", required=True, help="Model name, e.g. qwen2.5-coder:32b") - parser.add_argument("--system", action="append", default=None, help="System prompt (repeatable). 
Overrides default.") - parser.add_argument("--user", action="append", default=[], help="User prompt (repeatable)") - parser.add_argument("--stream", action="store_true", help="Enable streaming mode") - parser.add_argument("--timeout", type=float, default=120.0, help="Request timeout in seconds") + parser = argparse.ArgumentParser( + description="Send prompts to Ollama for docstring generation", + formatter_class=argparse.RawDescriptionHelpFormatter, + epilog=""" +Examples: + # Use a registered model by key + python scripts/run_ollama.py --model-key qwen2.5-coder-32b --user "def add(x, y): return x + y" + + # Use a raw Ollama model name + python scripts/run_ollama.py --model qwen2.5-coder:7b --user "def foo(): pass" + + # Use Qwen3 MoE model + python scripts/run_ollama.py --model-key qwen3-coder-30b --user "def hello(): print('hi')" + + # List available models + python scripts/run_ollama.py --list-models +""" + ) + parser.add_argument( + "--url", + default=DEFAULT_URL, + help="API endpoint URL (default: %(default)s)" + ) + parser.add_argument( + "--model-key", + help=f"Model key from registry (e.g., {DEFAULT_MODEL_KEY}, qwen3-coder-30b)" + ) + parser.add_argument( + "--model", + help="Raw Ollama model name (e.g., qwen2.5-coder:32b). Overrides --model-key" + ) + parser.add_argument( + "--list-models", + action="store_true", + help="List available model configurations and exit" + ) + parser.add_argument( + "--system", + action="append", + default=None, + help="System prompt (repeatable). Overrides default." + ) + parser.add_argument( + "--user", + action="append", + default=[], + help="User prompt (repeatable)" + ) + parser.add_argument( + "--stream", + action="store_true", + help="Enable streaming mode" + ) + parser.add_argument( + "--timeout", + type=float, + default=120.0, + help="Request timeout in seconds (default: %(default)s)" + ) args = parser.parse_args() + # Handle --list-models + if args.list_models: + print_models_list() + sys.exit(0) + + # Determine model configuration + if args.model: + # Raw Ollama model name provided + model_config = get_model_config(args.model) + print(f"Using model: {model_config.ollama_model}", file=sys.stderr) + elif args.model_key: + # Registry key provided + model_config = get_model_config(args.model_key) + print(f"Using model: {model_config.name} ({model_config.ollama_model})", file=sys.stderr) + else: + # Use default + model_config = get_model_config(DEFAULT_MODEL_KEY) + print(f"Using default model: {model_config.name} ({model_config.ollama_model})", file=sys.stderr) + + # Print sampling parameters + print( + f"Sampling: temp={model_config.sampling.temperature}, " + f"top_p={model_config.sampling.top_p}, " + f"top_k={model_config.sampling.top_k}", + file=sys.stderr + ) + # Use default system prompt if none provided if args.system: system_msgs = args.system @@ -58,9 +195,10 @@ def main(): if not args.user: print("Error: at least one --user prompt is required.", file=sys.stderr) + print("Use --help for usage information.", file=sys.stderr) sys.exit(2) - payload = build_payload(args.model, system_msgs, args.user, args.stream) + payload = build_payload(model_config, system_msgs, args.user, args.stream) # Track execution time start_time = time.time() @@ -93,8 +231,8 @@ def main(): # Print execution time elapsed_time = time.time() - start_time - print(f"Execution time: {elapsed_time:.2f}s", file=sys.stdout) + print(f"\nExecution time: {elapsed_time:.2f}s", file=sys.stderr) if __name__ == "__main__": - main() \ No newline at end of file + main() diff --git 
a/src/training/models.py b/src/training/models.py new file mode 100644 index 0000000..3040ca2 --- /dev/null +++ b/src/training/models.py @@ -0,0 +1,193 @@ +"""Model configuration registry for docstring generation. + +Provides model-specific configurations including sampling parameters, +context windows, and memory requirements for different LLM backends. +""" + +from dataclasses import dataclass +from typing import Optional + + +@dataclass(frozen=True) +class SamplingConfig: + """Sampling parameters for text generation. + + Parameters + ---------- + temperature : float + Controls randomness in generation. Higher values = more random. + top_p : float + Nucleus sampling parameter. Considers tokens with cumulative probability >= top_p. + top_k : int + Limits sampling to top k most likely tokens. + num_predict : int + Maximum number of tokens to generate. + """ + temperature: float = 0.7 + top_p: float = 0.9 + top_k: int = 40 + num_predict: int = 256 + + +@dataclass(frozen=True) +class ModelConfig: + """Configuration for a specific model. + + Parameters + ---------- + name : str + Human-readable display name for the model. + ollama_model : str + Ollama model identifier (e.g., "qwen2.5-coder:32b"). + context_window : int + Maximum context length in tokens. + sampling : SamplingConfig + Default sampling parameters for this model. + keep_alive : int + Model keep-alive time in seconds (0 = unload immediately after request). + architecture : str + Model architecture type: "dense" or "moe". + total_params : str, optional + Total parameter count description (e.g., "32B", "30B (3.3B active)"). + memory_q4 : str, optional + Approximate memory requirement at Q4 quantization. + description : str + Brief description of the model's characteristics. + """ + name: str + ollama_model: str + context_window: int + sampling: SamplingConfig + keep_alive: int = 0 + architecture: str = "dense" + total_params: Optional[str] = None + memory_q4: Optional[str] = None + description: str = "" + + +# Pre-defined model configurations +MODEL_REGISTRY: dict[str, ModelConfig] = { + "qwen2.5-coder-32b": ModelConfig( + name="Qwen 2.5 Coder 32B", + ollama_model="qwen2.5-coder:32b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="32B", + memory_q4="~18GB", + description="Dense 32B model, good balance of quality and speed" + ), + "qwen2.5-coder-14b": ModelConfig( + name="Qwen 2.5 Coder 14B", + ollama_model="qwen2.5-coder:14b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="14B", + memory_q4="~8GB", + description="Mid-size dense model, balanced performance" + ), + "qwen2.5-coder-7b": ModelConfig( + name="Qwen 2.5 Coder 7B", + ollama_model="qwen2.5-coder:7b", + context_window=32768, + sampling=SamplingConfig(temperature=0.7, top_p=0.9, top_k=40), + keep_alive=0, + architecture="dense", + total_params="7B", + memory_q4="~4GB", + description="Smaller variant, faster inference" + ), + "qwen3-coder-30b": ModelConfig( + name="Qwen3 Coder 30B-A3B", + ollama_model="qwen3-coder:30b-a3b", + context_window=262144, # 256K tokens + sampling=SamplingConfig(temperature=1.0, top_p=0.95, top_k=40), + keep_alive=300, # Longer keep_alive for MoE model to avoid reload overhead + architecture="moe", + total_params="30B (3.3B active)", + memory_q4="~18GB", + description="MoE model with 256K context, best quality" + ), +} + +# Default model key +DEFAULT_MODEL_KEY 
= "qwen2.5-coder-32b" + + +def get_model_config(model_key: str) -> ModelConfig: + """Get configuration for a model by registry key or Ollama model name. + + Parameters + ---------- + model_key : str + Either a registry key (e.g., "qwen2.5-coder-32b") or a raw Ollama + model name (e.g., "qwen2.5-coder:32b"). + + Returns + ------- + ModelConfig + The model configuration. If the key is not found in the registry, + creates a fallback configuration with default sampling parameters. + """ + # First, try direct registry lookup + if model_key in MODEL_REGISTRY: + return MODEL_REGISTRY[model_key] + + # Check if it matches any ollama_model in the registry + for config in MODEL_REGISTRY.values(): + if config.ollama_model == model_key: + return config + + # Fallback: create a default config for unknown models + return create_fallback_config(model_key) + + +def create_fallback_config(ollama_model: str) -> ModelConfig: + """Create a fallback configuration for unknown models. + + Parameters + ---------- + ollama_model : str + The Ollama model identifier. + + Returns + ------- + ModelConfig + A configuration with default sampling parameters. + """ + return ModelConfig( + name=ollama_model, + ollama_model=ollama_model, + context_window=32768, # Conservative default + sampling=SamplingConfig(), # Use defaults + keep_alive=0, + architecture="unknown", + description="Custom model (using default configuration)" + ) + + +def list_models() -> list[dict]: + """List all available model configurations. + + Returns + ------- + list of dict + List of model information dictionaries containing key, name, + ollama_model, context_window, architecture, memory_q4, and description. + """ + return [ + { + "key": key, + "name": config.name, + "ollama_model": config.ollama_model, + "context_window": config.context_window, + "architecture": config.architecture, + "total_params": config.total_params, + "memory_q4": config.memory_q4, + "description": config.description, + } + for key, config in MODEL_REGISTRY.items() + ] diff --git a/src/training/serve.py b/src/training/serve.py index ea5f9e8..cbcde12 100644 --- a/src/training/serve.py +++ b/src/training/serve.py @@ -1,7 +1,7 @@ -"""FastAPI inference server for the fine-tuned docstring generation model. +"""FastAPI inference server for docstring generation using multiple LLM backends. -Loads a LoRA-adapted model and serves predictions via HTTP. Designed to be -called by the VS Code extension for local, offline docstring generation. +Serves predictions via HTTP using Ollama as the backend. Supports multiple models +including Qwen 2.5 Coder and Qwen3 Coder variants with model-specific configurations. 
Usage: uvicorn src.training.serve:app --host 0.0.0.0 --port 8000 @@ -9,6 +9,7 @@ Endpoints: POST /generate - Generate a docstring for a given code snippet GET /health - Health check + GET /models - List available models """ import os @@ -19,12 +20,25 @@ from fastapi import FastAPI, HTTPException from pydantic import BaseModel +from src.training.models import ( + MODEL_REGISTRY, + DEFAULT_MODEL_KEY, + ModelConfig, + get_model_config, + list_models, +) + # Configuration OLLAMA_URL = os.getenv("OLLAMA_URL", "http://localhost:11434/api/chat") -OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", "qwen2.5-coder:32b") +OLLAMA_MODEL = os.getenv("OLLAMA_MODEL", DEFAULT_MODEL_KEY) REQUEST_TIMEOUT = float(os.getenv("REQUEST_TIMEOUT", "120.0")) +def get_active_model() -> ModelConfig: + """Get the currently active model configuration from environment.""" + return get_model_config(OLLAMA_MODEL) + + def load_system_prompt() -> str: """Load the default system prompt from the prompts directory.""" prompt_path = Path(__file__).parent / "prompts" / "system_prompt.md" @@ -40,26 +54,62 @@ def load_system_prompt() -> str: SYSTEM_PROMPT = load_system_prompt() -def generate_docstring(code: str, max_new_tokens: int = 256) -> str: - """Generate a docstring for the given code snippet using ollama API.""" +def generate_docstring( + code: str, + max_new_tokens: Optional[int] = None, + model_config: Optional[ModelConfig] = None +) -> str: + """Generate a docstring for the given code snippet using ollama API. + + Parameters + ---------- + code : str + The Python code snippet to generate a docstring for. + max_new_tokens : int, optional + Maximum number of tokens to generate. If None, uses the model's default. + model_config : ModelConfig, optional + Model configuration to use. If None, uses the active model from environment. + + Returns + ------- + str + The generated docstring. + + Raises + ------ + RuntimeError + If the request to ollama fails. + ValueError + If the response format is unexpected. 
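+
+    Examples
+    --------
+    >>> # Hedged illustration; requires a running Ollama instance.
+    >>> generate_docstring("def add(x, y): return x + y")  # doctest: +SKIP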
+ """ + if model_config is None: + model_config = get_active_model() + + # Use model-specific defaults if not provided + if max_new_tokens is None: + max_new_tokens = model_config.sampling.num_predict + payload = { - "model": OLLAMA_MODEL, + "model": model_config.ollama_model, "messages": [ {"role": "system", "content": SYSTEM_PROMPT}, {"role": "user", "content": code} ], "stream": False, - "keep_alive": 0, + "keep_alive": model_config.keep_alive, "options": { - "num_predict": max_new_tokens + "num_predict": max_new_tokens, + "temperature": model_config.sampling.temperature, + "top_p": model_config.sampling.top_p, + "top_k": model_config.sampling.top_k, } } - + try: resp = requests.post(OLLAMA_URL, json=payload, timeout=REQUEST_TIMEOUT) resp.raise_for_status() data = resp.json() - + # Handle /api/chat response format if "message" in data: return data["message"].get("content", "") @@ -88,36 +138,84 @@ def check_ollama_health() -> bool: # FastAPI app -app = FastAPI(title="Docstring Generation API", version="0.1.0") +app = FastAPI( + title="Docstring Generation API", + version="0.2.0", + description="Generate Python docstrings using Qwen Coder models via Ollama" +) class GenerateRequest(BaseModel): """Request model for docstring generation.""" code: str - max_new_tokens: Optional[int] = 256 + max_new_tokens: Optional[int] = None + model: Optional[str] = None class GenerateResponse(BaseModel): """Response model for docstring generation.""" docstring: str + model: str + + +class HealthResponse(BaseModel): + """Response model for health check.""" + status: str + service: str + active_model: str + ollama_model: str + + +class ModelsResponse(BaseModel): + """Response model for listing available models.""" + default: str + active: str + models: list @app.post("/generate", response_model=GenerateResponse) async def generate(request: GenerateRequest): - """Generate a docstring for the given code snippet.""" + """Generate a docstring for the given code snippet. + + Optionally specify a model to use for generation. If not specified, + uses the active model from the OLLAMA_MODEL environment variable. + """ try: - docstring = generate_docstring(request.code, request.max_new_tokens) - return GenerateResponse(docstring=docstring) + # Determine which model to use + if request.model: + model_config = get_model_config(request.model) + else: + model_config = get_active_model() + + docstring = generate_docstring( + request.code, + request.max_new_tokens, + model_config + ) + return GenerateResponse( + docstring=docstring, + model=model_config.ollama_model + ) except Exception as e: raise HTTPException(status_code=500, detail=str(e)) -@app.get("/health") +@app.get("/health", response_model=HealthResponse) async def health(): - """Health check endpoint that verifies ollama is running.""" + """Health check endpoint that verifies ollama is running. + + Returns information about the active model configuration. + """ is_healthy = check_ollama_health() + active_model = get_active_model() + if is_healthy: - return {"status": "healthy", "service": "ollama"} + return HealthResponse( + status="healthy", + service="ollama", + active_model=active_model.name, + ollama_model=active_model.ollama_model + ) else: raise HTTPException( status_code=503, @@ -125,6 +223,21 @@ async def health(): ) +@app.get("/models", response_model=ModelsResponse) +async def get_models(): + """List all available model configurations. 
+
+    Returns the default model, currently active model, and a list of
+    all registered models with their configurations.
+    """
+    return ModelsResponse(
+        default=DEFAULT_MODEL_KEY,
+        active=OLLAMA_MODEL,
+        models=list_models()
+    )
+
+
 def main():
     """Start the inference server."""
     import uvicorn
diff --git a/tests/test_models.py b/tests/test_models.py
new file mode 100644
index 0000000..0954689
--- /dev/null
+++ b/tests/test_models.py
@@ -0,0 +1,199 @@
+"""Tests for src.training.models module."""
+
+import pytest
+
+from src.training.models import (
+    MODEL_REGISTRY,
+    DEFAULT_MODEL_KEY,
+    ModelConfig,
+    SamplingConfig,
+    get_model_config,
+    create_fallback_config,
+    list_models,
+)
+
+
+class TestSamplingConfig:
+    """Tests for SamplingConfig dataclass."""
+
+    def test_default_values(self):
+        """SamplingConfig should have sensible defaults."""
+        config = SamplingConfig()
+        assert config.temperature == 0.7
+        assert config.top_p == 0.9
+        assert config.top_k == 40
+        assert config.num_predict == 256
+
+    def test_custom_values(self):
+        """SamplingConfig should accept custom values."""
+        config = SamplingConfig(temperature=1.0, top_p=0.95, top_k=50, num_predict=512)
+        assert config.temperature == 1.0
+        assert config.top_p == 0.95
+        assert config.top_k == 50
+        assert config.num_predict == 512
+
+    def test_immutable(self):
+        """SamplingConfig should be immutable (frozen)."""
+        config = SamplingConfig()
+        with pytest.raises(AttributeError):
+            config.temperature = 0.5
+
+
+class TestModelConfig:
+    """Tests for ModelConfig dataclass."""
+
+    def test_required_fields(self):
+        """ModelConfig should require essential fields."""
+        config = ModelConfig(
+            name="Test Model",
+            ollama_model="test:latest",
+            context_window=4096,
+            sampling=SamplingConfig()
+        )
+        assert config.name == "Test Model"
+        assert config.ollama_model == "test:latest"
+        assert config.context_window == 4096
+
+    def test_default_values(self):
+        """ModelConfig should have sensible defaults for optional fields."""
+        config = ModelConfig(
+            name="Test Model",
+            ollama_model="test:latest",
+            context_window=4096,
+            sampling=SamplingConfig()
+        )
+        assert config.keep_alive == 0
+        assert config.architecture == "dense"
+        assert config.total_params is None
+        assert config.memory_q4 is None
+        assert config.description == ""
+
+    def test_immutable(self):
+        """ModelConfig should be immutable (frozen)."""
+        config = ModelConfig(
+            name="Test Model",
+            ollama_model="test:latest",
+            context_window=4096,
+            sampling=SamplingConfig()
+        )
+        with pytest.raises(AttributeError):
+            config.name = "New Name"
+
+
+class TestModelRegistry:
+    """Tests for MODEL_REGISTRY and related functions."""
+
+    def test_registry_has_default_model(self):
+        """Registry should contain the default model."""
+        assert DEFAULT_MODEL_KEY in MODEL_REGISTRY
+
+    def test_registry_models_are_valid(self):
+        """All models in registry should be valid ModelConfig instances."""
+        for key, config in MODEL_REGISTRY.items():
+            assert isinstance(config, ModelConfig)
+            assert isinstance(config.sampling, SamplingConfig)
+            assert config.context_window > 0
+            assert config.ollama_model
+
+    def test_qwen25_coder_32b_config(self):
+        """Qwen 2.5 Coder 32B should have correct configuration."""
+        config = MODEL_REGISTRY["qwen2.5-coder-32b"]
+        assert config.name == "Qwen 2.5 Coder 32B"
+        assert config.ollama_model == "qwen2.5-coder:32b"
+        assert config.context_window == 32768
+        assert config.architecture == "dense"
+        assert config.sampling.temperature == 0.7
+        assert config.sampling.top_p == 0.9
+
+    def test_qwen25_coder_7b_config(self):
+        """Qwen 2.5 Coder 7B should have correct configuration."""
+        config = MODEL_REGISTRY["qwen2.5-coder-7b"]
+        assert config.name == "Qwen 2.5 Coder 7B"
+        assert config.ollama_model == "qwen2.5-coder:7b"
+        assert config.architecture == "dense"
+
+    def test_qwen3_coder_30b_config(self):
+        """Qwen3 Coder 30B should have correct MoE configuration."""
+        config = MODEL_REGISTRY["qwen3-coder-30b"]
+        assert config.name == "Qwen3 Coder 30B-A3B"
+        assert config.ollama_model == "qwen3-coder:30b-a3b"
+        assert config.context_window == 262144  # 256K
+        assert config.architecture == "moe"
+        # Qwen3 uses different sampling parameters
+        assert config.sampling.temperature == 1.0
+        assert config.sampling.top_p == 0.95
+        # MoE models have longer keep_alive
+        assert config.keep_alive == 300
+
+
+class TestGetModelConfig:
+    """Tests for get_model_config function."""
+
+    def test_get_by_registry_key(self):
+        """Should return config when given a valid registry key."""
+        config = get_model_config("qwen2.5-coder-32b")
+        assert config.ollama_model == "qwen2.5-coder:32b"
+
+    def test_get_by_ollama_model_name(self):
+        """Should return config when given a valid Ollama model name."""
+        config = get_model_config("qwen2.5-coder:32b")
+        assert config.name == "Qwen 2.5 Coder 32B"
+
+    def test_fallback_for_unknown_model(self):
+        """Should create fallback config for unknown models."""
+        config = get_model_config("unknown-model:latest")
+        assert config.ollama_model == "unknown-model:latest"
+        assert config.architecture == "unknown"
+        assert config.context_window == 32768  # Conservative default
+
+    def test_fallback_uses_default_sampling(self):
+        """Fallback config should use default sampling parameters."""
+        config = get_model_config("custom-model:v1")
+        assert config.sampling.temperature == 0.7
+        assert config.sampling.top_p == 0.9
+        assert config.sampling.top_k == 40
+
+
+class TestCreateFallbackConfig:
+    """Tests for create_fallback_config function."""
+
+    def test_creates_valid_config(self):
+        """Should create a valid ModelConfig for any model name."""
+        config = create_fallback_config("my-model:latest")
+        assert isinstance(config, ModelConfig)
+        assert config.ollama_model == "my-model:latest"
+        assert config.name == "my-model:latest"
+
+    def test_uses_conservative_defaults(self):
+        """Fallback should use conservative defaults."""
+        config = create_fallback_config("test")
+        assert config.context_window == 32768
+        assert config.keep_alive == 0
+        assert config.architecture == "unknown"
+
+
+class TestListModels:
+    """Tests for list_models function."""
+
+    def test_returns_list(self):
+        """Should return a list of model info dictionaries."""
+        models = list_models()
+        assert isinstance(models, list)
+        assert len(models) == len(MODEL_REGISTRY)
+
+    def test_model_info_structure(self):
+        """Each model info should have required fields."""
+        models = list_models()
+        for model_info in models:
+            assert "key" in model_info
+            assert "name" in model_info
+            assert "ollama_model" in model_info
+            assert "context_window" in model_info
+            assert "architecture" in model_info
+            assert "description" in model_info
+
+    def test_contains_all_registry_models(self):
+        """Should contain all models from the registry."""
+        models = list_models()
+        keys = {m["key"] for m in models}
+        assert keys == set(MODEL_REGISTRY.keys())
diff --git a/tests/test_serve.py b/tests/test_serve.py
index 922211d..6d932be 100644
--- a/tests/test_serve.py
+++ b/tests/test_serve.py
@@ -6,7 +6,8 @@
 import pytest
 from fastapi.testclient import TestClient
 
-from src.training.serve import app, check_ollama_health, generate_docstring
+from src.training.serve import app, check_ollama_health, generate_docstring, get_active_model
+from src.training.models import MODEL_REGISTRY, DEFAULT_MODEL_KEY, get_model_config
 
 
 @pytest.fixture
@@ -32,6 +33,9 @@ def test_health_success(self, mock_get, client):
         data = response.json()
         assert data["status"] == "healthy"
         assert data["service"] == "ollama"
+        # New fields for model info
+        assert "active_model" in data
+        assert "ollama_model" in data
         mock_get.assert_called_once()
 
     @patch("src.training.serve.requests.get")
@@ -99,6 +103,8 @@ def test_generate_success_message_format(self, mock_post, client):
         data = response.json()
         assert "docstring" in data
         assert "Compute the sum" in data["docstring"]
+        # New field for model info
+        assert "model" in data
         mock_post.assert_called_once()
 
     @patch("src.training.serve.requests.post")
@@ -310,3 +316,187 @@ def test_generate_docstring_failure(self, mock_post):
             generate_docstring("def test(): pass")
 
         assert "Failed to generate docstring" in str(exc_info.value)
+
+
+class TestModelsEndpoint:
+    """Tests for the /models endpoint."""
+
+    def test_list_models_returns_all(self, client):
+        """GET /models should return all registered models."""
+        response = client.get("/models")
+
+        assert response.status_code == 200
+        data = response.json()
+        assert "models" in data
+        assert len(data["models"]) == len(MODEL_REGISTRY)
+
+    def test_list_models_includes_default(self, client):
+        """Response should indicate the default model."""
+        response = client.get("/models")
+
+        assert response.status_code == 200
+        data = response.json()
+        assert "default" in data
+        assert data["default"] == DEFAULT_MODEL_KEY
+
+    def test_list_models_includes_active(self, client):
+        """Response should indicate the active model."""
+        response = client.get("/models")
+
+        assert response.status_code == 200
+        data = response.json()
+        assert "active" in data
+
+    def test_list_models_model_structure(self, client):
+        """Each model in the list should have required fields."""
+        response = client.get("/models")
+
+        assert response.status_code == 200
+        data = response.json()
+        for model in data["models"]:
+            assert "key" in model
+            assert "name" in model
+            assert "ollama_model" in model
+            assert "context_window" in model
+            assert "architecture" in model
+
+
+class TestModelSelection:
+    """Tests for per-request model selection."""
+
+    @patch("src.training.serve.requests.post")
+    def test_generate_with_specific_model_key(self, mock_post, client):
+        """Should use the specified model key in the request."""
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {
+            "message": {"content": '"""Test docstring."""'}
+        }
+        mock_post.return_value = mock_response
+
+        request_data = {
+            "code": "def test(): pass",
+            "model": "qwen2.5-coder-7b"
+        }
+        response = client.post("/generate", json=request_data)
+
+        assert response.status_code == 200
+        data = response.json()
+        assert data["model"] == "qwen2.5-coder:7b"
+
+        # Verify the correct model was used in the payload
+        call_args = mock_post.call_args
+        payload = call_args[1]["json"]
+        assert payload["model"] == "qwen2.5-coder:7b"
+
+    @patch("src.training.serve.requests.post")
+    def test_generate_with_raw_ollama_model(self, mock_post, client):
+        """Should accept raw Ollama model names."""
+        mock_response = Mock()
+        mock_response.status_code = 200
+        mock_response.json.return_value = {
+            "message": {"content": '"""Test docstring."""'}
'"""Test docstring."""'} + } + mock_post.return_value = mock_response + + request_data = { + "code": "def test(): pass", + "model": "qwen2.5-coder:14b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + @patch("src.training.serve.requests.post") + def test_generate_applies_model_sampling(self, mock_post, client): + """Should apply model-specific sampling parameters.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + # Use Qwen3 which has different sampling params + request_data = { + "code": "def test(): pass", + "model": "qwen3-coder-30b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + # Verify Qwen3 sampling parameters were used + call_args = mock_post.call_args + payload = call_args[1]["json"] + options = payload["options"] + assert options["temperature"] == 1.0 + assert options["top_p"] == 0.95 + assert options["top_k"] == 40 + + @patch("src.training.serve.requests.post") + def test_generate_applies_model_keep_alive(self, mock_post, client): + """Should apply model-specific keep_alive setting.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + # Use Qwen3 which has keep_alive=300 + request_data = { + "code": "def test(): pass", + "model": "qwen3-coder-30b" + } + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + # Verify Qwen3 keep_alive was used + call_args = mock_post.call_args + payload = call_args[1]["json"] + assert payload["keep_alive"] == 300 + + @patch("src.training.serve.requests.post") + def test_generate_without_model_uses_default(self, mock_post, client): + """Should use default model when none specified.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_response.json.return_value = { + "message": {"content": '"""Test docstring."""'} + } + mock_post.return_value = mock_response + + request_data = {"code": "def test(): pass"} + response = client.post("/generate", json=request_data) + + assert response.status_code == 200 + + # Verify default model was used + call_args = mock_post.call_args + payload = call_args[1]["json"] + default_config = get_model_config(DEFAULT_MODEL_KEY) + assert payload["model"] == default_config.ollama_model + + +class TestHealthWithModel: + """Tests for health endpoint with model information.""" + + @patch("src.training.serve.requests.get") + def test_health_reports_active_model(self, mock_get, client): + """Health endpoint should report the active model.""" + mock_response = Mock() + mock_response.status_code = 200 + mock_get.return_value = mock_response + + response = client.get("/health") + + assert response.status_code == 200 + data = response.json() + assert "active_model" in data + assert "ollama_model" in data + # Should match the configured model + active_model = get_active_model() + assert data["active_model"] == active_model.name + assert data["ollama_model"] == active_model.ollama_model From 0ecae074c16c775a0ac60f22c84f7d337fe128f6 Mon Sep 17 00:00:00 2001 From: Marat Saidov Date: Tue, 3 Feb 2026 21:12:20 +0100 Subject: [PATCH 4/6] Add GitHub Actions CI with tests, coverage, and linting - Add ci.yml workflow that runs on push/PR to master and qwen2.5-coder - Test on Python 3.12 
From 0ecae074c16c775a0ac60f22c84f7d337fe128f6 Mon Sep 17 00:00:00 2001
From: Marat Saidov
Date: Tue, 3 Feb 2026 21:12:20 +0100
Subject: [PATCH 4/6] Add GitHub Actions CI with tests, coverage, and linting

- Add ci.yml workflow that runs on push/PR to master and qwen2.5-coder
- Test on Python 3.12 with pytest and coverage reporting
- Run ruff linter (warnings only, doesn't fail build)
- Add pytest-cov to dev dependencies

Co-Authored-By: Claude Opus 4.5
---
 .github/workflows/ci.yml | 33 +++++++++++++++++++++++++++++++++
 pyproject.toml           |  1 +
 2 files changed, 34 insertions(+)
 create mode 100644 .github/workflows/ci.yml

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
new file mode 100644
index 0000000..a02d026
--- /dev/null
+++ b/.github/workflows/ci.yml
@@ -0,0 +1,33 @@
+name: CI
+
+on:
+  push:
+    branches: [master, qwen2.5-coder]
+  pull_request:
+    branches: [master, qwen2.5-coder]
+
+jobs:
+  test:
+    runs-on: ubuntu-latest
+
+    steps:
+      - uses: actions/checkout@v4
+
+      - name: Set up Python 3.12
+        uses: actions/setup-python@v5
+        with:
+          python-version: "3.12"
+          cache: 'pip'
+
+      - name: Install dependencies
+        run: |
+          python -m pip install --upgrade pip
+          pip install -e ".[dev]"
+
+      - name: Run linter
+        run: python -m ruff check .
+        continue-on-error: true
+
+      - name: Run tests with coverage
+        run: |
+          python -m pytest tests/ -v --tb=short --cov=src --cov-report=term-missing
diff --git a/pyproject.toml b/pyproject.toml
index 4778490..8bd2fa4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -30,6 +30,7 @@ dependencies = [
 [project.optional-dependencies]
 dev = [
     "pytest>=7.0",
+    "pytest-cov>=4.0",
     "ruff>=0.1.0",
     "httpx>=0.24.0",
 ]

From 54e85f189ae7fdde14c75c19a63bb5bfd14b825f Mon Sep 17 00:00:00 2001
From: Marat Saidov
Date: Tue, 3 Feb 2026 21:13:27 +0100
Subject: [PATCH 5/6] Add Codecov integration for coverage reporting

- Generate XML coverage report for Codecov
- Upload coverage using codecov-action@v4
- Requires CODECOV_TOKEN secret in repository settings

Co-Authored-By: Claude Opus 4.5
---
 .github/workflows/ci.yml | 10 +++++++++-
 1 file changed, 9 insertions(+), 1 deletion(-)

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index a02d026..bce6017 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -30,4 +30,12 @@ jobs:
 
       - name: Run tests with coverage
         run: |
-          python -m pytest tests/ -v --tb=short --cov=src --cov-report=term-missing
+          python -m pytest tests/ -v --tb=short --cov=src --cov-report=term-missing --cov-report=xml
+
+      - name: Upload coverage to Codecov
+        uses: codecov/codecov-action@v4
+        with:
+          token: ${{ secrets.CODECOV_TOKEN }}
+          files: ./coverage.xml
+          fail_ci_if_error: false
+          verbose: true

From 14a64ed5b9a1514d5a6589c59f511e71377ced02 Mon Sep 17 00:00:00 2001
From: Marat Saidov
Date: Mon, 16 Feb 2026 18:27:56 +0000
Subject: [PATCH 6/6] Add codecov.yml and tag the Codecov upload with flags

- Tag the Codecov upload with a flag and report name
- Add codecov.yml to control the PR comment layout
---
 .github/workflows/ci.yml |  2 ++
 codecov.yml              | 11 +++++++++++
 2 files changed, 13 insertions(+)
 create mode 100644 codecov.yml

diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml
index bce6017..a00ba4b 100644
--- a/.github/workflows/ci.yml
+++ b/.github/workflows/ci.yml
@@ -37,5 +37,7 @@ jobs:
       - name: Upload coverage to Codecov
         uses: codecov/codecov-action@v4
         with:
           token: ${{ secrets.CODECOV_TOKEN }}
           files: ./coverage.xml
+          flags: unittests
+          name: ci-coverage
           fail_ci_if_error: false
           verbose: true
diff --git a/codecov.yml b/codecov.yml
new file mode 100644
index 0000000..3f4c24a
--- /dev/null
+++ b/codecov.yml
@@ -0,0 +1,11 @@
+comment:
+  layout: "reach,diff,flags,tree"
+  behavior: default
+  require_changes: false
+  require_base: no
+  require_head: yes
+
+# Optional: configure thresholds or ignore patterns below
+# coverage:
+#   precision: 2
+#   round: down
\ No newline at end of file
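
The commented-out stub at the end of codecov.yml hints at an optional `coverage` block. One possible shape for it, should threshold gating be wanted later; all values here are illustrative assumptions and are not applied by this series:

    # Illustrative sketch only -- not part of the patches above.
    coverage:
      precision: 2
      round: down
      status:
        project:
          default:
            target: auto      # compare against the base commit
            threshold: 1%     # tolerate small drops without failing the check
        patch:
          default:
            target: 80%       # assumed bar for lines touched by a PR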