aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2023-02-18 14:53:29 +0100
committerGitHub <noreply@github.com>2023-02-18 14:53:29 +0100
commit665da0c90fc7d9a26d4307d24840267d809147e6 (patch)
tree67c07c2427161a24a1cc053965afd54795f0743d
parent21e0c4913dd0bc88dfec3016d4552f57fc0e7c4b (diff)
parent136821c9cdb5f2b9199ed7ebb7cc743faec6b785 (diff)
Merge pull request #26099 from vespa-engine/lesters/generator-model-ids
Add model ids for generator models
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java23
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java6
2 files changed, 22 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
index 0abd7212017..76403d369dd 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java
@@ -3,6 +3,9 @@ package com.yahoo.vespa.model.container.xml;
import com.yahoo.text.XML;
import org.w3c.dom.Element;
+
+import java.util.Collections;
+import java.util.HashMap;
import java.util.Map;
import java.util.stream.Collectors;
@@ -14,10 +17,22 @@ import java.util.stream.Collectors;
*/
public class ModelIdResolver {
- private static final Map<String, String> providedModels =
- Map.of("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx",
- "mpnet-base-v2", "https://data.vespa.oath.cloud/onnx_models/sentence-all-mpnet-base-v2.onnx",
- "bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt");
+ private static Map<String, String> setupProvidedModels() {
+ Map<String, String> models = new HashMap<>();
+ models.put("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx");
+ models.put("mpnet-base-v2", "https://data.vespa.oath.cloud/onnx_models/sentence-all-mpnet-base-v2.onnx");
+ models.put("bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt");
+ models.put("flan-t5-vocab", "https://data.vespa.oath.cloud/onnx_models/flan-t5-spiece.model");
+ models.put("flan-t5-small-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-encoder-model.onnx");
+ models.put("flan-t5-small-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-decoder-model.onnx");
+ models.put("flan-t5-base-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-encoder-model.onnx");
+ models.put("flan-t5-base-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-decoder-model.onnx");
+ models.put("flan-t5-large-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-encoder-model.onnx");
+ models.put("flan-t5-large-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-decoder-model.onnx");
+ return Collections.unmodifiableMap(models);
+ }
+
+ private static final Map<String, String> providedModels = setupProvidedModels();
/**
* Finds any config values of type 'model' below the given config element and
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 509d8527bf5..e7f8086e554 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;
public class EmbedderTestCase {
@@ -91,8 +92,7 @@ public class EmbedderTestCase {
" </config>" +
"</component>";
assertTransformThrows(embedder,
- "Unknown model id 'my_model_id' on 'transformerModel'. " +
- "Available models are [bert-base-uncased, minilm-l6-v2, mpnet-base-v2]",
+ "Unknown model id 'my_model_id' on 'transformerModel'",
true);
}
@@ -194,7 +194,7 @@ public class EmbedderTestCase {
ModelIdResolver.resolveModelIds(createElement(embedder), hosted);
fail("Expected exception was not thrown: " + expectedMessage);
} catch (IllegalArgumentException e) {
- assertEquals(expectedMessage, e.getMessage());
+ assertTrue(e.getMessage().contains(expectedMessage), "Expected error message not found");
}
}