-
Notifications
You must be signed in to change notification settings - Fork 3.6k
Add symbol extraction to get_file_contents #1983
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: sammorrowdrums/tree-sitter-semantic-diff
Are you sure you want to change the base?
Changes from all commits
db10882
9f6f67d
04e669a
fb8a684
d0290bf
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,67 @@ | ||
| package github | ||
|
|
||
| import ( | ||
| "fmt" | ||
| "strings" | ||
| ) | ||
|
|
||
| // ExtractSymbol searches source code for a named symbol and returns its text. | ||
| // It searches top-level declarations first, then recursively searches nested | ||
| // declarations (e.g. methods inside classes). Returns the symbol text and its | ||
| // kind, or an error if the symbol is not found or the language is unsupported. | ||
|
Comment on lines
+8
to
+11
|
||
| func ExtractSymbol(path string, source []byte, symbolName string) (text string, kind string, err error) { | ||
| config := languageForPath(path) | ||
| if config == nil { | ||
| return "", "", fmt.Errorf("symbol extraction is not supported for this file type") | ||
| } | ||
|
|
||
| decls, err := extractDeclarations(config, source) | ||
| if err != nil { | ||
| return "", "", fmt.Errorf("failed to parse file: %w", err) | ||
| } | ||
|
|
||
| // Search top-level declarations | ||
| if text, kind, found := findSymbol(decls, symbolName); found { | ||
| return text, kind, nil | ||
| } | ||
|
|
||
| // Search nested declarations (methods inside classes, etc.) | ||
| for _, decl := range decls { | ||
| nested := extractChildDeclarationsFromText(config, decl.Text) | ||
| if text, kind, found := findSymbol(nested, symbolName); found { | ||
| return text, kind, nil | ||
| } | ||
| } | ||
|
|
||
| // Build list of available symbols for the error message | ||
| available := listSymbolNames(config, decls) | ||
| return "", "", fmt.Errorf("symbol %q not found. Available symbols: %s", symbolName, strings.Join(available, ", ")) | ||
|
Comment on lines
+28
to
+38
|
||
| } | ||
|
|
||
| // findSymbol searches a slice of declarations for a matching name. | ||
| func findSymbol(decls []declaration, name string) (string, string, bool) { | ||
| for _, d := range decls { | ||
| if d.Name == name { | ||
| return d.Text, d.Kind, true | ||
| } | ||
| } | ||
| return "", "", false | ||
| } | ||
|
|
||
| // listSymbolNames returns all symbol names from top-level and one level of | ||
| // nested declarations, for use in error messages. | ||
| func listSymbolNames(config *languageConfig, decls []declaration) []string { | ||
| var names []string | ||
| for _, d := range decls { | ||
| if !strings.HasPrefix(d.Name, "_") { | ||
| names = append(names, d.Name) | ||
| } | ||
| nested := extractChildDeclarationsFromText(config, d.Text) | ||
| for _, n := range nested { | ||
| if !strings.HasPrefix(n.Name, "_") { | ||
| names = append(names, n.Name) | ||
| } | ||
| } | ||
| } | ||
| return names | ||
| } | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,116 @@ | ||
| package github | ||
|
|
||
| import ( | ||
| "testing" | ||
|
|
||
| "github.com/stretchr/testify/assert" | ||
| "github.com/stretchr/testify/require" | ||
| ) | ||
|
|
||
| func TestExtractSymbol(t *testing.T) { | ||
| t.Run("Go function", func(t *testing.T) { | ||
| source := []byte("package main\n\nfunc hello() {\n\tfmt.Println(\"hello\")\n}\n\nfunc world() {\n\tfmt.Println(\"world\")\n}\n") | ||
| text, kind, err := ExtractSymbol("main.go", source, "hello") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "function_declaration", kind) | ||
| assert.Contains(t, text, "func hello()") | ||
| assert.Contains(t, text, "hello") | ||
| assert.NotContains(t, text, "world") | ||
| }) | ||
|
|
||
| t.Run("Go method with receiver", func(t *testing.T) { | ||
| source := []byte("package main\n\ntype Server struct{}\n\nfunc (s *Server) Start() {\n\tlog.Println(\"start\")\n}\n\nfunc (s *Server) Stop() {\n\tlog.Println(\"stop\")\n}\n") | ||
| text, kind, err := ExtractSymbol("main.go", source, "(*Server).Start") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "method_declaration", kind) | ||
| assert.Contains(t, text, "Start") | ||
| assert.NotContains(t, text, "Stop") | ||
| }) | ||
|
|
||
| t.Run("Go type", func(t *testing.T) { | ||
| source := []byte("package main\n\ntype Config struct {\n\tHost string\n\tPort int\n}\n") | ||
| text, kind, err := ExtractSymbol("main.go", source, "Config") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "type_declaration", kind) | ||
| assert.Contains(t, text, "Host string") | ||
| }) | ||
|
|
||
| t.Run("Python function", func(t *testing.T) { | ||
| source := []byte("def hello():\n print('hello')\n\ndef world():\n print('world')\n") | ||
| text, kind, err := ExtractSymbol("app.py", source, "hello") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "function_definition", kind) | ||
| assert.Contains(t, text, "print('hello')") | ||
| assert.NotContains(t, text, "world") | ||
| }) | ||
|
|
||
| t.Run("Python class method (nested)", func(t *testing.T) { | ||
| source := []byte("class Dog:\n def bark(self):\n return 'woof'\n def fetch(self):\n return 'ball'\n") | ||
| text, kind, err := ExtractSymbol("app.py", source, "bark") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "function_definition", kind) | ||
| assert.Contains(t, text, "woof") | ||
| assert.NotContains(t, text, "ball") | ||
| }) | ||
|
|
||
| t.Run("TypeScript class", func(t *testing.T) { | ||
| source := []byte("class Api {\n get() {\n return fetch('/data');\n }\n}\n\nfunction helper() { return 1; }\n") | ||
| text, kind, err := ExtractSymbol("api.ts", source, "Api") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "class_declaration", kind) | ||
| assert.Contains(t, text, "get()") | ||
| assert.NotContains(t, text, "helper") | ||
| }) | ||
|
|
||
| t.Run("TypeScript class method (nested)", func(t *testing.T) { | ||
| source := []byte("class Api {\n get() {\n return fetch('/data');\n }\n post() {\n return fetch('/post');\n }\n}\n") | ||
| text, kind, err := ExtractSymbol("api.ts", source, "get") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "method_definition", kind) | ||
| assert.Contains(t, text, "/data") | ||
| assert.NotContains(t, text, "/post") | ||
| }) | ||
|
|
||
| t.Run("symbol not found lists available", func(t *testing.T) { | ||
| source := []byte("package main\n\nfunc hello() {}\n\nfunc world() {}\n") | ||
| _, _, err := ExtractSymbol("main.go", source, "nonexistent") | ||
| require.Error(t, err) | ||
| assert.Contains(t, err.Error(), "not found") | ||
| assert.Contains(t, err.Error(), "hello") | ||
| assert.Contains(t, err.Error(), "world") | ||
| }) | ||
|
|
||
| t.Run("unsupported file type", func(t *testing.T) { | ||
| source := []byte("some content") | ||
| _, _, err := ExtractSymbol("README.md", source, "anything") | ||
| require.Error(t, err) | ||
| assert.Contains(t, err.Error(), "not supported") | ||
| }) | ||
|
|
||
| t.Run("Java class with methods", func(t *testing.T) { | ||
| source := []byte("class Calculator {\n int add(int a, int b) {\n return a + b;\n }\n int multiply(int a, int b) {\n return a * b;\n }\n}\n") | ||
| text, kind, err := ExtractSymbol("Calculator.java", source, "add") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "method_declaration", kind) | ||
| assert.Contains(t, text, "a + b") | ||
| assert.NotContains(t, text, "a * b") | ||
| }) | ||
|
|
||
| t.Run("Rust function", func(t *testing.T) { | ||
| source := []byte("fn hello() {\n println!(\"hello\");\n}\n\nfn world() {\n println!(\"world\");\n}\n") | ||
| text, kind, err := ExtractSymbol("main.rs", source, "hello") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "function_item", kind) | ||
| assert.Contains(t, text, "hello") | ||
| assert.NotContains(t, text, "world") | ||
| }) | ||
|
|
||
| t.Run("Go var declaration", func(t *testing.T) { | ||
| source := []byte("package main\n\nvar defaultTimeout = 30\n\nvar maxRetries = 3\n") | ||
| text, kind, err := ExtractSymbol("main.go", source, "defaultTimeout") | ||
| require.NoError(t, err) | ||
| assert.Equal(t, "var_declaration", kind) | ||
| assert.Contains(t, text, "30") | ||
| assert.NotContains(t, text, "maxRetries") | ||
| }) | ||
| } |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Adding the `symbol` property changes the tool input schema, but `Test_GetFileContents` in `repositories_test.go` currently asserts the old schema keys (it checks for `sha` but not `symbol`). That test will fail once this PR is merged; update the schema assertions and consider adding a tool-level test case that passes `symbol` and verifies the extracted symbol is returned.

This issue also appears on line 655 of the same file.