Skip to content

Commit 4ff363f

Browse files
committed
Add filtering by cluster name
1 parent a950cd5 commit 4ff363f

File tree

8 files changed

+516
-23
lines changed

8 files changed

+516
-23
lines changed
Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
package vulnerability
2+
3+
import (
4+
"context"
5+
"fmt"
6+
7+
v1 "github.com/stackrox/rox/generated/api/v1"
8+
)
9+
10+
// resolveClusterNameToID resolves a cluster name to its ID.
11+
// Returns empty string if clusterName is empty (no filtering).
12+
// Returns error if cluster name is not found or if API call fails.
13+
func resolveClusterNameToID(ctx context.Context, clusterName string, client v1.ClustersServiceClient) (string, error) {
14+
// Empty cluster name means no filtering
15+
if clusterName == "" {
16+
return "", nil
17+
}
18+
19+
// Fetch all clusters
20+
resp, err := client.GetClusters(ctx, &v1.GetClustersRequest{})
21+
if err != nil {
22+
return "", fmt.Errorf("failed to fetch clusters: %w", err)
23+
}
24+
25+
// Find cluster by exact name match (case-sensitive)
26+
for _, cluster := range resp.GetClusters() {
27+
if cluster.GetName() == clusterName {
28+
return cluster.GetId(), nil
29+
}
30+
}
31+
32+
return "", fmt.Errorf("cluster with name %q not found", clusterName)
33+
}
Lines changed: 110 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,110 @@
1+
package vulnerability
2+
3+
import (
4+
"context"
5+
"errors"
6+
"net"
7+
"testing"
8+
9+
v1 "github.com/stackrox/rox/generated/api/v1"
10+
"github.com/stackrox/rox/generated/storage"
11+
"github.com/stackrox/stackrox-mcp/internal/toolsets/mock"
12+
"github.com/stretchr/testify/assert"
13+
"github.com/stretchr/testify/require"
14+
"google.golang.org/grpc"
15+
"google.golang.org/grpc/credentials/insecure"
16+
)
17+
18+
func TestResolveClusterNameToID(t *testing.T) {
19+
tests := map[string]struct {
20+
clusterName string
21+
mockClusters []*storage.Cluster
22+
mockError error
23+
expectedID string
24+
expectError bool
25+
expectedErrText string
26+
}{
27+
"empty cluster name returns empty ID": {
28+
clusterName: "",
29+
mockClusters: []*storage.Cluster{{Id: "cluster-1", Name: "production"}},
30+
expectedID: "",
31+
expectError: false,
32+
},
33+
"cluster name found returns correct ID": {
34+
clusterName: "production",
35+
mockClusters: []*storage.Cluster{
36+
{Id: "cluster-0", Name: "dev"},
37+
{Id: "cluster-1", Name: "production"},
38+
{Id: "cluster-2", Name: "staging"},
39+
},
40+
expectedID: "cluster-1",
41+
expectError: false,
42+
},
43+
"cluster name not found returns error": {
44+
clusterName: "nonexistent",
45+
mockClusters: []*storage.Cluster{
46+
{Id: "cluster-1", Name: "production"},
47+
},
48+
expectedID: "",
49+
expectError: true,
50+
expectedErrText: `cluster with name "nonexistent" not found`,
51+
},
52+
"case sensitive matching": {
53+
clusterName: "Production",
54+
mockClusters: []*storage.Cluster{
55+
{Id: "cluster-1", Name: "production"},
56+
},
57+
expectedID: "",
58+
expectError: true,
59+
expectedErrText: `cluster with name "Production" not found`,
60+
},
61+
"API error propagation": {
62+
clusterName: "production",
63+
mockError: errors.New("API connection failed"),
64+
expectedID: "",
65+
expectError: true,
66+
expectedErrText: "failed to fetch clusters:",
67+
},
68+
"exact match required": {
69+
clusterName: "prod",
70+
mockClusters: []*storage.Cluster{
71+
{Id: "cluster-1", Name: "production"},
72+
},
73+
expectedID: "",
74+
expectError: true,
75+
expectedErrText: `cluster with name "prod" not found`,
76+
},
77+
}
78+
79+
for testName, testCase := range tests {
80+
t.Run(testName, func(t *testing.T) {
81+
mockService := mock.NewClustersServiceMock(testCase.mockClusters, testCase.mockError)
82+
grpcServer, listener := mock.SetupClusterServer(mockService)
83+
defer grpcServer.Stop()
84+
85+
// Create a gRPC client connection to the mock server
86+
conn, err := grpc.NewClient(
87+
"passthrough://buffer",
88+
grpc.WithLocalDNSResolution(),
89+
grpc.WithContextDialer(func(_ context.Context, _ string) (net.Conn, error) {
90+
return listener.Dial()
91+
}),
92+
grpc.WithTransportCredentials(insecure.NewCredentials()),
93+
)
94+
require.NoError(t, err)
95+
defer conn.Close()
96+
97+
client := v1.NewClustersServiceClient(conn)
98+
99+
id, err := resolveClusterNameToID(context.Background(), testCase.clusterName, client)
100+
101+
if testCase.expectError {
102+
require.Error(t, err)
103+
assert.Contains(t, err.Error(), testCase.expectedErrText)
104+
} else {
105+
require.NoError(t, err)
106+
assert.Equal(t, testCase.expectedID, id)
107+
}
108+
})
109+
}
110+
}

