diff options
111 files changed, 1836 insertions, 520 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java index d0e1ede2cfa..7501f6162c7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/builder/xml/dom/DomComponentBuilder.java @@ -7,10 +7,11 @@ import com.yahoo.config.model.producer.AnyConfigProducer; import com.yahoo.config.model.producer.TreeConfigProducer; import com.yahoo.osgi.provider.model.ComponentModel; import com.yahoo.text.XML; -import com.yahoo.vespa.model.container.component.BertEmbedder; -import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; +import com.yahoo.vespa.model.container.component.BertEmbedder; +import com.yahoo.vespa.model.container.component.ColBertEmbedder; +import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.xml.BundleInstantiationSpecificationBuilder; import org.w3c.dom.Element; @@ -46,6 +47,7 @@ public class DomComponentBuilder extends VespaDomBuilder.DomConfigProducerBuilde case "hugging-face-embedder" -> new HuggingFaceEmbedder(spec, state); case "hugging-face-tokenizer" -> new HuggingFaceTokenizer(spec, state); case "bert-embedder" -> new BertEmbedder(spec, state); + case "colbert-embedder" -> new ColBertEmbedder(spec, state); default -> throw new IllegalArgumentException("Unknown component type '%s'".formatted(type)); }; } else { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java new file mode 100644 index 00000000000..c0fdfe3dc64 --- /dev/null +++ 
b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -0,0 +1,93 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package com.yahoo.vespa.model.container.component; + +import com.yahoo.config.ModelReference; +import com.yahoo.config.model.deploy.DeployState; +import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; +import com.yahoo.vespa.model.container.xml.ModelIdResolver; +import org.w3c.dom.Element; + +import java.util.Optional; + +import static com.yahoo.config.model.builder.xml.XmlHelper.getOptionalChild; +import static com.yahoo.text.XML.getChildValue; +import static com.yahoo.vespa.model.container.ContainerModelEvaluation.INTEGRATION_BUNDLE_NAME; + + +/** + * @author bergum + */ +public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderConfig.Producer { + private final ModelReference model; + private final ModelReference vocab; + + private final Integer maxQueryTokens; + + private final Integer maxDocumentTokens; + + private final Integer transformerStartSequenceToken; + private final Integer transformerEndSequenceToken; + private final Integer transformerMaskToken; + private final Integer maxTokens; + private final String transformerInputIds; + private final String transformerAttentionMask; + + private final String transformerOutput; + private final String onnxExecutionMode; + private final Integer onnxInteropThreads; + private final Integer onnxIntraopThreads; + private final Integer onnxGpuDevice; + + public ColBertEmbedder(Element xml, DeployState state) { + super("ai.vespa.embedding.ColBertEmbedder", INTEGRATION_BUNDLE_NAME, xml); + var transformerModelElem = getOptionalChild(xml, "transformer-model").orElseThrow(); + model = ModelIdResolver.resolveToModelReference(transformerModelElem, state); + vocab = getOptionalChild(xml, "tokenizer-model") + .map(elem -> 
ModelIdResolver.resolveToModelReference(elem, state)) + .orElseGet(() -> resolveDefaultVocab(transformerModelElem, state)); + maxTokens = getChildValue(xml, "max-tokens").map(Integer::parseInt).orElse(null); + maxQueryTokens = getChildValue(xml, "max-query-tokens").map(Integer::parseInt).orElse(null); + maxDocumentTokens = getChildValue(xml, "max-document-tokens").map(Integer::parseInt).orElse(null); + transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); + transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); + transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null); + transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); + transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); + transformerOutput = getChildValue(xml, "transformer-output").orElse(null); + onnxExecutionMode = getChildValue(xml, "onnx-execution-mode").orElse(null); + onnxInteropThreads = getChildValue(xml, "onnx-interop-threads").map(Integer::parseInt).orElse(null); + onnxIntraopThreads = getChildValue(xml, "onnx-intraop-threads").map(Integer::parseInt).orElse(null); + onnxGpuDevice = getChildValue(xml, "onnx-gpu-device").map(Integer::parseInt).orElse(null); + + } + + private static ModelReference resolveDefaultVocab(Element model, DeployState state) { + if (state.isHosted() && model.hasAttribute("model-id")) { + var implicitVocabId = model.getAttribute("model-id") + "-vocab"; + return ModelIdResolver.resolveToModelReference( + "tokenizer-model", Optional.of(implicitVocabId), Optional.empty(), Optional.empty(), state); + } + throw new IllegalArgumentException("'tokenizer-model' must be specified"); + } + + @Override + public void getConfig(ColBertEmbedderConfig.Builder b) { + b.transformerModel(model).tokenizerPath(vocab); + if (maxTokens != null) 
b.transformerMaxTokens(maxTokens); + if (transformerInputIds != null) b.transformerInputIds(transformerInputIds); + if (transformerAttentionMask != null) b.transformerAttentionMask(transformerAttentionMask); + if (transformerOutput != null) b.transformerOutput(transformerOutput); + if (maxQueryTokens != null) b.maxQueryTokens(maxQueryTokens); + if (maxDocumentTokens != null) b.maxDocumentTokens(maxDocumentTokens); + if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); + if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); + if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken); + if (onnxExecutionMode != null) b.transformerExecutionMode( + ColBertEmbedderConfig.TransformerExecutionMode.Enum.valueOf(onnxExecutionMode)); + if (onnxInteropThreads != null) b.transformerInterOpThreads(onnxInteropThreads); + if (onnxIntraopThreads != null) b.transformerIntraOpThreads(onnxIntraopThreads); + if (onnxGpuDevice != null) b.transformerGpuDevice(onnxGpuDevice); + } +} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFiles.java b/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFiles.java index 6d1de4bbc0a..8bed5e64bf5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFiles.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFiles.java @@ -57,7 +57,7 @@ public class UserConfiguredFiles implements Serializable { ConfigDefinition configDefinition = builder.getConfigDefinition(); if (configDefinition == null) { // TODO: throw new IllegalArgumentException("Unable to find config definition for " + builder); - logger.logApplicationPackage(Level.FINE, "Unable to find config definition " + key + + logger.logApplicationPackage(Level.INFO, "Unable to find config definition " + key + ". 
Will not register files for file distribution for this config"); return; } diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index ba7e2b6674e..e0d5e6a3344 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -80,7 +80,7 @@ ComponentDefinition = TypedComponentDefinition = attribute id { xsd:Name } & - (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder) & + (HuggingFaceEmbedder | HuggingFaceTokenizer | BertBaseEmbedder | ColBertEmbedder) & GenericConfig* & Component* @@ -110,15 +110,36 @@ BertBaseEmbedder = element transformer-attention-mask { xsd:string }? & element transformer-token-type-ids { xsd:string }? & element transformer-output { xsd:string }? & - element transformer-start-sequence-token { xsd:integer }? & - element transformer-end-sequence-token { xsd:integer }? & + StartOfSequence & + EndOfSequence & OnnxModelExecutionParams & EmbedderPoolingStrategy + +ColBertEmbedder = + attribute type { "colbert-embedder" } & + element transformer-model { ModelReference } & + element tokenizer-model { ModelReference }? & + element max-tokens { xsd:positiveInteger }? & + element max-query-tokens { xsd:positiveInteger }? & + element max-document-tokens { xsd:positiveInteger }? & + element transformer-mask-token { xsd:integer }? & + element transformer-input-ids { xsd:string }? & + element transformer-attention-mask { xsd:string }? & + element transformer-token-type-ids { xsd:string }? & + element transformer-output { xsd:string }? & + element normalize { xsd:boolean }? & + OnnxModelExecutionParams & + StartOfSequence & + EndOfSequence + OnnxModelExecutionParams = element onnx-execution-mode { "parallel" | "sequential" }? & element onnx-interop-threads { xsd:integer }? & element onnx-intraop-threads { xsd:integer }? & element onnx-gpu-device { xsd:integer }? -EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }?
\ No newline at end of file +EmbedderPoolingStrategy = element pooling-strategy { "cls" | "mean" }? + +StartOfSequence = element transformer-start-sequence-token { xsd:integer }? +EndOfSequence = element transformer-end-sequence-token { xsd:integer }?
\ No newline at end of file diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 70eef7ea54a..efb33d36761 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -43,6 +43,25 @@ <onnx-gpu-device>1</onnx-gpu-device> </component> + <component id="colbert" type="colbert-embedder"> + <transformer-model model-id="e5-base-v2" url="https://my/url/model.onnx"/> + <tokenizer-model model-id="e5-base-v2-vocab" path="app/tokenizer.json"/> + <max-tokens>1024</max-tokens> + <max-query-tokens>32</max-query-tokens> + <max-document-tokens>512</max-document-tokens> + <transformer-start-sequence-token>101</transformer-start-sequence-token> + <transformer-end-sequence-token>102</transformer-end-sequence-token> + <transformer-mask-token>103</transformer-mask-token> + <transformer-input-ids>my_input_ids</transformer-input-ids> + <transformer-attention-mask>my_attention_mask</transformer-attention-mask> + <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> + <transformer-output>my_output</transformer-output> + <onnx-execution-mode>parallel</onnx-execution-mode> + <onnx-intraop-threads>10</onnx-intraop-threads> + <onnx-interop-threads>8</onnx-interop-threads> + <onnx-gpu-device>1</onnx-gpu-device> + </component> + <nodes> <node hostalias="node1" /> </nodes> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 42b78db66b1..5832445d0d7 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -9,6 +9,7 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.config.model.deploy.DeployState; import 
com.yahoo.config.model.deploy.TestProperties; import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.embedding.ColBertEmbedderConfig; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.huggingface.config.HuggingFaceTokenizerConfig; import com.yahoo.path.Path; @@ -21,6 +22,7 @@ import com.yahoo.vespa.model.container.component.BertEmbedder; import com.yahoo.vespa.model.container.component.Component; import com.yahoo.vespa.model.container.component.HuggingFaceEmbedder; import com.yahoo.vespa.model.container.component.HuggingFaceTokenizer; +import com.yahoo.vespa.model.container.component.ColBertEmbedder; import com.yahoo.vespa.model.test.utils.VespaModelCreatorWithFilePkg; import com.yahoo.yolean.Exceptions; import org.junit.jupiter.api.Test; @@ -96,6 +98,29 @@ public class EmbedderTestCase { assertEquals(-1, tokenizerCfg.maxLength()); } + void colBertEmbedder_selfhosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), false); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertColBertEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", embedderCfg.transformerInputIds()); + assertEquals("https://my/url/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } + + void colBertEmbedder_hosted() throws Exception { + var model = loadModel(Path.fromString("src/test/cfg/application/embed/"), true); + var cluster = model.getContainerClusters().get("container"); + var embedderCfg = assertColBertEmbedderComponentPresent(cluster); + assertEquals("my_input_ids", 
embedderCfg.transformerInputIds()); + assertEquals("https://data.vespa.oath.cloud/onnx_models/e5-base-v2/model.onnx", modelReference(embedderCfg, "transformerModel").url().orElseThrow().value()); + assertEquals(1024, embedderCfg.transformerMaxTokens()); + var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); + assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); + assertEquals(-1, tokenizerCfg.maxLength()); + } @Test void bertEmbedder_selfhosted() throws Exception { @@ -233,6 +258,14 @@ public class EmbedderTestCase { return cfgBuilder.build(); } + private static ColBertEmbedderConfig assertColBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { + var colbert = (ColBertEmbedder) cluster.getComponentsMap().get(new ComponentId("colbert-embedder")); + assertEquals("ai.vespa.embedding.ColBertEmbedder", colbert.getClassId().getName()); + var cfgBuilder = new ColBertEmbedderConfig.Builder(); + colbert.getConfig(cfgBuilder); + return cfgBuilder.build(); + } + private static BertBaseEmbedderConfig assertBertEmbedderComponentPresent(ApplicationContainerCluster cluster) { var bertEmbedder = (BertEmbedder) cluster.getComponentsMap().get(new ComponentId("bert-embedder")); assertEquals("ai.vespa.embedding.BertBaseEmbedder", bertEmbedder.getClassId().getName()); diff --git a/config-model/src/test/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFilesTest.java b/config-model/src/test/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFilesTest.java index cd5f76b422a..bb5ba840c2c 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFilesTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/filedistribution/UserConfiguredFilesTest.java @@ -15,8 +15,10 @@ import com.yahoo.vespa.config.ConfigPayloadBuilder; import com.yahoo.vespa.model.SimpleConfigProducer; import 
org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.io.TempDir; import java.nio.ByteBuffer; +import java.nio.file.Path; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -55,8 +57,14 @@ public class UserConfiguredFilesTest { @Override public List<Entry> export() { - return null; + return pathToRef.entrySet().stream() + .map(e -> new Entry(e.getKey(), e.getValue())) + .toList(); } + + @Override + public String toString() { return export().toString(); } + } private UserConfiguredFiles userConfiguredFiles() { @@ -273,5 +281,20 @@ public class UserConfiguredFilesTest { } } + @Test + void require_that_using_empty_dir_gives_sane_error_message(@TempDir Path tempDir) { + String relativeTempDir = tempDir.toString().substring(tempDir.toString().lastIndexOf("target")); + try { + def.addPathDef("pathVal"); + builder.setField("pathVal", relativeTempDir); + fileRegistry.pathToRef.put(relativeTempDir, new FileReference("bazshash")); + userConfiguredFiles().register(producer); + fail("Should have thrown exception"); + } catch (IllegalArgumentException e) { + assertEquals("Unable to register file specified in services.xml for config 'mynamespace.myname': Directory '" + + relativeTempDir + "' is empty", + e.getMessage()); + } + } } diff --git a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java index 0d21b155571..c1eb3be4275 100644 --- a/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java +++ b/config-provisioning/src/main/java/com/yahoo/config/provision/NodeResources.java @@ -157,6 +157,10 @@ public class NodeResources { return new NodeResources.GpuResources(1, thisMem - otherMem); } + public GpuResources multipliedBy(double factor) { + return new GpuResources(this.count, this.memoryGb * factor); + } + @Override public boolean equals(Object o) { if (this == o) return true; @@ 
-329,10 +333,12 @@ public class NodeResources { } public NodeResources multipliedBy(double factor) { + if (isUnspecified()) return this; return this.withVcpu(vcpu * factor) .withMemoryGb(memoryGb * factor) .withDiskGb(diskGb * factor) - .withBandwidthGbps(bandwidthGbps * factor); + .withBandwidthGbps(bandwidthGbps * factor) + .with(gpuResources.multipliedBy(factor)); } private boolean isInterchangeableWith(NodeResources other) { diff --git a/configdefinitions/src/vespa/CMakeLists.txt b/configdefinitions/src/vespa/CMakeLists.txt index 29ed0f53421..bb29db2e3e3 100644 --- a/configdefinitions/src/vespa/CMakeLists.txt +++ b/configdefinitions/src/vespa/CMakeLists.txt @@ -88,5 +88,6 @@ install_config_definition(dataplane-proxy.def cloud.config.dataplane-proxy.def) install_config_definition(hugging-face-embedder.def embedding.huggingface.hugging-face-embedder.def) install_config_definition(hugging-face-tokenizer.def language.huggingface.config.hugging-face-tokenizer.def) install_config_definition(bert-base-embedder.def embedding.bert-base-embedder.def) +install_config_definition(col-bert-embedder.def embedding.col-bert-embedder.def) install_config_definition(cloud-data-plane-filter.def jdisc.http.filter.security.cloud.config.cloud-data-plane-filter.def) install_config_definition(cloud-token-data-plane-filter.def jdisc.http.filter.security.cloud.config.cloud-token-data-plane-filter.def) diff --git a/configdefinitions/src/vespa/col-bert-embedder.def b/configdefinitions/src/vespa/col-bert-embedder.def new file mode 100644 index 00000000000..c7944847d8b --- /dev/null +++ b/configdefinitions/src/vespa/col-bert-embedder.def @@ -0,0 +1,36 @@ + +namespace=embedding + +# Path to tokenizer.json +tokenizerPath model + +# Path to model.onnx +transformerModel model + +# Max query tokens for ColBERT +maxQueryTokens int default=32 + +# Max document query tokens for ColBERT +maxDocumentTokens int default=512 + +# Max length of token sequence model can handle +transformerMaxTokens int 
default=512 + +# Input names +transformerInputIds string default=input_ids +transformerAttentionMask string default=attention_mask + +# special token ids +transformerStartSequenceToken int default=101 +transformerEndSequenceToken int default=102 +transformerMaskToken int default=103 + +# Output name +transformerOutput string default=contextual + +# Settings for ONNX model evaluation +transformerExecutionMode enum { parallel, sequential } default=sequential +transformerInterOpThreads int default=1 +transformerIntraOpThreads int default=-4 +# GPU device id, -1 for CPU +transformerGpuDevice int default=0 diff --git a/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java b/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java index 3c4c99e0337..4f8d42e895b 100644 --- a/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java +++ b/configserver-flags/src/test/java/com/yahoo/vespa/configserver/flags/http/FlagsHandlerTest.java @@ -34,7 +34,7 @@ public class FlagsHandlerTest { "id1", false, List.of("joe"), "2010-01-01", "2030-01-01", "desc1", "mod1"); private static final UnboundBooleanFlag FLAG2 = Flags.defineFeatureFlag( "id2", true, List.of("joe"), "2010-01-01", "2030-01-01", "desc2", "mod2", - FetchVector.Dimension.HOSTNAME, FetchVector.Dimension.APPLICATION_ID); + FetchVector.Dimension.HOSTNAME, FetchVector.Dimension.INSTANCE_ID); private final FlagsDb flagsDb = new FlagsDbImpl(new MockCurator()); private final FlagsHandler handler = new FlagsHandler(FlagsHandler.testContext(), flagsDb); @@ -111,7 +111,7 @@ public class FlagsHandlerTest { }, { "type": "blacklist", - "dimension": "application", + "dimension": "instance", "values": [ "app1", "app2" ] } ], @@ -127,7 +127,7 @@ public class FlagsHandlerTest { // GET on id2 should now return what was put verifySuccessfulRequest(Method.GET, "/data/" + FLAG2.id(), "", - 
"{\"id\":\"id2\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"hostname\",\"values\":[\"host1\",\"host2\"]},{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"app1\",\"app2\"]}],\"value\":true}],\"attributes\":{\"zone\":\"zone1\"}}"); + "{\"id\":\"id2\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"hostname\",\"values\":[\"host1\",\"host2\"]},{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"app1\",\"app2\"]}],\"value\":true}],\"attributes\":{\"zone\":\"zone1\"}}"); // The list of flag data should return id1 and id2 verifySuccessfulRequest(Method.GET, "/data", @@ -153,7 +153,7 @@ public class FlagsHandlerTest { // Get all recursivelly displays all flag data verifySuccessfulRequest(Method.GET, "/data?recursive=true", "", - "{\"flags\":[{\"id\":\"id1\",\"rules\":[{\"value\":false}]},{\"id\":\"id2\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"hostname\",\"values\":[\"host1\",\"host2\"]},{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"app1\",\"app2\"]}],\"value\":true}],\"attributes\":{\"zone\":\"zone1\"}}]}"); + "{\"flags\":[{\"id\":\"id1\",\"rules\":[{\"value\":false}]},{\"id\":\"id2\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"hostname\",\"values\":[\"host1\",\"host2\"]},{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"app1\",\"app2\"]}],\"value\":true}],\"attributes\":{\"zone\":\"zone1\"}}]}"); // Deleting both flags verifySuccessfulRequest(Method.DELETE, "/data/" + FLAG1.id(), "", ""); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java index 296e31ce801..693252da43a 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java +++ 
b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java @@ -50,7 +50,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import static com.yahoo.vespa.curator.Curator.CompletionWaiter; -import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; +import static com.yahoo.vespa.flags.FetchVector.Dimension.INSTANCE_ID; import static java.util.stream.Collectors.toSet; /** @@ -418,7 +418,7 @@ public class TenantApplications implements RequestHandler, HostValidator { if (vespaVersion.isEmpty()) return true; Version wantedVersion = applicationMapper.getForVersion(application, Optional.empty(), clock.instant()) .getModel().wantedNodeVersion(); - return VersionCompatibility.fromVersionList(incompatibleVersions.with(APPLICATION_ID, application.serializedForm()).value()) + return VersionCompatibility.fromVersionList(incompatibleVersions.with(INSTANCE_ID, application.serializedForm()).value()) .accept(vespaVersion.get(), wantedVersion); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index bab2e666d0a..142f98e13e3 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -296,7 +296,7 @@ public class ModelContextImpl implements ModelContext { private static <V> V flagValue(FlagSource source, ApplicationId appId, Version vespaVersion, UnboundFlag<? 
extends V, ?, ?> flag) { return flag.bindTo(source) - .with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm()) + .with(FetchVector.Dimension.INSTANCE_ID, appId.serializedForm()) .with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString()) .with(FetchVector.Dimension.TENANT_ID, appId.tenant().value()) .boxedValue(); @@ -308,7 +308,7 @@ public class ModelContextImpl implements ModelContext { ClusterSpec.Type clusterType, UnboundFlag<? extends V, ?, ?> flag) { return flag.bindTo(source) - .with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm()) + .with(FetchVector.Dimension.INSTANCE_ID, appId.serializedForm()) .with(FetchVector.Dimension.CLUSTER_TYPE, clusterType.name()) .with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString()) .boxedValue(); @@ -320,7 +320,7 @@ public class ModelContextImpl implements ModelContext { ClusterSpec.Id clusterId, UnboundFlag<? extends V, ?, ?> flag) { return flag.bindTo(source) - .with(FetchVector.Dimension.APPLICATION_ID, appId.serializedForm()) + .with(FetchVector.Dimension.INSTANCE_ID, appId.serializedForm()) .with(FetchVector.Dimension.CLUSTER_ID, clusterId.value()) .with(FetchVector.Dimension.VESPA_VERSION, vespaVersion.toFullString()) .boxedValue(); @@ -397,21 +397,21 @@ public class ModelContextImpl implements ModelContext { this.tenantSecretStores = tenantSecretStores; this.secretStore = secretStore; this.jvmGCOptionsFlag = PermanentFlags.JVM_GC_OPTIONS.bindTo(flagSource) - .with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()); + .with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()); this.allowDisableMtls = PermanentFlags.ALLOW_DISABLE_MTLS.bindTo(flagSource) - .with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()).value(); + .with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()).value(); this.operatorCertificates = operatorCertificates; this.tlsCiphersOverride = 
PermanentFlags.TLS_CIPHERS_OVERRIDE.bindTo(flagSource) - .with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()).value(); + .with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()).value(); this.zoneDnsSuffixes = configserverConfig.zoneDnsSuffixes(); this.environmentVariables = PermanentFlags.ENVIRONMENT_VARIABLES.bindTo(flagSource) - .with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()).value(); + .with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()).value(); this.cloudAccount = cloudAccount; this.allowUserFilters = PermanentFlags.ALLOW_USER_FILTERS.bindTo(flagSource) - .with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()).value(); + .with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()).value(); this.endpointConnectionTtl = Duration.ofSeconds( PermanentFlags.ENDPOINT_CONNECTION_TTL.bindTo(flagSource) - .with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()).value()); + .with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()).value()); this.dataplaneTokens = dataplaneTokens; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index 44a656a1579..3b57945b21d 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -82,7 +82,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import static com.yahoo.vespa.curator.Curator.CompletionWaiter; -import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; +import static com.yahoo.vespa.flags.FetchVector.Dimension.INSTANCE_ID; import static java.nio.file.Files.readAttributes; /** @@ -728,7 +728,7 @@ public class SessionRepository { } catch (IllegalArgumentException 
e) { if (configserverConfig.hostedVespa()) { UnboundStringFlag flag = PermanentFlags.APPLICATION_FILES_WITH_UNKNOWN_EXTENSION; - String value = flag.bindTo(flagSource).with(APPLICATION_ID, applicationId.serializedForm()).value(); + String value = flag.bindTo(flagSource).with(INSTANCE_ID, applicationId.serializedForm()).value(); switch (value) { case "FAIL" -> throw new InvalidApplicationException(e); case "LOG" -> deployLogger.ifPresent(logger -> logger.logApplicationPackage(Level.WARNING, e.getMessage())); diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java index fa8a0ddcba1..856af9f4132 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchive.java @@ -5,6 +5,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.fasterxml.jackson.databind.node.TextNode; import com.yahoo.component.Version; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.CloudAccount; @@ -190,7 +191,7 @@ public class SystemFlagsDataArchive { flagData.rules().forEach(rule -> rule.conditions().forEach(condition -> { int force_switch_expression_dummy = switch (condition.type()) { case RELATIONAL -> switch (condition.dimension()) { - case APPLICATION_ID, CLOUD, CLOUD_ACCOUNT, CLUSTER_ID, CLUSTER_TYPE, CONSOLE_USER_EMAIL, + case INSTANCE_ID, CLOUD, CLOUD_ACCOUNT, CLUSTER_ID, CLUSTER_TYPE, CONSOLE_USER_EMAIL, ENVIRONMENT, HOSTNAME, NODE_TYPE, SYSTEM, TENANT_ID, ZONE_ID -> throw new FlagValidationException(condition.type().toWire() + " " + 
DimensionHelper.toWire(condition.dimension()) + @@ -205,7 +206,7 @@ public class SystemFlagsDataArchive { }; case WHITELIST, BLACKLIST -> switch (condition.dimension()) { - case APPLICATION_ID -> validateConditionValues(condition, ApplicationId::fromSerializedForm); + case INSTANCE_ID -> validateConditionValues(condition, ApplicationId::fromSerializedForm); case CONSOLE_USER_EMAIL -> validateConditionValues(condition, email -> { if (!email.contains("@")) throw new FlagValidationException("Invalid email address: " + email); @@ -255,6 +256,17 @@ public class SystemFlagsDataArchive { final JsonNode root; try { root = mapper.readTree(fileContent); + // TODO (mortent): Remove this after completing migration of APPLICATION_ID dimension + // replace "application" with "instance" for all dimension fields + List<JsonNode> dimensionParents = root.findParents("dimension"); + for (JsonNode parentNode : dimensionParents) { + JsonNode dimension = parentNode.get("dimension"); + if (dimension.isTextual() && "application".equals(dimension.textValue())) { + ObjectNode parent = (ObjectNode) parentNode; + parent.remove("dimension"); + parent.put("dimension", "instance"); + } + } } catch (JsonProcessingException e) { throw new FlagValidationException("Invalid JSON: " + e.getMessage()); } diff --git a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java index 759f21579d4..373f8ba9de2 100644 --- a/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java +++ b/controller-api/src/test/java/com/yahoo/vespa/hosted/controller/api/systemflags/v1/SystemFlagsDataArchiveTest.java @@ -245,7 +245,7 @@ public class SystemFlagsDataArchiveTest { "conditions": [ { "type": "whitelist", - "dimension": "application", + "dimension": "instance", "values": [ "f:o:o" ] } ], @@ 
-287,7 +287,7 @@ public class SystemFlagsDataArchiveTest { { "comment": "bar", "type": "whitelist", - "dimension": "application", + "dimension": "instance", "values": [ "f:o:o" ] } ], @@ -308,6 +308,7 @@ public class SystemFlagsDataArchiveTest { @Test void normalize_json_succeed_on_valid_values() { addFile(Condition.Type.WHITELIST, "application", "a:b:c"); + addFile(Condition.Type.WHITELIST, "instance", "a:b:c"); addFile(Condition.Type.WHITELIST, "cloud", "yahoo"); addFile(Condition.Type.WHITELIST, "cloud", "aws"); addFile(Condition.Type.WHITELIST, "cloud", "gcp"); @@ -361,7 +362,7 @@ public class SystemFlagsDataArchiveTest { @Test void normalize_json_fail_on_invalid_values() { - failAddFile(Condition.Type.WHITELIST, "application", "a.b.c", "In file flags/temporary/foo/default.json: Invalid application 'a.b.c' in whitelist condition: Application ids must be on the form tenant:application:instance, but was a.b.c"); + failAddFile(Condition.Type.WHITELIST, "application", "a.b.c", "In file flags/temporary/foo/default.json: Invalid instance 'a.b.c' in whitelist condition: Application ids must be on the form tenant:application:instance, but was a.b.c"); failAddFile(Condition.Type.WHITELIST, "cloud", "foo", "In file flags/temporary/foo/default.json: Unknown cloud: foo"); // cluster-id: any String is valid failAddFile(Condition.Type.WHITELIST, "cluster-type", "foo", "In file flags/temporary/foo/default.json: Invalid cluster-type 'foo' in whitelist condition: Illegal cluster type 'foo'"); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index 2f7b9f92316..90653d85aed 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -113,7 +113,7 @@ import java.util.logging.Logger; 
import java.util.stream.Collectors; import java.util.stream.Stream; -import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; +import static com.yahoo.vespa.flags.FetchVector.Dimension.INSTANCE_ID; import static com.yahoo.vespa.hosted.controller.api.integration.configserver.Node.State.active; import static com.yahoo.vespa.hosted.controller.api.integration.configserver.Node.State.reserved; import static com.yahoo.vespa.hosted.controller.versions.VespaVersion.Confidence.broken; @@ -676,7 +676,7 @@ public class ApplicationController { Optional<DockerImage> dockerImageRepo = Optional.ofNullable( dockerImageRepoFlag .with(FetchVector.Dimension.ZONE_ID, zone.value()) - .with(APPLICATION_ID, application.serializedForm()) + .with(INSTANCE_ID, application.serializedForm()) .value()) .filter(s -> !s.isBlank()) .map(DockerImage::fromString); @@ -962,7 +962,7 @@ public class ApplicationController { } public VersionCompatibility versionCompatibility(ApplicationId id) { - return VersionCompatibility.fromVersionList(incompatibleVersions.with(APPLICATION_ID, id.serializedForm()).value()); + return VersionCompatibility.fromVersionList(incompatibleVersions.with(INSTANCE_ID, id.serializedForm()).value()); } /** diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java index 2b2ec725d7e..091836a1eea 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java @@ -12,6 +12,7 @@ import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.zone.AuthMethod; import com.yahoo.config.provision.zone.RoutingMethod; +import com.yahoo.config.provision.zone.ZoneApi; import com.yahoo.config.provision.zone.ZoneId; import 
com.yahoo.vespa.flags.BooleanFlag; import com.yahoo.vespa.flags.FetchVector; @@ -132,46 +133,46 @@ public class RoutingController { } // Add zone-scoped endpoints - Map<EndpointId, List<GeneratedEndpoint>> generatedForDeclaredEndpoints = new HashMap<>(); + Map<EndpointId, GeneratedEndpointList> generatedForDeclaredEndpoints = new HashMap<>(); Set<ClusterSpec.Id> clustersWithToken = new HashSet<>(); - if (randomizedEndpointsEnabled(deployment.applicationId())) { // TODO(mpolden): Remove this guard once config-models < 8.220 are gone - RoutingPolicyList applicationPolicies = policies().read(TenantAndApplicationId.from(deployment.applicationId())); - RoutingPolicyList deploymentPolicies = applicationPolicies.deployment(deployment); - for (var container : services.containers()) { - ClusterSpec.Id clusterId = ClusterSpec.Id.from(container.id()); - boolean tokenSupported = container.authMethods().contains(BasicServicesXml.Container.AuthMethod.token); - if (tokenSupported) { - clustersWithToken.add(clusterId); - } - Optional<RoutingPolicy> clusterPolicy = deploymentPolicies.cluster(clusterId).first(); - List<GeneratedEndpoint> generatedForCluster = clusterPolicy.map(policy -> policy.generatedEndpoints().cluster().asList()) - .orElseGet(List::of); - // Generate endpoints if cluster does not have any - if (generatedForCluster.isEmpty()) { - generatedForCluster = generateEndpoints(tokenSupported, certificate, Optional.empty()); - } - endpoints = endpoints.and(endpointsOf(deployment, clusterId, GeneratedEndpointList.copyOf(generatedForCluster)).scope(Scope.zone)); + boolean generatedEndpointsEnabled = generatedEndpointsEnabled(deployment.applicationId()); + RoutingPolicyList applicationPolicies = policies().read(TenantAndApplicationId.from(deployment.applicationId())); + RoutingPolicyList deploymentPolicies = applicationPolicies.deployment(deployment); + for (var container : services.containers()) { + ClusterSpec.Id clusterId = ClusterSpec.Id.from(container.id()); + boolean 
tokenSupported = container.authMethods().contains(BasicServicesXml.Container.AuthMethod.token); + if (tokenSupported) { + clustersWithToken.add(clusterId); } - - // Generate endpoints if declared endpoint does not have any - for (var container : services.containers()) { - ClusterSpec.Id clusterId = ClusterSpec.Id.from(container.id()); - applicationPolicies.cluster(clusterId).asList().stream() - .flatMap(policy -> policy.generatedEndpoints().declared().asList().stream()) - .forEach(ge -> generatedForDeclaredEndpoints.computeIfAbsent(ge.endpoint().get(), (k) -> List.of(ge))); + Optional<RoutingPolicy> clusterPolicy = deploymentPolicies.cluster(clusterId).first(); + List<GeneratedEndpoint> generatedForCluster = clusterPolicy.map(policy -> policy.generatedEndpoints().cluster().asList()) + .orElseGet(List::of); + // Generate endpoints if cluster does not have any + if (generatedForCluster.isEmpty()) { + generatedForCluster = generateEndpoints(tokenSupported, certificate, Optional.empty()); } - Stream.concat(spec.endpoints().stream(), spec.instances().stream().flatMap(i -> i.endpoints().stream())) - .forEach(endpoint -> { - EndpointId endpointId = EndpointId.of(endpoint.endpointId()); - generatedForDeclaredEndpoints.computeIfAbsent(endpointId, (k) -> { - boolean tokenSupported = clustersWithToken.contains(ClusterSpec.Id.from(endpoint.containerId())); - return generateEndpoints(tokenSupported, certificate, Optional.of(endpointId)); - }); - }); + GeneratedEndpointList generatedEndpoints = generatedEndpointsEnabled ? 
GeneratedEndpointList.copyOf(generatedForCluster) : GeneratedEndpointList.EMPTY; + endpoints = endpoints.and(endpointsOf(deployment, clusterId, generatedEndpoints).scope(Scope.zone)); } // Add global- and application-scoped endpoints - endpoints = endpoints.and(declaredEndpointsOf(application.get().id(), spec, generatedForDeclaredEndpoints).targets(deployment)); + for (var container : services.containers()) { + ClusterSpec.Id clusterId = ClusterSpec.Id.from(container.id()); + applicationPolicies.cluster(clusterId).asList().stream() + .flatMap(policy -> policy.generatedEndpoints().declared().asList().stream()) + .forEach(ge -> generatedForDeclaredEndpoints.computeIfAbsent(ge.endpoint().get(), (k) -> GeneratedEndpointList.of(ge))); + } + // Generate endpoints if declared endpoint does not have any + Stream.concat(spec.endpoints().stream(), spec.instances().stream().flatMap(i -> i.endpoints().stream())) + .forEach(endpoint -> { + EndpointId endpointId = EndpointId.of(endpoint.endpointId()); + generatedForDeclaredEndpoints.computeIfAbsent(endpointId, (k) -> { + boolean tokenSupported = clustersWithToken.contains(ClusterSpec.Id.from(endpoint.containerId())); + return generatedEndpointsEnabled ? GeneratedEndpointList.copyOf(generateEndpoints(tokenSupported, certificate, Optional.of(endpointId))) : null; + }); + }); + Map<EndpointId, GeneratedEndpointList> generatedEndpoints = generatedEndpointsEnabled ? 
generatedForDeclaredEndpoints : Map.of(); + endpoints = endpoints.and(declaredEndpointsOf(application.get().id(), spec, generatedEndpoints).targets(deployment)); PreparedEndpoints prepared = new PreparedEndpoints(deployment, endpoints, application.get().require(deployment.applicationId().instance()).rotations(), @@ -203,10 +204,13 @@ public class RoutingController { .on(Port.fromRoutingMethod(routingMethod)) .target(cluster, deployment); endpoints.add(zoneEndpoint.in(controller.system())); + ZoneApi zone = controller.zoneRegistry().zones().all().get(deployment.zoneId()).get(); Endpoint.EndpointBuilder regionEndpoint = Endpoint.of(deployment.applicationId()) .routingMethod(routingMethod) .on(Port.fromRoutingMethod(routingMethod)) - .targetRegion(cluster, deployment.zoneId()); + .targetRegion(cluster, + zone.getCloudNativeRegionName(), + zone.getCloudName()); // Region endpoints are only used by global- and application-endpoints and are thus only needed in // production environments if (isProduction) { @@ -289,7 +293,7 @@ public class RoutingController { /** Read application and return endpoints for all instances in application */ public EndpointList readDeclaredEndpointsOf(Application application) { - return declaredEndpointsOf(application.id(), application.deploymentSpec(), readMultiDeploymentGeneratedEndpoints(application.id())); + return declaredEndpointsOf(application.id(), application.deploymentSpec(), readDeclaredGeneratedEndpoints(application.id())); } /** Read application and return declared endpoints for given instance */ @@ -299,7 +303,7 @@ public class RoutingController { return readDeclaredEndpointsOf(application).instance(instance.instance()); } - private EndpointList declaredEndpointsOf(TenantAndApplicationId application, DeploymentSpec deploymentSpec, Map<EndpointId, List<GeneratedEndpoint>> generatedEndpoints) { + private EndpointList declaredEndpointsOf(TenantAndApplicationId application, DeploymentSpec deploymentSpec, Map<EndpointId, 
GeneratedEndpointList> generatedEndpoints) { Set<Endpoint> endpoints = new LinkedHashSet<>(); // Global endpoints for (var spec : deploymentSpec.instances()) { @@ -311,7 +315,7 @@ public class RoutingController { ZoneId.from(Environment.prod, region))) .toList(); ClusterSpec.Id cluster = ClusterSpec.Id.from(declaredEndpoint.containerId()); - GeneratedEndpointList generatedForId = GeneratedEndpointList.copyOf(generatedEndpoints.getOrDefault(routingId.endpointId(), List.of())); + GeneratedEndpointList generatedForId = generatedEndpoints.getOrDefault(routingId.endpointId(), GeneratedEndpointList.EMPTY); endpoints.addAll(declaredEndpointsOf(routingId, cluster, deployments, generatedForId).asList()); } } @@ -323,7 +327,7 @@ public class RoutingController { t -> t.weight())); ClusterSpec.Id cluster = ClusterSpec.Id.from(declaredEndpoint.containerId()); EndpointId endpointId = EndpointId.of(declaredEndpoint.endpointId()); - GeneratedEndpointList generatedForId = GeneratedEndpointList.copyOf(generatedEndpoints.getOrDefault(endpointId, List.of())); + GeneratedEndpointList generatedForId = generatedEndpoints.getOrDefault(endpointId, GeneratedEndpointList.EMPTY); endpoints.addAll(declaredEndpointsOf(application, endpointId, cluster, deployments, generatedForId).asList()); } return EndpointList.copyOf(endpoints); @@ -407,7 +411,7 @@ public class RoutingController { var deployments = rotation.regions().stream() .map(region -> new DeploymentId(instance.id(), ZoneId.from(Environment.prod, region))) .toList(); - GeneratedEndpointList generatedForId = GeneratedEndpointList.copyOf(readMultiDeploymentGeneratedEndpoints(application.id()).getOrDefault(rotation.endpointId(), List.of())); + GeneratedEndpointList generatedForId = readDeclaredGeneratedEndpoints(application.id()).getOrDefault(rotation.endpointId(), GeneratedEndpointList.EMPTY); endpointsToRemove.addAll(declaredEndpointsOf(RoutingId.of(instance.id(), rotation.endpointId()), rotation.clusterId(), deployments, generatedForId) 
@@ -472,13 +476,13 @@ public class RoutingController { .toList(); } - /** Returns generated endpoint suitable for use in endpoints whose scope is {@link Scope#multiDeployment()} */ - private Map<EndpointId, List<GeneratedEndpoint>> readMultiDeploymentGeneratedEndpoints(TenantAndApplicationId application) { - Map<EndpointId, List<GeneratedEndpoint>> endpoints = new HashMap<>(); + /** Returns existing generated endpoints, grouped by their {@link Scope#multiDeployment()} endpoint */ + private Map<EndpointId, GeneratedEndpointList> readDeclaredGeneratedEndpoints(TenantAndApplicationId application) { + Map<EndpointId, GeneratedEndpointList> endpoints = new HashMap<>(); for (var policy : policies().read(application)) { - Map<EndpointId, List<GeneratedEndpoint>> generatedForDeclared = policy.generatedEndpoints().declared() - .asList().stream() - .collect(Collectors.groupingBy(ge -> ge.endpoint().get())); + Map<EndpointId, GeneratedEndpointList> generatedForDeclared = policy.generatedEndpoints() + .not().cluster() + .groupingBy(ge -> ge.endpoint().get()); generatedForDeclared.forEach(endpoints::putIfAbsent); } return endpoints; @@ -520,8 +524,8 @@ public class RoutingController { return Collections.unmodifiableList(routingMethods); } - public boolean randomizedEndpointsEnabled(ApplicationId instance) { - return randomizedEndpoints.with(FetchVector.Dimension.APPLICATION_ID, instance.serializedForm()).value(); + public boolean generatedEndpointsEnabled(ApplicationId instance) { + return randomizedEndpoints.with(FetchVector.Dimension.INSTANCE_ID, instance.serializedForm()).value(); } private static void requireGeneratedEndpoints(GeneratedEndpointList generatedEndpoints, boolean declared) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Endpoint.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Endpoint.java index b93d634101e..5c6611f80c3 100644 --- 
a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Endpoint.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/Endpoint.java @@ -2,7 +2,9 @@ package com.yahoo.vespa.hosted.controller.application; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.CloudName; import com.yahoo.config.provision.ClusterSpec; +import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.InstanceName; import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; @@ -302,22 +304,6 @@ public class Endpoint { return part.substring(Math.max(0, part.length() - 63)); } - /** Returns the given region without availability zone */ - private static RegionName effectiveRegion(RegionName region) { - if (region.value().length() < 2) return region; - String value = region.value(); - char lastChar = value.charAt(value.length() - 1); - if (lastChar >= 'a' && lastChar <= 'z') { // Remove availability zone - int skip = value.charAt(value.length() - 2) == '-' ? 
2 : 1; - value = value.substring(0, value.length() - skip); - } - return RegionName.from(value); - } - - private static ZoneId effectiveZone(ZoneId zone) { - return ZoneId.from(zone.environment(), effectiveRegion(zone.region())); - } - private static ClusterSpec.Id requireCluster(ClusterSpec.Id cluster, boolean certificateName) { if (!certificateName && cluster.value().equals("*")) throw new IllegalArgumentException("Wildcard found in cluster ID which is not a certificate name"); return cluster; @@ -550,10 +536,11 @@ public class Endpoint { } /** Sets the region target for this, deduced from given zone */ - public EndpointBuilder targetRegion(ClusterSpec.Id cluster, ZoneId zone) { + public EndpointBuilder targetRegion(ClusterSpec.Id cluster, String cloudNativeRegion, CloudName cloudName) { this.cluster = cluster; this.scope = requireUnset(Scope.weighted); - this.targets = List.of(new Target(new DeploymentId(application.instance(instance.get()), effectiveZone(zone)))); + RegionName region = RegionName.from(cloudName.value() + "-" + cloudNativeRegion); + this.targets = List.of(new Target(new DeploymentId(application.instance(instance.get()), ZoneId.from(Environment.prod, region)))); this.authMethod = AuthMethod.none; return this; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/BasicServicesXml.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/BasicServicesXml.java index 9eb10857526..33f20327d92 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/BasicServicesXml.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/application/pkg/BasicServicesXml.java @@ -7,7 +7,6 @@ import org.w3c.dom.Element; import java.util.ArrayList; import java.util.List; import java.util.Objects; -import java.util.stream.Collectors; /** * A partially parsed variant of services.xml, for use by the {@link 
com.yahoo.vespa.hosted.controller.Controller}. @@ -78,7 +77,7 @@ public record BasicServicesXml(List<Container> containers) { this.authMethods = Objects.requireNonNull(authMethods).stream() .distinct() .sorted() - .collect(Collectors.toList()); + .toList(); if (authMethods.isEmpty()) throw new IllegalArgumentException("Container must have at least one auth method"); } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java index b1cadcc341c..e01da00a27e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/certificate/EndpointCertificates.java @@ -155,7 +155,7 @@ public class EndpointCertificates { } private Optional<EndpointCertificate> getOrProvision(Instance instance, ZoneId zone, DeploymentSpec deploymentSpec) { - if (controller.routing().randomizedEndpointsEnabled(instance.id())) { + if (controller.routing().generatedEndpointsEnabled(instance.id())) { return Optional.of(assignFromPool(instance, zone)); } Optional<AssignedCertificate> assignedCertificate = curator.readAssignedCertificate(TenantAndApplicationId.from(instance.id()), Optional.of(instance.id().instance())); @@ -234,8 +234,8 @@ public class EndpointCertificates { .forEach(requiredNames::addAll); log.log(Level.INFO, String.format("Requesting new endpoint certificate from Cameo for application %s", deployment.applicationId().serializedForm())); - String algo = this.endpointCertificateAlgo.with(FetchVector.Dimension.APPLICATION_ID, deployment.applicationId().serializedForm()).value(); - boolean useAlternativeProvider = useAlternateCertProvider.with(FetchVector.Dimension.APPLICATION_ID, deployment.applicationId().serializedForm()).value(); + String algo = 
this.endpointCertificateAlgo.with(FetchVector.Dimension.INSTANCE_ID, deployment.applicationId().serializedForm()).value(); + boolean useAlternativeProvider = useAlternateCertProvider.with(FetchVector.Dimension.INSTANCE_ID, deployment.applicationId().serializedForm()).value(); String keyPrefix = deployment.applicationId().toFullString(); var t0 = Instant.now(); EndpointCertificate endpointCertificate = certificateProvider.requestCaSignedCertificate(keyPrefix, List.copyOf(requiredNames), currentCert, algo, useAlternativeProvider); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java index 0773c95e1f2..1c417e750e3 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/deployment/JobController.java @@ -155,7 +155,7 @@ public class JobController { } public boolean isDisabled(JobId id) { - return disabledZones.with(Dimension.APPLICATION_ID, id.application().serializedForm()).value().contains(id.type().zone().value()); + return disabledZones.with(Dimension.INSTANCE_ID, id.application().serializedForm()).value().contains(id.type().zone().value()); } /** Returns all entries currently logged for the given run. 
*/ diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/EndpointCertificateMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/EndpointCertificateMaintainer.java index e130e73cef1..c90fcb81c71 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/EndpointCertificateMaintainer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/EndpointCertificateMaintainer.java @@ -283,7 +283,7 @@ public class EndpointCertificateMaintainer extends ControllerMaintainer { assignedCertificates.stream() .filter(c -> c.instance().isPresent()) .filter(c -> c.certificate().randomizedId().isEmpty()) - .filter(c -> assignRandomizedId.with(FetchVector.Dimension.APPLICATION_ID, c.application().instance(c.instance().get()).serializedForm()).value()) + .filter(c -> assignRandomizedId.with(FetchVector.Dimension.INSTANCE_ID, c.application().instance(c.instance().get()).serializedForm()).value()) .filter(c -> controller().applications().getApplication(c.application()).isPresent()) // In case application has been deleted, but certificate is pending deletion .limit(assignRandomizedIdRate.value()) .forEach(c -> assignRandomizedId(c.application(), c.instance().get())); diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ResourceMeterMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ResourceMeterMaintainer.java index eed2a9ec991..6ee1a8b56d7 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ResourceMeterMaintainer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/ResourceMeterMaintainer.java @@ -231,13 +231,8 @@ public class ResourceMeterMaintainer extends ControllerMaintainer { } public static double cost(ClusterResources clusterResources, SystemName systemName) { - NodeResources nr = 
clusterResources.nodeResources(); - return cost(new ResourceAllocation(nr.vcpu(), nr.memoryGb(), nr.diskGb(), nr.architecture()).multiply(clusterResources.nodes()), systemName); - } - - private static double cost(ResourceAllocation allocation, SystemName systemName) { - var resources = new NodeResources(allocation.getCpuCores(), allocation.getMemoryGb(), allocation.getDiskGb(), 0); - return cost(resources, systemName); + var totalResources = clusterResources.nodeResources().multipliedBy(clusterResources.nodes()); + return cost(totalResources, systemName); } private static double cost(NodeResources resources, SystemName systemName) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/certificate/EndpointCertificatesHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/certificate/EndpointCertificatesHandler.java index 6c7ee4d0d85..912bd051a31 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/certificate/EndpointCertificatesHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/certificate/EndpointCertificatesHandler.java @@ -73,15 +73,15 @@ public class EndpointCertificatesHandler extends ThreadedHttpRequestHandler { public StringResponse reRequestEndpointCertificateFor(String instanceId, boolean ignoreExisting) { ApplicationId applicationId = ApplicationId.fromFullString(instanceId); - if (controller.routing().randomizedEndpointsEnabled(applicationId)) { + if (controller.routing().generatedEndpointsEnabled(applicationId)) { throw new IllegalArgumentException("Cannot re-request certificate. 
" + instanceId + " is assigned certificate from a pool"); } try (var lock = curator.lock(TenantAndApplicationId.from(applicationId))) { AssignedCertificate assignedCertificate = curator.readAssignedCertificate(applicationId) .orElseThrow(() -> new RestApiException.NotFound("No certificate found for application " + applicationId.serializedForm())); - String algo = this.endpointCertificateAlgo.with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()).value(); - boolean useAlternativeProvider = useAlternateCertProvider.with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()).value(); + String algo = this.endpointCertificateAlgo.with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()).value(); + boolean useAlternativeProvider = useAlternateCertProvider.with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()).value(); String keyPrefix = applicationId.toFullString(); EndpointCertificate cert = endpointCertificateProvider.requestCaSignedCertificate( diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializer.java index 54a24360b5a..c3acf01a53e 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializer.java @@ -20,7 +20,6 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.function.Predicate; -import java.util.stream.Collectors; import java.util.stream.Stream; /** @@ -72,7 +71,7 @@ public class UserFlagsSerializer { // For the other dimensions, filter the values down to an allowed subset switch (condition.dimension()) { case TENANT_ID: return valueSubset(condition, tenant -> isOperator || authorizedForTenantNames.contains(TenantName.from(tenant))); - case 
APPLICATION_ID: return valueSubset(condition, appId -> isOperator || authorizedForTenantNames.stream().anyMatch(tenant -> appId.startsWith(tenant.value() + ":"))); + case INSTANCE_ID: return valueSubset(condition, appId -> isOperator || authorizedForTenantNames.stream().anyMatch(tenant -> appId.startsWith(tenant.value() + ":"))); default: throw new IllegalArgumentException("Dimension " + condition.dimension() + " is not supported for user flags"); } } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java index 1a886a50589..6901b6f93c9 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java @@ -1040,8 +1040,8 @@ public class ControllerTest { @Test void testDeployWithGlobalEndpointsAndMultipleRoutingMethods() { var context = tester.newDeploymentContext(); - var zone1 = ZoneId.from("prod", "us-west-1"); - var zone2 = ZoneId.from("prod", "us-east-3"); + var zone1 = ZoneId.from("prod", "aws-us-east-1a"); + var zone2 = ZoneId.from("prod", "aws-us-east-1b"); var applicationPackage = new ApplicationPackageBuilder() .athenzIdentity(AthenzDomain.from("domain"), AthenzService.from("service")) .endpoint("default", "default", zone1.region().value(), zone2.region().value()) @@ -1059,20 +1059,20 @@ public class ControllerTest { var expectedRecords = List.of( // The weighted record for zone 2's region new Record(Record.Type.ALIAS, - RecordName.from("application.tenant.us-east-3-w.vespa.oath.cloud"), - new WeightedAliasTarget(HostName.of("lb-0--tenant.application.default--prod.us-east-3"), - "dns-zone-1", "prod.us-east-3", 1).pack()), + RecordName.from("application.tenant.aws-us-east-1-w.vespa.oath.cloud"), + new WeightedAliasTarget(HostName.of("lb-0--tenant.application.default--prod.aws-us-east-1b"), + "dns-zone-1", 
"prod.aws-us-east-1b", 1).pack()), // The 'east' global endpoint, pointing to the weighted record for zone 2's region new Record(Record.Type.ALIAS, RecordName.from("east.application.tenant.global.vespa.oath.cloud"), - new LatencyAliasTarget(HostName.of("application.tenant.us-east-3-w.vespa.oath.cloud"), - "dns-zone-1", ZoneId.from("prod.us-east-3")).pack()), + new LatencyAliasTarget(HostName.of("application.tenant.aws-us-east-1-w.vespa.oath.cloud"), + "dns-zone-1", ZoneId.from("prod.aws-us-east-1b")).pack()), // The zone-scoped endpoint pointing to zone 2 with exclusive routing new Record(Record.Type.CNAME, - RecordName.from("application.tenant.us-east-3.vespa.oath.cloud"), - RecordData.from("lb-0--tenant.application.default--prod.us-east-3."))); + RecordName.from("application.tenant.aws-us-east-1b.vespa.oath.cloud"), + RecordData.from("lb-0--tenant.application.default--prod.aws-us-east-1b."))); assertEquals(expectedRecords, List.copyOf(tester.controllerTester().nameService().records())); } @@ -1091,6 +1091,8 @@ public class ControllerTest { .region(zone1.region()) .region(zone2.region()) .region(zone3.region()) + .container("qrs", AuthMethod.mtls) + .container("default", AuthMethod.mtls) .endpoint("default", "default") .endpoint("foo", "qrs") .endpoint("us", "default", zone1.region().value(), zone2.region().value()) @@ -1100,15 +1102,19 @@ public class ControllerTest { // Deployment passes container endpoints to config server for (var zone : List.of(zone1, zone2)) { assertEquals(Set.of("application.tenant.global.vespa.oath.cloud", - "foo.application.tenant.global.vespa.oath.cloud", - "us.application.tenant.global.vespa.oath.cloud"), + "foo.application.tenant.global.vespa.oath.cloud", + "us.application.tenant.global.vespa.oath.cloud", + "qrs.application.tenant." + zone.region().value() + ".vespa.oath.cloud", + "application.tenant." 
+ zone.region().value() + ".vespa.oath.cloud"), tester.configServer().containerEndpointNames(context.deploymentIdIn(zone)), "Expected container endpoints in " + zone); } assertEquals(Set.of("application.tenant.global.vespa.oath.cloud", - "foo.application.tenant.global.vespa.oath.cloud"), - tester.configServer().containerEndpointNames(context.deploymentIdIn(zone3)), - "Expected container endpoints in " + zone3); + "foo.application.tenant.global.vespa.oath.cloud", + "qrs.application.tenant.eu-west-1.vespa.oath.cloud", + "application.tenant.eu-west-1.vespa.oath.cloud"), + tester.configServer().containerEndpointNames(context.deploymentIdIn(zone3)), + "Expected container endpoints in " + zone3); } @Test diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/application/EndpointTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/application/EndpointTest.java index 79e40e61387..fbc5567101f 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/application/EndpointTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/application/EndpointTest.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.controller.application; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.CloudName; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.zone.AuthMethod; @@ -245,44 +246,37 @@ public class EndpointTest { Map<String, Endpoint> tests = Map.of( "https://a1.t1.aws-us-north-1.w.vespa-app.cloud/", Endpoint.of(instance1) - .targetRegion(cluster, ZoneId.from("prod", "aws-us-north-1a")) + .targetRegion(cluster, "us-north-1", CloudName.AWS) .routingMethod(RoutingMethod.exclusive) .on(Port.tls()) .in(SystemName.Public), "https://a1.t1.gcp-us-south1.w.vespa-app.cloud/", Endpoint.of(instance1) - .targetRegion(cluster, ZoneId.from("prod", "gcp-us-south1-c")) + .targetRegion(cluster, 
"us-south1", CloudName.GCP) .routingMethod(RoutingMethod.exclusive) .on(Port.tls()) .in(SystemName.Public), - "https://a1.t1.us-north-2.w.vespa-app.cloud/", + "https://c1.a1.t1.aws-us-north-2.w.vespa-app.cloud/", Endpoint.of(instance1) - .targetRegion(cluster, prodZone) + .targetRegion(ClusterSpec.Id.from("c1"), "us-north-2", CloudName.AWS) .routingMethod(RoutingMethod.exclusive) .on(Port.tls()) .in(SystemName.Public), - "https://c1.a1.t1.us-north-2.w.vespa-app.cloud/", + "https://deadbeef.cafed00d.aws-us-north-2.w.vespa-app.cloud/", Endpoint.of(instance1) - .targetRegion(ClusterSpec.Id.from("c1"), prodZone) + .targetRegion(ClusterSpec.Id.from("c1"), "us-north-2", CloudName.AWS) .routingMethod(RoutingMethod.exclusive) + .generatedFrom(new GeneratedEndpoint("deadbeef", "cafed00d", AuthMethod.mtls, Optional.empty())) .on(Port.tls()) .in(SystemName.Public), - "https://deadbeef.cafed00d.us-north-2.w.vespa-app.cloud/", + "https://c1.a1.t1.aws-us-north-2-w.vespa.oath.cloud/", Endpoint.of(instance1) - .targetRegion(ClusterSpec.Id.from("c1"), prodZone) + .targetRegion(ClusterSpec.Id.from("c1"), "us-north-2", CloudName.AWS) .routingMethod(RoutingMethod.exclusive) - .generatedFrom(new GeneratedEndpoint("deadbeef", "cafed00d", AuthMethod.mtls, Optional.empty())) .on(Port.tls()) - .in(SystemName.Public) + .in(SystemName.main) ); tests.forEach((expected, endpoint) -> assertEquals(expected, endpoint.url().toString())); - - assertEquals("aws-us-north-1", - tests.get("https://a1.t1.aws-us-north-1.w.vespa-app.cloud/").targets().get(0).deployment().zoneId().region().value(), - "Availability zone is removed from region"); - assertEquals("gcp-us-south1", - tests.get("https://a1.t1.gcp-us-south1.w.vespa-app.cloud/").targets().get(0).deployment().zoneId().region().value(), - "Availability zone is removed from region"); } @Test diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/ApplicationPackageBuilder.java 
b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/ApplicationPackageBuilder.java index 5c20bce0099..3623ddc4e56 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/ApplicationPackageBuilder.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/ApplicationPackageBuilder.java @@ -129,7 +129,7 @@ public class ApplicationPackageBuilder { .append(m.name()).append("-").append(i) .append("'/>\n"); } - servicesBody.append(" </client>\n"); + servicesBody.append(" </client>\n"); } servicesBody.append(" </clients>\n") .append(" </container>\n"); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ZoneRegistryMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ZoneRegistryMock.java index dbb7f80df0e..ec526163507 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ZoneRegistryMock.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ZoneRegistryMock.java @@ -73,8 +73,8 @@ public class ZoneRegistryMock extends AbstractComponent implements ZoneRegistry ZoneApiMock.fromId("dev.us-east-1"), ZoneApiMock.newBuilder().withId("dev.aws-us-east-2a").withCloud("aws").build(), ZoneApiMock.fromId("perf.us-east-3"), - ZoneApiMock.newBuilder().withId("prod.aws-us-east-1a").withCloud("aws").build(), - ZoneApiMock.newBuilder().withId("prod.aws-us-east-1b").withCloud("aws").build(), + ZoneApiMock.newBuilder().withId("prod.aws-us-east-1a").withCloud("aws").withCloudNativeRegionName("us-east-1").build(), + ZoneApiMock.newBuilder().withId("prod.aws-us-east-1b").withCloud("aws").withCloudNativeRegionName("us-east-1").build(), ZoneApiMock.fromId("prod.ap-northeast-1"), ZoneApiMock.fromId("prod.ap-northeast-2"), ZoneApiMock.fromId("prod.ap-southeast-1"), diff --git 
a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ResourceMeterMaintainerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ResourceMeterMaintainerTest.java index fac05fc125f..8f9ba75f95c 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ResourceMeterMaintainerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/maintenance/ResourceMeterMaintainerTest.java @@ -3,6 +3,7 @@ package com.yahoo.vespa.hosted.controller.maintenance; import com.yahoo.component.Version; import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.ClusterResources; import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; @@ -123,6 +124,18 @@ public class ResourceMeterMaintainerTest { assertEquals(lastRefreshTime + millisAdvanced, tester.curator().readMeteringRefreshTime()); } + @Test + public void testClusterCost() { + var nodeResources = new NodeResources(10, 64, 100, 10, + NodeResources.DiskSpeed.fast, + NodeResources.StorageType.local, + NodeResources.Architecture.x86_64, + new NodeResources.GpuResources(2, 16)); + var clusterResources = new ClusterResources(5, 1, nodeResources); + + assertEquals(5 * nodeResources.cost(), ResourceMeterMaintainer.cost(clusterResources, SystemName.Public), 0.001); + } + private void setUpZones() { ZoneApiMock zone1 = ZoneApiMock.newBuilder().withId("prod.region-2").build(); ZoneApiMock zone2 = ZoneApiMock.newBuilder().withId("test.region-3").build(); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializerTest.java index ed7d02d0047..eb3f9daef53 100644 --- 
a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/user/UserFlagsSerializerTest.java @@ -26,7 +26,7 @@ import java.util.Set; import java.util.stream.Collectors; import java.util.stream.Stream; -import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; +import static com.yahoo.vespa.flags.FetchVector.Dimension.INSTANCE_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.CONSOLE_USER_EMAIL; import static com.yahoo.vespa.flags.FetchVector.Dimension.TENANT_ID; @@ -42,7 +42,7 @@ public class UserFlagsSerializerTest { try (Flags.Replacer ignored = Flags.clearFlagsForTesting()) { Flags.defineStringFlag("string-id", "default value", List.of("owner"), "1970-01-01", "2100-01-01", "desc", "mod", CONSOLE_USER_EMAIL); - Flags.defineIntFlag("int-id", 123, List.of("owner"), "1970-01-01", "2100-01-01", "desc", "mod", CONSOLE_USER_EMAIL, TENANT_ID, APPLICATION_ID); + Flags.defineIntFlag("int-id", 123, List.of("owner"), "1970-01-01", "2100-01-01", "desc", "mod", CONSOLE_USER_EMAIL, TENANT_ID, INSTANCE_ID); Flags.defineDoubleFlag("double-id", 3.14d, List.of("owner"), "1970-01-01", "2100-01-01", "desc", "mod"); Flags.defineListFlag("list-id", List.of("a"), String.class, List.of("owner"), "1970-01-01", "2100-01-01", "desc", "mod", CONSOLE_USER_EMAIL); Flags.defineJacksonFlag("jackson-id", new ExampleJacksonClass(123, "abc"), ExampleJacksonClass.class, @@ -52,9 +52,9 @@ public class UserFlagsSerializerTest { flagData("string-id", rule("\"value1\"", condition(CONSOLE_USER_EMAIL, Condition.Type.WHITELIST, email1))), flagData("int-id", rule("456")), flagData("list-id", - rule("[\"value1\"]", condition(CONSOLE_USER_EMAIL, Condition.Type.WHITELIST, email1), condition(APPLICATION_ID, Condition.Type.BLACKLIST, "tenant1:video:default", "tenant1:video:default", "tenant2:music:default")), + rule("[\"value1\"]", 
condition(CONSOLE_USER_EMAIL, Condition.Type.WHITELIST, email1), condition(INSTANCE_ID, Condition.Type.BLACKLIST, "tenant1:video:default", "tenant1:video:default", "tenant2:music:default")), rule("[\"value2\"]", condition(CONSOLE_USER_EMAIL, Condition.Type.WHITELIST, email2)), - rule("[\"value1\",\"value3\"]", condition(APPLICATION_ID, Condition.Type.BLACKLIST, "tenant1:video:default", "tenant1:video:default", "tenant2:music:default"))), + rule("[\"value1\",\"value3\"]", condition(INSTANCE_ID, Condition.Type.BLACKLIST, "tenant1:video:default", "tenant1:video:default", "tenant2:music:default"))), flagData("jackson-id", rule("{\"integer\":456,\"string\":\"xyz\"}", condition(CONSOLE_USER_EMAIL, Condition.Type.WHITELIST, email1), condition(TENANT_ID, Condition.Type.WHITELIST, "tenant1", "tenant3"))) ).collect(Collectors.toMap(FlagData::id, fd -> fd)); @@ -63,7 +63,7 @@ public class UserFlagsSerializerTest { "{\"id\":\"int-id\",\"rules\":[{\"value\":456}]}," + // Default from DB "{\"id\":\"jackson-id\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"tenant\"}],\"value\":{\"integer\":456,\"string\":\"xyz\"}},{\"value\":{\"integer\":123,\"string\":\"abc\"}}]}," + // Resolved for email // Resolved for email, but conditions are empty since this user is not authorized for any tenants - "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\"}],\"value\":[\"value1\"]},{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\"}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + + "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\"}],\"value\":[\"value1\"]},{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\"}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + "{\"id\":\"string-id\",\"rules\":[{\"value\":\"value1\"}]}]}", // resolved for email flagData, Set.of(), false, email1); @@ -72,7 +72,7 @@ public class 
UserFlagsSerializerTest { "{\"id\":\"int-id\",\"rules\":[{\"value\":456}]}," + // Default from DB "{\"id\":\"jackson-id\",\"rules\":[{\"conditions\":[{\"type\":\"whitelist\",\"dimension\":\"tenant\",\"values\":[\"tenant1\"]}],\"value\":{\"integer\":456,\"string\":\"xyz\"}},{\"value\":{\"integer\":123,\"string\":\"abc\"}}]}," + // Resolved for email // Resolved for email, but conditions have filtered out tenant2 - "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\"]}],\"value\":[\"value1\"]},{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\"]}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + + "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\"]}],\"value\":[\"value1\"]},{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\"]}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + "{\"id\":\"string-id\",\"rules\":[{\"value\":\"value1\"}]}]}", // resolved for email flagData, Set.of("tenant1"), false, email1); @@ -81,7 +81,7 @@ public class UserFlagsSerializerTest { "{\"id\":\"int-id\",\"rules\":[{\"value\":456}]}," + // Default from DB "{\"id\":\"jackson-id\",\"rules\":[{\"value\":{\"integer\":123,\"string\":\"abc\"}}]}," + // Default from code, no DB values match // Includes last value from DB which is not conditioned on email and the default from code - "{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"application\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\",\"tenant2:music:default\"]}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + + 
"{\"id\":\"list-id\",\"rules\":[{\"conditions\":[{\"type\":\"blacklist\",\"dimension\":\"instance\",\"values\":[\"tenant1:video:default\",\"tenant1:video:default\",\"tenant2:music:default\"]}],\"value\":[\"value1\",\"value3\"]},{\"value\":[\"a\"]}]}," + "{\"id\":\"string-id\",\"rules\":[{\"value\":\"default value\"}]}]}", // Default from code flagData, Set.of(), true, "operator@domain.tld"); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/RoutingPoliciesTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/RoutingPoliciesTest.java index dd7ab709161..630de5137bb 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/RoutingPoliciesTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/routing/RoutingPoliciesTest.java @@ -8,6 +8,7 @@ import com.yahoo.config.application.api.ValidationId; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.AthenzDomain; import com.yahoo.config.provision.AthenzService; +import com.yahoo.config.provision.CloudName; import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.Environment; import com.yahoo.config.provision.HostName; @@ -16,6 +17,7 @@ import com.yahoo.config.provision.RegionName; import com.yahoo.config.provision.SystemName; import com.yahoo.config.provision.zone.AuthMethod; import com.yahoo.config.provision.zone.RoutingMethod; +import com.yahoo.config.provision.zone.ZoneApi; import com.yahoo.config.provision.zone.ZoneId; import com.yahoo.vespa.flags.Flags; import com.yahoo.vespa.hosted.controller.ControllerTester; @@ -24,6 +26,7 @@ import com.yahoo.vespa.hosted.controller.api.identifiers.DeploymentId; import com.yahoo.vespa.hosted.controller.api.integration.certificates.EndpointCertificate; import com.yahoo.vespa.hosted.controller.api.integration.configserver.ContainerEndpoint; import 
com.yahoo.vespa.hosted.controller.api.integration.configserver.LoadBalancer; +import com.yahoo.vespa.hosted.controller.api.integration.deployment.JobType; import com.yahoo.vespa.hosted.controller.api.integration.dns.Record; import com.yahoo.vespa.hosted.controller.api.integration.dns.Record.Type; import com.yahoo.vespa.hosted.controller.api.integration.dns.RecordData; @@ -62,6 +65,7 @@ import java.util.stream.Collectors; import java.util.stream.Stream; import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; import static org.junit.jupiter.api.Assertions.assertSame; import static org.junit.jupiter.api.Assertions.assertThrows; import static org.junit.jupiter.api.Assertions.assertTrue; @@ -73,12 +77,41 @@ import static org.junit.jupiter.api.Assertions.fail; */ public class RoutingPoliciesTest { - private static final ZoneId zone1 = ZoneId.from("prod", "us-west-1"); - private static final ZoneId zone2 = ZoneId.from("prod", "us-central-1"); - private static final ZoneId zone3 = ZoneId.from("prod", "aws-us-east-1a"); - private static final ZoneId zone4 = ZoneId.from("prod", "aws-us-east-1b"); - private static final ZoneId zone5 = ZoneId.from("prod", "north"); - private static final ZoneId zone6 = ZoneId.from("prod", "south"); + private static final ZoneApiMock zoneApi1 = ZoneApiMock.newBuilder() + .with(ZoneId.from("prod", "aws-us-west-11a")) + .with(CloudName.AWS) + .withCloudNativeRegionName("us-west-11") + .build(); + private static final ZoneApiMock zoneApi2 = ZoneApiMock.newBuilder().with(ZoneId.from("prod", "aws-us-central-22a")) + .with(CloudName.AWS) + .withCloudNativeRegionName("us-central-22") + .build(); + private static final ZoneApiMock zoneApi3 = ZoneApiMock.newBuilder().with(ZoneId.from("prod", "aws-us-east-33a")) + .with(CloudName.AWS) + .withCloudNativeRegionName("us-east-33") + .build(); + private static final ZoneApiMock zoneApi4 = ZoneApiMock.newBuilder() + .with(ZoneId.from("prod", 
"aws-us-east-33b")) + .with(CloudName.AWS) + .withCloudNativeRegionName("us-east-33") + .build(); + private static final ZoneApiMock zoneApi5 = ZoneApiMock.newBuilder() + .with(ZoneId.from("prod", "aws-us-north-44a")) + .with(CloudName.AWS) + .withCloudNativeRegionName("north-44") + .build(); + private static final ZoneApiMock zoneApi6 = ZoneApiMock.newBuilder() + .with(ZoneId.from("prod", "aws-us-south-55a")) + .with(CloudName.AWS) + .withCloudNativeRegionName("south-55") + .build(); + + private static final ZoneId zone1 = zoneApi1.getId(); + private static final ZoneId zone2 = zoneApi2.getId(); + private static final ZoneId zone3 = zoneApi3.getId(); + private static final ZoneId zone4 = zoneApi4.getId(); + private static final ZoneId zone5 = zoneApi5.getId(); + private static final ZoneId zone6 = zoneApi6.getId(); private static final ApplicationPackage applicationPackage = applicationPackageBuilder().region(zone1.region()) .region(zone2.region()) @@ -95,7 +128,7 @@ public class RoutingPoliciesTest { .region(zone1.region()) .region(zone2.region()) .endpoint("r0", "c0") - .endpoint("r1", "c0", "us-west-1") + .endpoint("r1", "c0", zone1.region().value()) .endpoint("r2", "c1") .build(); tester.provisionLoadBalancers(clustersPerZone, context1.instanceId(), zone1, zone2); @@ -115,7 +148,7 @@ public class RoutingPoliciesTest { .region(zone2.region()) .region(zone3.region()) .endpoint("r0", "c0") - .endpoint("r1", "c0", "us-west-1") + .endpoint("r1", "c0", zone1.region().value()) .endpoint("r2", "c1") .build(); numberOfDeployments++; @@ -175,14 +208,14 @@ public class RoutingPoliciesTest { assertTrue(policies.asList().stream().allMatch(policy -> policy.instanceEndpoints().isEmpty()), "Rotation membership is removed from all policies"); assertEquals(1, tester.aliasDataOf(endpoint4).size(), "Rotations for " + context2.application() + " are not removed"); - assertEquals(List.of("c0.app1.tenant1.aws-us-east-1a.vespa.oath.cloud", - 
"c0.app1.tenant1.us-central-1.vespa.oath.cloud", - "c0.app1.tenant1.us-west-1.vespa.oath.cloud", - "c0.app2.tenant1.us-west-1-w.vespa.oath.cloud", - "c0.app2.tenant1.us-west-1.vespa.oath.cloud", - "c1.app1.tenant1.aws-us-east-1a.vespa.oath.cloud", - "c1.app1.tenant1.us-central-1.vespa.oath.cloud", - "c1.app1.tenant1.us-west-1.vespa.oath.cloud", + assertEquals(List.of("c0.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c0.app1.tenant1.aws-us-east-33a.vespa.oath.cloud", + "c0.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c0.app2.tenant1.aws-us-west-11-w.vespa.oath.cloud", + "c0.app2.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-east-33a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", "r0.app2.tenant1.global.vespa.oath.cloud"), tester.recordNames(), "Endpoints in DNS matches current config"); @@ -265,10 +298,10 @@ public class RoutingPoliciesTest { // Deployment creates records and policies for all clusters in all zones List<String> expectedRecords = List.of( - "c0.app1.tenant1.us-central-1.vespa.oath.cloud", - "c0.app1.tenant1.us-west-1.vespa.oath.cloud", - "c1.app1.tenant1.us-central-1.vespa.oath.cloud", - "c1.app1.tenant1.us-west-1.vespa.oath.cloud" + "c0.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c0.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-west-11a.vespa.oath.cloud" ); assertEquals(expectedRecords, tester.recordNames()); assertEquals(4, tester.policiesOf(context1.instanceId()).size()); @@ -282,12 +315,12 @@ public class RoutingPoliciesTest { tester.provisionLoadBalancers(clustersPerZone + 1, context1.instanceId(), sharedRoutingLayer, zone1, zone2); context1.submit(applicationPackage).deferLoadBalancerProvisioningIn(Environment.prod).deploy(); expectedRecords = List.of( - "c0.app1.tenant1.us-central-1.vespa.oath.cloud", - 
"c0.app1.tenant1.us-west-1.vespa.oath.cloud", - "c1.app1.tenant1.us-central-1.vespa.oath.cloud", - "c1.app1.tenant1.us-west-1.vespa.oath.cloud", - "c2.app1.tenant1.us-central-1.vespa.oath.cloud", - "c2.app1.tenant1.us-west-1.vespa.oath.cloud" + "c0.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c0.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c2.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c2.app1.tenant1.aws-us-west-11a.vespa.oath.cloud" ); assertEquals(expectedRecords, tester.recordNames()); assertEquals(6, tester.policiesOf(context1.instanceId()).size()); @@ -296,16 +329,16 @@ public class RoutingPoliciesTest { tester.provisionLoadBalancers(clustersPerZone, context2.instanceId(), sharedRoutingLayer, zone1, zone2); context2.submit(applicationPackage).deferLoadBalancerProvisioningIn(Environment.prod).deploy(); expectedRecords = List.of( - "c0.app1.tenant1.us-central-1.vespa.oath.cloud", - "c0.app1.tenant1.us-west-1.vespa.oath.cloud", - "c0.app2.tenant1.us-central-1.vespa.oath.cloud", - "c0.app2.tenant1.us-west-1.vespa.oath.cloud", - "c1.app1.tenant1.us-central-1.vespa.oath.cloud", - "c1.app1.tenant1.us-west-1.vespa.oath.cloud", - "c1.app2.tenant1.us-central-1.vespa.oath.cloud", - "c1.app2.tenant1.us-west-1.vespa.oath.cloud", - "c2.app1.tenant1.us-central-1.vespa.oath.cloud", - "c2.app1.tenant1.us-west-1.vespa.oath.cloud" + "c0.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c0.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c0.app2.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c0.app2.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c1.app2.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c1.app2.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c2.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + 
"c2.app1.tenant1.aws-us-west-11a.vespa.oath.cloud" ); assertEquals(expectedRecords.stream().sorted().toList(), tester.recordNames().stream().sorted().toList()); assertEquals(4, tester.policiesOf(context2.instanceId()).size()); @@ -314,14 +347,14 @@ public class RoutingPoliciesTest { tester.provisionLoadBalancers(clustersPerZone, context1.instanceId(), sharedRoutingLayer, zone1, zone2); context1.submit(applicationPackage).deferLoadBalancerProvisioningIn(Environment.prod).deploy(); expectedRecords = List.of( - "c0.app1.tenant1.us-central-1.vespa.oath.cloud", - "c0.app1.tenant1.us-west-1.vespa.oath.cloud", - "c0.app2.tenant1.us-central-1.vespa.oath.cloud", - "c0.app2.tenant1.us-west-1.vespa.oath.cloud", - "c1.app1.tenant1.us-central-1.vespa.oath.cloud", - "c1.app1.tenant1.us-west-1.vespa.oath.cloud", - "c1.app2.tenant1.us-central-1.vespa.oath.cloud", - "c1.app2.tenant1.us-west-1.vespa.oath.cloud" + "c0.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c0.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c0.app2.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c0.app2.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c1.app2.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c1.app2.tenant1.aws-us-west-11a.vespa.oath.cloud" ); assertEquals(expectedRecords, tester.recordNames()); @@ -330,10 +363,10 @@ public class RoutingPoliciesTest { .forEach(zone -> tester.controllerTester().controller().applications().deactivate(context2.instanceId(), zone)); context2.flushDnsUpdates(); expectedRecords = List.of( - "c0.app1.tenant1.us-central-1.vespa.oath.cloud", - "c0.app1.tenant1.us-west-1.vespa.oath.cloud", - "c1.app1.tenant1.us-central-1.vespa.oath.cloud", - "c1.app1.tenant1.us-west-1.vespa.oath.cloud" + "c0.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + "c0.app1.tenant1.aws-us-west-11a.vespa.oath.cloud", + "c1.app1.tenant1.aws-us-central-22a.vespa.oath.cloud", + 
"c1.app1.tenant1.aws-us-west-11a.vespa.oath.cloud" ); assertEquals(expectedRecords, tester.recordNames()); assertTrue(tester.routingPolicies().read(context2.instanceId()).isEmpty(), "Removes stale routing policies " + context2.application()); @@ -490,7 +523,7 @@ public class RoutingPoliciesTest { context.submit(applicationPackage).deploy(); var zone = ZoneId.from("dev", "us-east-1"); tester.controllerTester().setRoutingMethod(List.of(zone), RoutingMethod.exclusive); - var prodRecords = List.of("app1.tenant1.us-central-1.vespa.oath.cloud", "app1.tenant1.us-west-1.vespa.oath.cloud"); + var prodRecords = List.of("app1.tenant1.aws-us-central-22a.vespa.oath.cloud", "app1.tenant1.aws-us-west-11a.vespa.oath.cloud"); assertEquals(prodRecords, tester.recordNames()); // Deploy to dev under different instance @@ -520,7 +553,7 @@ public class RoutingPoliciesTest { // Application is deployed context.submit(applicationPackage).deferLoadBalancerProvisioningIn(Environment.prod).deploy(); var expectedRecords = List.of( - "c0.app1.tenant1.us-west-1.vespa.oath.cloud" + "c0.app1.tenant1.aws-us-west-11a.vespa.oath.cloud" ); assertEquals(expectedRecords, tester.recordNames()); assertEquals(1, tester.policiesOf(context.instanceId()).size()); @@ -566,21 +599,21 @@ public class RoutingPoliciesTest { app.deploy(); // TXT records are cleaned up as we go—the last challenge is the last to go here, and we must flush it ourselves. 
- assertEquals(List.of("a.t.aws-us-east-1a.vespa.oath.cloud", - "challenge--a.t.aws-us-east-1a.vespa.oath.cloud"), + assertEquals(List.of("a.t.aws-us-east-33a.vespa.oath.cloud", + "challenge--a.t.aws-us-east-33a.vespa.oath.cloud"), tester.recordNames()); app.flushDnsUpdates(); assertEquals(Set.of(new Record(Type.CNAME, - RecordName.from("a.t.aws-us-east-1a.vespa.oath.cloud"), - RecordData.from("lb-0--t.a.default--prod.aws-us-east-1a."))), + RecordName.from("a.t.aws-us-east-33a.vespa.oath.cloud"), + RecordData.from("lb-0--t.a.default--prod.aws-us-east-33a."))), tester.controllerTester().nameService().records()); tester.tester.controllerTester().serviceRegistry().vpcEndpointService().outcomes - .put(RecordName.from("challenge--a.t.aws-us-east-1a.vespa.oath.cloud"), ChallengeState.running); + .put(RecordName.from("challenge--a.t.aws-us-east-33a.vespa.oath.cloud"), ChallengeState.running); // Deployment fails because challenge is not answered (immediately). - assertEquals("Status of run 2 of production-aws-us-east-1a for t.a ==> expected: <succeeded> but was: <unfinished>", + assertEquals("Status of run 2 of production-aws-us-east-33a for t.a ==> expected: <succeeded> but was: <unfinished>", assertThrows(AssertionError.class, () -> app.submit(appPackage).deploy()) .getMessage()); @@ -716,8 +749,15 @@ public class RoutingPoliciesTest { .build(); // Application starts deployment + List<JobType> testJobs = tester.controllerTester().zoneRegistry().zones().all() + .in(Environment.test, Environment.staging) + .in(CloudName.AWS) + .ids() + .stream() + .map(JobType::deploymentTo) + .toList(); context = context.submit(applicationPackage); - for (var testJob : List.of(DeploymentContext.systemTest, DeploymentContext.stagingTest)) { + for (var testJob : testJobs) { context = context.runJob(testJob); // Since runJob implicitly tears down the deployment and immediately deletes DNS records associated with the // deployment, we consume only one DNS update at a time here @@ -760,7 
+800,7 @@ public class RoutingPoliciesTest { tester.routingPolicies().setRoutingStatus(context.deploymentIdIn(zone2), RoutingStatus.Value.out, RoutingStatus.Agent.tenant); } catch (IllegalArgumentException e) { - assertEquals("Cannot deactivate routing for tenant1.app1 in prod.us-central-1 as it's the last remaining active deployment in endpoint https://r0.app1.tenant1.global.vespa.oath.cloud/ [scope=global, legacy=false, routingMethod=exclusive, authMethod=mtls, name=r0]", e.getMessage()); + assertEquals("Cannot deactivate routing for tenant1.app1 in prod.aws-us-central-22a as it's the last remaining active deployment in endpoint https://r0.app1.tenant1.global.vespa.oath.cloud/ [scope=global, legacy=false, routingMethod=exclusive, authMethod=mtls, name=r0]", e.getMessage()); } context.flushDnsUpdates(); tester.assertTargets(context.instanceId(), EndpointId.of("r0"), 0, zone2); @@ -888,18 +928,18 @@ public class RoutingPoliciesTest { .named(EndpointId.of("a1"), Endpoint.Scope.application).isEmpty(), "Endpoint removed"); assertEquals(List.of("a0.app1.tenant1.a.vespa.oath.cloud", - "beta.app1.tenant1.north.vespa.oath.cloud", - "beta.app1.tenant1.south.vespa.oath.cloud", - "c0.beta.app1.tenant1.north.vespa.oath.cloud", - "c0.beta.app1.tenant1.south.vespa.oath.cloud", - "c0.main.app1.tenant1.north.vespa.oath.cloud", - "c0.main.app1.tenant1.south.vespa.oath.cloud", - "c1.beta.app1.tenant1.north.vespa.oath.cloud", - "c1.beta.app1.tenant1.south.vespa.oath.cloud", - "c1.main.app1.tenant1.north.vespa.oath.cloud", - "c1.main.app1.tenant1.south.vespa.oath.cloud", - "main.app1.tenant1.north.vespa.oath.cloud", - "main.app1.tenant1.south.vespa.oath.cloud"), + "beta.app1.tenant1.aws-us-north-44a.vespa.oath.cloud", + "beta.app1.tenant1.aws-us-south-55a.vespa.oath.cloud", + "c0.beta.app1.tenant1.aws-us-north-44a.vespa.oath.cloud", + "c0.beta.app1.tenant1.aws-us-south-55a.vespa.oath.cloud", + "c0.main.app1.tenant1.aws-us-north-44a.vespa.oath.cloud", + 
"c0.main.app1.tenant1.aws-us-south-55a.vespa.oath.cloud", + "c1.beta.app1.tenant1.aws-us-north-44a.vespa.oath.cloud", + "c1.beta.app1.tenant1.aws-us-south-55a.vespa.oath.cloud", + "c1.main.app1.tenant1.aws-us-north-44a.vespa.oath.cloud", + "c1.main.app1.tenant1.aws-us-south-55a.vespa.oath.cloud", + "main.app1.tenant1.aws-us-north-44a.vespa.oath.cloud", + "main.app1.tenant1.aws-us-south-55a.vespa.oath.cloud"), tester.recordNames(), "Endpoints in DNS matches current config"); @@ -960,7 +1000,7 @@ public class RoutingPoliciesTest { tester.routingPolicies().setRoutingStatus(mainZone2, RoutingStatus.Value.out, RoutingStatus.Agent.tenant); fail("Expected exception"); } catch (IllegalArgumentException e) { - assertEquals("Cannot deactivate routing for tenant1.app1.main in prod.south as it's the last remaining active deployment in endpoint https://a0.app1.tenant1.a.vespa.oath.cloud/ [scope=application, legacy=false, routingMethod=exclusive, authMethod=mtls, name=a0]", + assertEquals("Cannot deactivate routing for tenant1.app1.main in prod.aws-us-south-55a as it's the last remaining active deployment in endpoint https://a0.app1.tenant1.a.vespa.oath.cloud/ [scope=application, legacy=false, routingMethod=exclusive, authMethod=mtls, name=a0]", e.getMessage()); } @@ -1218,13 +1258,42 @@ public class RoutingPoliciesTest { return loadBalancers; } - private static List<ZoneId> publicZones() { - var sharedRegion = RegionName.from("aws-us-east-1c"); - return List.of(ZoneId.from(Environment.prod, sharedRegion), - ZoneId.from(Environment.prod, RegionName.from("aws-eu-west-1a")), - ZoneId.from(Environment.prod, RegionName.from("gcp-us-south1-b")), - ZoneId.from(Environment.staging, RegionName.from("us-east-3")), - ZoneId.from(Environment.test, RegionName.from("us-east-1"))); + private static List<ZoneApi> publicZones() { + return List.of(ZoneApiMock.newBuilder() + .with(ZoneId.from(Environment.prod, RegionName.from("aws-us-east-1c"))) + .with(CloudName.AWS) + 
.withCloudNativeRegionName("us-east-1") + .build(), + ZoneApiMock.newBuilder() + .with(ZoneId.from(Environment.prod, RegionName.from("aws-eu-west-1a"))) + .with(CloudName.AWS) + .withCloudNativeRegionName("eu-west-1") + .build(), + ZoneApiMock.newBuilder() + .with(ZoneId.from(Environment.prod, RegionName.from("gcp-us-south1-b"))) + .with(CloudName.GCP) + .withCloudNativeRegionName("us-south1") + .build(), + ZoneApiMock.newBuilder() + .with(ZoneId.from(Environment.staging, RegionName.from("aws-us-east-3c"))) + .with(CloudName.AWS) + .withCloudNativeRegionName("us-east-3") + .build(), + ZoneApiMock.newBuilder() + .with(ZoneId.from(Environment.test, RegionName.from("aws-us-east-2c"))) + .with(CloudName.AWS) + .withCloudNativeRegionName("us-east-2") + .build(), + ZoneApiMock.newBuilder() + .with(ZoneId.from(Environment.staging, RegionName.from("gcp-us-east-99"))) + .with(CloudName.GCP) + .withCloudNativeRegionName("us-east-99") + .build(), + ZoneApiMock.newBuilder() + .with(ZoneId.from(Environment.test, RegionName.from("gcp-us-east-99"))) + .with(CloudName.GCP) + .withCloudNativeRegionName("us-east-99") + .build()); } private static class RoutingPoliciesTester { @@ -1243,23 +1312,16 @@ public class RoutingPoliciesTest { public RoutingPoliciesTester(DeploymentTester tester, boolean exclusiveRouting) { this.tester = tester; - List<ZoneId> zones; + List<ZoneApi> zones; if (tester.controller().system().isPublic()) { zones = publicZones(); - tester.controllerTester().setZones(zones); } else { - zones = new ArrayList<>(tester.controllerTester().zoneRegistry().zones().all().ids()); // Default zones - zones.add(zone4); // Missing from default ZoneRegistryMock zones - tester.controllerTester().setZones(zones); - tester.controllerTester().zoneRegistry().addZones(ZoneApiMock.newBuilder().withId(zone5.value()).withCloud("aws").build()); - tester.controllerTester().zoneRegistry().addZones(ZoneApiMock.newBuilder().withId(zone6.value()).withCloud("gcp").build()); - zones.add(zone5); 
- zones.add(zone6); - tester.configServer().bootstrap(zones, SystemApplication.notController()); - } - if (exclusiveRouting) { - tester.controllerTester().setRoutingMethod(zones, RoutingMethod.exclusive); + zones = new ArrayList<>(tester.controllerTester().zoneRegistry().zones().all().zones()); + zones.addAll(List.of(zoneApi1, zoneApi2, zoneApi3, zoneApi4, zoneApi5, zoneApi6)); } + tester.controllerTester().zoneRegistry().setZones(zones); + tester.configServer().bootstrap(toZoneIds(zones), SystemApplication.notController()); + tester.controllerTester().setRoutingMethod(toZoneIds(zones), exclusiveRouting ? RoutingMethod.exclusive : RoutingMethod.sharedLayer4); } public Map<DeploymentId, Set<ContainerEndpoint>> containerEndpoints(Environment environment) { @@ -1284,6 +1346,10 @@ public class RoutingPoliciesTest { return tester.controllerTester(); } + private List<ZoneId> toZoneIds(List<ZoneApi> zoneApis) { + return zoneApis.stream().map(ZoneApi::getId).toList(); + } + private void provisionLoadBalancers(int clustersPerZone, ApplicationId application, boolean shared, ZoneId... 
zones) { for (ZoneId zone : zones) { tester.configServer().removeLoadBalancers(application, zone); @@ -1348,7 +1414,7 @@ public class RoutingPoliciesTest { deploymentsByDnsName.computeIfAbsent(dnsName, (k) -> new ArrayList<>()) .add(deployment); } - assertTrue(deploymentsByDnsName.size() >= 1, "Found " + endpointId + " for " + application); + assertFalse(deploymentsByDnsName.isEmpty(), "Found " + endpointId + " for " + application); deploymentsByDnsName.forEach((dnsName, deployments) -> { Set<String> weightedTargets = deployments.stream() .map(d -> "weighted/lb-" + loadBalancerId + "--" + diff --git a/default_build_settings.cmake b/default_build_settings.cmake index 4c855f9c923..3d13abb8652 100644 --- a/default_build_settings.cmake +++ b/default_build_settings.cmake @@ -4,10 +4,10 @@ include(VespaExtendedDefaultBuildSettings OPTIONAL) function(setup_vespa_default_build_settings_darwin) message("-- Setting up default build settings for darwin") - set(DEFAULT_EXTRA_LINK_DIRECTORY "${VESPA_DEPS_PREFIX}/lib" "${VESPA_HOMEBREW_PREFIX}/opt/bison/lib" "${VESPA_HOMEBREW_PREFIX}/opt/flex/lib" "${VESPA_HOMEBREW_PREFIX}/opt/icu4c/lib" "${VESPA_HOMEBREW_PREFIX}/opt/openssl@1.1/lib" "${VESPA_HOMEBREW_PREFIX}/opt/openblas/lib") + set(DEFAULT_EXTRA_LINK_DIRECTORY "${VESPA_DEPS_PREFIX}/lib" "${VESPA_HOMEBREW_PREFIX}/opt/bison/lib" "${VESPA_HOMEBREW_PREFIX}/opt/flex/lib" "${VESPA_HOMEBREW_PREFIX}/opt/icu4c/lib" "${VESPA_HOMEBREW_PREFIX}/opt/openssl@3/lib" "${VESPA_HOMEBREW_PREFIX}/opt/openblas/lib") list(APPEND DEFAULT_EXTRA_LINK_DIRECTORY "${VESPA_HOMEBREW_PREFIX}/lib") set(DEFAULT_EXTRA_LINK_DIRECTORY "${DEFAULT_EXTRA_LINK_DIRECTORY}" PARENT_SCOPE) - set(DEFAULT_EXTRA_INCLUDE_DIRECTORY "${VESPA_DEPS_PREFIX}/include" "${VESPA_HOMEBREW_PREFIX}/opt/flex/include" "${VESPA_HOMEBREW_PREFIX}/opt/icu4c/include" "${VESPA_HOMEBREW_PREFIX}/opt/openssl@1.1/include" "${VESPA_HOMEBREW_PREFIX}/opt/openblas/include") + set(DEFAULT_EXTRA_INCLUDE_DIRECTORY "${VESPA_DEPS_PREFIX}/include" 
"${VESPA_HOMEBREW_PREFIX}/opt/flex/include" "${VESPA_HOMEBREW_PREFIX}/opt/icu4c/include" "${VESPA_HOMEBREW_PREFIX}/opt/openssl@3/include" "${VESPA_HOMEBREW_PREFIX}/opt/openblas/include") list(APPEND DEFAULT_EXTRA_INCLUDE_DIRECTORY "${VESPA_HOMEBREW_PREFIX}/include") set(DEFAULT_EXTRA_INCLUDE_DIRECTORY "${DEFAULT_EXTRA_INCLUDE_DIRECTORY}" PARENT_SCOPE) endfunction() @@ -84,7 +84,7 @@ endfunction() function(vespa_use_default_cmake_prefix_path) set(DEFAULT_CMAKE_PREFIX_PATH ${VESPA_DEPS_PREFIX}) if (APPLE) - list(APPEND DEFAULT_CMAKE_PREFIX_PATH "${VESPA_HOMEBREW_PREFIX}/opt/bison" "${VESPA_HOMEBREW_PREFIX}/opt/flex" "${VESPA_HOMEBREW_PREFIX}/opt/openssl@1.1" "${VESPA_HOMEBREW_PREFIX}/opt/openblas" "${VESPA_HOMEBREW_PREFIX}/opt/icu4c") + list(APPEND DEFAULT_CMAKE_PREFIX_PATH "${VESPA_HOMEBREW_PREFIX}/opt/bison" "${VESPA_HOMEBREW_PREFIX}/opt/flex" "${VESPA_HOMEBREW_PREFIX}/opt/openssl@3" "${VESPA_HOMEBREW_PREFIX}/opt/openblas" "${VESPA_HOMEBREW_PREFIX}/opt/icu4c") endif() message("-- DEFAULT_CMAKE_PREFIX_PATH is ${DEFAULT_CMAKE_PREFIX_PATH}") if(NOT DEFINED CMAKE_PREFIX_PATH) diff --git a/dependency-versions/pom.xml b/dependency-versions/pom.xml index f24baa036c0..962d666bf6b 100644 --- a/dependency-versions/pom.xml +++ b/dependency-versions/pom.xml @@ -112,7 +112,7 @@ <mimepull.vespa.version>1.10.0</mimepull.vespa.version> <mockito.vespa.version>5.5.0</mockito.vespa.version> <mojo-executor.vespa.version>2.4.0</mojo-executor.vespa.version> - <netty.vespa.version>4.1.97.Final</netty.vespa.version> + <netty.vespa.version>4.1.98.Final</netty.vespa.version> <netty-tcnative.vespa.version>2.0.61.Final</netty-tcnative.vespa.version> <onnxruntime.vespa.version>1.15.1</onnxruntime.vespa.version> <opennlp.vespa.version>2.3.0</opennlp.vespa.version> diff --git a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp index 03db333d582..31cb1d6b385 100644 --- a/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp +++ 
b/eval/src/apps/analyze_onnx_model/analyze_onnx_model.cpp @@ -78,10 +78,12 @@ MemoryUsage extract_memory_usage() { vespalib::string vm_size = UNKNOWN; vespalib::string vm_rss = UNKNOWN; FilePointer file(fopen("/proc/self/status", "r")); - vespalib::string line; - while (read_line(file, line)) { - extract(line, "VmSize:", vm_size); - extract(line, "VmRSS:", vm_rss); + if (file.valid()) { + vespalib::string line; + while (read_line(file, line)) { + extract(line, "VmSize:", vm_size); + extract(line, "VmRSS:", vm_rss); + } } return {convert(vm_size), convert(vm_rss)}; } diff --git a/flags/src/main/java/com/yahoo/vespa/flags/FetchVector.java b/flags/src/main/java/com/yahoo/vespa/flags/FetchVector.java index 7af1661cf0c..b16d26a04a4 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/FetchVector.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/FetchVector.java @@ -22,9 +22,6 @@ public class FetchVector { * Note: If this enum is changed, you must also change {@link DimensionHelper}. */ public enum Dimension { - /** Value from ApplicationId::serializedForm of the form tenant:applicationName:instance. */ - APPLICATION_ID, - /** * Cloud from com.yahoo.config.provision.CloudName::value, e.g. yahoo, aws, gcp. * @@ -59,6 +56,9 @@ public class FetchVector { */ HOSTNAME, + /** Value from ApplicationId::serializedForm of the form tenant:applicationName:instance. */ + INSTANCE_ID, + /** Node type from com.yahoo.config.provision.NodeType::name, e.g. tenant, host, confighost, controller, etc. 
*/ NODE_TYPE, diff --git a/flags/src/main/java/com/yahoo/vespa/flags/FlagDefinition.java b/flags/src/main/java/com/yahoo/vespa/flags/FlagDefinition.java index 1773a03feb1..451f45ec792 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/FlagDefinition.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/FlagDefinition.java @@ -77,7 +77,7 @@ public class FlagDefinition { if (dimensions.contains(FetchVector.Dimension.CONSOLE_USER_EMAIL)) { Set<FetchVector.Dimension> disallowedCombinations = EnumSet.allOf(FetchVector.Dimension.class); disallowedCombinations.remove(FetchVector.Dimension.CONSOLE_USER_EMAIL); - disallowedCombinations.remove(FetchVector.Dimension.APPLICATION_ID); + disallowedCombinations.remove(FetchVector.Dimension.INSTANCE_ID); disallowedCombinations.remove(FetchVector.Dimension.TENANT_ID); disallowedCombinations.retainAll(dimensions); if (!disallowedCombinations.isEmpty()) diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index de14d4e7b00..2e158f0f3ef 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -13,7 +13,7 @@ import java.util.Optional; import java.util.TreeMap; import java.util.function.Predicate; -import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; +import static com.yahoo.vespa.flags.FetchVector.Dimension.INSTANCE_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.CLOUD_ACCOUNT; import static com.yahoo.vespa.flags.FetchVector.Dimension.CLUSTER_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.CLUSTER_TYPE; @@ -35,7 +35,7 @@ import static com.yahoo.vespa.flags.FetchVector.Dimension.VESPA_VERSION; * an unbound flag to a flag source produces a (bound) flag, e.g. {@link BooleanFlag} and {@link StringFlag}.</li> * <li>If you would like your flag value to be dependent on e.g. the application ID, then 1. 
you should * declare this in the unbound flag definition in this file (referring to - * {@link FetchVector.Dimension#APPLICATION_ID}), and 2. specify the application ID when retrieving the value, e.g. + * {@link FetchVector.Dimension#INSTANCE_ID}), and 2. specify the application ID when retrieving the value, e.g. * {@link BooleanFlag#with(FetchVector.Dimension, String)}. See {@link FetchVector} for more info.</li> * </ol> * @@ -53,7 +53,7 @@ public class Flags { List.of("baldersheim"), "2020-12-02", "2023-12-31", "Default limit for when to apply termwise query evaluation", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundStringFlag QUERY_DISPATCH_POLICY = defineStringFlag( "query-dispatch-policy", "adaptive", @@ -61,62 +61,62 @@ public class Flags { "Select query dispatch policy, valid values are adaptive, round-robin, best-of-random-2," + " latency-amortized-over-requests, latency-amortized-over-time", "Takes effect at redeployment (requires restart)", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundStringFlag SUMMARY_DECODE_POLICY = defineStringFlag( "summary-decode-policy", "eager", List.of("baldersheim"), "2023-03-30", "2023-12-31", "Select summary decoding policy, valid values are eager and on-demand/ondemand.", "Takes effect at redeployment (requires restart)", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundStringFlag FEED_SEQUENCER_TYPE = defineStringFlag( "feed-sequencer-type", "THROUGHPUT", List.of("baldersheim"), "2020-12-02", "2023-12-31", "Selects type of sequenced executor used for feeding in proton, valid values are LATENCY, ADAPTIVE, THROUGHPUT", "Takes effect at redeployment (requires restart)", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MAX_UNCOMMITTED_MEMORY = defineIntFlag( "max-uncommitted-memory", 130000, List.of("geirst, baldersheim"), "2021-10-21", "2023-12-31", "Max amount of memory holding updates to an attribute before we do a commit.", 
"Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundStringFlag RESPONSE_SEQUENCER_TYPE = defineStringFlag( "response-sequencer-type", "ADAPTIVE", List.of("baldersheim"), "2020-12-02", "2023-12-31", "Selects type of sequenced executor used for mbus responses, valid values are LATENCY, ADAPTIVE, THROUGHPUT", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag RESPONSE_NUM_THREADS = defineIntFlag( "response-num-threads", 2, List.of("baldersheim"), "2020-12-02", "2023-12-31", "Number of threads used for mbus responses, default is 2, negative number = numcores/4", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag USE_ASYNC_MESSAGE_HANDLING_ON_SCHEDULE = defineFeatureFlag( "async-message-handling-on-schedule", false, List.of("baldersheim"), "2020-12-02", "2023-12-31", "Optionally deliver async messages in own thread", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundDoubleFlag FEED_CONCURRENCY = defineDoubleFlag( "feed-concurrency", 0.5, List.of("baldersheim"), "2020-12-02", "2023-12-31", "How much concurrency should be allowed for feed", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundDoubleFlag FEED_NICENESS = defineDoubleFlag( "feed-niceness", 0.0, List.of("baldersheim"), "2022-06-24", "2023-12-31", "How nice feeding shall be", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MBUS_JAVA_NUM_TARGETS = defineIntFlag( @@ -124,71 +124,71 @@ public class Flags { List.of("baldersheim"), "2022-07-05", "2023-12-31", "Number of rpc targets per service", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MBUS_CPP_NUM_TARGETS = defineIntFlag( "mbus-cpp-num-targets", 2, List.of("baldersheim"), "2022-07-05", "2023-12-31", "Number of rpc targets 
per service", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag RPC_NUM_TARGETS = defineIntFlag( "rpc-num-targets", 2, List.of("baldersheim"), "2022-07-05", "2023-12-31", "Number of rpc targets per content node", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MBUS_JAVA_EVENTS_BEFORE_WAKEUP = defineIntFlag( "mbus-java-events-before-wakeup", 1, List.of("baldersheim"), "2022-07-05", "2023-12-31", "Number write events before waking up transport thread", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MBUS_CPP_EVENTS_BEFORE_WAKEUP = defineIntFlag( "mbus-cpp-events-before-wakeup", 1, List.of("baldersheim"), "2022-07-05", "2023-12-31", "Number write events before waking up transport thread", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag RPC_EVENTS_BEFORE_WAKEUP = defineIntFlag( "rpc-events-before-wakeup", 1, List.of("baldersheim"), "2022-07-05", "2023-12-31", "Number write events before waking up transport thread", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MBUS_NUM_NETWORK_THREADS = defineIntFlag( "mbus-num-network-threads", 1, List.of("baldersheim"), "2022-07-01", "2023-12-31", "Number of threads used for mbus network", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag SHARED_STRING_REPO_NO_RECLAIM = defineFeatureFlag( "shared-string-repo-no-reclaim", false, List.of("baldersheim"), "2022-06-14", "2023-12-31", "Controls whether we do track usage and reclaim unused enum values in shared string repo", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag CONTAINER_DUMP_HEAP_ON_SHUTDOWN_TIMEOUT = defineFeatureFlag( "container-dump-heap-on-shutdown-timeout", false, List.of("baldersheim"), 
"2021-09-25", "2023-12-31", "Will trigger a heap dump during if container shutdown times out", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag LOAD_CODE_AS_HUGEPAGES = defineFeatureFlag( "load-code-as-hugepages", false, List.of("baldersheim"), "2022-05-13", "2023-12-31", "Will try to map the code segment with huge (2M) pages", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundDoubleFlag CONTAINER_SHUTDOWN_TIMEOUT = defineDoubleFlag( "container-shutdown-timeout", 50.0, List.of("baldersheim"), "2021-09-25", "2023-12-31", "Timeout for shutdown of a jdisc container", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); // TODO: Move to a permanent flag public static final UnboundListFlag<String> ALLOWED_ATHENZ_PROXY_IDENTITIES = defineListFlag( @@ -203,14 +203,14 @@ public class Flags { "Allows replicas in up to N content groups to not be activated " + "for query visibility if they are out of sync with a majority of other replicas", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundDoubleFlag MIN_NODE_RATIO_PER_GROUP = defineDoubleFlag( "min-node-ratio-per-group", 0.0, List.of("geirst", "vekterli"), "2021-07-16", "2023-12-01", "Minimum ratio of nodes that have to be available (i.e. 
not Down) in any hierarchic content cluster group for the group to be Up", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundStringFlag SYSTEM_MEMORY_HIGH = defineStringFlag( "system-memory-high", "", @@ -243,28 +243,28 @@ public class Flags { List.of("arnej"), "2021-11-15", "2023-12-31", "Use Vespa 8 types and formats for geographical positions", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MAX_COMPACT_BUFFERS = defineIntFlag( "max-compact-buffers", 1, List.of("baldersheim", "geirst", "toregge"), "2021-12-15", "2023-12-31", "Upper limit of buffers to compact in a data store at the same time for each reason (memory usage, address space usage)", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag ENABLE_PROXY_PROTOCOL_MIXED_MODE = defineFeatureFlag( "enable-proxy-protocol-mixed-mode", true, List.of("tokle"), "2022-05-09", "2023-10-01", "Enable or disable proxy protocol mixed mode", "Takes effect on redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundStringFlag LOG_FILE_COMPRESSION_ALGORITHM = defineStringFlag( "log-file-compression-algorithm", "", List.of("arnej"), "2022-06-14", "2024-12-31", "Which algorithm to use for compressing log files. 
Valid values: empty string (default), gzip, zstd", "Takes effect immediately", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag SEPARATE_METRIC_CHECK_CONFIG = defineFeatureFlag( "separate-metric-check-config", false, @@ -278,7 +278,7 @@ public class Flags { List.of("bjorncs", "vekterli"), "2022-07-21", "2024-01-01", "Configure Vespa TLS capability enforcement mode", "Takes effect on restart of Docker container", - APPLICATION_ID,HOSTNAME,NODE_TYPE,TENANT_ID,VESPA_VERSION + INSTANCE_ID,HOSTNAME,NODE_TYPE,TENANT_ID,VESPA_VERSION ); public static final UnboundBooleanFlag ENABLE_OTELCOL = defineFeatureFlag( @@ -286,7 +286,7 @@ public class Flags { List.of("olaa"), "2022-09-23", "2023-12-01", "Whether an OpenTelemetry collector should be enabled", "Takes effect at next tick", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundStringFlag CORE_ENCRYPTION_PUBLIC_KEY_ID = defineStringFlag( "core-encryption-public-key-id", "", @@ -300,7 +300,7 @@ public class Flags { List.of("arnej", "bjorncs"), "2023-02-28", "2024-01-10", "Enable global phase ranking", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag ENABLE_CROWDSTRIKE = defineFeatureFlag( "enable-crowdstrike", true, List.of("andreer"), "2023-04-13", "2023-10-14", @@ -311,7 +311,7 @@ public class Flags { "randomized-endpoint-names", false, List.of("andreer"), "2023-04-26", "2023-10-14", "Whether to use randomized endpoint names", "Takes effect on application deployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag ENABLE_THE_ONE_THAT_SHOULD_NOT_BE_NAMED = defineFeatureFlag( "enable-the-one-that-should-not-be-named", false, List.of("hmusum"), "2023-05-08", "2023-10-01", @@ -329,14 +329,14 @@ public class Flags { List.of("baldersheim"), "2023-06-29", "2023-12-31", "Should we enable proper nested multivalue grouping", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final 
UnboundBooleanFlag USE_RECONFIGURABLE_DISPATCHER = defineFeatureFlag( "use-reconfigurable-dispatcher", false, List.of("jonmv"), "2023-07-14", "2023-10-01", "Whether to set up a ReconfigurableDispatcher with config self-sub for backend nodes", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag WRITE_CONFIG_SERVER_SESSION_DATA_AS_ONE_BLOB = defineFeatureFlag( "write-config-server-session-data-as-blob", false, @@ -382,21 +382,21 @@ public class Flags { List.of("freva"), "2023-09-08", "2023-11-01", "Minimum amount of advertised memory for exclusive nodes", "Takes effect immediately", - APPLICATION_ID, CLUSTER_ID, CLUSTER_TYPE); + INSTANCE_ID, CLUSTER_ID, CLUSTER_TYPE); public static final UnboundBooleanFlag ASSIGN_RANDOMIZED_ID = defineFeatureFlag( "assign-randomized-id", false, List.of("mortent"), "2023-08-31", "2024-02-01", "Whether to assign randomized id to the application", "Takes effect immediately", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag ASSIGNED_RANDOMIZED_ID_RATE = defineIntFlag( "assign-randomized-id-rate", 5, List.of("mortent"), "2023-09-11", "2024-02-01", "Rate for requesting assigned ids for existing certificates. Rate is per maintainer cycle.", "Takes effect immediately", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag CONTENT_LAYER_METADATA_FEATURE_LEVEL = defineIntFlag( "content-layer-metadata-feature-level", 0, @@ -404,7 +404,7 @@ public class Flags { "Value semantics: 0) legacy behavior, 1) operation cancellation, 2) operation " + "cancellation and ephemeral content node sequence numbers for bucket replicas", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); /** WARNING: public for testing: All flags should be defined in {@link Flags}. 
*/ public static UnboundBooleanFlag defineFeatureFlag(String flagId, boolean defaultValue, List<String> owners, diff --git a/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java b/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java index f856ebeb456..abdc55f068e 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java @@ -13,7 +13,7 @@ import java.util.Set; import java.util.function.Predicate; import java.util.regex.Pattern; -import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; +import static com.yahoo.vespa.flags.FetchVector.Dimension.INSTANCE_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.CLUSTER_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.CLUSTER_TYPE; import static com.yahoo.vespa.flags.FetchVector.Dimension.CONSOLE_USER_EMAIL; @@ -43,19 +43,19 @@ public class PermanentFlags { "jvm-gc-options", "", "Sets default jvm gc options", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag HEAP_SIZE_PERCENTAGE = defineIntFlag( "heap-size-percentage", 70, "Sets default jvm heap size percentage", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundDoubleFlag QUERY_DISPATCH_WARMUP = defineDoubleFlag( "query-dispatch-warmup", 5, "Warmup duration for query dispatcher", "Takes effect at redeployment (requires restart)", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag FLEET_CANARY = defineFeatureFlag( "fleet-canary", false, @@ -86,13 +86,13 @@ public class PermanentFlags { "host-flavor", "", "Specifies the Vespa flavor name that the hosts of the matching nodes should have.", "Takes effect on next deployment (including internal redeployment).", - APPLICATION_ID, CLUSTER_TYPE, CLUSTER_ID); + INSTANCE_ID, CLUSTER_TYPE, CLUSTER_ID); public static final UnboundBooleanFlag SKIP_MAINTENANCE_DEPLOYMENT = 
defineFeatureFlag( "node-repository-skip-maintenance-deployment", false, "Whether PeriodicApplicationMaintainer should skip deployment for an application", "Takes effect at next run of maintainer", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundListFlag<String> INACTIVE_MAINTENANCE_JOBS = defineListFlag( "inactive-maintenance-jobs", List.of(), String.class, @@ -120,19 +120,19 @@ public class PermanentFlags { "Hard limit on how many CPUs a container may use. This value is multiplied by CPU allocated to node, so " + "to cap CPU at 200%, set this to 2, etc. 0 disables the cap to allow unlimited CPU.", "Takes effect on next node agent tick. Change is orchestrated, but does NOT require container restart", - HOSTNAME, APPLICATION_ID, CLUSTER_ID, CLUSTER_TYPE); + HOSTNAME, INSTANCE_ID, CLUSTER_ID, CLUSTER_TYPE); public static final UnboundIntFlag MIN_DISK_THROUGHPUT_MB_S = defineIntFlag( "min-disk-throughput-mb-s", 0, "Minimum required disk throughput performance, 0 = default, Only when using remote disk", "Takes effect when node is provisioned", - APPLICATION_ID, TENANT_ID, CLUSTER_ID, CLUSTER_TYPE); + INSTANCE_ID, TENANT_ID, CLUSTER_ID, CLUSTER_TYPE); public static final UnboundIntFlag MIN_DISK_IOPS_K = defineIntFlag( "min-disk-iops-k", 0, "Minimum required disk I/O operations per second, unit is kilo, 0 = default, Only when using remote disk", "Takes effect when node is provisioned", - APPLICATION_ID, TENANT_ID, CLUSTER_ID, CLUSTER_TYPE); + INSTANCE_ID, TENANT_ID, CLUSTER_ID, CLUSTER_TYPE); public static final UnboundListFlag<String> DISABLED_HOST_ADMIN_TASKS = defineListFlag( "disabled-host-admin-tasks", List.of(), String.class, @@ -145,13 +145,13 @@ public class PermanentFlags { "docker-image-repo", "", "Override default docker image repo. 
Docker image version will be Vespa version.", "Takes effect on next deployment from controller", - APPLICATION_ID); + INSTANCE_ID); - public static final UnboundBooleanFlag SEND_LIMITED_METRIC_SET = defineFeatureFlag( - "send-limited-metric-set", true, - "Whether a limited metric set should be fetched from metrics-proxy (CD systems only)", + public static final UnboundStringFlag METRIC_SET = defineStringFlag( + "metric-set", "Vespa", + "Determines which metric set we should use for the given application", "Takes effect on next host admin tick", - APPLICATION_ID); + INSTANCE_ID); private static final String VERSION_QUALIFIER_REGEX = "[a-zA-Z0-9_-]+"; private static final Pattern QUALIFIER_PATTERN = Pattern.compile("^" + VERSION_QUALIFIER_REGEX + "$"); @@ -164,13 +164,13 @@ public class PermanentFlags { "Otherwise a '.' + the flag value will be appended.", "Takes effect on the next host admin tick. The upgrade to the new wanted docker image is orchestrated.", value -> value.isEmpty() || QUALIFIER_PATTERN.matcher(value).find() || VERSION_PATTERN.matcher(value).find(), - HOSTNAME, NODE_TYPE, TENANT_ID, APPLICATION_ID, CLUSTER_TYPE, CLUSTER_ID, VESPA_VERSION); + HOSTNAME, NODE_TYPE, TENANT_ID, INSTANCE_ID, CLUSTER_TYPE, CLUSTER_ID, VESPA_VERSION); public static final UnboundStringFlag ZOOKEEPER_SERVER_VERSION = defineStringFlag( "zookeeper-server-version", "3.8.0", "ZooKeeper server version, a jar file zookeeper-server-<ZOOKEEPER_SERVER_VERSION>-jar-with-dependencies.jar must exist", "Takes effect on restart of Docker container", - NODE_TYPE, APPLICATION_ID, HOSTNAME); + NODE_TYPE, INSTANCE_ID, HOSTNAME); public static final UnboundBooleanFlag ENABLE_PUBLIC_SIGNUP_FLOW = defineFeatureFlag( "enable-public-signup-flow", false, @@ -188,7 +188,7 @@ public class PermanentFlags { "jvm-omit-stack-trace-in-fast-throw", true, "Controls JVM option OmitStackTraceInFastThrow (default feature flag value is true, which is the default JVM option value as well)", "takes effect on JVM 
restart", - CLUSTER_TYPE, APPLICATION_ID); + CLUSTER_TYPE, INSTANCE_ID); public static final UnboundIntFlag MAX_TRIAL_TENANTS = defineIntFlag( "max-trial-tenants", -1, @@ -200,7 +200,7 @@ public class PermanentFlags { "allow-disable-mtls", true, "Allow application to disable client authentication", "Takes effect on redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MAX_OS_UPGRADES = defineIntFlag( "max-os-upgrades", 30, @@ -219,7 +219,7 @@ public class PermanentFlags { "tls-ciphers-override", List.of(), String.class, "Override TLS ciphers enabled for port 4443 on hosted application containers", "Takes effect on redeployment", - APPLICATION_ID + INSTANCE_ID ); public static final UnboundStringFlag ENDPOINT_CERTIFICATE_ALGORITHM = defineStringFlag( @@ -227,20 +227,20 @@ public class PermanentFlags { // Acceptable values are: "rsa_4096", "ecdsa_p256" "Selects algorithm used for an applications endpoint certificate", "Takes effect when a new endpoint certificate is requested (on first deployment or deployment adding new endpoints)", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundDoubleFlag RESOURCE_LIMIT_DISK = defineDoubleFlag( "resource-limit-disk", 0.75, "Resource limit (between 0.0 and 1.0) for disk usage on content nodes, used by cluster controller for when to block feed", "Takes effect on next deployment", - APPLICATION_ID + INSTANCE_ID ); public static final UnboundDoubleFlag RESOURCE_LIMIT_MEMORY = defineDoubleFlag( "resource-limit-memory", 0.8, "Resource limit (between 0.0 and 1.0) for memory usage on content nodes, used by cluster controller for when to block feed", "Takes effect on next deployment", - APPLICATION_ID + INSTANCE_ID ); public static final UnboundListFlag<String> LOGCTL_OVERRIDE = defineListFlag( @@ -248,7 +248,7 @@ public class PermanentFlags { "A list of vespa-logctl statements that are run on container startup. 
" + "Each item should be on the form <service>:<component> <level>=on", "Takes effect on container restart", - APPLICATION_ID, HOSTNAME + INSTANCE_ID, HOSTNAME ); public static final UnboundListFlag<String> ENVIRONMENT_VARIABLES = defineListFlag( @@ -256,14 +256,14 @@ public class PermanentFlags { "A list of environment variables set for all services. " + "Each item should be on the form <ENV_VAR>=<VALUE>", "Takes effect on service restart", - APPLICATION_ID + INSTANCE_ID ); public static final UnboundStringFlag CONFIG_PROXY_JVM_ARGS = defineStringFlag( "config-proxy-jvm-args", "", "Sets jvm args for config proxy (added at the end of startup command, will override existing ones)", "Takes effect on restart of Docker container", - APPLICATION_ID + INSTANCE_ID ); // This must be set in a feature flag to avoid flickering between the new and old value during config server upgrade @@ -277,20 +277,20 @@ public class PermanentFlags { "forward-issues-as-errors", true, "When the backend detects a problematic issue with a query, it will by default send it as an error message to the QRS, which adds it in an ErrorHit in the result. May be disabled using this flag.", "Takes effect immediately", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag DEACTIVATE_ROUTING = defineFeatureFlag( "deactivate-routing", false, "Deactivates routing for an application by removing all reals from its load balancers. Used in " + "cases where we immediately need to stop serving an application, i.e. 
in case of service violations", "Takes effect on next redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundListFlag<String> IGNORED_HTTP_USER_AGENTS = defineListFlag( "ignored-http-user-agents", List.of(), String.class, "List of user agents to ignore (crawlers etc)", "Takes effect immediately.", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundListFlag<String> INCOMPATIBLE_VERSIONS = defineListFlag( "incompatible-versions", List.of("8"), String.class, @@ -305,7 +305,7 @@ public class PermanentFlags { "The config server will refuse to serve config to nodes running a version which is incompatible with their " + "current wanted node version, i.e., nodes about to upgrade to a version which is incompatible with the current.", "Takes effect immediately", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundStringFlag ADMIN_CLUSTER_NODE_ARCHITECTURE = defineStringFlag( "admin-cluster-node-architecture", "x86_64", @@ -313,7 +313,7 @@ public class PermanentFlags { "(logserver and clustercontroller clusters).", "Takes effect on next redeployment", value -> Set.of("any", "arm64", "x86_64").contains(value), - APPLICATION_ID); + INSTANCE_ID); public static final UnboundListFlag<String> CLOUD_ACCOUNTS = defineListFlag( "cloud-accounts", List.of(), String.class, @@ -325,20 +325,20 @@ public class PermanentFlags { "fail-deployment-for-files-with-unknown-extension", "FAIL", "Whether to log or fail for deployments when app has a file with unknown extension (valid values: LOG, FAIL)", "Takes effect at redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundListFlag<String> DISABLED_DEPLOYMENT_ZONES = defineListFlag( "disabled-deployment-zones", List.of(), String.class, "The zones, e.g., prod.norway-71, where deployments jobs are currently disabled", "Takes effect immediately", - APPLICATION_ID + INSTANCE_ID ); public static final UnboundBooleanFlag ALLOW_USER_FILTERS = defineFeatureFlag( "allow-user-filters", true, 
"Allow user filter (chains) in application", "Takes effect on next redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundLongFlag CONFIG_SERVER_SESSION_EXPIRY_TIME = defineLongFlag( "config-server-session-expiry-time", 3600, @@ -357,27 +357,27 @@ public class PermanentFlags { "keep-file-references-days", 30, "How many days to keep file references on tenant nodes (based on last modification time)", "Takes effect on restart of Docker container", - APPLICATION_ID + INSTANCE_ID ); public static final UnboundIntFlag KEEP_FILE_REFERENCES_COUNT = defineIntFlag( "keep-file-references-count", 20, "How many file references to keep on tenant nodes (no matter what last modification time is)", "Takes effect on restart of Docker container", - ZONE_ID, APPLICATION_ID + ZONE_ID, INSTANCE_ID ); public static final UnboundIntFlag ENDPOINT_CONNECTION_TTL = defineIntFlag( "endpoint-connection-ttl", 45, "Time to live for connections to endpoints in seconds", "Takes effect on next redeployment", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundBooleanFlag AUTOSCALING = defineFeatureFlag( "autoscaling", true, "Whether to enable autoscaling", "Takes effect immediately", - APPLICATION_ID); + INSTANCE_ID); public static final UnboundIntFlag MAX_HOSTS_PER_HOUR = defineIntFlag( "max-hosts-per-hour", 40, @@ -392,7 +392,7 @@ public class PermanentFlags { "Takes effect on next tick", // The application ID is the exclusive application ID associated with the host, // if any, or otherwise hosted-vespa:tenant-host:default. - APPLICATION_ID, TENANT_ID, CLUSTER_ID, CLUSTER_TYPE); + INSTANCE_ID, TENANT_ID, CLUSTER_ID, CLUSTER_TYPE); public static final UnboundIntFlag DROP_DENTRIES = defineIntFlag( "drop-dentries", -1, @@ -401,7 +401,7 @@ public class PermanentFlags { "Takes effect on next tick", // The application ID is the exclusive application ID associated with the host, // if any, or otherwise hosted-vespa:tenant-host:default. 
- APPLICATION_ID, TENANT_ID, CLUSTER_ID, CLUSTER_TYPE); + INSTANCE_ID, TENANT_ID, CLUSTER_ID, CLUSTER_TYPE); public static final UnboundIntFlag CERT_POOL_SIZE = defineIntFlag( "cert-pool-size", 0, diff --git a/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java b/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java index 2193d70ec47..8fb48c8a82f 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/json/DimensionHelper.java @@ -4,30 +4,31 @@ package com.yahoo.vespa.flags.json; import com.yahoo.vespa.flags.FetchVector; import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; import java.util.Map; -import java.util.stream.Collectors; /** * @author hakonhall */ public class DimensionHelper { - private static final Map<FetchVector.Dimension, String> serializedDimensions = new HashMap<>(); + private static final Map<FetchVector.Dimension, List<String>> serializedDimensions = new HashMap<>(); static { - serializedDimensions.put(FetchVector.Dimension.APPLICATION_ID, "application"); - serializedDimensions.put(FetchVector.Dimension.CLOUD, "cloud"); - serializedDimensions.put(FetchVector.Dimension.CLOUD_ACCOUNT, "cloud-account"); - serializedDimensions.put(FetchVector.Dimension.CLUSTER_ID, "cluster-id"); - serializedDimensions.put(FetchVector.Dimension.CLUSTER_TYPE, "cluster-type"); - serializedDimensions.put(FetchVector.Dimension.CONSOLE_USER_EMAIL, "console-user-email"); - serializedDimensions.put(FetchVector.Dimension.ENVIRONMENT, "environment"); - serializedDimensions.put(FetchVector.Dimension.HOSTNAME, "hostname"); - serializedDimensions.put(FetchVector.Dimension.NODE_TYPE, "node-type"); - serializedDimensions.put(FetchVector.Dimension.SYSTEM, "system"); - serializedDimensions.put(FetchVector.Dimension.TENANT_ID, "tenant"); - serializedDimensions.put(FetchVector.Dimension.VESPA_VERSION, "vespa-version"); - 
serializedDimensions.put(FetchVector.Dimension.ZONE_ID, "zone"); + serializedDimensions.put(FetchVector.Dimension.CLOUD, List.of("cloud")); + serializedDimensions.put(FetchVector.Dimension.CLOUD_ACCOUNT, List.of("cloud-account")); + serializedDimensions.put(FetchVector.Dimension.CLUSTER_ID, List.of("cluster-id")); + serializedDimensions.put(FetchVector.Dimension.CLUSTER_TYPE, List.of("cluster-type")); + serializedDimensions.put(FetchVector.Dimension.CONSOLE_USER_EMAIL, List.of("console-user-email")); + serializedDimensions.put(FetchVector.Dimension.ENVIRONMENT, List.of("environment")); + serializedDimensions.put(FetchVector.Dimension.HOSTNAME, List.of("hostname")); + serializedDimensions.put(FetchVector.Dimension.INSTANCE_ID, List.of("instance", "application")); + serializedDimensions.put(FetchVector.Dimension.NODE_TYPE, List.of("node-type")); + serializedDimensions.put(FetchVector.Dimension.SYSTEM, List.of("system")); + serializedDimensions.put(FetchVector.Dimension.TENANT_ID, List.of("tenant")); + serializedDimensions.put(FetchVector.Dimension.VESPA_VERSION, List.of("vespa-version")); + serializedDimensions.put(FetchVector.Dimension.ZONE_ID, List.of("zone")); if (serializedDimensions.size() != FetchVector.Dimension.values().length) { throw new IllegalStateException(FetchVectorHelper.class.getName() + " is not in sync with " + @@ -35,16 +36,27 @@ public class DimensionHelper { } } - private static final Map<String, FetchVector.Dimension> deserializedDimensions = serializedDimensions. 
- entrySet().stream().collect(Collectors.toMap(Map.Entry::getValue, Map.Entry::getKey)); + private static final Map<String, FetchVector.Dimension> deserializedDimensions = reverseMapping(serializedDimensions); + + private static Map<String, FetchVector.Dimension> reverseMapping(Map<FetchVector.Dimension, List<String>> mapping) { + Map<String, FetchVector.Dimension> reverseMapping = new LinkedHashMap<>(); + mapping.forEach((dimension, serializedDimensions) -> { + serializedDimensions.forEach(serializedDimension -> { + if (reverseMapping.put(serializedDimension, dimension) != null) { + throw new IllegalStateException("Duplicate serialized dimension: '" + serializedDimension + "'"); + } + }); + }); + return Map.copyOf(reverseMapping); + } public static String toWire(FetchVector.Dimension dimension) { - String serializedDimension = serializedDimensions.get(dimension); - if (serializedDimension == null) { + List<String> serializedDimension = serializedDimensions.get(dimension); + if (serializedDimension == null || serializedDimension.isEmpty()) { throw new IllegalArgumentException("Unsupported dimension (please add it): '" + dimension + "'"); } - return serializedDimension; + return serializedDimension.get(0); } public static FetchVector.Dimension fromWire(String serializedDimension) { diff --git a/flags/src/test/java/com/yahoo/vespa/flags/FlagsTest.java b/flags/src/test/java/com/yahoo/vespa/flags/FlagsTest.java index 3edde140de8..dd332be6627 100644 --- a/flags/src/test/java/com/yahoo/vespa/flags/FlagsTest.java +++ b/flags/src/test/java/com/yahoo/vespa/flags/FlagsTest.java @@ -50,7 +50,7 @@ public class FlagsTest { // zone is set because it was set on the unbound flag above assertThat(vector.getValue().getValue(FetchVector.Dimension.ZONE_ID), is(Optional.of("a-zone"))); // application and node type are not set - assertThat(vector.getValue().getValue(FetchVector.Dimension.APPLICATION_ID), is(Optional.empty())); + 
assertThat(vector.getValue().getValue(FetchVector.Dimension.INSTANCE_ID), is(Optional.empty())); assertThat(vector.getValue().getValue(FetchVector.Dimension.NODE_TYPE), is(Optional.empty())); RawFlag rawFlag = mock(RawFlag.class); @@ -58,11 +58,11 @@ public class FlagsTest { when(rawFlag.asJsonNode()).thenReturn(BooleanNode.getTrue()); // raw flag deserializes to true - assertThat(booleanFlag.with(FetchVector.Dimension.APPLICATION_ID, "an-app").value(), equalTo(true)); + assertThat(booleanFlag.with(FetchVector.Dimension.INSTANCE_ID, "an-app").value(), equalTo(true)); verify(source, times(2)).fetch(any(), vector.capture()); // application was set on the (bound) flag. - assertThat(vector.getValue().getValue(FetchVector.Dimension.APPLICATION_ID), is(Optional.of("an-app"))); + assertThat(vector.getValue().getValue(FetchVector.Dimension.INSTANCE_ID), is(Optional.of("an-app"))); } @Test diff --git a/flags/src/test/java/com/yahoo/vespa/flags/json/ConditionTest.java b/flags/src/test/java/com/yahoo/vespa/flags/json/ConditionTest.java index 4da66bd5cc1..084ad3b9395 100644 --- a/flags/src/test/java/com/yahoo/vespa/flags/json/ConditionTest.java +++ b/flags/src/test/java/com/yahoo/vespa/flags/json/ConditionTest.java @@ -18,7 +18,7 @@ public class ConditionTest { var params = new Condition.CreateParams(FetchVector.Dimension.HOSTNAME).withValues(hostname1); Condition condition = WhitelistCondition.create(params); assertFalse(condition.test(new FetchVector())); - assertFalse(condition.test(new FetchVector().with(FetchVector.Dimension.APPLICATION_ID, "foo"))); + assertFalse(condition.test(new FetchVector().with(FetchVector.Dimension.INSTANCE_ID, "foo"))); assertFalse(condition.test(new FetchVector().with(FetchVector.Dimension.HOSTNAME, "bar"))); assertTrue(condition.test(new FetchVector().with(FetchVector.Dimension.HOSTNAME, hostname1))); } @@ -29,7 +29,7 @@ public class ConditionTest { var params = new Condition.CreateParams(FetchVector.Dimension.HOSTNAME).withValues(hostname1); 
Condition condition = BlacklistCondition.create(params); assertTrue(condition.test(new FetchVector())); - assertTrue(condition.test(new FetchVector().with(FetchVector.Dimension.APPLICATION_ID, "foo"))); + assertTrue(condition.test(new FetchVector().with(FetchVector.Dimension.INSTANCE_ID, "foo"))); assertTrue(condition.test(new FetchVector().with(FetchVector.Dimension.HOSTNAME, "bar"))); assertFalse(condition.test(new FetchVector().with(FetchVector.Dimension.HOSTNAME, hostname1))); } diff --git a/flags/src/test/java/com/yahoo/vespa/flags/json/FlagDataTest.java b/flags/src/test/java/com/yahoo/vespa/flags/json/FlagDataTest.java index 3ca7f59c759..ed81afc8054 100644 --- a/flags/src/test/java/com/yahoo/vespa/flags/json/FlagDataTest.java +++ b/flags/src/test/java/com/yahoo/vespa/flags/json/FlagDataTest.java @@ -52,6 +52,8 @@ public class FlagDataTest { } }"""; + private final String json_with_instance = json.replace("application", "instance"); + private final FetchVector vector = new FetchVector(); @Test @@ -62,16 +64,16 @@ public class FlagDataTest { // First rule matches only if both conditions match verify(Optional.of("false"), vector .with(FetchVector.Dimension.HOSTNAME, "host1") - .with(FetchVector.Dimension.APPLICATION_ID, "app2")); + .with(FetchVector.Dimension.INSTANCE_ID, "app2")); verify(Optional.of("true"), vector .with(FetchVector.Dimension.HOSTNAME, "host1") - .with(FetchVector.Dimension.APPLICATION_ID, "app3")); + .with(FetchVector.Dimension.INSTANCE_ID, "app3")); // Verify unsetting a dimension with null works. 
verify(Optional.of("true"), vector .with(FetchVector.Dimension.HOSTNAME, "host1") - .with(FetchVector.Dimension.APPLICATION_ID, "app3") - .with(FetchVector.Dimension.APPLICATION_ID, null)); + .with(FetchVector.Dimension.INSTANCE_ID, "app3") + .with(FetchVector.Dimension.INSTANCE_ID, null)); // No rules apply if zone is overridden to an unknown zone verify(Optional.empty(), vector.with(FetchVector.Dimension.ZONE_ID, "unknown zone")); @@ -81,7 +83,7 @@ public class FlagDataTest { void testPartialResolve() { FlagData data = FlagData.deserialize(json); assertEquals(data.partialResolve(vector), data); - assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.APPLICATION_ID, "app1")), + assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.INSTANCE_ID, "app1")), FlagData.deserialize(""" { "id": "id1", @@ -102,7 +104,7 @@ public class FlagDataTest { } }""")); - assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.APPLICATION_ID, "app1")), + assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.INSTANCE_ID, "app1")), FlagData.deserialize(""" { "id": "id1", @@ -123,7 +125,7 @@ public class FlagDataTest { } }""")); - assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.APPLICATION_ID, "app3")), + assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.INSTANCE_ID, "app3")), FlagData.deserialize(""" { "id": "id1", @@ -154,7 +156,7 @@ public class FlagDataTest { } }""")); - assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.APPLICATION_ID, "app3") + assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.INSTANCE_ID, "app3") .with(FetchVector.Dimension.HOSTNAME, "host1")), FlagData.deserialize(""" { @@ -169,7 +171,7 @@ public class FlagDataTest { } }""")); - assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.APPLICATION_ID, "app3") + assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.INSTANCE_ID, "app3") .with(FetchVector.Dimension.HOSTNAME, 
"host3")), FlagData.deserialize(""" { @@ -191,7 +193,7 @@ public class FlagDataTest { } }""")); - assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.APPLICATION_ID, "app3") + assertEquals(data.partialResolve(vector.with(FetchVector.Dimension.INSTANCE_ID, "app3") .with(FetchVector.Dimension.HOSTNAME, "host3") .with(FetchVector.Dimension.ZONE_ID, "zone2")), FlagData.deserialize(""" @@ -204,7 +206,7 @@ public class FlagDataTest { ] }""")); - FlagData fullyResolved = data.partialResolve(vector.with(FetchVector.Dimension.APPLICATION_ID, "app3") + FlagData fullyResolved = data.partialResolve(vector.with(FetchVector.Dimension.INSTANCE_ID, "app3") .with(FetchVector.Dimension.HOSTNAME, "host3") .with(FetchVector.Dimension.ZONE_ID, "zone3")); assertEquals(fullyResolved, FlagData.deserialize(""" @@ -271,6 +273,11 @@ public class FlagDataTest { } private void verify(Optional<String> expectedValue, FetchVector vector) { + verify(json, expectedValue, vector); + verify(json_with_instance, expectedValue, vector); + } + + private void verify(String json, Optional<String> expectedValue, FetchVector vector) { FlagData data = FlagData.deserialize(json); assertEquals("id1", data.id().toString()); Optional<RawFlag> rawFlag = data.resolve(vector); diff --git a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilter.java b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilter.java index 6597f10198d..e81f0b1d897 100644 --- a/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilter.java +++ b/jdisc-security-filters/src/main/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilter.java @@ -55,7 +55,6 @@ public class CloudTokenDataPlaneFilter extends JsonSecurityRequestFilterBase { private static List<Client> parseClients(CloudTokenDataPlaneFilterConfig cfg) { Set<String> ids = new HashSet<>(); List<Client> 
clients = new ArrayList<>(cfg.clients().size()); - if (cfg.clients().isEmpty()) throw new IllegalArgumentException("Empty clients configuration"); for (var c : cfg.clients()) { if (ids.contains(c.id())) throw new IllegalArgumentException("Clients definition has duplicate id '%s'".formatted(c.id())); diff --git a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilterTest.java b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilterTest.java index a34d2eb67c3..2ff6209fd06 100644 --- a/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilterTest.java +++ b/jdisc-security-filters/src/test/java/com/yahoo/jdisc/http/filter/security/cloud/CloudTokenDataPlaneFilterTest.java @@ -166,6 +166,29 @@ class CloudTokenDataPlaneFilterTest { assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); } + @Test + void rejects_tokens_on_empty_clients() { + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.GET) + .withHeader("Authorization", "Bearer " + UNKNOWN_TOKEN.secretTokenString()) + .build(); + var responseHandler = new MockResponseHandler(); + newFilterWithEmptyClientsConfig().filter(req, responseHandler); + assertNotNull(responseHandler.getResponse()); + assertEquals(FORBIDDEN, responseHandler.getResponse().getStatus()); + } + + @Test + void rejects_missing_tokens_on_empty_clients() { + var req = FilterTestUtils.newRequestBuilder() + .withMethod(Method.GET) + .build(); + var responseHandler = new MockResponseHandler(); + newFilterWithEmptyClientsConfig().filter(req, responseHandler); + assertNotNull(responseHandler.getResponse()); + assertEquals(UNAUTHORIZED, responseHandler.getResponse().getStatus()); + } + private CloudTokenDataPlaneFilter newFilterWithClientsConfig() { return new CloudTokenDataPlaneFilter( new CloudTokenDataPlaneFilterConfig.Builder() @@ -191,4 +214,12 @@ class 
CloudTokenDataPlaneFilterTest { clock); } + private CloudTokenDataPlaneFilter newFilterWithEmptyClientsConfig() { + return new CloudTokenDataPlaneFilter( + new CloudTokenDataPlaneFilterConfig.Builder() + .tokenContext(TOKEN_CONTEXT) + .build(), + clock); + } + } diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java new file mode 100644 index 00000000000..aafb9877c27 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -0,0 +1,306 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.embedding; + +import ai.vespa.modelintegration.evaluator.OnnxEvaluator; +import ai.vespa.modelintegration.evaluator.OnnxEvaluatorOptions; +import com.yahoo.api.annotations.Beta; +import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.component.AbstractComponent; +import com.yahoo.component.annotation.Inject; +import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.language.huggingface.HuggingFaceTokenizer; +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import java.nio.file.Paths; +import java.util.Map; +import java.util.List; +import java.util.ArrayList; +import java.util.Set; +import java.util.HashSet; +import java.util.BitSet; +import java.util.Arrays; + +import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGEST_FIRST; + +/** + * A ColBERT embedder implementation that maps text to multiple vectors, one vector per subword id. + * This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model. + * + * See col-bert-embedder.def for configurable parameters. 
+ * @author bergum + */ +@Beta +public class ColBertEmbedder extends AbstractComponent implements Embedder { + private final Embedder.Runtime runtime; + private final String inputIdsName; + private final String attentionMaskName; + + private final String outputName; + + private final HuggingFaceTokenizer tokenizer; + private final OnnxEvaluator evaluator; + + private final int maxTransformerTokens; + private final int maxQueryTokens; + private final int maxDocumentTokens; + + private final long startSequenceToken; + private final long endSequenceToken; + private final long maskSequenceToken; + + + @Inject + public ColBertEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, ColBertEmbedderConfig config) { + this.runtime = runtime; + inputIdsName = config.transformerInputIds(); + attentionMaskName = config.transformerAttentionMask(); + outputName = config.transformerOutput(); + maxTransformerTokens = config.transformerMaxTokens(); + if(config.maxDocumentTokens() > maxTransformerTokens) + throw new IllegalArgumentException("maxDocumentTokens must be less than or equal to transformerMaxTokens"); + maxDocumentTokens = config.maxDocumentTokens(); + maxQueryTokens = config.maxQueryTokens(); + startSequenceToken = config.transformerStartSequenceToken(); + endSequenceToken = config.transformerEndSequenceToken(); + maskSequenceToken = config.transformerMaskToken(); + + var tokenizerPath = Paths.get(config.tokenizerPath().toString()); + var builder = new HuggingFaceTokenizer.Builder() + .addSpecialTokens(false) + .addDefaultModel(tokenizerPath) + .setPadding(false); + var info = HuggingFaceTokenizer.getModelInfo(tokenizerPath); + if (info.maxLength() == -1 || info.truncation() != LONGEST_FIRST) { + // Force truncation to max token vector length accepted by model if tokenizer.json contains no valid truncation configuration + int maxLength = info.maxLength() > 0 && info.maxLength() <= config.transformerMaxTokens() + ? 
info.maxLength() + : config.transformerMaxTokens(); + builder.setTruncation(true).setMaxLength(maxLength); + } + this.tokenizer = builder.build(); + var onnxOpts = new OnnxEvaluatorOptions(); + + if (config.transformerGpuDevice() >= 0) + onnxOpts.setGpuDevice(config.transformerGpuDevice()); + onnxOpts.setExecutionMode(config.transformerExecutionMode().toString()); + onnxOpts.setThreads(config.transformerInterOpThreads(), config.transformerIntraOpThreads()); + evaluator = onnx.evaluatorOf(config.transformerModel().toString(), onnxOpts); + validateModel(); + } + + public void validateModel() { + Map<String, TensorType> inputs = evaluator.getInputInfo(); + validateName(inputs, inputIdsName, "input"); + validateName(inputs, attentionMaskName, "input"); + Map<String, TensorType> outputs = evaluator.getOutputInfo(); + validateName(outputs, outputName, "output"); + } + + private void validateName(Map<String, TensorType> types, String name, String type) { + if (!types.containsKey(name)) { + throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " + + "Model contains: " + String.join(",", types.keySet())); + } + } + + @Override + public List<Integer> embed(String text, Context context) { + throw new UnsupportedOperationException("This embedder only supports embed with tensor type"); + } + + @Override + public Tensor embed(String text, Context context, TensorType tensorType) { + if(!verifyTensorType(tensorType)) { + throw new IllegalArgumentException("Invalid ColBERT embedder tensor destination." 
+ + "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType.toString()); + } + if (context.getDestination().startsWith("query")) { + return embedQuery(text, context, tensorType); + } else { + return embedDocument(text, context, tensorType); + } + } + + @Override + public void deconstruct() { + evaluator.close(); + tokenizer.close(); + } + + protected Tensor embedQuery(String text, Context context, TensorType tensorType) { + if(tensorType.valueType() == TensorType.Value.INT8) + throw new IllegalArgumentException("ColBert query embed does not accept int8 tensor value type"); + + long Q_TOKEN_ID = 1; // [unused0] token id used during training to differentiate query versus document. + + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + + List<Long> ids = encoding.ids(); + if (ids.size() > maxQueryTokens - 3) + ids = ids.subList(0, maxQueryTokens - 3); + + List<Long> inputIds = new ArrayList<>(maxQueryTokens); + List<Long> attentionMask = new ArrayList<>(maxQueryTokens); + + inputIds.add(startSequenceToken); + inputIds.add(Q_TOKEN_ID); + inputIds.addAll(ids); + inputIds.add(endSequenceToken); + int length = inputIds.size(); + + int padding = maxQueryTokens - length; + for (int i = 0; i < padding; i++) + inputIds.add(maskSequenceToken); + + for (int i = 0; i < length; i++) + attentionMask.add((long) 1); + for (int i = 0; i < padding; i++) + attentionMask.add((long) 0);//Do not attend to mask paddings + + Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1"); + Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1"); + + var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), + attentionMaskName, attentionMaskTensor.expand("d0")); + Map<String, Tensor> outputs = evaluator.evaluate(inputs); + Tensor tokenEmbeddings = outputs.get(outputName); + IndexedTensor result = (IndexedTensor) 
tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); + + int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); + if(dims != result.shape()[1]) { + throw new IllegalArgumentException("Token dimensionality does not" + + " match indexed dimensionality of " + dims); + } + Tensor.Builder builder = Tensor.Builder.of(tensorType); + for (int token = 0; token < result.shape()[0]; token++) + for (int d = 0; d < result.shape()[1]; d++) + builder.cell(TensorAddress.of(token, d), result.get(TensorAddress.of(token, d))); + runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); + return builder.build(); + } + + protected Tensor embedDocument(String text, Context context, TensorType tensorType) { + long D_TOKEN_ID = 2; // [unused1] token id used during training to differentiate query versus document. + var start = System.nanoTime(); + var encoding = tokenizer.encode(text, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); + + List<Long> ids = encoding.ids().stream().filter(token + -> !PUNCTUATION_TOKEN_IDS.contains(token)).toList(); + ; + + if (ids.size() > maxDocumentTokens - 3) + ids = ids.subList(0, maxDocumentTokens - 3); + List<Long> inputIds = new ArrayList<>(maxDocumentTokens); + List<Long> attentionMask = new ArrayList<>(maxDocumentTokens); + inputIds.add(startSequenceToken); + inputIds.add(D_TOKEN_ID); + inputIds.addAll(ids); + inputIds.add(endSequenceToken); + for (int i = 0; i < inputIds.size(); i++) + attentionMask.add((long) 1); + + Tensor inputIdsTensor = createTensorRepresentation(inputIds, "d1"); + Tensor attentionMaskTensor = createTensorRepresentation(attentionMask, "d1"); + + var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"), + attentionMaskName, attentionMaskTensor.expand("d0")); + + Map<String, Tensor> outputs = evaluator.evaluate(inputs); + Tensor tokenEmbeddings = outputs.get(outputName); + IndexedTensor result = (IndexedTensor) 
tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); + Tensor contextualEmbeddings; + if(tensorType.valueType() == TensorType.Value.INT8) { + contextualEmbeddings = toBitTensor(result, tensorType); + } else { + contextualEmbeddings = toFloatTensor(result, tensorType); + } + + runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context); + return contextualEmbeddings; + } + + public static Tensor toFloatTensor(IndexedTensor result, TensorType type) { + int size = type.indexedSubtype().dimensions().size(); + if (size != 1) + throw new IllegalArgumentException("Indexed tensor must have one dimension"); + int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDim = (int)result.shape()[1]; + if(resultDim != dims) { + throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDim + + " + dimensions into tensor with " + dims); + } + Tensor.Builder builder = Tensor.Builder.of(type); + for (int token = 0; token < result.shape()[0]; token++) { + for (int d = 0; d < result.shape()[1]; d++) { + var value = result.get(TensorAddress.of(token, d)); + builder.cell(TensorAddress.of(token,d),value); + } + } + return builder.build(); + } + + public static Tensor toBitTensor(IndexedTensor result, TensorType type) { + if (type.valueType() != TensorType.Value.INT8) + throw new IllegalArgumentException("Only a int8 tensor type can be" + + " the destination of bit packing"); + int size = type.indexedSubtype().dimensions().size(); + if (size != 1) + throw new IllegalArgumentException("Indexed tensor must have one dimension"); + int dims = type.indexedSubtype().dimensions().get(0).size().get().intValue(); + int resultDim = (int)result.shape()[1]; + if(resultDim/8 != dims) { + throw new IllegalArgumentException("Not possible to pack " + resultDim + + " + dimensions into " + dims); + } + Tensor.Builder builder = Tensor.Builder.of(type); + for (int token = 0; token < result.shape()[0]; token++) { + 
BitSet bitSet = new BitSet(8); + int key = 0; + for (int d = 0; d < result.shape()[1]; d++) { + var value = result.get(TensorAddress.of(token, d)); + int bitIndex = 7 - (d % 8); + if (value > 0.0) { + bitSet.set(bitIndex); + } else { + bitSet.clear(bitIndex); + } + if ((d + 1) % 8 == 0) { + byte[] bytes = bitSet.toByteArray(); + byte packed = (bytes.length == 0) ? 0 : bytes[0]; + builder.cell(TensorAddress.of(token, key), packed); + key++; + bitSet = new BitSet(8); + } + } + } + return builder.build(); + } + + protected boolean verifyTensorType(TensorType target) { + return target.dimensions().size() == 2 && + target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1; + } + + private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) { + int size = input.size(); + TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type); + for (int i = 0; i < size; ++i) { + builder.cell(input.get(i), i); + } + return builder.build(); + } + + private static final Set<Long> PUNCTUATION_TOKEN_IDS = new HashSet<>( + Arrays.asList(999L, 1000L, 1001L, 1002L, 1003L, 1004L, 1005L, 1006L, + 1007L, 1008L, 1009L, 1010L, 1011L, 1012L, 1013L, 1024L, + 1025L, 1026L, 1027L, 1028L, 1029L, 1030L, 1031L, 1032L, + 1033L, 1034L, 1035L, 1036L, 1063L, 1064L, 1065L, 1066L)); +} diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java new file mode 100644 index 00000000000..8516f6e6689 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -0,0 +1,126 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.embedding; + +import ai.vespa.modelintegration.evaluator.OnnxRuntime; +import com.yahoo.config.ModelReference; +import com.yahoo.embedding.ColBertEmbedderConfig; +import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.MixedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertThrows; +import static org.junit.Assume.assumeTrue; + +public class ColBertEmbedderTest { + + @Test + public void testPacking() { + assertPackedRight( + "" + + "tensor<float>(d1[6],d2[8]):" + + "[" + + "[0, 0, 0, 0, 0, 0, 0, 1]," + + "[0, 0, 0, 0, 0, 1, 0, 1]," + + "[0, 0, 0, 0, 0, 0, 1, 1]," + + "[0, 1, 1, 1, 1, 1, 1, 1]," + + "[1, 0, 0, 0, 0, 0, 0, 0]," + + "[1, 1, 1, 1, 1, 1, 1, 1]" + + "]", + TensorType.fromSpec("tensor<int8>(dt{},x[1])"), + "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}" + ); + assertPackedRight( + "" + + "tensor<float>(d1[2],d2[16]):" + + "[" + + "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," + + "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" + + "]", + TensorType.fromSpec("tensor<int8>(dt{},x[2])"), + "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}" + ); + } + + @Test + public void testEmbedder() { + assertEmbed("tensor<float>(dt{},x[128])", "this is a document", indexingContext); + assertEmbed("tensor<int8>(dt{},x[16])", "this is a document", indexingContext); + assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext); + + assertThrows(IllegalArgumentException.class, () -> { + //throws because int8 is not supported for query context + assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext); + }); + assertThrows(IllegalArgumentException.class, () -> { + //throws because 16 is less than model output (128) and we want float + assertEmbed("tensor<float>(qt{},x[16])", "this is a query", 
queryContext); + }); + + assertThrows(IllegalArgumentException.class, () -> { + //throws because 128/8 does not fit into 15 + assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext); + }); + } + + @Test + public void testLenghtLimits() { + StringBuilder sb = new StringBuilder(); + for(int i = 0; i < 1024; i++) { + sb.append("annoyance"); + sb.append(" "); + } + String text = sb.toString(); + Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); + assertEquals(512*128,fullFloat.size()); + + Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext); + assertEquals(32*128,query.size()); + + Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext); + assertEquals(512*16,binaryRep.size()); + + Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext); + // 4 tokens, 16 bytes each = 64 bytes + //because of CLS, special, sequence, SEP + assertEquals(4*16,shortDoc.size());; + } + + static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) { + TensorType destType = TensorType.fromSpec(tensorSpec); + Tensor result = embedder.embed(text, context, destType); + assertEquals(destType,result.type()); + MixedTensor mixedTensor = (MixedTensor) result; + if(context == queryContext) { + assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size()); + } + return result; + } + + static void assertPackedRight(String numbers, TensorType destination,String expected) { + Tensor packed = ColBertEmbedder.toBitTensor((IndexedTensor) Tensor.from(numbers), destination); + assertEquals(expected,packed.toString()); + } + + static final Embedder embedder; + static final Embedder.Context indexingContext; + static final Embedder.Context queryContext; + static { + indexingContext = new Embedder.Context("schema.indexing"); + queryContext = new Embedder.Context("query(qt)"); + embedder = getEmbedder(); + } + private static Embedder getEmbedder() { 
+ String vocabPath = "src/test/models/onnx/transformer/tokenizer.json"; + String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx"; + assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath)); + ColBertEmbedderConfig.Builder builder = new ColBertEmbedderConfig.Builder(); + builder.tokenizerPath(ModelReference.valueOf(vocabPath)); + builder.transformerModel(ModelReference.valueOf(modelPath)); + builder.transformerGpuDevice(-1); + return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); + } +}
\ No newline at end of file diff --git a/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx Binary files differnew file mode 100644 index 00000000000..5ab1060e59e --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/colbert-dummy-v2.onnx diff --git a/model-integration/src/test/models/onnx/transformer/tokenizer.json b/model-integration/src/test/models/onnx/transformer/tokenizer.json new file mode 100644 index 00000000000..28340f289bb --- /dev/null +++ b/model-integration/src/test/models/onnx/transformer/tokenizer.json @@ -0,0 +1,175 @@ +{ + "version": "1.0", + "truncation": null, + "padding": null, + "added_tokens": [ + { + "id": 0, + "content": "[PAD]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 100, + "content": "[UNK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 101, + "content": "[CLS]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 102, + "content": "[SEP]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + }, + { + "id": 103, + "content": "[MASK]", + "single_word": false, + "lstrip": false, + "rstrip": false, + "normalized": false, + "special": true + } + ], + "normalizer": { + "type": "BertNormalizer", + "clean_text": true, + "handle_chinese_chars": true, + "strip_accents": null, + "lowercase": true + }, + "pre_tokenizer": { + "type": "BertPreTokenizer" + }, + "post_processor": { + "type": "TemplateProcessing", + "single": [ + { + "SpecialToken": { + "id": "[CLS]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 0 + } + } + ], + "pair": [ + { + "SpecialToken": { + "id": "[CLS]", + 
"type_id": 0 + } + }, + { + "Sequence": { + "id": "A", + "type_id": 0 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 0 + } + }, + { + "Sequence": { + "id": "B", + "type_id": 1 + } + }, + { + "SpecialToken": { + "id": "[SEP]", + "type_id": 1 + } + } + ], + "special_tokens": { + "[CLS]": { + "id": "[CLS]", + "ids": [101], + "tokens": ["[CLS]"] + }, + "[SEP]": { + "id": "[SEP]", + "ids": [102], + "tokens": ["[SEP]"] + } + } + }, + "decoder": { + "type": "WordPiece", + "prefix": "##", + "cleanup": true + }, + "model": { + "type": "WordPiece", + "unk_token": "[UNK]", + "continuing_subword_prefix": "##", + "max_input_chars_per_word": 100, + "vocab": { + "[PAD]": 0, + "[unused0]": 1, + "[unused1]": 2, + "[UNK]": 100, + "[CLS]": 101, + "[SEP]": 102, + "[MASK]": 103, + "a": 1037, + "b": 1038, + "c": 1039, + "d": 1040, + "e": 1041, + "f": 1042, + "g": 1043, + "h": 1044, + "i": 1045, + "j": 1046, + "k": 1047, + "l": 1048, + "m": 1049, + "n": 1050, + "o": 1051, + "p": 1052, + "q": 1053, + "r": 1054, + "s": 1055, + "t": 1056, + "u": 1057, + "v": 1058, + "w": 1059, + "x": 1060, + "y": 1061, + "z": 1062 + } + } +} diff --git a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java index 466ee65fcc1..98252a696f2 100644 --- a/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java +++ b/node-admin/src/main/java/com/yahoo/vespa/hosted/node/admin/nodeagent/NodeAgentImpl.java @@ -40,7 +40,7 @@ import java.util.function.Function; import java.util.logging.Level; import java.util.logging.Logger; -import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; +import static com.yahoo.vespa.flags.FetchVector.Dimension.INSTANCE_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.CLUSTER_ID; import static com.yahoo.vespa.flags.FetchVector.Dimension.CLUSTER_TYPE; import static 
com.yahoo.vespa.flags.FetchVector.Dimension.HOSTNAME; @@ -416,7 +416,7 @@ public class NodeAgentImpl implements NodeAgent { private ContainerResources getContainerResources(NodeAgentContext context) { double cpuCap = context.vcpuOnThisHost() * containerCpuCap - .with(APPLICATION_ID, context.node().owner().map(ApplicationId::serializedForm)) + .with(INSTANCE_ID, context.node().owner().map(ApplicationId::serializedForm)) .with(CLUSTER_ID, context.node().membership().map(NodeMembership::clusterId)) .with(CLUSTER_TYPE, context.node().membership().map(membership -> membership.type().value())) .with(HOSTNAME, context.node().hostname()) diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerInstance.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerInstance.java index e228d31384c..f42d1ce9bd3 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerInstance.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerInstance.java @@ -21,7 +21,8 @@ import java.util.Set; public class LoadBalancerInstance { private final Optional<DomainName> hostname; - private final Optional<String> ipAddress; + private final Optional<String> ip4Address; + private final Optional<String> ip6Address; private final Optional<DnsZone> dnsZone; private final Set<Integer> ports; private final Set<String> networks; @@ -30,11 +31,12 @@ public class LoadBalancerInstance { private final List<PrivateServiceId> serviceIds; private final CloudAccount cloudAccount; - public LoadBalancerInstance(Optional<DomainName> hostname, Optional<String> ipAddress, + public LoadBalancerInstance(Optional<DomainName> hostname, Optional<String> ip4Address, Optional<String> ip6Address, Optional<DnsZone> dnsZone, Set<Integer> ports, Set<String> networks, Set<Real> reals, ZoneEndpoint settings, List<PrivateServiceId> serviceIds, CloudAccount cloudAccount) { this.hostname = 
Objects.requireNonNull(hostname, "hostname must be non-null"); - this.ipAddress = Objects.requireNonNull(ipAddress, "ip must be non-null"); + this.ip4Address = Objects.requireNonNull(ip4Address, "ip4Address must be non-null"); + this.ip6Address = Objects.requireNonNull(ip6Address, "ip6Address must be non-null"); this.dnsZone = Objects.requireNonNull(dnsZone, "dnsZone must be non-null"); this.ports = ImmutableSortedSet.copyOf(requirePorts(ports)); this.networks = ImmutableSortedSet.copyOf(Objects.requireNonNull(networks, "networks must be non-null")); @@ -43,9 +45,9 @@ public class LoadBalancerInstance { this.serviceIds = List.copyOf(Objects.requireNonNull(serviceIds, "private service id must be non-null")); this.cloudAccount = Objects.requireNonNull(cloudAccount, "cloudAccount must be non-null"); - if (hostname.isEmpty() == ipAddress.isEmpty()) { - throw new IllegalArgumentException("Exactly 1 of hostname=%s and ipAddress=%s must be set".formatted( - hostname.map(DomainName::value).orElse("<empty>"), ipAddress.orElse("<empty>"))); + if (hostname.isEmpty() == ip4Address.isEmpty()) { + throw new IllegalArgumentException("Exactly 1 of hostname=%s and ip4Address=%s must be set".formatted( + hostname.map(DomainName::value).orElse("<empty>"), ip4Address.orElse("<empty>"))); } } @@ -54,9 +56,14 @@ public class LoadBalancerInstance { return hostname; } - /** IP address of this (public) load balancer */ - public Optional<String> ipAddress() { - return ipAddress; + /** IPv4 address of this (public) load balancer */ + public Optional<String> ip4Address() { + return ip4Address; + } + + /** IPv6 address of this (public) load balancer */ + public Optional<String> ip6Address() { + return ip6Address; } /** ID of the DNS zone associated with this */ @@ -114,7 +121,7 @@ public class LoadBalancerInstance { public LoadBalancerInstance with(Set<Real> reals, ZoneEndpoint settings, Optional<PrivateServiceId> serviceId) { List<PrivateServiceId> ids = new ArrayList<>(serviceIds); 
serviceId.filter(id -> ! ids.contains(id)).ifPresent(ids::add); - return new LoadBalancerInstance(hostname, ipAddress, dnsZone, ports, networks, + return new LoadBalancerInstance(hostname, ip4Address, ip6Address, dnsZone, ports, networks, reals, settings, ids, cloudAccount); } @@ -123,7 +130,7 @@ public class LoadBalancerInstance { public LoadBalancerInstance withServiceIds(List<PrivateServiceId> serviceIds) { List<PrivateServiceId> ids = new ArrayList<>(serviceIds); for (PrivateServiceId id : this.serviceIds) if ( ! ids.contains(id)) ids.add(id); - return new LoadBalancerInstance(hostname, ipAddress, dnsZone, ports, networks, reals, settings, ids, cloudAccount); + return new LoadBalancerInstance(hostname, ip4Address, ip6Address, dnsZone, ports, networks, reals, settings, ids, cloudAccount); } } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java index a79766a577d..c79ccc2aece 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/LoadBalancerServiceMock.java @@ -57,6 +57,7 @@ public class LoadBalancerServiceMock implements LoadBalancerService { var instance = new LoadBalancerInstance( Optional.of(DomainName.of("lb-" + spec.application().toShortString() + "-" + spec.cluster().value())), Optional.empty(), + Optional.empty(), Optional.of(new DnsZone("zone-id-1")), Collections.singleton(4443), ImmutableSet.of("10.2.3.0/24", "10.4.5.0/24"), diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java index e49d1b302cf..073662b39fe 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java +++ 
b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/lb/SharedLoadBalancerService.java @@ -45,6 +45,7 @@ public class SharedLoadBalancerService implements LoadBalancerService { return new LoadBalancerInstance(Optional.of(DomainName.of(vipHostname)), Optional.empty(), Optional.empty(), + Optional.empty(), Set.of(4443), Set.of(), spec.reals(), diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java index 8638087c5cd..eb290d9ec2a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/AutoscalingMaintainer.java @@ -59,7 +59,7 @@ public class AutoscalingMaintainer extends NodeRepositoryMaintainer { int failures = 0; outer: for (var applicationNodes : activeNodesByApplication().entrySet()) { - boolean enabled = enabledFlag.with(FetchVector.Dimension.APPLICATION_ID, + boolean enabled = enabledFlag.with(FetchVector.Dimension.INSTANCE_ID, applicationNodes.getKey().serializedForm()).value(); if (!enabled) continue; for (var clusterNodes : nodesByCluster(applicationNodes.getValue()).entrySet()) { diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/PeriodicApplicationMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/PeriodicApplicationMaintainer.java index 14693c75436..9cbf1778b3b 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/PeriodicApplicationMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/PeriodicApplicationMaintainer.java @@ -58,7 +58,7 @@ public class PeriodicApplicationMaintainer extends ApplicationMaintainer { private boolean shouldMaintain(ApplicationId id) { BooleanFlag 
skipMaintenanceDeployment = PermanentFlags.SKIP_MAINTENANCE_DEPLOYMENT.bindTo(flagSource) - .with(FetchVector.Dimension.APPLICATION_ID, id.serializedForm()); + .with(FetchVector.Dimension.INSTANCE_ID, id.serializedForm()); return ! skipMaintenanceDeployment.value(); } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java index 3c3868bfeb8..8ad975f5334 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/CuratorDb.java @@ -47,6 +47,7 @@ import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; import java.util.stream.IntStream; +import java.util.stream.Stream; import static com.yahoo.stream.CustomCollectors.toLinkedMap; import static java.util.stream.Collectors.collectingAndThen; @@ -456,7 +457,12 @@ public class CuratorDb { transaction.onCommitted(() -> { for (var lb : loadBalancers) { if (lb.state() == fromState) continue; - Optional<String> target = lb.instance().flatMap(instance -> instance.hostname().map(DomainName::value).or(instance::ipAddress)); + Optional<String> target = lb.instance() + .flatMap(instance -> instance.hostname() + .map(DomainName::value) + .or(() -> Optional.of(Stream.concat(instance.ip4Address().stream(), + instance.ip6Address().stream()) + .collect(Collectors.joining(","))))); if (fromState == null) { log.log(Level.INFO, () -> "Creating " + lb.id() + target.map(t -> " (" + t + ")").orElse("") + " in " + lb.state()); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java index b85d96c6b54..d329676f842 100644 --- 
a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializer.java @@ -45,6 +45,7 @@ public class LoadBalancerSerializer { private static final String idField = "id"; private static final String hostnameField = "hostname"; private static final String lbIpAddressField = "ipAddress"; + private static final String lbIp6AddressField = "ip6Address"; private static final String stateField = "state"; private static final String changedAtField = "changedAt"; private static final String dnsZoneField = "dnsZone"; @@ -69,7 +70,8 @@ public class LoadBalancerSerializer { root.setString(idField, loadBalancer.id().serializedForm()); loadBalancer.instance().flatMap(LoadBalancerInstance::hostname).ifPresent(hostname -> root.setString(hostnameField, hostname.value())); - loadBalancer.instance().flatMap(LoadBalancerInstance::ipAddress).ifPresent(ip -> root.setString(lbIpAddressField, ip)); + loadBalancer.instance().flatMap(LoadBalancerInstance::ip4Address).ifPresent(ip -> root.setString(lbIpAddressField, ip)); + loadBalancer.instance().flatMap(LoadBalancerInstance::ip6Address).ifPresent(ip -> root.setString(lbIp6AddressField, ip)); root.setString(stateField, asString(loadBalancer.state())); root.setLong(changedAtField, loadBalancer.changedAt().toEpochMilli()); loadBalancer.instance().flatMap(LoadBalancerInstance::dnsZone).ifPresent(dnsZone -> root.setString(dnsZoneField, dnsZone.id())); @@ -123,7 +125,8 @@ public class LoadBalancerSerializer { object.field(networksField).traverse((ArrayTraverser) (i, network) -> networks.add(network.asString())); Optional<DomainName> hostname = optionalString(object.field(hostnameField), Function.identity()).filter(s -> !s.isEmpty()).map(DomainName::of); - Optional<String> ipAddress = optionalString(object.field(lbIpAddressField), Function.identity()).filter(s -> !s.isEmpty()); + Optional<String> 
ip4Address = optionalString(object.field(lbIpAddressField), Function.identity()).filter(s -> !s.isEmpty()); + Optional<String> ip6Address = optionalString(object.field(lbIp6AddressField), Function.identity()).filter(s -> !s.isEmpty()); Optional<DnsZone> dnsZone = optionalString(object.field(dnsZoneField), DnsZone::new); ZoneEndpoint settings = zoneEndpoint(object.field(settingsField)); Optional<PrivateServiceId> serviceId = optionalString(object.field(serviceIdField), PrivateServiceId::of); @@ -131,9 +134,9 @@ public class LoadBalancerSerializer { object.field(serviceIdsField).traverse((ArrayTraverser) (__, serviceIdObject) -> serviceIds.add(PrivateServiceId.of(serviceIdObject.asString()))); if (serviceIds.isEmpty()) serviceId.ifPresent(serviceIds::add); // TODO: remove after winter vacation '23 CloudAccount cloudAccount = optionalString(object.field(cloudAccountField), CloudAccount::from).orElse(CloudAccount.empty); - Optional<LoadBalancerInstance> instance = hostname.isEmpty() && ipAddress.isEmpty() + Optional<LoadBalancerInstance> instance = hostname.isEmpty() && ip4Address.isEmpty() && ip6Address.isEmpty() ? 
Optional.empty() - : Optional.of(new LoadBalancerInstance(hostname, ipAddress, dnsZone, ports, networks, reals, settings, serviceIds, cloudAccount)); + : Optional.of(new LoadBalancerInstance(hostname, ip4Address, ip6Address, dnsZone, ports, networks, reals, settings, serviceIds, cloudAccount)); return new LoadBalancer(LoadBalancerId.fromSerializedForm(object.field(idField).asString()), instance, diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java index 4236f78336b..e5599ac3d18 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/CapacityPolicies.java @@ -18,7 +18,7 @@ import java.util.Map; import java.util.TreeMap; import static com.yahoo.config.provision.NodeResources.Architecture; -import static com.yahoo.vespa.flags.FetchVector.Dimension.APPLICATION_ID; +import static com.yahoo.vespa.flags.FetchVector.Dimension.INSTANCE_ID; import static java.util.Objects.requireNonNull; /** @@ -146,7 +146,7 @@ public class CapacityPolicies { } private Architecture adminClusterArchitecture(ApplicationId instance) { - return Architecture.valueOf(adminClusterNodeArchitecture.with(APPLICATION_ID, instance.serializedForm()).value()); + return Architecture.valueOf(adminClusterNodeArchitecture.with(INSTANCE_ID, instance.serializedForm()).value()); } /** Returns the resources for the newest version not newer than that requested in the cluster spec. 
*/ diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java index c414c70f315..22909122079 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/LoadBalancerProvisioner.java @@ -273,7 +273,7 @@ public class LoadBalancerProvisioner { LoadBalancer currentLoadBalancer, ZoneEndpoint zoneEndpoint, CloudAccount cloudAccount) { - boolean shouldDeactivateRouting = deactivateRouting.with(FetchVector.Dimension.APPLICATION_ID, + boolean shouldDeactivateRouting = deactivateRouting.with(FetchVector.Dimension.INSTANCE_ID, id.application().serializedForm()) .value(); Set<Real> reals = shouldDeactivateRouting ? Set.of() : realsOf(nodes, cloudAccount); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java index b289a965567..ca170d2af6b 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeAllocation.java @@ -94,7 +94,7 @@ class NodeAllocation { this.nextIndex = nextIndex; this.nodeRepository = nodeRepository; this.requiredHostFlavor = Optional.of(PermanentFlags.HOST_FLAVOR.bindTo(nodeRepository.flagSource()) - .with(FetchVector.Dimension.APPLICATION_ID, application.serializedForm()) + .with(FetchVector.Dimension.INSTANCE_ID, application.serializedForm()) .with(FetchVector.Dimension.CLUSTER_TYPE, cluster.type().name()) .with(FetchVector.Dimension.CLUSTER_ID, cluster.id().value()) .value()) diff --git 
a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeResourceLimits.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeResourceLimits.java index 06ab9eb1a10..ab222c45252 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeResourceLimits.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/provisioning/NodeResourceLimits.java @@ -88,7 +88,7 @@ public class NodeResourceLimits { if (cluster.type() == ClusterSpec.Type.admin) return 1; if (!exclusive) return 4; return minExclusiveAdvertisedMemoryGbFlag - .with(FetchVector.Dimension.APPLICATION_ID, applicationId.serializedForm()) + .with(FetchVector.Dimension.INSTANCE_ID, applicationId.serializedForm()) .with(FetchVector.Dimension.CLUSTER_ID, cluster.id().value()) .with(FetchVector.Dimension.CLUSTER_TYPE, cluster.type().name()) .value(); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/LoadBalancersResponse.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/LoadBalancersResponse.java index 09f947503f6..20aa7d8181e 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/LoadBalancersResponse.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/LoadBalancersResponse.java @@ -57,7 +57,8 @@ public class LoadBalancersResponse extends SlimeJsonResponse { lbObject.setString("instance", lb.id().application().instance().value()); lbObject.setString("cluster", lb.id().cluster().value()); lb.instance().flatMap(LoadBalancerInstance::hostname).ifPresent(hostname -> lbObject.setString("hostname", hostname.value())); - lb.instance().flatMap(LoadBalancerInstance::ipAddress).ifPresent(ipAddress -> lbObject.setString("ipAddress", ipAddress)); + lb.instance().flatMap(LoadBalancerInstance::ip4Address).ifPresent(ip -> lbObject.setString("ipAddress", ip)); + 
lb.instance().flatMap(LoadBalancerInstance::ip6Address).ifPresent(ip -> lbObject.setString("ip6Address", ip)); lb.instance().flatMap(LoadBalancerInstance::dnsZone).ifPresent(dnsZone -> lbObject.setString("dnsZone", dnsZone.id())); Cursor networkArray = lbObject.setArray("networks"); diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodesResponse.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodesResponse.java index 2b908efde94..a8f526544d7 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodesResponse.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/restapi/NodesResponse.java @@ -201,7 +201,7 @@ class NodesResponse extends SlimeJsonResponse { .with(FetchVector.Dimension.HOSTNAME, node.hostname()) .with(FetchVector.Dimension.NODE_TYPE, node.type().name()) .with(FetchVector.Dimension.TENANT_ID, allocation.owner().tenant().value()) - .with(FetchVector.Dimension.APPLICATION_ID, allocation.owner().serializedForm()) + .with(FetchVector.Dimension.INSTANCE_ID, allocation.owner().serializedForm()) .with(FetchVector.Dimension.CLUSTER_TYPE, allocation.membership().cluster().type().name()) .with(FetchVector.Dimension.CLUSTER_ID, allocation.membership().cluster().id().value()) .with(FetchVector.Dimension.VESPA_VERSION, allocation.membership().cluster().vespaVersion().toFullString()) diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java index 6dc681ae5c8..b5257e23d9e 100644 --- a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/persistence/LoadBalancerSerializerTest.java @@ -40,6 +40,7 @@ public class LoadBalancerSerializerTest { Optional.of(new 
LoadBalancerInstance( Optional.of(DomainName.of("lb-host")), Optional.empty(), + Optional.empty(), Optional.of(new DnsZone("zone-id-1")), Set.of(4080, 4443), Set.of("10.2.3.4/24"), @@ -73,6 +74,7 @@ public class LoadBalancerSerializerTest { Optional.of(new LoadBalancerInstance( Optional.empty(), Optional.of("1.2.3.4"), + Optional.of("fd00::1"), Optional.of(new DnsZone("zone-id-1")), Set.of(4443), Set.of("10.2.3.4/24", "12.3.2.1/30"), @@ -86,6 +88,8 @@ public class LoadBalancerSerializerTest { var serialized = LoadBalancerSerializer.fromJson(LoadBalancerSerializer.toJson(loadBalancer)); assertEquals(loadBalancer.id(), serialized.id()); assertEquals(loadBalancer.instance().get().hostname(), serialized.instance().get().hostname()); + assertEquals(loadBalancer.instance().get().ip4Address(), serialized.instance().get().ip4Address()); + assertEquals(loadBalancer.instance().get().ip6Address(), serialized.instance().get().ip6Address()); assertEquals(loadBalancer.instance().get().dnsZone(), serialized.instance().get().dnsZone()); assertEquals(loadBalancer.instance().get().ports(), serialized.instance().get().ports()); assertEquals(loadBalancer.instance().get().networks(), serialized.instance().get().networks()); diff --git a/searchcore/src/tests/proton/matching/matching_test.cpp b/searchcore/src/tests/proton/matching/matching_test.cpp index 6ef462f80c4..ec549ee6f71 100644 --- a/searchcore/src/tests/proton/matching/matching_test.cpp +++ b/searchcore/src/tests/proton/matching/matching_test.cpp @@ -1135,12 +1135,15 @@ TEST("require that docsum matcher can extract matching elements from single attr EXPECT_EQUAL(list[1], 3u); } +using FMA = vespalib::FuzzyMatchingAlgorithm; + struct AttributeBlueprintParamsFixture { BlueprintFactory factory; search::fef::test::IndexEnvironment index_env; RankSetup rank_setup; Properties rank_properties; - AttributeBlueprintParamsFixture(double lower_limit, double upper_limit, double target_hits_max_adjustment_factor) + 
AttributeBlueprintParamsFixture(double lower_limit, double upper_limit, double target_hits_max_adjustment_factor, + FMA fuzzy_matching_algorithm) : factory(), index_env(), rank_setup(factory, index_env), @@ -1149,36 +1152,41 @@ struct AttributeBlueprintParamsFixture { rank_setup.set_global_filter_lower_limit(lower_limit); rank_setup.set_global_filter_upper_limit(upper_limit); rank_setup.set_target_hits_max_adjustment_factor(target_hits_max_adjustment_factor); + rank_setup.set_fuzzy_matching_algorithm(fuzzy_matching_algorithm); } void set_query_properties(vespalib::stringref lower_limit, vespalib::stringref upper_limit, - vespalib::stringref target_hits_max_adjustment_factor) { + vespalib::stringref target_hits_max_adjustment_factor, + const vespalib::string fuzzy_matching_algorithm) { rank_properties.add(GlobalFilterLowerLimit::NAME, lower_limit); rank_properties.add(GlobalFilterUpperLimit::NAME, upper_limit); rank_properties.add(TargetHitsMaxAdjustmentFactor::NAME, target_hits_max_adjustment_factor); + rank_properties.add(FuzzyAlgorithm::NAME, fuzzy_matching_algorithm); } AttributeBlueprintParams extract(uint32_t active_docids = 9, uint32_t docid_limit = 10) const { return MatchToolsFactory::extract_attribute_blueprint_params(rank_setup, rank_properties, active_docids, docid_limit); } }; -TEST_F("attribute blueprint params are extracted from rank profile", AttributeBlueprintParamsFixture(0.2, 0.8, 5.0)) +TEST_F("attribute blueprint params are extracted from rank profile", AttributeBlueprintParamsFixture(0.2, 0.8, 5.0, FMA::BruteForce)) { auto params = f.extract(); EXPECT_EQUAL(0.2, params.global_filter_lower_limit); EXPECT_EQUAL(0.8, params.global_filter_upper_limit); EXPECT_EQUAL(5.0, params.target_hits_max_adjustment_factor); + EXPECT_EQUAL(FMA::BruteForce, params.fuzzy_matching_algorithm); } -TEST_F("attribute blueprint params are extracted from query", AttributeBlueprintParamsFixture(0.2, 0.8, 5.0)) +TEST_F("attribute blueprint params are extracted from 
query", AttributeBlueprintParamsFixture(0.2, 0.8, 5.0, FMA::BruteForce)) { - f.set_query_properties("0.15", "0.75", "3.0"); + f.set_query_properties("0.15", "0.75", "3.0", "dfa_explicit"); auto params = f.extract(); EXPECT_EQUAL(0.15, params.global_filter_lower_limit); EXPECT_EQUAL(0.75, params.global_filter_upper_limit); EXPECT_EQUAL(3.0, params.target_hits_max_adjustment_factor); + EXPECT_EQUAL(FMA::DfaExplicit, params.fuzzy_matching_algorithm); } -TEST_F("global filter params are scaled with active hit ratio", AttributeBlueprintParamsFixture(0.2, 0.8, 5.0)) +TEST_F("global filter params are scaled with active hit ratio", AttributeBlueprintParamsFixture(0.2, 0.8, 5.0, FMA::BruteForce)) { auto params = f.extract(5, 10); EXPECT_EQUAL(0.12, params.global_filter_lower_limit); diff --git a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp index f62f4c60a6c..5ae671b88cb 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/match_tools.cpp @@ -331,6 +331,7 @@ MatchToolsFactory::extract_attribute_blueprint_params(const RankSetup& rank_setu double lower_limit = GlobalFilterLowerLimit::lookup(rank_properties, rank_setup.get_global_filter_lower_limit()); double upper_limit = GlobalFilterUpperLimit::lookup(rank_properties, rank_setup.get_global_filter_upper_limit()); double target_hits_max_adjustment_factor = TargetHitsMaxAdjustmentFactor::lookup(rank_properties, rank_setup.get_target_hits_max_adjustment_factor()); + auto fuzzy_matching_algorithm = FuzzyAlgorithm::lookup(rank_properties, rank_setup.get_fuzzy_matching_algorithm()); // Note that we count the reserved docid 0 as active. // This ensures that when searchable-copies=1, the ratio is 1.0. 
@@ -338,7 +339,8 @@ MatchToolsFactory::extract_attribute_blueprint_params(const RankSetup& rank_setu return {lower_limit * active_hit_ratio, upper_limit * active_hit_ratio, - target_hits_max_adjustment_factor}; + target_hits_max_adjustment_factor, + fuzzy_matching_algorithm}; } AttributeOperationTask::AttributeOperationTask(const RequestContext & requestContext, diff --git a/searchlib/src/tests/attribute/enum_comparator/enum_comparator_test.cpp b/searchlib/src/tests/attribute/enum_comparator/enum_comparator_test.cpp index 1c7b8b2b695..975a2918026 100644 --- a/searchlib/src/tests/attribute/enum_comparator/enum_comparator_test.cpp +++ b/searchlib/src/tests/attribute/enum_comparator/enum_comparator_test.cpp @@ -4,12 +4,16 @@ #include <vespa/searchlib/attribute/dfa_string_comparator.h> #include <vespa/vespalib/btree/btreeroot.h> #include <vespa/vespalib/gtest/gtest.h> +#include <vespa/vespalib/text/lowercase.h> +#include <vespa/vespalib/text/utf8.h> #include <vespa/searchlib/attribute/enumstore.hpp> using namespace vespalib::btree; using vespalib::datastore::AtomicEntryRef; +using vespalib::LowerCase; +using vespalib::Utf8ReaderForZTS; namespace vespalib::datastore { @@ -18,6 +22,22 @@ std::ostream & operator << (std::ostream& os, const EntryRef& ref) { } } + +namespace { + +std::vector<uint32_t> as_utf32(const char* key) +{ + std::vector<uint32_t> result; + Utf8ReaderForZTS reader(key); + while (reader.hasMore()) { + uint32_t code_point = reader.getChar(); + result.push_back(code_point); + } + return result; +} + +} + namespace search { using NumericEnumStore = EnumStoreT<int32_t>; @@ -253,14 +273,16 @@ TEST(DfaStringComparatorTest, require_that_less_is_working) EnumIndex e1 = es.insert("Aa"); EnumIndex e2 = es.insert("aa"); EnumIndex e3 = es.insert("aB"); - DfaStringComparator cmp1(es.get_data_store(), "aa"); + auto aa_utf32 = as_utf32("aa"); + DfaStringComparator cmp1(es.get_data_store(), aa_utf32); EXPECT_FALSE(cmp1.less(EnumIndex(), e1)); 
EXPECT_FALSE(cmp1.less(EnumIndex(), e2)); EXPECT_TRUE(cmp1.less(EnumIndex(), e3)); EXPECT_FALSE(cmp1.less(e1, EnumIndex())); EXPECT_FALSE(cmp1.less(e2, EnumIndex())); EXPECT_FALSE(cmp1.less(e3, EnumIndex())); - DfaStringComparator cmp2(es.get_data_store(), "Aa"); + auto Aa_utf32 = as_utf32("Aa"); + DfaStringComparator cmp2(es.get_data_store(), Aa_utf32); EXPECT_TRUE(cmp2.less(EnumIndex(), e1)); EXPECT_TRUE(cmp2.less(EnumIndex(), e2)); EXPECT_TRUE(cmp2.less(EnumIndex(), e3)); diff --git a/searchlib/src/tests/ranksetup/ranksetup_test.cpp b/searchlib/src/tests/ranksetup/ranksetup_test.cpp index f708df0a862..8d51eb56cc3 100644 --- a/searchlib/src/tests/ranksetup/ranksetup_test.cpp +++ b/searchlib/src/tests/ranksetup/ranksetup_test.cpp @@ -536,6 +536,7 @@ void RankSetupTest::testRankSetup() env.getProperties().add(matching::GlobalFilterLowerLimit::NAME, "0.3"); env.getProperties().add(matching::GlobalFilterUpperLimit::NAME, "0.7"); env.getProperties().add(matching::TargetHitsMaxAdjustmentFactor::NAME, "5.0"); + env.getProperties().add(matching::FuzzyAlgorithm::NAME, "dfa_implicit"); RankSetup rs(_factory, env); EXPECT_FALSE(rs.has_match_features()); @@ -577,6 +578,7 @@ void RankSetupTest::testRankSetup() EXPECT_EQUAL(rs.get_global_filter_lower_limit(), 0.3); EXPECT_EQUAL(rs.get_global_filter_upper_limit(), 0.7); EXPECT_EQUAL(rs.get_target_hits_max_adjustment_factor(), 5.0); + EXPECT_EQUAL(rs.get_fuzzy_matching_algorithm(), vespalib::FuzzyMatchingAlgorithm::DfaImplicit); } bool diff --git a/searchlib/src/tests/util/folded_string_compare/folded_string_compare_test.cpp b/searchlib/src/tests/util/folded_string_compare/folded_string_compare_test.cpp index e00cf109f8e..c0353e53bd1 100644 --- a/searchlib/src/tests/util/folded_string_compare/folded_string_compare_test.cpp +++ b/searchlib/src/tests/util/folded_string_compare/folded_string_compare_test.cpp @@ -3,12 +3,32 @@ #include <vespa/searchlib/util/foldedstringcompare.h> #include <vespa/vespalib/gtest/gtest.h> #include 
<vespa/vespalib/stllike/string.h> +#include <vespa/vespalib/text/lowercase.h> +#include <vespa/vespalib/text/utf8.h> using search::FoldedStringCompare; +using vespalib::LowerCase; +using vespalib::Utf8ReaderForZTS; using IntVec = std::vector<int>; using StringVec = std::vector<vespalib::string>; +namespace { + +template <bool fold> +std::vector<uint32_t> as_utf32(const char* key) +{ + std::vector<uint32_t> result; + Utf8ReaderForZTS reader(key); + while (reader.hasMore()) { + uint32_t code_point = fold ? LowerCase::convert(reader.getChar()) : reader.getChar(); + result.push_back(code_point); + } + return result; +} + +} + class FoldedStringCompareTest : public ::testing::Test { protected: @@ -21,10 +41,22 @@ protected: template <bool fold_lhs, bool fold_rhs> int - compare_folded_helper(const vespalib::string& lhs, const vespalib::string& rhs) + compare_folded_helper2(const vespalib::string& lhs, const vespalib::string& rhs) { int ret = FoldedStringCompare::compareFolded<fold_lhs, fold_rhs>(lhs.c_str(), rhs.c_str()); - EXPECT_EQ(-ret, (FoldedStringCompare::compareFolded<fold_rhs, fold_lhs>(rhs.c_str(), lhs.c_str()))); + auto folded_lhs_utf32 = as_utf32<fold_lhs>(lhs.c_str()); + EXPECT_EQ(ret, (FoldedStringCompare::compareFolded<false, fold_rhs>(std::cref(folded_lhs_utf32), rhs.c_str()))); + auto folded_rhs_utf32 = as_utf32<fold_rhs>(rhs.c_str()); + EXPECT_EQ(ret, (FoldedStringCompare::compareFolded<fold_lhs, false>(lhs.c_str(), std::cref(folded_rhs_utf32)))); + return ret; + } + + template <bool fold_lhs, bool fold_rhs> + int + compare_folded_helper(const vespalib::string& lhs, const vespalib::string& rhs) + { + int ret = compare_folded_helper2<fold_lhs, fold_rhs>(lhs, rhs); + EXPECT_EQ(-ret, (compare_folded_helper2<fold_rhs, fold_lhs>(rhs, lhs))); return ret; } diff --git a/searchlib/src/vespa/searchcommon/attribute/search_context_params.h b/searchlib/src/vespa/searchcommon/attribute/search_context_params.h index 8ed7eadf919..1c3b32bd777 100644 --- 
a/searchlib/src/vespa/searchcommon/attribute/search_context_params.h +++ b/searchlib/src/vespa/searchcommon/attribute/search_context_params.h @@ -3,6 +3,8 @@ #pragma once #include "i_document_meta_store_context.h" +#include <vespa/searchlib/fef/indexproperties.h> +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> #include <cstddef> #include <limits> #include <cstdint> @@ -21,6 +23,8 @@ private: uint32_t _diversityCutoffGroups; bool _useBitVector; bool _diversityCutoffStrict; + vespalib::FuzzyMatchingAlgorithm _fuzzy_matching_algorithm; + public: SearchContextParams() @@ -28,13 +32,15 @@ public: _metaStoreReadGuard(nullptr), _diversityCutoffGroups(std::numeric_limits<uint32_t>::max()), _useBitVector(false), - _diversityCutoffStrict(false) + _diversityCutoffStrict(false), + _fuzzy_matching_algorithm(search::fef::indexproperties::matching::FuzzyAlgorithm::DEFAULT_VALUE) { } bool useBitVector() const { return _useBitVector; } const IAttributeVector * diversityAttribute() const { return _diversityAttribute; } uint32_t diversityCutoffGroups() const { return _diversityCutoffGroups; } bool diversityCutoffStrict() const { return _diversityCutoffStrict; } const IDocumentMetaStoreContext::IReadGuard::SP * metaStoreReadGuard() const { return _metaStoreReadGuard; } + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm() const { return _fuzzy_matching_algorithm; } SearchContextParams &useBitVector(bool value) { _useBitVector = value; @@ -56,6 +62,10 @@ public: _metaStoreReadGuard = readGuard; return *this; } + SearchContextParams& fuzzy_matching_algorithm(vespalib::FuzzyMatchingAlgorithm value) { + _fuzzy_matching_algorithm = value; + return *this; + } }; } diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 453b7b321b9..1519bb14554 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ 
b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -713,6 +713,7 @@ public: template <class TermNode> void visitTerm(TermNode &n) { SearchContextParams scParams = createContextParams(_field.isFilter()); + scParams.fuzzy_matching_algorithm(getRequestContext().get_attribute_blueprint_params().fuzzy_matching_algorithm); const string stack = StackDumpCreator::create(n); setResult(std::make_unique<AttributeFieldBlueprint>(_field, _attr, stack, scParams)); } diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_params.h b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_params.h index 64213235c23..1f9a3ebfa7e 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_params.h +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_params.h @@ -3,6 +3,7 @@ #pragma once #include <vespa/searchlib/fef/indexproperties.h> +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> namespace search::attribute { @@ -14,20 +15,24 @@ struct AttributeBlueprintParams double global_filter_lower_limit; double global_filter_upper_limit; double target_hits_max_adjustment_factor; + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm; AttributeBlueprintParams(double global_filter_lower_limit_in, double global_filter_upper_limit_in, - double target_hits_max_adjustment_factor_in) + double target_hits_max_adjustment_factor_in, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm_in) : global_filter_lower_limit(global_filter_lower_limit_in), global_filter_upper_limit(global_filter_upper_limit_in), - target_hits_max_adjustment_factor(target_hits_max_adjustment_factor_in) + target_hits_max_adjustment_factor(target_hits_max_adjustment_factor_in), + fuzzy_matching_algorithm(fuzzy_matching_algorithm_in) { } AttributeBlueprintParams() : AttributeBlueprintParams(fef::indexproperties::matching::GlobalFilterLowerLimit::DEFAULT_VALUE, fef::indexproperties::matching::GlobalFilterUpperLimit::DEFAULT_VALUE, - 
fef::indexproperties::matching::TargetHitsMaxAdjustmentFactor::DEFAULT_VALUE) + fef::indexproperties::matching::TargetHitsMaxAdjustmentFactor::DEFAULT_VALUE, + fef::indexproperties::matching::FuzzyAlgorithm::DEFAULT_VALUE) { } }; diff --git a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h index 6b873020994..fcba13f85a4 100644 --- a/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h +++ b/searchlib/src/vespa/searchlib/attribute/dfa_fuzzy_matcher.h @@ -17,7 +17,7 @@ namespace search::attribute { class DfaFuzzyMatcher { private: vespalib::fuzzy::LevenshteinDfa _dfa; - std::string _successor; + std::vector<uint32_t> _successor; public: DfaFuzzyMatcher(std::string_view target, uint8_t max_edits, bool cased, vespalib::fuzzy::LevenshteinDfa::DfaType dfa_type); @@ -29,7 +29,7 @@ public: if (match.matches()) { return true; } else { - DfaStringComparator cmp(data_store, _successor.c_str()); + DfaStringComparator cmp(data_store, _successor); itr.seek(vespalib::datastore::AtomicEntryRef(), cmp); return false; } diff --git a/searchlib/src/vespa/searchlib/attribute/dfa_string_comparator.cpp b/searchlib/src/vespa/searchlib/attribute/dfa_string_comparator.cpp index ddbe4fd110f..e9710553ef1 100644 --- a/searchlib/src/vespa/searchlib/attribute/dfa_string_comparator.cpp +++ b/searchlib/src/vespa/searchlib/attribute/dfa_string_comparator.cpp @@ -5,8 +5,9 @@ namespace search::attribute { -DfaStringComparator::DfaStringComparator(const DataStoreType& data_store, const char* candidate) - : ParentType(data_store, candidate) +DfaStringComparator::DfaStringComparator(const DataStoreType& data_store, const std::vector<uint32_t>& candidate) + : ParentType(data_store), + _candidate(std::cref(candidate)) { } @@ -17,13 +18,13 @@ DfaStringComparator::less(const vespalib::datastore::EntryRef lhs, const vespali if (rhs.valid()) { return FoldedStringCompare::compareFolded<true, true>(get(lhs), get(rhs)) < 0; } else { 
- return FoldedStringCompare::compareFolded<true, false>(get(lhs), get(rhs)) < 0; + return FoldedStringCompare::compareFolded<true, false>(get(lhs), _candidate) < 0; } } else { if (rhs.valid()) { - return FoldedStringCompare::compareFolded<false, true>(get(lhs), get(rhs)) < 0; + return FoldedStringCompare::compareFolded<false, true>(_candidate, get(rhs)) < 0; } else { - return FoldedStringCompare::compareFolded<false, false>(get(lhs), get(rhs)) < 0; + return false; } } } diff --git a/searchlib/src/vespa/searchlib/attribute/dfa_string_comparator.h b/searchlib/src/vespa/searchlib/attribute/dfa_string_comparator.h index 7ef14aa1719..8c80035c8fb 100644 --- a/searchlib/src/vespa/searchlib/attribute/dfa_string_comparator.h +++ b/searchlib/src/vespa/searchlib/attribute/dfa_string_comparator.h @@ -4,6 +4,7 @@ #include "i_enum_store.h" #include <vespa/vespalib/datastore/unique_store_string_comparator.h> +#include <functional> namespace search::attribute { @@ -24,9 +25,10 @@ public: using DataStoreType = ParentType::DataStoreType; private: using ParentType::get; + std::reference_wrapper<const std::vector<uint32_t>> _candidate; public: - DfaStringComparator(const DataStoreType& data_store, const char* candidate); + DfaStringComparator(const DataStoreType& data_store, const std::vector<uint32_t>& candidate); bool less(const vespalib::datastore::EntryRef lhs, const vespalib::datastore::EntryRef rhs) const override; }; diff --git a/searchlib/src/vespa/searchlib/attribute/multi_string_enum_hint_search_context.h b/searchlib/src/vespa/searchlib/attribute/multi_string_enum_hint_search_context.h index 3ae342be61b..f418e698585 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_string_enum_hint_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_string_enum_hint_search_context.h @@ -4,6 +4,8 @@ #include "multi_string_enum_search_context.h" #include "enumhintsearchcontext.h" +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> + namespace 
search::attribute { @@ -17,7 +19,12 @@ class MultiStringEnumHintSearchContext : public MultiStringEnumSearchContext<M>, public EnumHintSearchContext { public: - MultiStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, MultiValueMappingReadView<M> mv_mapping_read_view, const EnumStoreT<const char*>& enum_store, uint32_t doc_id_limit, uint64_t num_values); + MultiStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm, + const AttributeVector& toBeSearched, + MultiValueMappingReadView<M> mv_mapping_read_view, + const EnumStoreT<const char*>& enum_store, + uint32_t doc_id_limit, uint64_t num_values); ~MultiStringEnumHintSearchContext() override; }; diff --git a/searchlib/src/vespa/searchlib/attribute/multi_string_enum_hint_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/multi_string_enum_hint_search_context.hpp index fc1f72c940f..f4b96a46e3d 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_string_enum_hint_search_context.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_string_enum_hint_search_context.hpp @@ -6,8 +6,13 @@ namespace search::attribute { template <typename M> -MultiStringEnumHintSearchContext<M>::MultiStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, MultiValueMappingReadView<M> mv_mapping_read_view, const EnumStoreT<const char*>& enum_store, uint32_t doc_id_limit, uint64_t num_values) - : MultiStringEnumSearchContext<M>(std::move(qTerm), cased, toBeSearched, mv_mapping_read_view, enum_store), +MultiStringEnumHintSearchContext<M>::MultiStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm, + const AttributeVector& toBeSearched, + MultiValueMappingReadView<M> mv_mapping_read_view, + const EnumStoreT<const char*>& enum_store, + 
uint32_t doc_id_limit, uint64_t num_values) + : MultiStringEnumSearchContext<M>(std::move(qTerm), cased, fuzzy_matching_algorithm, toBeSearched, mv_mapping_read_view, enum_store), EnumHintSearchContext(enum_store.get_dictionary(), doc_id_limit, num_values) { diff --git a/searchlib/src/vespa/searchlib/attribute/multi_string_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/multi_string_enum_search_context.h index 1787ea0086d..c9b8e8271b1 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_string_enum_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/multi_string_enum_search_context.h @@ -4,6 +4,7 @@ #include "multi_enum_search_context.h" #include "string_search_context.h" +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> namespace search::attribute { @@ -15,7 +16,11 @@ template <typename M> class MultiStringEnumSearchContext : public MultiEnumSearchContext<const char*, StringSearchContext, M> { public: - MultiStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, MultiValueMappingReadView<M> mv_mapping_read_view, const EnumStoreT<const char*>& enum_store); + MultiStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm, + const AttributeVector& toBeSearched, + MultiValueMappingReadView<M> mv_mapping_read_view, + const EnumStoreT<const char*>& enum_store); }; } diff --git a/searchlib/src/vespa/searchlib/attribute/multi_string_enum_search_context.hpp b/searchlib/src/vespa/searchlib/attribute/multi_string_enum_search_context.hpp index 1d74db04373..48d1e8b6406 100644 --- a/searchlib/src/vespa/searchlib/attribute/multi_string_enum_search_context.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multi_string_enum_search_context.hpp @@ -9,8 +9,12 @@ namespace search::attribute { template <typename M> -MultiStringEnumSearchContext<M>::MultiStringEnumSearchContext(std::unique_ptr<QueryTermSimple> 
qTerm, bool cased, const AttributeVector& toBeSearched, MultiValueMappingReadView<M> mv_mapping_read_view, const EnumStoreT<const char*>& enum_store) - : MultiEnumSearchContext<const char*, StringSearchContext, M>(StringMatcher(std::move(qTerm), cased), toBeSearched, mv_mapping_read_view, enum_store) +MultiStringEnumSearchContext<M>::MultiStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm, + const AttributeVector& toBeSearched, + MultiValueMappingReadView<M> mv_mapping_read_view, + const EnumStoreT<const char*>& enum_store) + : MultiEnumSearchContext<const char*, StringSearchContext, M>(StringMatcher(std::move(qTerm), cased, fuzzy_matching_algorithm), toBeSearched, mv_mapping_read_view, enum_store) { } diff --git a/searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp index 43bb1c5ebb0..53e5f0d2e12 100644 --- a/searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multistringattribute.hpp @@ -46,11 +46,12 @@ MultiValueStringAttributeT<B, M>::freezeEnumDictionary() template <typename B, typename M> std::unique_ptr<attribute::SearchContext> MultiValueStringAttributeT<B, M>::getSearch(QueryTermSimpleUP qTerm, - const attribute::SearchContextParams &) const + const attribute::SearchContextParams &params) const { bool cased = this->get_match_is_cased(); auto doc_id_limit = this->getCommittedDocIdLimit(); - return std::make_unique<attribute::MultiStringEnumHintSearchContext<M>>(std::move(qTerm), cased, *this, this->_mvMapping.make_read_view(doc_id_limit), this->_enumStore, doc_id_limit, this->getStatus().getNumValues()); + return std::make_unique<attribute::MultiStringEnumHintSearchContext<M>>(std::move(qTerm), cased, params.fuzzy_matching_algorithm(), + *this, this->_mvMapping.make_read_view(doc_id_limit), this->_enumStore, doc_id_limit, 
this->getStatus().getNumValues()); } template <typename B, typename M> diff --git a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp index fe52b785fa7..3da6357bb53 100644 --- a/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/multistringpostattribute.hpp @@ -99,7 +99,7 @@ MultiValueStringPostingAttributeT<B, T>::getSearch(QueryTermSimpleUP qTerm, using SC = attribute::StringPostingSearchContext<BaseSC, SelfType, int32_t>; bool cased = this->get_match_is_cased(); auto doc_id_limit = this->getCommittedDocIdLimit(); - BaseSC base_sc(std::move(qTerm), cased, *this, this->_mvMapping.make_read_view(doc_id_limit), this->_enumStore); + BaseSC base_sc(std::move(qTerm), cased, params.fuzzy_matching_algorithm(), *this, this->_mvMapping.make_read_view(doc_id_limit), this->_enumStore); return std::make_unique<SC>(std::move(base_sc), params.useBitVector(), *this); } diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp index 2d1748cefa5..95ba37d85be 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp +++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.cpp @@ -5,8 +5,13 @@ namespace search::attribute { -SingleStringEnumHintSearchContext::SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store, uint64_t num_values) - : SingleStringEnumSearchContext(std::move(qTerm), cased, toBeSearched, enum_indices, enum_store), +SingleStringEnumHintSearchContext::SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm, + 
const AttributeVector& toBeSearched, + EnumIndices enum_indices, + const EnumStoreT<const char*>& enum_store, + uint64_t num_values) + : SingleStringEnumSearchContext(std::move(qTerm), cased, fuzzy_matching_algorithm, toBeSearched, enum_indices, enum_store), EnumHintSearchContext(enum_store.get_dictionary(), enum_indices.size(), num_values) { diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h index f157bf17a71..595d1ac8c57 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_hint_search_context.h @@ -4,6 +4,7 @@ #include "single_string_enum_search_context.h" #include "enumhintsearchcontext.h" +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> namespace search::attribute { @@ -16,7 +17,12 @@ class SingleStringEnumHintSearchContext : public SingleStringEnumSearchContext, public EnumHintSearchContext { public: - SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store, uint64_t num_values); + SingleStringEnumHintSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm, + const AttributeVector& toBeSearched, + EnumIndices enum_indices, + const EnumStoreT<const char*>& enum_store, + uint64_t num_values); ~SingleStringEnumHintSearchContext() override; }; diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp index 8d23eaf7af0..42aebe9f814 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp +++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.cpp @@ -6,8 
+6,13 @@ namespace search::attribute { -SingleStringEnumSearchContext::SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store) - : SingleEnumSearchContext<const char*, StringSearchContext>(StringMatcher(std::move(qTerm), cased), toBeSearched, enum_indices, enum_store) +SingleStringEnumSearchContext::SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm, + const AttributeVector& toBeSearched, + EnumIndices enum_indices, + const EnumStoreT<const char*>& enum_store) + : SingleEnumSearchContext<const char*, StringSearchContext>(StringMatcher(std::move(qTerm), cased, fuzzy_matching_algorithm), + toBeSearched, enum_indices, enum_store) { } diff --git a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h index b8014b1b0e3..71c62af33aa 100644 --- a/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/single_string_enum_search_context.h @@ -4,6 +4,7 @@ #include "single_enum_search_context.h" #include "string_search_context.h" +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> namespace search::attribute { @@ -14,7 +15,11 @@ namespace search::attribute { class SingleStringEnumSearchContext : public SingleEnumSearchContext<const char*, StringSearchContext> { public: - SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, const AttributeVector& toBeSearched, EnumIndices enum_indices, const EnumStoreT<const char*>& enum_store); + SingleStringEnumSearchContext(std::unique_ptr<QueryTermSimple> qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm, + const AttributeVector& toBeSearched, + EnumIndices enum_indices, + const 
EnumStoreT<const char*>& enum_store); SingleStringEnumSearchContext(SingleStringEnumSearchContext&&) noexcept; ~SingleStringEnumSearchContext() override; }; diff --git a/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp index c3f5c295260..c4c6fc97053 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlestringattribute.hpp @@ -43,11 +43,12 @@ SingleValueStringAttributeT<B>::freezeEnumDictionary() template <typename B> std::unique_ptr<attribute::SearchContext> SingleValueStringAttributeT<B>::getSearch(QueryTermSimpleUP qTerm, - const attribute::SearchContextParams &) const + const attribute::SearchContextParams& params) const { bool cased = this->get_match_is_cased(); auto docid_limit = this->getCommittedDocIdLimit(); - return std::make_unique<attribute::SingleStringEnumHintSearchContext>(std::move(qTerm), cased, *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore, this->getStatus().getNumValues()); + return std::make_unique<attribute::SingleStringEnumHintSearchContext>(std::move(qTerm), cased, params.fuzzy_matching_algorithm(), + *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore, this->getStatus().getNumValues()); } } diff --git a/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp b/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp index 60847636baa..20d672411f8 100644 --- a/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp +++ b/searchlib/src/vespa/searchlib/attribute/singlestringpostattribute.hpp @@ -146,7 +146,7 @@ SingleValueStringPostingAttributeT<B>::getSearch(QueryTermSimpleUP qTerm, using SC = attribute::StringPostingSearchContext<BaseSC, SelfType, vespalib::btree::BTreeNoLeafData>; bool cased = this->get_match_is_cased(); auto docid_limit = this->getCommittedDocIdLimit(); - BaseSC 
base_sc(std::move(qTerm), cased, *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore); + BaseSC base_sc(std::move(qTerm), cased, params.fuzzy_matching_algorithm(), *this, this->_enumIndices.make_read_view(docid_limit), this->_enumStore); return std::make_unique<SC>(std::move(base_sc), params.useBitVector(), *this); diff --git a/searchlib/src/vespa/searchlib/attribute/string_matcher.cpp b/searchlib/src/vespa/searchlib/attribute/string_matcher.cpp index bc3637e7215..8b755d5f3b1 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_matcher.cpp +++ b/searchlib/src/vespa/searchlib/attribute/string_matcher.cpp @@ -5,9 +5,9 @@ namespace search::attribute { -StringMatcher::StringMatcher(std::unique_ptr<QueryTermSimple> query_term, bool cased) +StringMatcher::StringMatcher(std::unique_ptr<QueryTermSimple> query_term, bool cased, vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm) : _query_term(static_cast<QueryTermUCS4 *>(query_term.release())), - _helper(*_query_term, cased) + _helper(*_query_term, cased, fuzzy_matching_algorithm) { } diff --git a/searchlib/src/vespa/searchlib/attribute/string_matcher.h b/searchlib/src/vespa/searchlib/attribute/string_matcher.h index ea4debecc0d..05089e1251a 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_matcher.h +++ b/searchlib/src/vespa/searchlib/attribute/string_matcher.h @@ -3,6 +3,7 @@ #pragma once #include "string_search_helper.h" +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> namespace search { class QueryTermSimple; } @@ -18,7 +19,7 @@ private: std::unique_ptr<QueryTermUCS4> _query_term; attribute::StringSearchHelper _helper; public: - StringMatcher(std::unique_ptr<QueryTermSimple> qTerm, bool cased); + StringMatcher(std::unique_ptr<QueryTermSimple> qTerm, bool cased, vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm); StringMatcher(StringMatcher&&) noexcept; ~StringMatcher(); protected: diff --git a/searchlib/src/vespa/searchlib/attribute/string_search_context.cpp 
b/searchlib/src/vespa/searchlib/attribute/string_search_context.cpp index fadf7a3151d..119b4a60d0c 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_search_context.cpp +++ b/searchlib/src/vespa/searchlib/attribute/string_search_context.cpp @@ -9,9 +9,10 @@ namespace search::attribute { -StringSearchContext::StringSearchContext(const AttributeVector& to_be_searched, std::unique_ptr<QueryTermSimple> query_term, bool cased) +StringSearchContext::StringSearchContext(const AttributeVector& to_be_searched, std::unique_ptr<QueryTermSimple> query_term, + bool cased, vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm) : SearchContext(to_be_searched), - StringMatcher(std::move(query_term), cased) + StringMatcher(std::move(query_term), cased, fuzzy_matching_algorithm) { } diff --git a/searchlib/src/vespa/searchlib/attribute/string_search_context.h b/searchlib/src/vespa/searchlib/attribute/string_search_context.h index a0014379436..e459153d2b8 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_search_context.h +++ b/searchlib/src/vespa/searchlib/attribute/string_search_context.h @@ -4,6 +4,7 @@ #include "search_context.h" #include "string_matcher.h" +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> namespace search { @@ -24,7 +25,8 @@ class StringSearchContext : public SearchContext, public StringMatcher protected: using MatcherType = StringMatcher; public: - StringSearchContext(const AttributeVector& to_be_searched, std::unique_ptr<QueryTermSimple> query_term, bool cased); + StringSearchContext(const AttributeVector& to_be_searched, std::unique_ptr<QueryTermSimple> query_term, + bool cased, vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm); StringSearchContext(const AttributeVector& to_be_searched, StringMatcher&& matcher); StringSearchContext(StringSearchContext &&) noexcept; ~StringSearchContext() override; diff --git a/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp 
b/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp index 60c00a043d0..1efe39667b8 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp +++ b/searchlib/src/vespa/searchlib/attribute/string_search_helper.cpp @@ -9,7 +9,7 @@ namespace search::attribute { -StringSearchHelper::StringSearchHelper(QueryTermUCS4 & term, bool cased) +StringSearchHelper::StringSearchHelper(QueryTermUCS4 & term, bool cased, vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm) : _regex(), _fuzzyMatcher(), _term(), @@ -24,6 +24,8 @@ StringSearchHelper::StringSearchHelper(QueryTermUCS4 & term, bool cased) ? vespalib::Regex::from_pattern(term.getTerm(), vespalib::Regex::Options::None) : vespalib::Regex::from_pattern(term.getTerm(), vespalib::Regex::Options::IgnoreCase); } else if (isFuzzy()) { + (void) fuzzy_matching_algorithm; + // TODO: Select implementation based on algorithm. _fuzzyMatcher = std::make_unique<vespalib::FuzzyMatcher>(term.getTerm(), term.getFuzzyMaxEditDistance(), term.getFuzzyPrefixLength(), diff --git a/searchlib/src/vespa/searchlib/attribute/string_search_helper.h b/searchlib/src/vespa/searchlib/attribute/string_search_helper.h index 3db0d4dbb5f..0e7a116a874 100644 --- a/searchlib/src/vespa/searchlib/attribute/string_search_helper.h +++ b/searchlib/src/vespa/searchlib/attribute/string_search_helper.h @@ -2,6 +2,7 @@ #pragma once +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> #include <vespa/vespalib/regex/regex.h> namespace vespalib { class FuzzyMatcher; } @@ -16,7 +17,8 @@ namespace search::attribute { class StringSearchHelper { public: using FuzzyMatcher = vespalib::FuzzyMatcher; - StringSearchHelper(QueryTermUCS4 & qTerm, bool cased); + StringSearchHelper(QueryTermUCS4 & qTerm, bool cased, + vespalib::FuzzyMatchingAlgorithm fuzzy_matching_algorithm = vespalib::FuzzyMatchingAlgorithm::BruteForce); StringSearchHelper(StringSearchHelper&&) noexcept; StringSearchHelper(const StringSearchHelper &) = delete; 
StringSearchHelper & operator =(const StringSearchHelper &) = delete; diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp index 7871e66970e..b006aebbcdb 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.cpp +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.cpp @@ -438,6 +438,22 @@ TargetHitsMaxAdjustmentFactor::lookup(const Properties& props, double defaultVal return lookupDouble(props, NAME, defaultValue); } +const vespalib::string FuzzyAlgorithm::NAME("vespa.matching.fuzzy.algorithm"); +const vespalib::FuzzyMatchingAlgorithm FuzzyAlgorithm::DEFAULT_VALUE(vespalib::FuzzyMatchingAlgorithm::BruteForce); + +vespalib::FuzzyMatchingAlgorithm +FuzzyAlgorithm::lookup(const Properties& props) +{ + return lookup(props, DEFAULT_VALUE); +} + +vespalib::FuzzyMatchingAlgorithm +FuzzyAlgorithm::lookup(const Properties& props, vespalib::FuzzyMatchingAlgorithm default_value) +{ + auto value = lookupString(props, NAME, vespalib::to_string(default_value)); + return vespalib::fuzzy_matching_algorithm_from_string(value, default_value); +} + } // namespace matching namespace softtimeout { diff --git a/searchlib/src/vespa/searchlib/fef/indexproperties.h b/searchlib/src/vespa/searchlib/fef/indexproperties.h index 4f38a27d3fe..1f16d6b5f57 100644 --- a/searchlib/src/vespa/searchlib/fef/indexproperties.h +++ b/searchlib/src/vespa/searchlib/fef/indexproperties.h @@ -2,9 +2,10 @@ #pragma once +#include <vespa/searchlib/common/feature.h> +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> #include <vespa/vespalib/stllike/string.h> #include <vector> -#include <vespa/searchlib/common/feature.h> namespace search::fef { class Properties; } @@ -328,6 +329,16 @@ namespace matching { static double lookup(const Properties &props); static double lookup(const Properties &props, double defaultValue); }; + + /** + * Property to control the algorithm using for fuzzy matching. 
+ **/ + struct FuzzyAlgorithm { + static const vespalib::string NAME; + static const vespalib::FuzzyMatchingAlgorithm DEFAULT_VALUE; + static vespalib::FuzzyMatchingAlgorithm lookup(const Properties& props); + static vespalib::FuzzyMatchingAlgorithm lookup(const Properties& props, vespalib::FuzzyMatchingAlgorithm default_value); + }; } namespace softtimeout { diff --git a/searchlib/src/vespa/searchlib/fef/ranksetup.cpp b/searchlib/src/vespa/searchlib/fef/ranksetup.cpp index 9d4e547feef..02b56701cdb 100644 --- a/searchlib/src/vespa/searchlib/fef/ranksetup.cpp +++ b/searchlib/src/vespa/searchlib/fef/ranksetup.cpp @@ -69,6 +69,7 @@ RankSetup::RankSetup(const BlueprintFactory &factory, const IIndexEnvironment &i _global_filter_lower_limit(0.0), _global_filter_upper_limit(1.0), _target_hits_max_adjustment_factor(20.0), + _fuzzy_matching_algorithm(vespalib::FuzzyMatchingAlgorithm::BruteForce), _mutateOnMatch(), _mutateOnFirstPhase(), _mutateOnSecondPhase(), @@ -123,6 +124,7 @@ RankSetup::configure() set_global_filter_lower_limit(matching::GlobalFilterLowerLimit::lookup(_indexEnv.getProperties())); set_global_filter_upper_limit(matching::GlobalFilterUpperLimit::lookup(_indexEnv.getProperties())); set_target_hits_max_adjustment_factor(matching::TargetHitsMaxAdjustmentFactor::lookup(_indexEnv.getProperties())); + set_fuzzy_matching_algorithm(matching::FuzzyAlgorithm::lookup(_indexEnv.getProperties())); _mutateOnMatch._attribute = mutate::on_match::Attribute::lookup(_indexEnv.getProperties()); _mutateOnMatch._operation = mutate::on_match::Operation::lookup(_indexEnv.getProperties()); _mutateOnFirstPhase._attribute = mutate::on_first_phase::Attribute::lookup(_indexEnv.getProperties()); diff --git a/searchlib/src/vespa/searchlib/fef/ranksetup.h b/searchlib/src/vespa/searchlib/fef/ranksetup.h index 72432c2ed8a..3170f965e58 100644 --- a/searchlib/src/vespa/searchlib/fef/ranksetup.h +++ b/searchlib/src/vespa/searchlib/fef/ranksetup.h @@ -8,6 +8,7 @@ #include "blueprintresolver.h" 
#include "rank_program.h" #include <vespa/searchlib/common/stringmap.h> +#include <vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h> namespace search::fef { @@ -77,6 +78,7 @@ private: double _global_filter_lower_limit; double _global_filter_upper_limit; double _target_hits_max_adjustment_factor; + vespalib::FuzzyMatchingAlgorithm _fuzzy_matching_algorithm; MutateOperation _mutateOnMatch; MutateOperation _mutateOnFirstPhase; MutateOperation _mutateOnSecondPhase; @@ -396,6 +398,8 @@ public: double get_global_filter_upper_limit() const { return _global_filter_upper_limit; } void set_target_hits_max_adjustment_factor(double v) { _target_hits_max_adjustment_factor = v; } double get_target_hits_max_adjustment_factor() const { return _target_hits_max_adjustment_factor; } + void set_fuzzy_matching_algorithm(vespalib::FuzzyMatchingAlgorithm v) { _fuzzy_matching_algorithm = v; } + vespalib::FuzzyMatchingAlgorithm get_fuzzy_matching_algorithm() const { return _fuzzy_matching_algorithm; } /** * This method may be used to indicate that certain features diff --git a/searchlib/src/vespa/searchlib/util/foldedstringcompare.cpp b/searchlib/src/vespa/searchlib/util/foldedstringcompare.cpp index a61a12cebf6..53b9a2db31d 100644 --- a/searchlib/src/vespa/searchlib/util/foldedstringcompare.cpp +++ b/searchlib/src/vespa/searchlib/util/foldedstringcompare.cpp @@ -8,6 +8,45 @@ using vespalib::LowerCase; using vespalib::Utf8ReaderForZTS; namespace search { +using Utf32VectorRef = std::reference_wrapper<const std::vector<uint32_t>>; + +namespace foldedstringcompare { + +class Utf32Reader { + using Iterator = typename std::vector<uint32_t>::const_iterator; + + Iterator _cur; + Iterator _end; +public: + Utf32Reader(const std::vector<uint32_t>& key) + : _cur(key.begin()), + _end(key.end()) + { + } + + bool hasMore() const noexcept { return _cur != _end; } + uint32_t getChar() noexcept { return *_cur++; } +}; + +template <typename T> class FoldableStringHelper; + +template <> class 
FoldableStringHelper<const char*> +{ +public: + using Reader = Utf8ReaderForZTS; +}; + +template <> class FoldableStringHelper<Utf32VectorRef> +{ +public: + using Reader = Utf32Reader; +}; + +} + +template <typename KeyType> +using Reader = typename foldedstringcompare::FoldableStringHelper<KeyType>::Reader; + size_t FoldedStringCompare:: size(const char *key) @@ -15,15 +54,20 @@ size(const char *key) return Utf8ReaderForZTS::countChars(key); } -template <bool fold_lhs, bool fold_rhs> +template <bool fold_lhs, bool fold_rhs, detail::FoldableString KeyType, detail::FoldableString OKeyType> int FoldedStringCompare:: -compareFolded(const char *key, const char *okey) +compareFolded(KeyType key, OKeyType okey) { - Utf8ReaderForZTS kreader(key); - Utf8ReaderForZTS oreader(okey); + Reader<KeyType> kreader(key); + Reader<OKeyType> oreader(okey); for (;;) { + if (!kreader.hasMore()) { + return oreader.hasMore() ? -1 : 0; + } else if (!oreader.hasMore()) { + return 1; + } uint32_t kval = fold_lhs ? LowerCase::convert(kreader.getChar()) : kreader.getChar(); uint32_t oval = fold_rhs ? 
LowerCase::convert(oreader.getChar()) : oreader.getChar(); @@ -34,13 +78,9 @@ compareFolded(const char *key, const char *okey) return 1; } } - if (kval == 0) { - return 0; - } } } - template <bool fold_lhs, bool fold_rhs> int FoldedStringCompare:: @@ -91,6 +131,11 @@ compare(const char *key, const char *okey) return strcmp(key, okey); } +template int FoldedStringCompare::compareFolded<false, false>(const char* key, Utf32VectorRef okey); +template int FoldedStringCompare::compareFolded<true, false>(const char* key, Utf32VectorRef okey); +template int FoldedStringCompare::compareFolded<false, false>(Utf32VectorRef key, const char* okey); +template int FoldedStringCompare::compareFolded<false, true>(Utf32VectorRef key, const char* okey); + template int FoldedStringCompare::compareFolded<false, false>(const char* key, const char* okey); template int FoldedStringCompare::compareFolded<false, true>(const char* key, const char* okey); template int FoldedStringCompare::compareFolded<true, false>(const char* key, const char* okey); diff --git a/searchlib/src/vespa/searchlib/util/foldedstringcompare.h b/searchlib/src/vespa/searchlib/util/foldedstringcompare.h index fb842e190e2..cd7cd325667 100644 --- a/searchlib/src/vespa/searchlib/util/foldedstringcompare.h +++ b/searchlib/src/vespa/searchlib/util/foldedstringcompare.h @@ -3,9 +3,19 @@ #pragma once #include <cstddef> +#include <cstdint> +#include <functional> +#include <vector> namespace search { +namespace detail { + +template <typename T> +concept FoldableString = std::same_as<const char*,T> || std::same_as<std::reference_wrapper<const std::vector<uint32_t>>, T>; + +} + class FoldedStringCompare { public: @@ -20,12 +30,12 @@ public: /** * Compare UTF-8 key with UTF-8 other key after folding both * - * @param key NUL terminated UTF-8 string - * @param okey NUL terminated UTF-8 string + * @param key NUL terminated UTF-8 string or vector<uint32_t> + * @param okey NUL terminated UTF-8 string or vector<uint32_t> * @return 
integer -1 if key < okey, 0 if key == okey, 1 if key > okey **/ - template <bool fold_lhs, bool fold_rhs> - static int compareFolded(const char *key, const char *okey); + template <bool fold_lhs, bool fold_rhs, detail::FoldableString KeyType, detail::FoldableString OKeyType> + static int compareFolded(KeyType key, OKeyType okey); /** * Compare UTF-8 key with UTF-8 other key after folding both. diff --git a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt index 6228b2ecfda..2972ea7745e 100644 --- a/vespa-dependencies-enforcer/allowed-maven-dependencies.txt +++ b/vespa-dependencies-enforcer/allowed-maven-dependencies.txt @@ -54,17 +54,17 @@ io.dropwizard.metrics:metrics-core:4.2.19 io.jsonwebtoken:jjwt-api:0.11.5 io.jsonwebtoken:jjwt-impl:0.11.5 io.jsonwebtoken:jjwt-jackson:0.11.5 -io.netty:netty-buffer:4.1.97.Final -io.netty:netty-codec:4.1.97.Final -io.netty:netty-common:4.1.97.Final -io.netty:netty-handler:4.1.97.Final -io.netty:netty-resolver:4.1.97.Final +io.netty:netty-buffer:4.1.98.Final +io.netty:netty-codec:4.1.98.Final +io.netty:netty-common:4.1.98.Final +io.netty:netty-handler:4.1.98.Final +io.netty:netty-resolver:4.1.98.Final io.netty:netty-tcnative:2.0.61.Final io.netty:netty-tcnative-classes:2.0.61.Final -io.netty:netty-transport:4.1.97.Final -io.netty:netty-transport-classes-epoll:4.1.97.Final -io.netty:netty-transport-native-epoll:4.1.97.Final -io.netty:netty-transport-native-unix-common:4.1.97.Final +io.netty:netty-transport:4.1.98.Final +io.netty:netty-transport-classes-epoll:4.1.98.Final +io.netty:netty-transport-native-epoll:4.1.98.Final +io.netty:netty-transport-native-unix-common:4.1.98.Final io.prometheus:simpleclient:0.16.0 io.prometheus:simpleclient_common:0.16.0 io.prometheus:simpleclient_tracer_common:0.16.0 diff --git a/vespalib/src/vespa/vespalib/fuzzy/CMakeLists.txt b/vespalib/src/vespa/vespalib/fuzzy/CMakeLists.txt index bdbb03bcfee..5e8d29980cd 100644 --- 
a/vespalib/src/vespa/vespalib/fuzzy/CMakeLists.txt +++ b/vespalib/src/vespa/vespalib/fuzzy/CMakeLists.txt @@ -3,6 +3,7 @@ vespa_add_library(vespalib_vespalib_fuzzy OBJECT SOURCES explicit_levenshtein_dfa.cpp fuzzy_matcher.cpp + fuzzy_matching_algorithm.cpp implicit_levenshtein_dfa.cpp levenshtein_dfa.cpp levenshtein_distance.cpp diff --git a/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.h b/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.h index 630e07738fb..b7ac35b9d19 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.h +++ b/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.h @@ -123,18 +123,17 @@ public: _nodes[from_node_idx].set_wildcard_out_edge(to_node_idx); } - [[nodiscard]] MatchResult match(std::string_view u8str, std::string* successor_out) const override; + [[nodiscard]] MatchResult match(std::string_view u8str) const override; - [[nodiscard]] MatchResult match(std::string_view u8str, std::vector<uint32_t>* successor_out) const override; + [[nodiscard]] MatchResult match(std::string_view u8str, std::string& successor_out) const override; + + [[nodiscard]] MatchResult match(std::string_view u8str, std::vector<uint32_t>& successor_out) const override; [[nodiscard]] size_t memory_usage() const noexcept override { return sizeof(DfaNodeType) * _nodes.size(); } void dump_as_graphviz(std::ostream& os) const override; -private: - template <typename SuccessorT> - [[nodiscard]] MatchResult match_impl(std::string_view u8str, SuccessorT* successor_out) const; }; template <typename Traits> diff --git a/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.hpp b/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.hpp index 7860d841fbf..0a371d3b277 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.hpp +++ b/vespalib/src/vespa/vespalib/fuzzy/explicit_levenshtein_dfa.hpp @@ -94,23 +94,24 @@ struct ExplicitDfaMatcher { }; template <uint8_t MaxEdits> -template <typename SuccessorT> 
LevenshteinDfa::MatchResult -ExplicitLevenshteinDfaImpl<MaxEdits>::match_impl(std::string_view u8str, SuccessorT* successor_out) const { +ExplicitLevenshteinDfaImpl<MaxEdits>::match(std::string_view u8str) const { ExplicitDfaMatcher<MaxEdits> matcher(_nodes, _is_cased); - return MatchAlgorithm<MaxEdits>::match(matcher, u8str, successor_out); + return MatchAlgorithm<MaxEdits>::match(matcher, u8str); } template <uint8_t MaxEdits> LevenshteinDfa::MatchResult -ExplicitLevenshteinDfaImpl<MaxEdits>::match(std::string_view u8str, std::string* successor_out) const { - return match_impl(u8str, successor_out); +ExplicitLevenshteinDfaImpl<MaxEdits>::match(std::string_view u8str, std::string& successor_out) const { + ExplicitDfaMatcher<MaxEdits> matcher(_nodes, _is_cased); + return MatchAlgorithm<MaxEdits>::match(matcher, u8str, successor_out); } template <uint8_t MaxEdits> LevenshteinDfa::MatchResult -ExplicitLevenshteinDfaImpl<MaxEdits>::match(std::string_view u8str, std::vector<uint32_t>* successor_out) const { - return match_impl(u8str, successor_out); +ExplicitLevenshteinDfaImpl<MaxEdits>::match(std::string_view u8str, std::vector<uint32_t>& successor_out) const { + ExplicitDfaMatcher<MaxEdits> matcher(_nodes, _is_cased); + return MatchAlgorithm<MaxEdits>::match(matcher, u8str, successor_out); } template <uint8_t MaxEdits> diff --git a/vespalib/src/vespa/vespalib/fuzzy/fuzzy_matching_algorithm.cpp b/vespalib/src/vespa/vespalib/fuzzy/fuzzy_matching_algorithm.cpp new file mode 100644 index 00000000000..826b0beffd6 --- /dev/null +++ b/vespalib/src/vespa/vespalib/fuzzy/fuzzy_matching_algorithm.cpp @@ -0,0 +1,51 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include "fuzzy_matching_algorithm.h" + +namespace vespalib { + +namespace { + +const vespalib::string brute_force = "brute_force"; +const vespalib::string dfa_implicit = "dfa_implicit"; +const vespalib::string dfa_explicit = "dfa_explicit"; + +} + +vespalib::string +to_string(FuzzyMatchingAlgorithm algo) +{ + switch (algo) { + case FuzzyMatchingAlgorithm::BruteForce: + return brute_force; + case FuzzyMatchingAlgorithm::DfaImplicit: + return dfa_implicit; + case FuzzyMatchingAlgorithm::DfaExplicit: + return dfa_explicit; + default: + return ""; + } +} + +FuzzyMatchingAlgorithm +fuzzy_matching_algorithm_from_string(const vespalib::string& algo, + FuzzyMatchingAlgorithm default_algo) +{ + if (algo == brute_force) { + return FuzzyMatchingAlgorithm::BruteForce; + } else if (algo == dfa_implicit) { + return FuzzyMatchingAlgorithm::DfaImplicit; + } else if (algo == dfa_explicit) { + return FuzzyMatchingAlgorithm::DfaExplicit; + } + return default_algo; +} + +std::ostream& +operator<<(std::ostream& out, FuzzyMatchingAlgorithm algo) +{ + out << to_string(algo); + return out; +} + +} diff --git a/vespalib/src/vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h b/vespalib/src/vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h new file mode 100644 index 00000000000..83cb121fe5f --- /dev/null +++ b/vespalib/src/vespa/vespalib/fuzzy/fuzzy_matching_algorithm.h @@ -0,0 +1,26 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/vespalib/stllike/string.h> +#include <ostream> + +namespace vespalib { + +/** + * Algorithms that are supported for fuzzy matching. 
+ */ +enum class FuzzyMatchingAlgorithm { + BruteForce, + DfaImplicit, + DfaExplicit +}; + +vespalib::string to_string(FuzzyMatchingAlgorithm algo); + +FuzzyMatchingAlgorithm fuzzy_matching_algorithm_from_string(const vespalib::string& algo, + FuzzyMatchingAlgorithm default_algo); + +std::ostream& operator<<(std::ostream& out, FuzzyMatchingAlgorithm algo); + +} diff --git a/vespalib/src/vespa/vespalib/fuzzy/implicit_levenshtein_dfa.h b/vespalib/src/vespa/vespalib/fuzzy/implicit_levenshtein_dfa.h index bb4a0918593..e12d8c3dedb 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/implicit_levenshtein_dfa.h +++ b/vespalib/src/vespa/vespalib/fuzzy/implicit_levenshtein_dfa.h @@ -27,9 +27,11 @@ public: ~ImplicitLevenshteinDfa() override = default; - [[nodiscard]] MatchResult match(std::string_view u8str, std::string* successor_out) const override; + [[nodiscard]] MatchResult match(std::string_view u8str) const override; - [[nodiscard]] MatchResult match(std::string_view u8str, std::vector<uint32_t>* successor_out) const override; + [[nodiscard]] MatchResult match(std::string_view u8str, std::string& successor_out) const override; + + [[nodiscard]] MatchResult match(std::string_view u8str, std::vector<uint32_t>& successor_out) const override; [[nodiscard]] size_t memory_usage() const noexcept override { return _u32_str_buf.size() * sizeof(uint32_t); @@ -37,9 +39,6 @@ public: void dump_as_graphviz(std::ostream& os) const override; private: - template <typename SuccessorT> - [[nodiscard]] MatchResult match_impl(std::string_view u8str, SuccessorT* successor_out) const; - void precompute_utf8_target_with_offsets(); }; diff --git a/vespalib/src/vespa/vespalib/fuzzy/implicit_levenshtein_dfa.hpp b/vespalib/src/vespa/vespalib/fuzzy/implicit_levenshtein_dfa.hpp index 25fd3fdcc4e..ff381b7ba7c 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/implicit_levenshtein_dfa.hpp +++ b/vespalib/src/vespa/vespalib/fuzzy/implicit_levenshtein_dfa.hpp @@ -135,23 +135,24 @@ struct ImplicitDfaMatcher : 
public DfaSteppingBase<Traits> { }; template <typename Traits> -template <typename SuccessorT> LevenshteinDfa::MatchResult -ImplicitLevenshteinDfa<Traits>::match_impl(std::string_view u8str, SuccessorT* successor_out) const { +ImplicitLevenshteinDfa<Traits>::match(std::string_view u8str) const { ImplicitDfaMatcher<Traits> matcher(_u32_str_buf, _target_as_utf8, _target_utf8_char_offsets, _is_cased); - return MatchAlgorithm<Traits::max_edits()>::match(matcher, u8str, successor_out); + return MatchAlgorithm<Traits::max_edits()>::match(matcher, u8str); } template <typename Traits> LevenshteinDfa::MatchResult -ImplicitLevenshteinDfa<Traits>::match(std::string_view u8str, std::string* successor_out) const { - return match_impl(u8str, successor_out); +ImplicitLevenshteinDfa<Traits>::match(std::string_view u8str, std::string& successor_out) const { + ImplicitDfaMatcher<Traits> matcher(_u32_str_buf, _target_as_utf8, _target_utf8_char_offsets, _is_cased); + return MatchAlgorithm<Traits::max_edits()>::match(matcher, u8str, successor_out); } template <typename Traits> LevenshteinDfa::MatchResult -ImplicitLevenshteinDfa<Traits>::match(std::string_view u8str, std::vector<uint32_t>* successor_out) const { - return match_impl(u8str, successor_out); +ImplicitLevenshteinDfa<Traits>::match(std::string_view u8str, std::vector<uint32_t>& successor_out) const { + ImplicitDfaMatcher<Traits> matcher(_u32_str_buf, _target_as_utf8, _target_utf8_char_offsets, _is_cased); + return MatchAlgorithm<Traits::max_edits()>::match(matcher, u8str, successor_out); } template <typename Traits> diff --git a/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.cpp b/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.cpp index 6e38821851b..1caae408176 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.cpp +++ b/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.cpp @@ -19,17 +19,17 @@ LevenshteinDfa::~LevenshteinDfa() = default; LevenshteinDfa::MatchResult LevenshteinDfa::match(std::string_view 
u8str) const { - return _impl->match(u8str, static_cast<std::vector<uint32_t>*>(nullptr)); // TODO rewire + return _impl->match(u8str); } LevenshteinDfa::MatchResult LevenshteinDfa::match(std::string_view u8str, std::string& successor_out) const { - return _impl->match(u8str, &successor_out); + return _impl->match(u8str, successor_out); } LevenshteinDfa::MatchResult LevenshteinDfa::match(std::string_view u8str, std::vector<uint32_t>& successor_out) const { - return _impl->match(u8str, &successor_out); + return _impl->match(u8str, successor_out); } size_t LevenshteinDfa::memory_usage() const noexcept { diff --git a/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.h b/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.h index 85ad98e2a09..c6ca06d4de3 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.h +++ b/vespalib/src/vespa/vespalib/fuzzy/levenshtein_dfa.h @@ -140,8 +140,9 @@ public: struct Impl { virtual ~Impl() = default; - [[nodiscard]] virtual MatchResult match(std::string_view u8str, std::string* successor_out) const = 0; - [[nodiscard]] virtual MatchResult match(std::string_view u8str, std::vector<uint32_t>* successor_out) const = 0; + [[nodiscard]] virtual MatchResult match(std::string_view u8str) const = 0; + [[nodiscard]] virtual MatchResult match(std::string_view u8str, std::string& successor_out) const = 0; + [[nodiscard]] virtual MatchResult match(std::string_view u8str, std::vector<uint32_t>& successor_out) const = 0; [[nodiscard]] virtual size_t memory_usage() const noexcept = 0; virtual void dump_as_graphviz(std::ostream& out) const = 0; }; diff --git a/vespalib/src/vespa/vespalib/fuzzy/match_algorithm.hpp b/vespalib/src/vespa/vespalib/fuzzy/match_algorithm.hpp index 2b3c06aa7cf..fb5ec32abc7 100644 --- a/vespalib/src/vespa/vespalib/fuzzy/match_algorithm.hpp +++ b/vespalib/src/vespa/vespalib/fuzzy/match_algorithm.hpp @@ -149,7 +149,7 @@ struct MatchAlgorithm { template <DfaMatcher Matcher, typename SuccessorT> static MatchResult 
match(const Matcher& matcher, std::string_view source, - SuccessorT* successor_out) + SuccessorT& successor_out) { using StateType = typename Matcher::StateType; Utf8Reader u8_reader(source.data(), source.size()); @@ -166,7 +166,7 @@ struct MatchAlgorithm { if (raw_mch != mch) { can_use_raw_prefix = false; // FIXME this is pessimistic; considers entire string, not just prefix } - if (successor_out && matcher.has_higher_out_edge(state, mch)) { + if (matcher.has_higher_out_edge(state, mch)) { last_state_with_higher_out = state; n_prefix_u8_bytes = u8_pos_before_char; char_after_prefix = mch; @@ -175,14 +175,12 @@ struct MatchAlgorithm { if (matcher.can_match(maybe_next)) { state = maybe_next; } else { - // Can never match; find the successor if requested - if (successor_out) { - emit_successor_prefix(*successor_out, source, n_prefix_u8_bytes, - matcher.is_cased() || can_use_raw_prefix); - assert(matcher.valid_state(last_state_with_higher_out)); - backtrack_and_emit_greater_suffix(matcher, last_state_with_higher_out, - char_after_prefix, *successor_out); - } + // Can never match; find the successor + emit_successor_prefix(successor_out, source, n_prefix_u8_bytes, + matcher.is_cased() || can_use_raw_prefix); + assert(matcher.valid_state(last_state_with_higher_out)); + backtrack_and_emit_greater_suffix(matcher, last_state_with_higher_out, + char_after_prefix, successor_out); return MatchResult::make_mismatch(max_edits()); } } @@ -190,10 +188,33 @@ struct MatchAlgorithm { if (edits <= max_edits()) { return MatchResult::make_match(max_edits(), edits); } - if (successor_out) { - emit_successor_prefix(*successor_out, source, source.size(), - matcher.is_cased() || can_use_raw_prefix); - emit_smallest_matching_suffix(matcher, state, *successor_out); + emit_successor_prefix(successor_out, source, source.size(), + matcher.is_cased() || can_use_raw_prefix); + emit_smallest_matching_suffix(matcher, state, successor_out); + return MatchResult::make_mismatch(max_edits()); + } + + 
/** + * Simplified match loop which does _not_ emit a successor on mismatch. Otherwise the + * exact same semantics as the successor-emitting `match()` overload. + */ + template <DfaMatcher Matcher> + static MatchResult match(const Matcher& matcher, std::string_view source) { + using StateType = typename Matcher::StateType; + Utf8Reader u8_reader(source.data(), source.size()); + StateType state = matcher.start(); + while (u8_reader.hasMore()) { + const uint32_t mch = normalized_match_char(u8_reader.getChar(), matcher.is_cased()); + auto maybe_next = matcher.match_input(state, mch); + if (matcher.can_match(maybe_next)) { + state = maybe_next; + } else { + return MatchResult::make_mismatch(max_edits()); + } + } + const auto edits = matcher.match_edit_distance(state); + if (edits <= max_edits()) { + return MatchResult::make_match(max_edits(), edits); } return MatchResult::make_mismatch(max_edits()); } |