diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2023-02-18 14:53:29 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-02-18 14:53:29 +0100 |
commit | 665da0c90fc7d9a26d4307d24840267d809147e6 (patch) | |
tree | 67c07c2427161a24a1cc053965afd54795f0743d | |
parent | 21e0c4913dd0bc88dfec3016d4552f57fc0e7c4b (diff) | |
parent | 136821c9cdb5f2b9199ed7ebb7cc743faec6b785 (diff) |
Merge pull request #26099 from vespa-engine/lesters/generator-model-ids
Add model ids for generator models
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java | 23 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java | 6 |
2 files changed, 22 insertions, 7 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java index 0abd7212017..76403d369dd 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/xml/ModelIdResolver.java @@ -3,6 +3,9 @@ package com.yahoo.vespa.model.container.xml; import com.yahoo.text.XML; import org.w3c.dom.Element; + +import java.util.Collections; +import java.util.HashMap; import java.util.Map; import java.util.stream.Collectors; @@ -14,10 +17,22 @@ import java.util.stream.Collectors; */ public class ModelIdResolver { - private static final Map<String, String> providedModels = - Map.of("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", - "mpnet-base-v2", "https://data.vespa.oath.cloud/onnx_models/sentence-all-mpnet-base-v2.onnx", - "bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"); + private static Map<String, String> setupProvidedModels() { + Map<String, String> models = new HashMap<>(); + models.put("minilm-l6-v2", "https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx"); + models.put("mpnet-base-v2", "https://data.vespa.oath.cloud/onnx_models/sentence-all-mpnet-base-v2.onnx"); + models.put("bert-base-uncased", "https://data.vespa.oath.cloud/onnx_models/bert-base-uncased-vocab.txt"); + models.put("flan-t5-vocab", "https://data.vespa.oath.cloud/onnx_models/flan-t5-spiece.model"); + models.put("flan-t5-small-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-encoder-model.onnx"); + models.put("flan-t5-small-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-small-decoder-model.onnx"); + models.put("flan-t5-base-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-encoder-model.onnx"); + models.put("flan-t5-base-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-base-decoder-model.onnx"); + models.put("flan-t5-large-encoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-encoder-model.onnx"); + models.put("flan-t5-large-decoder", "https://data.vespa.oath.cloud/onnx_models/flan-t5-large-decoder-model.onnx"); + return Collections.unmodifiableMap(models); + } + + private static final Map<String, String> providedModels = setupProvidedModels(); /** * Finds any config values of type 'model' below the given config element and diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 509d8527bf5..e7f8086e554 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -24,6 +24,7 @@ import java.nio.charset.StandardCharsets; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; public class EmbedderTestCase { @@ -91,8 +92,7 @@ public class EmbedderTestCase { " </config>" + "</component>"; assertTransformThrows(embedder, - "Unknown model id 'my_model_id' on 'transformerModel'. " + - "Available models are [bert-base-uncased, minilm-l6-v2, mpnet-base-v2]", + "Unknown model id 'my_model_id' on 'transformerModel'", true); } @@ -194,7 +194,7 @@ public class EmbedderTestCase { ModelIdResolver.resolveModelIds(createElement(embedder), hosted); fail("Expected exception was not thrown: " + expectedMessage); } catch (IllegalArgumentException e) { - assertEquals(expectedMessage, e.getMessage()); + assertTrue(e.getMessage().contains(expectedMessage), "Expected error message not found"); } } |