internal/toolsets/vulnerability/clusters.go

Lines changed: 30 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -18,15 +18,20 @@ import (
1818

1919
// getClustersForCVEInput defines the input parameters for get_clusters_for_cve tool.
2020
type getClustersForCVEInput struct {
21-
CVEName string `json:"cveName"`
22-
FilterClusterID string `json:"filterClusterId,omitempty"`
21+
CVEName string `json:"cveName"`
22+
FilterClusterID string `json:"filterClusterId,omitempty"`
23+
FilterClusterName string `json:"filterClusterName,omitempty"`
2324
}
2425

2526
func (input *getClustersForCVEInput) validate() error {
2627
if input.CVEName == "" {
2728
return errors.New("CVE name is required")
2829
}
2930

31+
if input.FilterClusterID != "" && input.FilterClusterName != "" {
32+
return errors.New("cannot specify both filterClusterId and filterClusterName")
33+
}
34+
3035
return nil
3136
}
3237

@@ -76,9 +81,7 @@ func (t *getClustersForCVETool) GetTool() *mcp.Tool {
7681
" Call ALL THREE CVE tools (get_clusters_with_orchestrator_cve, get_deployments_for_cve, get_nodes_for_cve)" +
7782
" for comprehensive coverage." +
7883
" 2) When user asks specifically about 'orchestrator', 'Kubernetes components'," +
79-
" or 'control plane': Use ONLY this tool." +
80-
" 3) For single cluster queries (e.g., 'in cluster X'): First call list_clusters to get cluster ID," +
81-
" then call ONLY this tool with filterClusterId.",
84+
" or 'control plane': Use ONLY this tool.",
8285
InputSchema: getClustersForCVEInputSchema(),
8386
}
8487
}
@@ -97,11 +100,13 @@ func getClustersForCVEInputSchema() *jsonschema.Schema {
97100

98101
schema.Properties["cveName"].Description = "CVE name to filter clusters (e.g., CVE-2021-44228)"
99102
schema.Properties["filterClusterId"].Description =
100-
"Optional cluster ID (cluster ID only, not cluster name) to verify if CVE is detected in a specific cluster." +
101-
" Only use this parameter when the user's query explicitly mentions a specific cluster name." +
102-
" When checking if a CVE exists at all, call without this parameter to check all clusters at once." +
103-
" To resolve cluster names to IDs, use list_clusters tool first." +
104-
" If the cluster doesn't exist, respond that the CVE is not detected in that cluster (since it doesn't exist)."
103+
"Optional cluster ID to verify if CVE is detected in a specific cluster." +
104+
" Cannot be used together with filterClusterName." +
105+
" When checking if a CVE exists at all, call without this parameter to check all clusters at once."
106+
schema.Properties["filterClusterName"].Description =
107+
"Optional cluster name to verify if CVE is detected in a specific cluster." +
108+
" Cannot be used together with filterClusterId." +
109+
" When checking if a CVE exists at all, call without this parameter to check all clusters at once."
105110

106111
return schema
107112
}
@@ -143,7 +148,21 @@ func (t *getClustersForCVETool) handle(
143148

144149
clustersClient := v1.NewClustersServiceClient(conn)
145150

146-
query := buildClusterQuery(input)
151+
// Resolve cluster name to ID if provided
152+
resolvedClusterID := input.FilterClusterID
153+
if input.FilterClusterName != "" {
154+
resolvedClusterID, err = resolveClusterNameToID(callCtx, input.FilterClusterName, clustersClient)
155+
if err != nil {
156+
return nil, nil, err
157+
}
158+
}
159+
160+
// Build query using the resolved cluster ID
161+
queryInput := getClustersForCVEInput{
162+
CVEName: input.CVEName,
163+
FilterClusterID: resolvedClusterID,
164+
}
165+
query := buildClusterQuery(queryInput)
147166

148167
resp, err := clustersClient.GetClusters(callCtx, &v1.GetClustersRequest{
149168
Query: query,

internal/toolsets/vulnerability/clusters_test.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,29 @@ func TestClusterInputValidate(t *testing.T) {
7878
expectError: true,
7979
errorMsg: "CVE name is required",
8080
},
81+
"both cluster ID and name provided": {
82+
input: getClustersForCVEInput{
83+
CVEName: "CVE-2021-44228",
84+
FilterClusterID: "cluster-123",
85+
FilterClusterName: "production",
86+
},
87+
expectError: true,
88+
errorMsg: "cannot specify both filterClusterId and filterClusterName",
89+
},
90+
"only cluster ID provided": {
91+
input: getClustersForCVEInput{
92+
CVEName: "CVE-2021-44228",
93+
FilterClusterID: "cluster-123",
94+
},
95+
expectError: false,
96+
},
97+
"only cluster name provided": {
98+
input: getClustersForCVEInput{
99+
CVEName: "CVE-2021-44228",
100+
FilterClusterName: "production",
101+
},
102+
expectError: false,
103+
},
81104
}
82105

83106
for testName, testCase := range tests {
@@ -297,3 +320,68 @@ func TestClusterHandle_WithFilters(t *testing.T) {
297320
})
298321
}
299322
}
323+
324+
func TestClusterHandle_WithClusterName(t *testing.T) {
325+
tests := map[string]struct {
326+
clusterName string
327+
availableClusters []*storage.Cluster
328+
expectError bool
329+
expectedErrText string
330+
expectedQuery string
331+
}{
332+
"cluster name found": {
333+
clusterName: "production",
334+
availableClusters: []*storage.Cluster{
335+
{Id: "cluster-1", Name: "production"},
336+
{Id: "cluster-2", Name: "staging"},
337+
},
338+
expectError: false,
339+
expectedQuery: `CVE:"CVE-2021-44228"+Cluster ID:"cluster-1"`,
340+
},
341+
"cluster name not found": {
342+
clusterName: "nonexistent",
343+
availableClusters: []*storage.Cluster{
344+
{Id: "cluster-1", Name: "production"},
345+
},
346+
expectError: true,
347+
expectedErrText: `cluster with name "nonexistent" not found`,
348+
},
349+
"empty cluster name": {
350+
clusterName: "",
351+
availableClusters: []*storage.Cluster{
352+
{Id: "cluster-1", Name: "production"},
353+
},
354+
expectError: false,
355+
expectedQuery: `CVE:"CVE-2021-44228"`,
356+
},
357+
}
358+
359+
for testName, testCase := range tests {
360+
t.Run(testName, func(t *testing.T) {
361+
mockService := mock.NewClustersServiceMock(testCase.availableClusters, nil)
362+
grpcServer, listener := mock.SetupClusterServer(mockService)
363+
defer grpcServer.Stop()
364+
365+
tool, ok := NewGetClustersForCVETool(createTestClient(t, listener)).(*getClustersForCVETool)
366+
require.True(t, ok)
367+
368+
input := getClustersForCVEInput{
369+
CVEName: "CVE-2021-44228",
370+
FilterClusterName: testCase.clusterName,
371+
}
372+
373+
result, output, err := tool.handle(context.Background(), &mcp.CallToolRequest{}, input)
374+
375+
if testCase.expectError {
376+
require.Error(t, err)
377+
assert.Contains(t, err.Error(), testCase.expectedErrText)
378+
assert.Nil(t, output)
379+
} else {
380+
require.NoError(t, err)
381+
require.NotNil(t, output)
382+
assert.Nil(t, result)
383+
assert.Contains(t, mockService.GetLastCallQuery(), testCase.expectedQuery)
384+
}
385+
})
386+
}
387+
}

