26 changes: 24 additions & 2 deletions cmd/commit.go
@@ -1,8 +1,10 @@
package cmd

import (
"context"
"fmt"
"html"
"io"
"os"
"path"
"strings"
@@ -73,9 +75,29 @@ func init() {
"display the prompt without sending to OpenAI")
commitCmd.PersistentFlags().BoolVar(&noConfirm, "no_confirm", false,
"skip all confirmation prompts")
commitCmd.PersistentFlags().Bool("stream", false,
"enable streaming output for real-time token display")
_ = viper.BindPFlag("openai.stream", commitCmd.PersistentFlags().Lookup("stream"))
_ = viper.BindPFlag("output.file", commitCmd.PersistentFlags().Lookup("file"))
}
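
Note: binding the flag through viper.BindPFlag means openai.stream resolves to --stream whenever the flag is set on the command line, and otherwise falls back to the environment or config file. This is standard viper precedence rather than behavior added by this PR, so existing configurations keep working unchanged.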

func callCompletion(
ctx context.Context,
client core.Generative,
content string,
w io.Writer,
) (*core.Response, error) {
if viper.GetBool("openai.stream") {
resp, err := client.CompletionStream(ctx, content, w)
if err != nil {
return nil, err
}
fmt.Fprintln(w)
Copilot AI Mar 7, 2026

callCompletion ignores the error from the trailing fmt.Fprintln(w) after streaming. If stdout/file is closed or fails, the command will still proceed as if successful; consider checking and returning this error.

Suggested change
fmt.Fprintln(w)
if _, err := fmt.Fprintln(w); err != nil {
return resp, err
}
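
Returning resp together with the error, as in the suggestion, would also preserve the already-streamed content for the caller instead of discarding it.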

return resp, nil
}
return client.Completion(ctx, content)
}

// commitCmd represents the commit command.
var commitCmd = &cobra.Command{
Use: "commit",
@@ -152,7 +174,7 @@ var commitCmd = &cobra.Command{

// Get summarized comment from diff data
color.Cyan("Summarizing git diff...")
resp, err := client.Completion(cmd.Context(), out)
resp, err := callCompletion(cmd.Context(), client, out, os.Stdout)
if err != nil {
return err
}
@@ -284,7 +306,7 @@ var commitCmd = &cobra.Command{
viper.GetString("output.lang"),
),
)
resp, err := client.Completion(cmd.Context(), out)
resp, err := callCompletion(cmd.Context(), client, out, os.Stdout)
if err != nil {
return err
}
1 change: 1 addition & 0 deletions cmd/config_list.go
@@ -38,6 +38,7 @@ var availableKeys = map[string]string{
"openai.top_p": "Nucleus sampling parameter: controls diversity by limiting to top percentage of probability mass",
"openai.frequency_penalty": "Parameter to reduce repetition by penalizing tokens based on their frequency",
"openai.presence_penalty": "Parameter to encourage topic diversity by penalizing previously used tokens",
"openai.stream": "Enable streaming output for real-time token display",
"prompt.folder": "Directory path for custom prompt templates",
"gemini.project_id": "VertexAI project for Gemini provider",
"gemini.location": "VertexAI location for Gemini provider",
8 changes: 6 additions & 2 deletions cmd/review.go
@@ -1,6 +1,7 @@
package cmd

import (
"os"
"strings"

"github.com/appleboy/CodeGPT/core"
@@ -31,6 +32,9 @@ func init() {
"Replace the tip of the current branch by creating a new commit")
reviewCmd.PersistentFlags().BoolVar(&promptOnly, "prompt_only", false,
"Show prompt only without sending request to OpenAI")
reviewCmd.PersistentFlags().Bool("stream", false,
"enable streaming output for real-time token display")
_ = viper.BindPFlag("openai.stream", reviewCmd.PersistentFlags().Lookup("stream"))
}

var reviewCmd = &cobra.Command{
@@ -87,7 +91,7 @@ var reviewCmd = &cobra.Command{

// Get summarized comment from diff data
color.Cyan("We are trying to review code changes")
resp, err := client.Completion(cmd.Context(), out)
resp, err := callCompletion(cmd.Context(), client, out, os.Stdout)
if err != nil {
return err
}
@@ -109,7 +113,7 @@ var reviewCmd = &cobra.Command{
// translate a git commit message
color.Cyan("we are trying to translate code review to " +
prompt.GetLanguage(viper.GetString("output.lang")) + " language")
resp, err := client.Completion(cmd.Context(), out)
resp, err := callCompletion(cmd.Context(), client, out, os.Stdout)
if err != nil {
return err
}
5 changes: 5 additions & 0 deletions core/openai.go
@@ -2,6 +2,7 @@ package core

import (
"context"
"io"
"strconv"

"github.com/sashabaranov/go-openai"
@@ -49,4 +50,8 @@ type Generative interface {
// GetSummaryPrefix generates a summary prefix based on the provided content.
// It takes a context and a string as input and returns a Response pointer and an error.
GetSummaryPrefix(ctx context.Context, content string) (resp *Response, err error)

// CompletionStream generates a completion and streams tokens to the writer as they arrive.
// Returns the full accumulated Response on completion.
CompletionStream(ctx context.Context, content string, w io.Writer) (resp *Response, err error)
}
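
For orientation, a minimal sketch of the contract this method adds: write each token to w as it arrives, then return the accumulated text. The fake below is hypothetical (not part of this PR) and would mainly be useful as a test double.

package fake

import (
	"context"
	"io"
	"strings"

	"github.com/appleboy/CodeGPT/core"
)

// StreamFake is a hypothetical client that replays canned tokens instead of
// calling a real provider; it exists only to illustrate the streaming contract.
type StreamFake struct {
	Tokens []string
}

// CompletionStream writes each canned token to w as it "arrives" and returns
// the accumulated content, mirroring the interface documentation above.
func (c *StreamFake) CompletionStream(
	ctx context.Context, content string, w io.Writer,
) (*core.Response, error) {
	var sb strings.Builder
	for _, tok := range c.Tokens {
		// Honor cancellation between tokens, as a real stream would.
		if err := ctx.Err(); err != nil {
			return nil, err
		}
		sb.WriteString(tok)
		if _, err := io.WriteString(w, tok); err != nil {
			return nil, err
		}
	}
	return &core.Response{Content: sb.String()}, nil
}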
54 changes: 54 additions & 0 deletions provider/anthropic/anthropic.go
@@ -5,6 +5,8 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"strings"

"github.com/appleboy/CodeGPT/core"
"github.com/appleboy/CodeGPT/core/transport"
@@ -65,6 +67,58 @@ func (c *Client) Completion(ctx context.Context, content string) (*core.Response
}, nil
}

// CompletionStream streams completion tokens to the writer as they arrive.
func (c *Client) CompletionStream(
ctx context.Context,
content string,
w io.Writer,
) (*core.Response, error) {
var sb strings.Builder
resp, err := c.client.CreateMessagesStream(ctx, anthropic.MessagesStreamRequest{
MessagesRequest: anthropic.MessagesRequest{
Model: c.model,
Messages: []anthropic.Message{
anthropic.NewUserTextMessage(content),
},
MaxTokens: c.maxTokens,
Temperature: convert.ToPtr(c.temperature),
TopP: convert.ToPtr(c.topP),
},
OnContentBlockDelta: func(data anthropic.MessagesEventContentBlockDeltaData) {
if data.Delta.Text != nil {
sb.WriteString(*data.Delta.Text)
_, _ = io.WriteString(w, *data.Delta.Text)
}
},
})
if err != nil {
var e *anthropic.APIError
if errors.As(err, &e) {
fmt.Printf("Messages error, type: %s, message: %s", e.Type, e.Message)
} else {
fmt.Printf("Messages error: %v\n", err)
}
return nil, err
}

usage := core.Usage{
PromptTokens: resp.Usage.InputTokens,
CompletionTokens: resp.Usage.OutputTokens,
TotalTokens: resp.Usage.InputTokens + resp.Usage.OutputTokens,
}

if resp.Usage.CacheCreationInputTokens > 0 || resp.Usage.CacheReadInputTokens > 0 {
usage.PromptTokensDetails = &openai.PromptTokensDetails{
CachedTokens: resp.Usage.CacheCreationInputTokens + resp.Usage.CacheReadInputTokens,
}
}

return &core.Response{
Content: sb.String(),
Usage: usage,
}, nil
}

// GetSummaryPrefix is an API call to get a summary prefix using function call.
func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) {
request := anthropic.MessagesRequest{
73 changes: 73 additions & 0 deletions provider/anthropic/stream_test.go
@@ -0,0 +1,73 @@
package anthropic

import (
"bytes"
"context"
"fmt"
"net/http"
"net/http/httptest"
"testing"

"github.com/liushuangls/go-anthropic/v2"
)

func TestCompletionStream(t *testing.T) {
// Create a mock SSE server that returns Anthropic streaming events
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/event-stream")

events := []string{
`event: message_start
data: {"type":"message_start","message":{"id":"msg_1","type":"message","role":"assistant","content":[],"model":"claude-sonnet-4-20250514","usage":{"input_tokens":10,"output_tokens":0}}}`,
`event: content_block_start
data: {"type":"content_block_start","index":0,"content_block":{"type":"text","text":""}}`,
`event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hello"}}`,
`event: content_block_delta
data: {"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":" world"}}`,
`event: content_block_stop
data: {"type":"content_block_stop","index":0}`,
`event: message_delta
data: {"type":"message_delta","delta":{"stop_reason":"end_turn"},"usage":{"output_tokens":2}}`,
`event: message_stop
data: {"type":"message_stop"}`,
}

for _, event := range events {
fmt.Fprintf(w, "%s\n\n", event)
}
}))
defer server.Close()

client := &Client{
client: anthropic.NewClient(
"test-token",
anthropic.WithBaseURL(server.URL),
),
model: anthropic.ModelClaude3Haiku20240307,
maxTokens: 1024,
}

var buf bytes.Buffer
resp, err := client.CompletionStream(context.Background(), "test prompt", &buf)
if err != nil {
t.Fatalf("CompletionStream failed: %v", err)
}

expectedContent := "Hello world"
if resp.Content != expectedContent {
t.Errorf("expected content %q, got %q", expectedContent, resp.Content)
}

if buf.String() != expectedContent {
t.Errorf("expected writer output %q, got %q", expectedContent, buf.String())
}

if resp.Usage.PromptTokens != 10 {
t.Errorf("expected prompt tokens 10, got %d", resp.Usage.PromptTokens)
}

if resp.Usage.TotalTokens != 12 {
t.Errorf("expected total tokens 12, got %d", resp.Usage.TotalTokens)
}
}
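
Since the mock SSE server exercises go-anthropic's real event parsing end to end, this test should run offline with: go test ./provider/anthropic -run TestCompletionStream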
58 changes: 58 additions & 0 deletions provider/gemini/gemini.go
@@ -4,7 +4,9 @@ import (
"context"
"errors"
"fmt"
"io"
"net/http"
"strings"

"github.com/appleboy/CodeGPT/core"
"github.com/appleboy/CodeGPT/core/transport"
@@ -66,6 +68,62 @@ func (c *Client) Completion(ctx context.Context, content string) (*core.Response
}, nil
}

// CompletionStream streams completion tokens to the writer as they arrive.
func (c *Client) CompletionStream(
ctx context.Context,
content string,
w io.Writer,
) (*core.Response, error) {
cfg := &genai.GenerateContentConfig{
TopP: convert.ToPtr(c.topP),
Temperature: convert.ToPtr(c.temperature),
MaxOutputTokens: c.maxTokens,
}
data := []*genai.Content{
{
Role: "user",
Parts: []*genai.Part{
{
Text: content,
},
},
},
}

var sb strings.Builder
usage := core.Usage{}
for resp, err := range c.client.Models.GenerateContentStream(ctx, c.model, data, cfg) {
if err != nil {
return nil, err
}

if resp.UsageMetadata != nil {
usage.PromptTokens = int(resp.UsageMetadata.PromptTokenCount)
usage.CompletionTokens = int(resp.UsageMetadata.CandidatesTokenCount)
usage.TotalTokens = int(resp.UsageMetadata.TotalTokenCount)
if resp.UsageMetadata.CachedContentTokenCount > 0 {
usage.PromptTokensDetails = &openai.PromptTokensDetails{
CachedTokens: int(resp.UsageMetadata.CachedContentTokenCount),
}
}
}

if len(resp.Candidates) > 0 && resp.Candidates[0].Content != nil {
for _, part := range resp.Candidates[0].Content.Parts {
if part.Text != "" {
sb.WriteString(part.Text)
_, _ = io.WriteString(w, part.Text)
}
}
}
}

return &core.Response{
Content: sb.String(),
Usage: usage,
}, nil
}
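
Note: the for ... range loop here relies on GenerateContentStream exposing the stream as a Go 1.23 range-over-func iterator, and because the usage fields are overwritten on every chunk, the final Usage reflects the last chunk received, which normally carries the cumulative counts.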

// GetSummaryPrefix is an API call to get a summary prefix using function call.
func (c *Client) GetSummaryPrefix(ctx context.Context, content string) (*core.Response, error) {
cfg := &genai.GenerateContentConfig{