// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.model.container.xml; import com.yahoo.component.ComponentId; import com.yahoo.config.InnerNode; import com.yahoo.config.ModelNode; import com.yahoo.config.ModelReference; import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.embedding.BertBaseEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; import com.yahoo.text.XML; import com.yahoo.vespa.config.ConfigDefinitionKey; import com.yahoo.vespa.config.ConfigPayloadBuilder; import com.yahoo.vespa.model.VespaModel; import com.yahoo.vespa.model.container.ApplicationContainerCluster; import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; import com.yahoo.yolean.Exceptions; import org.junit.jupiter.api.Test; import org.w3c.dom.Document; import org.w3c.dom.Element; import org.w3c.dom.NamedNodeMap; import org.xml.sax.SAXException; import java.io.ByteArrayInputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; import java.util.List; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; import static org.junit.jupiter.api.Assertions.fail; public class EmbedderTestCase { @Test void testApplicationComponentWithModelReference_hosted() throws IOException, SAXException { String input = "" + " " + " " + " " + " " + ""; String component = "" + " " + " " + " " + " " + ""; assertTransform(input, component, true); } @Test void testUnknownModelId_hosted() throws IOException, SAXException { String embedder = "" + " " + " " + " " + " " + ""; assertTransformThrows(embedder, "Unknown model id 'my_model_id' on 'model'", true); } @Test void huggingfaceEmbedder_selfhosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); var cluster = model.getContainerClusters().get("container"); var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster); assertEquals("my_input_ids", embedderCfg.transformerInputIds()); assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); assertEquals(1024, embedderCfg.transformerMaxTokens()); var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); assertEquals(-1, tokenizerCfg.maxLength()); } @Test void huggingfaceEmbedder_hosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); var cluster = model.getContainerClusters().get("container"); var embedderCfg = assertHuggingfaceEmbedderComponentPresent(cluster); assertEquals("my_input_ids", embedderCfg.transformerInputIds()); assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); assertEquals(1024, embedderCfg.transformerMaxTokens()); var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); assertEquals(-1, tokenizerCfg.maxLength()); } @Test void bertEmbedder_selfhosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); var cluster = model.getContainerClusters().get("container"); var embedderCfg = assertBertEmbedderComponentPresent(cluster); assertEquals("application-url", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value()); } @Test void bertEmbedder_hosted() throws Exception { var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); var cluster = model.getContainerClusters().get("container"); var embedderCfg = assertBertEmbedderComponentPresent(cluster); assertEquals("https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); assertTrue(modelReference(embedderCfg, "tokenizerVocab").url().isEmpty()); assertEquals("files/vocab.txt", modelReference(embedderCfg, "tokenizerVocab").path().orElseThrow().value()); } @Test void passesXmlValidation() { new VespaModelCreatorWithFilePkg("src/test/cfg/application/embed/").create(); } @Test void testApplicationPackageWithApplicationEmbedder_selfhosted() throws Exception { Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/"); VespaModel model = loadModel(applicationDir, false); ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); Component testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); assertEquals("minilm-l6-v2 application-url \"\"", config.getObject("model").getValue()); assertEquals("\"\" \"\" files/vocab.txt", config.getObject("vocab").getValue()); } @Test void testApplicationPackageWithApplicationEmbedder_hosted() throws Exception { Path applicationDir = Path.fromString("src/test/cfg/application/embed_generic/"); VespaModel model = loadModel(applicationDir, true); ApplicationContainerCluster containerCluster = model.getContainerClusters().get("container"); Component testComponent = containerCluster.getComponentsMap().get(new ComponentId("transformer")); ConfigPayloadBuilder config = testComponent.getUserConfigs().get(new ConfigDefinitionKey("sentence-embedder", "ai.vespa.example.paragraph")); assertEquals("minilm-l6-v2 https://data.vespa.oath.cloud/onnx_models/sentence_all_MiniLM_L6_v2.onnx \"\"", config.getObject("model").getValue()); assertEquals("\"\" \"\" files/vocab.txt", config.getObject("vocab").getValue()); } @Test void testApplicationPackageWithApplicationEmbedder_selfhosted_cloud_only() throws Exception { try { Path applicationDir = Path.fromString("src/test/cfg/application/embed_cloud_only/"); VespaModel model = loadModel(applicationDir, false); fail("Expected failure"); } catch (IllegalArgumentException e) { assertEquals("model is configured with only a 'model-id'. Add a 'path' or 'url' to deploy this outside Vespa Cloud", Exceptions.toMessageString(e)); } } private VespaModel loadModel(Path path, boolean hosted) throws Exception { FilesApplicationPackage applicationPackage = FilesApplicationPackage.fromFile(path.toFile()); TestProperties properties = new TestProperties().setHostedVespa(hosted); DeployState state = new DeployState.Builder().properties(properties).applicationPackage(applicationPackage).build(); return new VespaModel(state); } private void assertTransform(String inputComponent, String expectedComponent, boolean hosted) throws IOException, SAXException { Element component = createElement(inputComponent); ModelIdResolver.resolveModelIds(component, hosted); assertSpec(createElement(expectedComponent), component); } private void assertSpec(Element e1, Element e2) { assertEquals(e1.getTagName(), e2.getTagName()); assertAttributes(e1, e2); assertAttributes(e2, e1); assertEquals(XML.getValue(e1).trim(), XML.getValue(e2).trim(), "Content of " + e1.getTagName() + "' is identical"); assertChildren(e1, e2); } private void assertAttributes(Element e1, Element e2) { NamedNodeMap map = e1.getAttributes(); for (int i = 0; i < map.getLength(); ++i) { String attribute = map.item(i).getNodeName(); assertEquals(e1.getAttribute(attribute), e2.getAttribute(attribute), "Attribute '" + attribute + "' is equal"); } } private void assertChildren(Element e1, Element e2) { List list1 = XML.getChildren(e1); List list2 = XML.getChildren(e2); assertEquals(list1.size(), list2.size()); for (int i = 0; i < list1.size(); ++i) { Element child1 = list1.get(i); Element child2 = list2.get(i); assertSpec(child1, child2); } } private void assertTransformThrows(String embedder, String expectedMessage, boolean hosted) throws IOException, SAXException { try { ModelIdResolver.resolveModelIds(createElement(embedder), hosted); fail("Expected exception was not thrown: " + expectedMessage); } catch (IllegalArgumentException e) { assertTrue(e.getMessage().contains(expectedMessage), "Expected error message not found"); } } private Element createElement(String xml) throws IOException, SAXException { Document doc = XML.getDocumentBuilder().parse(new ByteArrayInputStream(xml.getBytes(StandardCharsets.UTF_8))); return (Element) doc.getFirstChild(); } private static HuggingFaceTokenizerConfig assertHuggingfaceTokenizerComponentPresent(ApplicationContainerCluster cluster) { var hfTokenizer = (HuggingFaceTokenizer) cluster.getComponentsMap().get(new ComponentId("hf-tokenizer")); assertEquals("com.yahoo.language.huggingface.HuggingFaceTokenizer", hfTokenizer.getClassId().getName()); var cfgBuilder = new HuggingFaceTokenizerConfig.Builder(); hfTokenizer.getConfig(cfgBuilder); return cfgBuilder.build(); } private static HuggingFaceEmbedderConfig assertHuggingfaceEmbedderComponentPresent(ApplicationContainerCluster cluster) { var hfEmbedder = (HuggingFaceEmbedder) cluster.getComponentsMap().get(new ComponentId("hf-embedder")); assertEquals("ai.vespa.embedding.huggingface.HuggingFaceEmbedder", hfEmbedder.getClassId().getName()); var cfgBuilder = new HuggingFaceEmbedderConfig.Builder(); hfEmbedder.getConfig(cfgBuilder); return cfgBuilder.build(); } private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); var cfgBuilder = new BertBaseEmbedderConfig.Builder(); bertEmbedder.getConfig(cfgBuilder); return cfgBuilder.build(); } // Ugly hack to read underlying model reference from config instance private static ModelReference modelReference(InnerNode cfg, String name) { try { var f = cfg.getClass().getDeclaredField(name); f.setAccessible(true); return ((ModelNode) f.get(cfg)).getModelReference(); } catch (NoSuchFieldException | IllegalAccessException e) { throw new RuntimeException(e); } } }