internal/toolsets/vulnerability/deployments.go

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ const (
3333
type getDeploymentsForCVEInput struct {
3434
CVEName string `json:"cveName"`
3535
FilterClusterID string `json:"filterClusterId,omitempty"`
36+
FilterClusterName string `json:"filterClusterName,omitempty"`
3637
FilterNamespace string `json:"filterNamespace,omitempty"`
3738
FilterPlatform filterPlatformType `json:"filterPlatform,omitempty"`
3839
IncludeDetectedImages bool `json:"includeDetectedImages,omitempty"`
@@ -44,6 +45,10 @@ func (input *getDeploymentsForCVEInput) validate() error {
4445
return errors.New("CVE name is required")
4546
}
4647

48+
if input.FilterClusterID != "" && input.FilterClusterName != "" {
49+
return errors.New("cannot specify both filterClusterId and filterClusterName")
50+
}
51+
4752
return nil
4853
}
4954

@@ -100,9 +105,7 @@ func (t *getDeploymentsForCVETool) GetTool() *mcp.Tool {
100105
" Call ALL THREE CVE tools (get_clusters_with_orchestrator_cve, get_deployments_for_cve, get_nodes_for_cve)" +
101106
" for comprehensive coverage." +
102107
" 2) When user asks specifically about 'deployments', 'workloads', 'applications'," +
103-
" or 'containers': Use ONLY this tool." +
104-
" 3) For single cluster queries (e.g., 'in cluster X'): First call list_clusters to get cluster ID," +
105-
" then call ONLY this tool with filterClusterId.",
108+
" or 'containers': Use ONLY this tool.",
106109
InputSchema: getDeploymentsForCVEInputSchema(),
107110
}
108111
}
@@ -120,7 +123,10 @@ func getDeploymentsForCVEInputSchema() *jsonschema.Schema {
120123
schema.Required = []string{"cveName"}
121124

122125
schema.Properties["cveName"].Description = "CVE name to filter deployments (e.g., CVE-2021-44228)"
123-
schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter deployments"
126+
schema.Properties["filterClusterId"].Description = "Optional cluster ID to filter deployments." +
127+
" Cannot be used together with filterClusterName."
128+
schema.Properties["filterClusterName"].Description = "Optional cluster name to filter deployments." +
129+
" Cannot be used together with filterClusterId."
124130
schema.Properties["filterNamespace"].Description = "Optional namespace to filter deployments"
125131

126132
schema.Properties["filterPlatform"].Description =
@@ -287,10 +293,29 @@ func (t *getDeploymentsForCVETool) handle(
287293
}
288294

289295
callCtx := auth.WithMCPRequestContext(ctx, req)
296+
297+
// Resolve cluster name to ID if provided
298+
resolvedClusterID := input.FilterClusterID
299+
if input.FilterClusterName != "" {
300+
clustersClient := v1.NewClustersServiceClient(conn)
301+
resolvedClusterID, err = resolveClusterNameToID(callCtx, input.FilterClusterName, clustersClient)
302+
if err != nil {
303+
return nil, nil, err
304+
}
305+
}
306+
290307
deploymentClient := v1.NewDeploymentServiceClient(conn)
291308

309+
// Build query using the resolved cluster ID
310+
queryInput := getDeploymentsForCVEInput{
311+
CVEName: input.CVEName,
312+
FilterClusterID: resolvedClusterID,
313+
FilterNamespace: input.FilterNamespace,
314+
FilterPlatform: input.FilterPlatform,
315+
}
316+
292317
listReq := &v1.RawQuery{
293-
Query: buildQuery(input),
318+
Query: buildQuery(queryInput),
294319
Pagination: &v1.Pagination{
295320
Offset: currCursor.GetOffset(),
296321
Limit: defaultLimit + 1,

0 commit comments

Comments
 (0)