Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 52 additions & 0 deletions core/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -201,6 +201,58 @@
<plugin>
<artifactId>maven-compiler-plugin</artifactId>
</plugin>
<plugin>
<artifactId>maven-surefire-plugin</artifactId>
<executions>
<execution>
<id>basic</id>
<goals>
<goal>test</goal>
</goals>
</execution>
<execution>
<id>apigee-llm</id>
<goals>
<goal>test</goal>
</goals>
<configuration>
<test>ApigeeLlmTest</test>
<!-- this test requires the following env -->
<environmentVariables>
<GOOGLE_API_KEY>api-key</GOOGLE_API_KEY>
</environmentVariables>
</configuration>
</execution>
<execution>
<id>apigee-llm-vertex-ai</id>
<goals>
<goal>test</goal>
</goals>
<configuration>
<test>ApigeeLlmTest#generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi</test>
<environmentVariables>
<GOOGLE_API_KEY>api-key</GOOGLE_API_KEY>
<!-- runs a second variant of the test method -->
<GOOGLE_GENAI_USE_VERTEXAI>true</GOOGLE_GENAI_USE_VERTEXAI>
</environmentVariables>
</configuration>
</execution>
<execution>
<id>apigee-llm-proxy-url</id>
<goals>
<goal>test</goal>
</goals>
<configuration>
<test>ApigeeLlmTest#build_withoutProxyUrl_readsFromEnvironment</test>
<environmentVariables>
<GOOGLE_API_KEY>api-key</GOOGLE_API_KEY>
<!-- runs a second variant of the test method -->
<APIGEE_PROXY_URL>proxy-url</APIGEE_PROXY_URL>
</environmentVariables>
</configuration>
</execution>
</executions>
</plugin>
</plugins>
</build>
</project>
40 changes: 20 additions & 20 deletions core/src/test/java/com/google/adk/models/ApigeeLlmTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -57,11 +57,11 @@ public void checkApiKey() {
@Test
public void build_withValidModelStrings_succeeds() {
String[] validModelStrings = {
"apigee/gemini-1.5-flash",
"apigee/v1/gemini-1.5-flash",
"apigee/vertex_ai/gemini-1.5-flash",
"apigee/gemini/v1/gemini-1.5-flash",
"apigee/vertex_ai/v1beta/gemini-1.5-flash"
"apigee/whatever-model",
"apigee/v1/whatever-model",
"apigee/vertex_ai/whatever-model",
"apigee/gemini/v1/whatever-model",
"apigee/vertex_ai/v1beta/whatever-model"
};

for (String modelName : validModelStrings) {
Expand Down Expand Up @@ -93,26 +93,26 @@ public void build_withInvalidModelStrings_throwsException() {
public void generateContent_stripsApigeePrefixAndSendsToDelegate() {
when(mockGeminiDelegate.generateContent(any(), anyBoolean())).thenReturn(Flowable.empty());

ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/gemini-1.5-flash", mockGeminiDelegate);
ApigeeLlm llm = new ApigeeLlm("apigee/gemini/v1/whatever-model", mockGeminiDelegate);

LlmRequest request =
LlmRequest.builder()
.model("apigee/gemini/v1/gemini-1.5-flash")
.model("apigee/gemini/v1/whatever-model")
.contents(ImmutableList.of(Content.builder().parts(Part.fromText("hi")).build()))
.build();
llm.generateContent(request, true).test().assertNoErrors();

ArgumentCaptor<LlmRequest> requestCaptor = ArgumentCaptor.forClass(LlmRequest.class);
verify(mockGeminiDelegate).generateContent(requestCaptor.capture(), eq(true));
assertThat(requestCaptor.getValue().model()).hasValue("gemini-1.5-flash");
assertThat(requestCaptor.getValue().model()).hasValue("whatever-model");
}

// Add a test to verify the vertexAI flag is set correctly.
@Test
public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() {
ApigeeLlm llm =
ApigeeLlm.builder()
.modelName("apigee/vertex_ai/gemini-1.5-flash")
.modelName("apigee/vertex_ai/whatever-model")
.proxyUrl(PROXY_URL)
.build();
assertThat(llm.getApiClient().vertexAI()).isTrue();
Expand All @@ -122,7 +122,7 @@ public void generateContent_setsVertexAiFlagCorrectly_withVertexAi() {
public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() {

ApigeeLlm llm =
ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").proxyUrl(PROXY_URL).build();
ApigeeLlm.builder().modelName("apigee/whatever-model").proxyUrl(PROXY_URL).build();
if (System.getenv("GOOGLE_GENAI_USE_VERTEXAI") != null) {
assertThat(llm.getApiClient().vertexAI()).isTrue();
} else {
Expand All @@ -133,7 +133,7 @@ public void generateContent_setsVertexAiFlagCorrectly_withOrWithoutVertexAi() {
@Test
public void generateContent_setsVertexAiFlagCorrectly_withGemini() {
ApigeeLlm llm =
ApigeeLlm.builder().modelName("apigee/gemini/gemini-1.5-flash").proxyUrl(PROXY_URL).build();
ApigeeLlm.builder().modelName("apigee/gemini/whatever-model").proxyUrl(PROXY_URL).build();
assertThat(llm.getApiClient().vertexAI()).isFalse();
}

Expand All @@ -142,11 +142,11 @@ public void generateContent_setsVertexAiFlagCorrectly_withGemini() {
public void generateContent_setsApiVersionCorrectly() {
ImmutableMap<String, String> modelToApiVersion =
ImmutableMap.of(
"apigee/gemini-1.5-flash", "",
"apigee/v1/gemini-1.5-flash", "v1",
"apigee/vertex_ai/gemini-1.5-flash", "",
"apigee/gemini/v1/gemini-1.5-flash", "v1",
"apigee/vertex_ai/v1beta/gemini-1.5-flash", "v1beta");
"apigee/whatever-model", "",
"apigee/v1/whatever-model", "v1",
"apigee/vertex_ai/whatever-model", "",
"apigee/gemini/v1/whatever-model", "v1",
"apigee/vertex_ai/v1beta/whatever-model", "v1beta");

for (Map.Entry<String, String> entry : modelToApiVersion.entrySet()) {
String modelName = entry.getKey();
Expand All @@ -165,7 +165,7 @@ public void build_withCustomHeaders_setsHeadersInHttpOptions() {
ImmutableMap<String, String> customHeaders = ImmutableMap.of("X-Test-Header", "TestValue");
ApigeeLlm llm =
ApigeeLlm.builder()
.modelName("apigee/gemini-1.5-flash")
.modelName("apigee/whatever-model")
.proxyUrl(PROXY_URL)
.customHeaders(customHeaders)
.build();
Expand All @@ -192,14 +192,14 @@ public void build_withTrailingSlashInModel_parsesVersionAndModelId() {
public void build_withoutProxyUrl_readsFromEnvironment() {
String envProxyUrl = System.getenv("APIGEE_PROXY_URL");
if (envProxyUrl != null) {
ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build();
ApigeeLlm llm = ApigeeLlm.builder().modelName("apigee/whatever-model").build();
assertThat(llm.getHttpOptions().baseUrl()).hasValue(envProxyUrl);
} else {
assertThrows(
IllegalArgumentException.class,
() -> ApigeeLlm.builder().modelName("apigee/gemini-1.5-flash").build());
() -> ApigeeLlm.builder().modelName("apigee/whatever-model").build());
ApigeeLlm llm =
ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/gemini-1.5-flash").build();
ApigeeLlm.builder().proxyUrl(PROXY_URL).modelName("apigee/whatever-model").build();
assertThat(llm.getHttpOptions().baseUrl()).hasValue(PROXY_URL);
}
}
Expand Down