diff options
215 files changed, 2460 insertions, 2207 deletions
diff --git a/cloud-tenant-base-dependencies-enforcer/pom.xml b/cloud-tenant-base-dependencies-enforcer/pom.xml index 16dbd6b4014..414039ac492 100644 --- a/cloud-tenant-base-dependencies-enforcer/pom.xml +++ b/cloud-tenant-base-dependencies-enforcer/pom.xml @@ -235,6 +235,7 @@ <include>commons-beanutils:commons-beanutils:1.7.0:jar:test</include> <include>commons-codec:commons-codec:1.11:jar:test</include> <include>commons-digester:commons-digester:1.8:jar:test</include> + <include>io.airlift:aircompressor:0.17:jar:test</include> <include>io.airlift:airline:0.7:jar:test</include> <include>io.jsonwebtoken:jjwt:0.9.1:jar:test</include> <include>io.prometheus:simpleclient:0.6.0:jar:test</include> diff --git a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/SchemaValidator.java b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/SchemaValidator.java index 1f6822a770b..b0b1209aa90 100644 --- a/config-application-package/src/main/java/com/yahoo/config/model/application/provider/SchemaValidator.java +++ b/config-application-package/src/main/java/com/yahoo/config/model/application/provider/SchemaValidator.java @@ -101,7 +101,7 @@ public class SchemaValidator { } private String message(SAXParseException e) { - return "XML error in " + fileName + ": " + + return "Invalid XML according to XML schema, error in " + fileName + ": " + Exceptions.toMessageString(e) + " [" + e.getLineNumber() + ":" + e.getColumnNumber() + "]" + ", input:\n" + getErrorContext(e.getLineNumber()); diff --git a/config-lib/src/main/java/com/yahoo/config/InnerNode.java b/config-lib/src/main/java/com/yahoo/config/InnerNode.java index 94afe64b667..45f61fee315 100644 --- a/config-lib/src/main/java/com/yahoo/config/InnerNode.java +++ b/config-lib/src/main/java/com/yahoo/config/InnerNode.java @@ -64,6 +64,8 @@ public abstract class InnerNode extends Node { if ( !(other instanceof InnerNode) || (other.getClass() != this.getClass())) return 
false; + /* This implementation requires getChildren() to return elements in order. + Hence we should make it final. Or make equals independent of order. */ Collection<Object> children = getChildren().values(); Collection<Object> otherChildren = ((InnerNode)other).getChildren().values(); @@ -86,8 +88,9 @@ public abstract class InnerNode extends Node { return res; } + // TODO Make final before Vespa 8 as correct order is required protected Map<String, Object> getChildren() { - HashMap<String, Object> ret = new LinkedHashMap<String, Object>(); + HashMap<String, Object> ret = new LinkedHashMap<>(); Field fields[] = getClass().getDeclaredFields(); for (Field field : fields) { field.setAccessible(true); @@ -108,10 +111,11 @@ public abstract class InnerNode extends Node { /** * Returns a flat map of this node's direct children, including all NodeVectors' elements. * Keys are the node name, including index for vector elements, e.g. 'arr[0]'. + * TODO Make final before Vespa 8 as correct order is required */ @SuppressWarnings("unchecked") protected Map<String, Node> getChildrenWithVectorsFlattened() { - HashMap<String, Node> ret = new LinkedHashMap<String, Node>(); + HashMap<String, Node> ret = new LinkedHashMap<>(); Map<String, Object> children = getChildren(); for (Map.Entry<String, Object> childEntry : children.entrySet()) { @@ -151,7 +155,7 @@ public abstract class InnerNode extends Node { * @return map of leaf nodes */ private static Map<String, LeafNode<?>> getAllDescendantLeafNodes(String parentName, InnerNode node) { - Map<String, LeafNode<?>> ret = new LinkedHashMap<String, LeafNode<?>>(); + Map<String, LeafNode<?>> ret = new LinkedHashMap<>(); String prefix = parentName.isEmpty() ? 
"" : parentName + "."; Map<String, Node> children = node.getChildrenWithVectorsFlattened(); for (Map.Entry<String, Node> childEntry : children.entrySet()) { diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java index 1cf698af9cb..18dcba02dff 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/ModelContext.java @@ -31,7 +31,9 @@ public interface ModelContext { ApplicationPackage applicationPackage(); Optional<Model> previousModel(); Optional<ApplicationPackage> permanentApplicationPackage(); - Optional<HostProvisioner> hostProvisioner(); + // TODO: Remove after 7.338 has been released + default Optional<HostProvisioner> hostProvisioner() { return Optional.of(getHostProvisioner()); } + HostProvisioner getHostProvisioner(); Provisioned provisioned(); DeployLogger deployLogger(); ConfigDefinitionRepo configDefinitionRepo(); @@ -68,8 +70,7 @@ public interface ModelContext { @ModelFeatureFlag(owners = {"bjorncs", "jonmv"}) default double reindexerWindowSizeIncrement() { return 0.2; } @ModelFeatureFlag(owners = {"baldersheim"}, comment = "Revisit in May or June 2020") default double defaultTermwiseLimit() { throw new UnsupportedOperationException("TODO specify default value"); } @ModelFeatureFlag(owners = {"vekterli"}) default boolean useThreePhaseUpdates() { throw new UnsupportedOperationException("TODO specify default value"); } - @ModelFeatureFlag(owners = {"geirst"}, comment = "Remove on 7.XXX when this is default on") default boolean useDirectStorageApiRpc() { throw new UnsupportedOperationException("TODO specify default value"); } - @ModelFeatureFlag(owners = {"geirst"}, comment = "Remove when 7.328 is no longer in use") default boolean useFastValueTensorImplementation() { return true; } + @ModelFeatureFlag(owners = {"geirst"}, comment = "Remove when 7.336 
is no longer in use") default boolean useDirectStorageApiRpc() { return true; } @ModelFeatureFlag(owners = {"baldersheim"}, comment = "Select sequencer type use while feeding") default String feedSequencerType() { throw new UnsupportedOperationException("TODO specify default value"); } @ModelFeatureFlag(owners = {"baldersheim"}) default String responseSequencerType() { throw new UnsupportedOperationException("TODO specify default value"); } @ModelFeatureFlag(owners = {"baldersheim"}) default int defaultNumResponseThreads() { return 2; } @@ -114,7 +115,7 @@ public interface ModelContext { // TODO(somebody): Only needed for LbServicesProducerTest default boolean useDedicatedNodeForLogserver() { return true; } - // NOTE: Use FeatureFlags interface above instead of non-permament flags + // NOTE: Use FeatureFlags interface above instead of non-permanent flags @Deprecated double defaultTermwiseLimit(); @Deprecated default int defaultNumResponseThreads() { return 2; } @Deprecated String feedSequencerType(); @@ -127,8 +128,7 @@ public interface ModelContext { @Deprecated int mergeChunkSize(); @Deprecated double feedConcurrency(); @Deprecated boolean useThreePhaseUpdates(); - @Deprecated boolean useDirectStorageApiRpc(); - @Deprecated default boolean useFastValueTensorImplementation() { return true; } + @Deprecated default boolean useDirectStorageApiRpc() { return true; } @Deprecated default boolean useAccessControlTlsHandshakeClientAuth() { return false; } } diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java index 3e7017b78e1..521b8f6998d 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/TestProperties.java @@ -37,13 +37,11 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea private final Set<ContainerEndpoint> endpoints = 
Collections.emptySet(); private boolean useDedicatedNodeForLogserver = false; private boolean useThreePhaseUpdates = false; - private boolean useDirectStorageApiRpc = false; - private boolean useFastValueTensorImplementation = true; private double defaultTermwiseLimit = 1.0; private String jvmGCOptions = null; private String sequencerType = "LATENCY"; private String responseSequencerType = "ADAPTIVE"; - private int reponseNumThreads = 2; + private int responseNumThreads = 2; private Optional<EndpointCertificateSecrets> endpointCertificateSecrets = Optional.empty(); private AthenzDomain athenzDomain; private ApplicationRoles applicationRoles; @@ -74,12 +72,11 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea @Override public Optional<EndpointCertificateSecrets> endpointCertificateSecrets() { return endpointCertificateSecrets; } @Override public double defaultTermwiseLimit() { return defaultTermwiseLimit; } @Override public boolean useThreePhaseUpdates() { return useThreePhaseUpdates; } - @Override public boolean useDirectStorageApiRpc() { return useDirectStorageApiRpc; } - @Override public boolean useFastValueTensorImplementation() { return useFastValueTensorImplementation; } + @Override public boolean useDirectStorageApiRpc() { return true; } @Override public Optional<AthenzDomain> athenzDomain() { return Optional.ofNullable(athenzDomain); } @Override public Optional<ApplicationRoles> applicationRoles() { return Optional.ofNullable(applicationRoles); } @Override public String responseSequencerType() { return responseSequencerType; } - @Override public int defaultNumResponseThreads() { return reponseNumThreads; } + @Override public int defaultNumResponseThreads() { return responseNumThreads; } @Override public boolean skipCommunicationManagerThread() { return false; } @Override public boolean skipMbusRequestThread() { return false; } @Override public boolean skipMbusReplyThread() { return false; } @@ -124,7 +121,7 @@ public class 
TestProperties implements ModelContext.Properties, ModelContext.Fea return this; } public TestProperties setResponseNumThreads(int numThreads) { - reponseNumThreads = numThreads; + responseNumThreads = numThreads; return this; } public TestProperties setDefaultTermwiseLimit(double limit) { @@ -137,11 +134,6 @@ public class TestProperties implements ModelContext.Properties, ModelContext.Fea return this; } - public TestProperties setUseDirectStorageApiRpc(boolean useDirectStorageApiRpc) { - this.useDirectStorageApiRpc = useDirectStorageApiRpc; - return this; - } - public TestProperties setApplicationId(ApplicationId applicationId) { this.applicationId = applicationId; return this; diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java index 5de03f17958..d6673cd49e9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModelFactory.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.model; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; @@ -11,7 +11,6 @@ import com.yahoo.config.model.MapConfigModelRegistry; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.ConfigChangeAction; import com.yahoo.config.model.api.ConfigModelPlugin; -import com.yahoo.config.model.api.HostProvisioner; import com.yahoo.config.model.api.Model; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.api.ModelCreateResult; @@ -22,10 +21,8 @@ import com.yahoo.config.model.builder.xml.ConfigModelBuilder; import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.provision.TransientException; import com.yahoo.config.provision.Zone; -import com.yahoo.container.QrConfig; import com.yahoo.vespa.config.VespaVersion; import com.yahoo.vespa.model.application.validation.Validation; -import com.yahoo.vespa.model.container.ApplicationContainerCluster; import org.xml.sax.SAXException; import java.io.IOException; @@ -147,7 +144,7 @@ public class VespaModelFactory implements ModelFactory { .permanentApplicationPackage(modelContext.permanentApplicationPackage()) .properties(modelContext.properties()) .vespaVersion(version()) - .modelHostProvisioner(createHostProvisioner(modelContext)) + .modelHostProvisioner(modelContext.getHostProvisioner()) .provisioned(modelContext.provisioned()) .endpoints(modelContext.properties().endpoints()) .modelImporters(modelImporters) @@ -160,11 +157,6 @@ public class VespaModelFactory implements ModelFactory { return builder.build(validationParameters); } - private static HostProvisioner createHostProvisioner(ModelContext modelContext) { - return modelContext.hostProvisioner().orElse( - DeployState.getDefaultModelHostProvisioner(modelContext.applicationPackage())); - } - private void validateXML(ApplicationPackage applicationPackage, boolean ignoreValidationErrors) { try { applicationPackage.validateXML(); diff --git 
a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index 00797876395..f218c06754a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -149,6 +149,7 @@ public class RankSetupValidator extends Validator { private void writeExtraVerifyRanksetupConfig(String dir, DocumentDatabase db) throws IOException { String configName = "verify-ranksetup.cfg"; + String configContent = ""; // Assist verify-ranksetup in finding the actual ONNX model files Map<String, OnnxModel> models = db.getDerivedConfiguration().getSearch().onnxModels().asMap(); @@ -159,8 +160,9 @@ public class RankSetupValidator extends Validator { config.add(String.format("file[%d].ref \"%s\"", config.size() / 2, model.getFileReference())); config.add(String.format("file[%d].path \"%s\"", config.size() / 2, modelPath)); } - IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(config), false); + configContent = StringUtilities.implodeMultiline(config); } + IOUtils.writeFile(dir + configName, configContent, false); } public static String getFileRepositoryPath(Path path, String fileReference) { diff --git a/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java b/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java index 2c209dd79bf..9780e9b503a 100644 --- a/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/ApplicationDeployTest.java @@ -203,7 +203,7 @@ public class ApplicationDeployTest { ApplicationPackageTester.create(tmpDir.getAbsolutePath()); fail("Expected exception"); } catch (IllegalArgumentException e) { - assertEquals("XML error in deployment.xml: element 
\"instance\" not allowed here; expected the element end-tag or element \"delay\", \"region\", \"steps\" or \"test\" [7:30], input:\n", e.getMessage()); + assertEquals("Invalid XML according to XML schema, error in deployment.xml: element \"instance\" not allowed here; expected the element end-tag or element \"delay\", \"region\", \"steps\" or \"test\" [7:30], input:\n", e.getMessage()); } } diff --git a/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java b/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java index f8ab3cc54c8..98cbd363bca 100644 --- a/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java +++ b/config-model/src/test/java/com/yahoo/config/model/MockModelContext.java @@ -13,6 +13,7 @@ import com.yahoo.config.model.api.Provisioned; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.application.provider.MockFileRegistry; import com.yahoo.config.model.application.provider.StaticConfigDefinitionRepo; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.test.MockApplicationPackage; @@ -54,6 +55,11 @@ public class MockModelContext implements ModelContext { } @Override + public HostProvisioner getHostProvisioner() { + return DeployState.getDefaultModelHostProvisioner(applicationPackage); + } + + @Override public Provisioned provisioned() { return new Provisioned(); } @Override diff --git a/config-model/src/test/java/com/yahoo/config/model/application/provider/SchemaValidatorTest.java b/config-model/src/test/java/com/yahoo/config/model/application/provider/SchemaValidatorTest.java index a89378cb7ba..c2938746443 100644 --- a/config-model/src/test/java/com/yahoo/config/model/application/provider/SchemaValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/config/model/application/provider/SchemaValidatorTest.java @@ -77,7 +77,7 @@ public class SchemaValidatorTest { } 
private String expectedErrorMessage(String input) { - return "XML error in " + input + ": The element type \"config\" must be terminated by the matching end-tag \"</config>\". [7:5], input:\n" + + return "Invalid XML according to XML schema, error in " + input + ": The element type \"config\" must be terminated by the matching end-tag \"</config>\". [7:5], input:\n" + "4: <basicStruct>\n" + "5: <stringVal>default</stringVal>\n" + "6: </basicStruct>\n" + diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java deleted file mode 100644 index 14228968161..00000000000 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionFeatureArgumentsTestCase.java +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchdefinition; - -import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; -import com.yahoo.collections.Pair; -import com.yahoo.search.query.profile.QueryProfileRegistry; -import com.yahoo.searchdefinition.derived.AttributeFields; -import com.yahoo.searchdefinition.derived.RawRankProfile; -import com.yahoo.searchdefinition.parser.ParseException; -import org.junit.Test; - -import java.util.List; - -import static org.junit.Assert.assertEquals; - -/** - * @author lesters - */ -public class RankingExpressionFeatureArgumentsTestCase extends SchemaTestCase { - - @Test - public void testFeatureWithExpressionArguments() throws ParseException { - RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); - SearchBuilder builder = new SearchBuilder(rankProfileRegistry); - builder.importString( - "search test {\n" + - " document test { \n" + - " field t1 type tensor<float>(x{}) { \n" + - " indexing: attribute | summary \n" + - " }\n" + - " field t2 type 
tensor<float>(x{}) { \n" + - " indexing: attribute | summary \n" + - " }\n" + - " }\n" + - " rank-profile test {\n" + - " function my_func(t) {\n" + - " expression: sum(t, x) \n" + - " }\n" + - " function eval_func() {\n" + - " expression: my_func( attribute(t1) ) \n" + - " }\n" + - " function eval_func_with_expr() {\n" + - " expression: my_func( attribute(t1) * attribute(t2) ) \n" + - " }\n" + - " function eval_func_with_expr_2() {\n" + - " expression: my_func( attribute(t1){x:0} ) \n" + - " }\n" + - " function eval_func_via_func_with_expr() {\n" + - " expression: call_func_with_expr( attribute(t1), attribute(t2) ) \n" + - " }\n" + - " function call_func_with_expr(a, b) {\n" + - " expression: my_func( a * b ) \n" + - " }\n" + - " first-phase {\n" + - " expression: 42 \n" + - " }\n" + - " }\n" + - "\n" + - "}\n"); - builder.build(); - Search s = builder.getSearch(); - RankProfile test = rankProfileRegistry.get(s, "test").compile(new QueryProfileRegistry(), new ImportedMlModels()); - List<Pair<String, String>> testRankProperties = new RawRankProfile(test, - new QueryProfileRegistry(), - new ImportedMlModels(), - new AttributeFields(s)).configProperties(); - - for(Pair<String,String> prop : testRankProperties) { - System.out.println(prop); - } - - assertEquals("(rankingExpression(my_func).rankingScript, reduce(t, sum, x))", - testRankProperties.get(0).toString()); - - // eval_func - assertEquals("(rankingExpression(eval_func).rankingScript, rankingExpression(my_func@9bbaee2bad5a2fc0))", - testRankProperties.get(2).toString()); - assertEquals("(rankingExpression(my_func@9bbaee2bad5a2fc0).rankingScript, reduce(attribute(t1), sum, x))", - testRankProperties.get(1).toString()); - - // The following functions should generate features to evaluate the expression argument before passing to my_func - - // eval_func_with_expr - assertEquals("(rankingExpression(eval_func_with_expr).rankingScript, rankingExpression(my_func@45673ba956ae9b77))", - 
testRankProperties.get(5).toString()); - assertEquals("(rankingExpression(my_func@45673ba956ae9b77).rankingScript, reduce(autogenerated_ranking_feature@43bc412603c00a4a, sum, x))", - testRankProperties.get(4).toString()); - assertEquals("(rankingExpression(autogenerated_ranking_feature@43bc412603c00a4a).rankingScript, attribute(t1) * attribute(t2))", - testRankProperties.get(3).toString()); - - // eval_func_with_expr_2 - assertEquals("(rankingExpression(eval_func_with_expr_2).rankingScript, rankingExpression(my_func@2192533eaad2293d))", - testRankProperties.get(8).toString()); - assertEquals("(rankingExpression(my_func@2192533eaad2293d).rankingScript, reduce(autogenerated_ranking_feature@71a4196136b577cf, sum, x))", - testRankProperties.get(7).toString()); - assertEquals("(rankingExpression(autogenerated_ranking_feature@71a4196136b577cf).rankingScript, attribute(t1){x:0})", - testRankProperties.get(6).toString()); - - // eval_func_via_func_with_expr - assertEquals("(rankingExpression(eval_func_via_func_with_expr).rankingScript, rankingExpression(call_func_with_expr@640470df47a83000.c156faa8f98c0b0c))", - testRankProperties.get(10).toString()); - assertEquals("(rankingExpression(call_func_with_expr@640470df47a83000.c156faa8f98c0b0c).rankingScript, rankingExpression(my_func@45673ba956ae9b77))", - testRankProperties.get(9).toString()); - // my_func@45673ba956ae9b77 is the same as under eval_func_with_expr - - } - -}
\ No newline at end of file diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java index 20182c89a8c..84a6d2a154a 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionShadowingTestCase.java @@ -207,20 +207,22 @@ public class RankingExpressionShadowingTestCase extends SchemaTestCase { queryProfiles, new ImportedMlModels(), new AttributeFields(s)).configProperties(); - assertEquals("(rankingExpression(autogenerated_ranking_feature@).rankingScript, reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input))", + assertEquals("(rankingExpression(relu@).rankingScript, max(1.0,reduce(query(q) * constant(W_hidden), sum, input) + constant(b_input)))", censorBindingHash(testRankProperties.get(0).toString())); - assertEquals("(rankingExpression(relu@).rankingScript, max(1.0,autogenerated_ranking_feature@))", - censorBindingHash(testRankProperties.get(1).toString())); assertEquals("(rankingExpression(hidden_layer).rankingScript, rankingExpression(relu@))", + censorBindingHash(testRankProperties.get(1).toString())); + assertEquals("(rankingExpression(hidden_layer).type, tensor(x[1]))", censorBindingHash(testRankProperties.get(2).toString())); assertEquals("(rankingExpression(final_layer).rankingScript, sigmoid(reduce(rankingExpression(hidden_layer) * constant(W_final), sum, hidden) + constant(b_final)))", + testRankProperties.get(3).toString()); + assertEquals("(rankingExpression(final_layer).type, tensor(x[1]))", testRankProperties.get(4).toString()); assertEquals("(rankingExpression(relu).rankingScript, max(1.0,x))", - testRankProperties.get(6).toString()); + testRankProperties.get(5).toString()); assertEquals("(vespa.rank.secondphase, rankingExpression(secondphase))", - 
testRankProperties.get(7).toString()); + testRankProperties.get(6).toString()); assertEquals("(rankingExpression(secondphase).rankingScript, reduce(rankingExpression(final_layer), sum))", - testRankProperties.get(8).toString()); + testRankProperties.get(7).toString()); } private QueryProfileRegistry queryProfileWith(String field, String type) { diff --git a/config-model/src/test/java/com/yahoo/vespa/model/VespaModelFactoryTest.java b/config-model/src/test/java/com/yahoo/vespa/model/VespaModelFactoryTest.java index a9bf8bdcc49..33f9d715801 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/VespaModelFactoryTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/VespaModelFactoryTest.java @@ -146,6 +146,9 @@ public class VespaModelFactoryTest { } @Override + public HostProvisioner getHostProvisioner() { return provisionerToOverride; } + + @Override public Properties properties() { return new TestProperties(); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerIncludeTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerIncludeTest.java index 7d4be4b5e33..4c90d415bf0 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerIncludeTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ContainerIncludeTest.java @@ -123,7 +123,7 @@ public class ContainerIncludeTest { creator.create(true); fail("Expected exception due to xml schema violation ('zearcer')"); } catch (IllegalArgumentException e) { - assertThat(e.getMessage(), containsString("XML error")); + assertThat(e.getMessage(), containsString("Invalid XML according to XML schema")); assertThat(e.getMessage(), containsString("zearcer")); } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java index 61056856242..9f17a1c4142 100644 --- 
a/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/content/ContentClusterTest.java @@ -17,7 +17,6 @@ import com.yahoo.vespa.config.content.AllClustersBucketSpacesConfig; import com.yahoo.vespa.config.content.FleetcontrollerConfig; import com.yahoo.vespa.config.content.StorDistributionConfig; import com.yahoo.vespa.config.content.StorFilestorConfig; -import com.yahoo.vespa.config.content.core.StorCommunicationmanagerConfig; import com.yahoo.vespa.config.content.core.StorDistributormanagerConfig; import com.yahoo.vespa.config.content.core.StorServerConfig; import com.yahoo.vespa.config.search.DispatchConfig; @@ -967,33 +966,6 @@ public class ContentClusterTest extends ContentBaseTest { assertTrue(resolveThreePhaseUpdateConfigWithFeatureFlag(true)); } - void assertDirectStorageApiRpcConfig(boolean expUseDirectStorageApiRpc, ContentNode node) { - var builder = new StorCommunicationmanagerConfig.Builder(); - node.getConfig(builder); - var config = new StorCommunicationmanagerConfig(builder); - assertEquals(expUseDirectStorageApiRpc, config.use_direct_storageapi_rpc()); - } - - void assertDirectStorageApiRpcFlagIsPropagatedToConfig(boolean useDirectStorageApiRpc) { - VespaModel model = createEnd2EndOneNode(new TestProperties().setUseDirectStorageApiRpc(useDirectStorageApiRpc)); - - ContentCluster cc = model.getContentClusters().get("storage"); - assertFalse(cc.getDistributorNodes().getChildren().isEmpty()); - for (Distributor d : cc.getDistributorNodes().getChildren().values()) { - assertDirectStorageApiRpcConfig(useDirectStorageApiRpc, d); - } - assertFalse(cc.getStorageNodes().getChildren().isEmpty()); - for (StorageNode node : cc.getStorageNodes().getChildren().values()) { - assertDirectStorageApiRpcConfig(useDirectStorageApiRpc, node); - } - } - - @Test - public void use_direct_storage_api_rpc_config_is_controlled_by_properties() { - 
assertDirectStorageApiRpcFlagIsPropagatedToConfig(false); - assertDirectStorageApiRpcFlagIsPropagatedToConfig(true); - } - void assertZookeeperServerImplementation(boolean reconfigurable, String expectedClassName) { VespaModel model = createEnd2EndOneNode( new TestProperties() diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/GlobalComponentRegistry.java b/configserver/src/main/java/com/yahoo/vespa/config/server/GlobalComponentRegistry.java index 1eb18773898..23f4fdead62 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/GlobalComponentRegistry.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/GlobalComponentRegistry.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server; import com.yahoo.cloud.config.ConfigserverConfig; @@ -9,7 +9,7 @@ import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.Zone; import com.yahoo.container.jdisc.secretstore.SecretStore; import com.yahoo.vespa.config.server.application.PermanentApplicationPackage; -import com.yahoo.vespa.config.server.host.HostRegistries; +import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.modelfactory.ModelFactoryRegistry; import com.yahoo.vespa.config.server.monitoring.Metrics; import com.yahoo.vespa.config.server.session.SessionPreparer; @@ -38,7 +38,6 @@ public interface GlobalComponentRegistry { ReloadListener getReloadListener(); ConfigDefinitionRepo getStaticConfigDefinitionRepo(); PermanentApplicationPackage getPermanentApplicationPackage(); - HostRegistries getHostRegistries(); ModelFactoryRegistry getModelFactoryRegistry(); Optional<Provisioner> getHostProvisioner(); Zone getZone(); @@ -48,4 +47,5 @@ public interface GlobalComponentRegistry { 
FlagSource getFlagSource(); ExecutorService getZkCacheExecutor(); SecretStore getSecretStore(); + HostRegistry hostRegistry(); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/InjectedGlobalComponentRegistry.java b/configserver/src/main/java/com/yahoo/vespa/config/server/InjectedGlobalComponentRegistry.java index 9badd19009f..c9022e49f2d 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/InjectedGlobalComponentRegistry.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/InjectedGlobalComponentRegistry.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server; import com.google.inject.Inject; @@ -11,7 +11,7 @@ import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.Zone; import com.yahoo.container.jdisc.secretstore.SecretStore; import com.yahoo.vespa.config.server.application.PermanentApplicationPackage; -import com.yahoo.vespa.config.server.host.HostRegistries; +import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.modelfactory.ModelFactoryRegistry; import com.yahoo.vespa.config.server.monitoring.Metrics; import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; @@ -44,7 +44,6 @@ public class InjectedGlobalComponentRegistry implements GlobalComponentRegistry private final ConfigserverConfig configserverConfig; private final ConfigDefinitionRepo staticConfigDefinitionRepo; private final PermanentApplicationPackage permanentApplicationPackage; - private final HostRegistries hostRegistries; private final Optional<Provisioner> hostProvisioner; private final Zone zone; private final ConfigServerDB configServerDB; @@ -52,6 +51,7 @@ public class InjectedGlobalComponentRegistry implements 
GlobalComponentRegistry private final SecretStore secretStore; private final StripedExecutor<TenantName> zkWatcherExecutor; private final ExecutorService zkCacheExecutor; + private final HostRegistry hostRegistry; @SuppressWarnings("WeakerAccess") @Inject @@ -62,15 +62,14 @@ public class InjectedGlobalComponentRegistry implements GlobalComponentRegistry SessionPreparer sessionPreparer, RpcServer rpcServer, ConfigserverConfig configserverConfig, - SuperModelGenerationCounter superModelGenerationCounter, ConfigDefinitionRepo staticConfigDefinitionRepo, PermanentApplicationPackage permanentApplicationPackage, - HostRegistries hostRegistries, HostProvisionerProvider hostProvisionerProvider, Zone zone, ConfigServerDB configServerDB, FlagSource flagSource, - SecretStore secretStore) { + SecretStore secretStore, + HostRegistry hostRegistry) { this.curator = curator; this.configCurator = configCurator; this.metrics = metrics; @@ -80,7 +79,6 @@ public class InjectedGlobalComponentRegistry implements GlobalComponentRegistry this.configserverConfig = configserverConfig; this.staticConfigDefinitionRepo = staticConfigDefinitionRepo; this.permanentApplicationPackage = permanentApplicationPackage; - this.hostRegistries = hostRegistries; this.hostProvisioner = hostProvisionerProvider.getHostProvisioner(); this.zone = zone; this.configServerDB = configServerDB; @@ -88,6 +86,7 @@ public class InjectedGlobalComponentRegistry implements GlobalComponentRegistry this.secretStore = secretStore; this.zkWatcherExecutor = new StripedExecutor<>(); this.zkCacheExecutor = Executors.newFixedThreadPool(1, ThreadFactoryFactory.getThreadFactory(TenantRepository.class.getName())); + this.hostRegistry = hostRegistry; } @Override @@ -109,8 +108,6 @@ public class InjectedGlobalComponentRegistry implements GlobalComponentRegistry @Override public PermanentApplicationPackage getPermanentApplicationPackage() { return permanentApplicationPackage; } @Override - public HostRegistries getHostRegistries() { 
return hostRegistries; } - @Override public ModelFactoryRegistry getModelFactoryRegistry() { return modelFactoryRegistry; } @Override @@ -146,4 +143,8 @@ public class InjectedGlobalComponentRegistry implements GlobalComponentRegistry public SecretStore getSecretStore() { return secretStore; } + + @Override + public HostRegistry hostRegistry() { return hostRegistry; } + } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/ReloadListener.java b/configserver/src/main/java/com/yahoo/vespa/config/server/ReloadListener.java index 773a4862033..7cdb596780a 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/ReloadListener.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/ReloadListener.java @@ -1,8 +1,7 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server; import com.yahoo.config.provision.ApplicationId; -import com.yahoo.config.provision.TenantName; import com.yahoo.vespa.config.server.application.ApplicationSet; import java.util.Collection; @@ -17,21 +16,21 @@ import java.util.Collection; public interface ReloadListener { /** - * Signal the listener that hosts used by by a particular tenant. + * Signals the listener that hosts used by a particular tenant. * - * @param tenant Name of tenant. + * @param applicationId application id * @param newHosts a {@link Collection} of hosts used by tenant. */ - void hostsUpdated(TenantName tenant, Collection<String> newHosts); + void hostsUpdated(ApplicationId applicationId, Collection<String> newHosts); /** - * Verify that given hosts are available for use by tenant. + * Verifies that given hosts are available for use by tenant. * - * @param tenant tenant that wants to allocate hosts. 
+ * @param applicationId application id * @param newHosts a {@link java.util.Collection} of hosts that tenant wants to allocate. * @throws java.lang.IllegalArgumentException if one or more of the hosts are in use by another tenant. */ - void verifyHostsAreAvailable(TenantName tenant, Collection<String> newHosts); + void verifyHostsAreAvailable(ApplicationId applicationId, Collection<String> newHosts); /** * Configs has been activated for an application: Either an application diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java index 5a34217dbdd..b126b006212 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/application/TenantApplications.java @@ -27,7 +27,6 @@ import com.yahoo.vespa.curator.Curator; import com.yahoo.vespa.curator.Lock; import com.yahoo.vespa.curator.transaction.CuratorTransaction; import org.apache.curator.framework.CuratorFramework; -import org.apache.curator.framework.recipes.cache.ChildData; import org.apache.curator.framework.recipes.cache.PathChildrenCacheEvent; import java.nio.file.Files; @@ -62,7 +61,7 @@ public class TenantApplications implements RequestHandler, HostValidator<Applica private final TenantName tenant; private final ReloadListener reloadListener; private final ConfigResponseFactory responseFactory; - private final HostRegistry<ApplicationId> hostRegistry; + private final HostRegistry hostRegistry; private final ApplicationMapper applicationMapper = new ApplicationMapper(); private final MetricUpdater tenantMetricUpdater; private final Clock clock; @@ -70,7 +69,7 @@ public class TenantApplications implements RequestHandler, HostValidator<Applica public TenantApplications(TenantName tenant, Curator curator, StripedExecutor<TenantName> zkWatcherExecutor, 
ExecutorService zkCacheExecutor, Metrics metrics, ReloadListener reloadListener, - ConfigserverConfig configserverConfig, HostRegistry<ApplicationId> hostRegistry, + ConfigserverConfig configserverConfig, HostRegistry hostRegistry, TenantFileSystemDirs tenantFileSystemDirs, Clock clock) { this.database = new ApplicationCuratorDatabase(tenant, curator); this.tenant = tenant; @@ -96,7 +95,7 @@ public class TenantApplications implements RequestHandler, HostValidator<Applica componentRegistry.getMetrics(), componentRegistry.getReloadListener(), componentRegistry.getConfigserverConfig(), - componentRegistry.getHostRegistries().createApplicationHostRegistry(tenantName), + new HostRegistry(), new TenantFileSystemDirs(componentRegistry.getConfigServerDB(), tenantName), componentRegistry.getClock()); } @@ -222,7 +221,10 @@ public class TenantApplications implements RequestHandler, HostValidator<Applica } private void notifyReloadListeners(ApplicationSet applicationSet) { - reloadListener.hostsUpdated(tenant, hostRegistry.getAllHosts()); + if (applicationSet.getAllApplications().isEmpty()) throw new IllegalArgumentException("application set cannot be empty"); + + reloadListener.hostsUpdated(applicationSet.getAllApplications().get(0).toApplicationInfo().getApplicationId(), + hostRegistry.getAllHosts()); reloadListener.configActivated(applicationSet); } @@ -271,7 +273,7 @@ public class TenantApplications implements RequestHandler, HostValidator<Applica } private void reloadListenersOnRemove(ApplicationId applicationId) { - reloadListener.hostsUpdated(tenant, hostRegistry.getAllHosts()); + reloadListener.hostsUpdated(applicationId, hostRegistry.getAllHosts()); reloadListener.applicationRemoved(applicationId); } @@ -382,9 +384,9 @@ public class TenantApplications implements RequestHandler, HostValidator<Applica } @Override - public void verifyHosts(ApplicationId key, Collection<String> newHosts) { - hostRegistry.verifyHosts(key, newHosts); - 
reloadListener.verifyHostsAreAvailable(tenant, newHosts); + public void verifyHosts(ApplicationId applicationId, Collection<String> newHosts) { + hostRegistry.verifyHosts(applicationId, newHosts); + reloadListener.verifyHostsAreAvailable(applicationId, newHosts); } public HostValidator<ApplicationId> getHostValidator() { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java index 14d156d39fb..40fbfe9ece8 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/deploy/ModelContextImpl.java @@ -49,7 +49,7 @@ public class ModelContextImpl implements ModelContext { private final DeployLogger deployLogger; private final ConfigDefinitionRepo configDefinitionRepo; private final FileRegistry fileRegistry; - private final Optional<HostProvisioner> hostProvisioner; + private final HostProvisioner hostProvisioner; private final Provisioned provisioned; private final Optional<? extends Reindexing> reindexing; private final ModelContext.Properties properties; @@ -76,7 +76,7 @@ public class ModelContextImpl implements ModelContext { ConfigDefinitionRepo configDefinitionRepo, FileRegistry fileRegistry, Optional<? 
extends Reindexing> reindexing, - Optional<HostProvisioner> hostProvisioner, + HostProvisioner hostProvisioner, Provisioned provisioned, ModelContext.Properties properties, Optional<File> appDir, @@ -112,9 +112,8 @@ public class ModelContextImpl implements ModelContext { * Returns the host provisioner to use, or empty to use the default provisioner, * creating hosts from the application package defined hosts */ - // TODO: Don't allow empty here but create the right provisioner when this is set up instead @Override - public Optional<HostProvisioner> hostProvisioner() { return hostProvisioner; } + public HostProvisioner getHostProvisioner() { return hostProvisioner; } @Override public Provisioned provisioned() { return provisioned; } @@ -152,8 +151,6 @@ public class ModelContextImpl implements ModelContext { private final double reindexerWindowSizeIncrement; private final double defaultTermwiseLimit; private final boolean useThreePhaseUpdates; - private final boolean useDirectStorageApiRpc; - private final boolean useFastValueTensorImplementation; private final String feedSequencer; private final String responseSequencer; private final int numResponseThreads; @@ -172,8 +169,6 @@ public class ModelContextImpl implements ModelContext { this.reindexerWindowSizeIncrement = flagValue(source, appId, Flags.REINDEXER_WINDOW_SIZE_INCREMENT); this.defaultTermwiseLimit = flagValue(source, appId, Flags.DEFAULT_TERM_WISE_LIMIT); this.useThreePhaseUpdates = flagValue(source, appId, Flags.USE_THREE_PHASE_UPDATES); - this.useDirectStorageApiRpc = flagValue(source, appId, Flags.USE_DIRECT_STORAGE_API_RPC); - this.useFastValueTensorImplementation = flagValue(source, appId, Flags.USE_FAST_VALUE_TENSOR_IMPLEMENTATION); this.feedSequencer = flagValue(source, appId, Flags.FEED_SEQUENCER_TYPE); this.responseSequencer = flagValue(source, appId, Flags.RESPONSE_SEQUENCER_TYPE); this.numResponseThreads = flagValue(source, appId, Flags.RESPONSE_NUM_THREADS); @@ -192,8 +187,6 @@ public class 
ModelContextImpl implements ModelContext { @Override public double reindexerWindowSizeIncrement() { return reindexerWindowSizeIncrement; } @Override public double defaultTermwiseLimit() { return defaultTermwiseLimit; } @Override public boolean useThreePhaseUpdates() { return useThreePhaseUpdates; } - @Override public boolean useDirectStorageApiRpc() { return useDirectStorageApiRpc; } - @Override public boolean useFastValueTensorImplementation() { return useFastValueTensorImplementation; } @Override public String feedSequencerType() { return feedSequencer; } @Override public String responseSequencerType() { return responseSequencer; } @Override public int defaultNumResponseThreads() { return numResponseThreads; } @@ -240,8 +233,6 @@ public class ModelContextImpl implements ModelContext { // Old non-permanent feature flags. Use ModelContext.FeatureFlag instead private final double defaultTermwiseLimit; private final boolean useThreePhaseUpdates; - private final boolean useDirectStorageApiRpc; - private final boolean useFastValueTensorImplementation; private final String feedSequencer; private final String responseSequencer; private final int numResponseThreads; @@ -287,8 +278,6 @@ public class ModelContextImpl implements ModelContext { // Old non-permanent feature flags. 
Use ModelContext.FeatureFlag instead defaultTermwiseLimit = flagValue(flagSource, applicationId, Flags.DEFAULT_TERM_WISE_LIMIT); useThreePhaseUpdates = flagValue(flagSource, applicationId, Flags.USE_THREE_PHASE_UPDATES); - useDirectStorageApiRpc = flagValue(flagSource, applicationId, Flags.USE_DIRECT_STORAGE_API_RPC); - useFastValueTensorImplementation = flagValue(flagSource, applicationId, Flags.USE_FAST_VALUE_TENSOR_IMPLEMENTATION); feedSequencer = flagValue(flagSource, applicationId, Flags.FEED_SEQUENCER_TYPE); responseSequencer = flagValue(flagSource, applicationId, Flags.RESPONSE_SEQUENCER_TYPE); numResponseThreads = flagValue(flagSource, applicationId, Flags.RESPONSE_NUM_THREADS); @@ -359,8 +348,6 @@ public class ModelContextImpl implements ModelContext { // Old non-permanent feature flags. Use ModelContext.FeatureFlag instead @Override public double defaultTermwiseLimit() { return defaultTermwiseLimit; } @Override public boolean useThreePhaseUpdates() { return useThreePhaseUpdates; } - @Override public boolean useDirectStorageApiRpc() { return useDirectStorageApiRpc; } - @Override public boolean useFastValueTensorImplementation() { return useFastValueTensorImplementation; } @Override public String feedSequencerType() { return feedSequencer; } @Override public String responseSequencerType() { return responseSequencer; } @Override public int defaultNumResponseThreads() { return numResponseThreads; } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/host/HostRegistries.java b/configserver/src/main/java/com/yahoo/vespa/config/server/host/HostRegistries.java deleted file mode 100644 index c25ab0315a3..00000000000 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/host/HostRegistries.java +++ /dev/null @@ -1,34 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-package com.yahoo.vespa.config.server.host; - -import com.yahoo.config.provision.ApplicationId; -import com.yahoo.config.provision.TenantName; - -import java.util.Map; -import java.util.concurrent.ConcurrentHashMap; - -/** - * Component to hold host registries. - * - * @author hmusum - */ -public class HostRegistries { - - private final HostRegistry<TenantName> tenantHostRegistry = new HostRegistry<>(); - private final Map<TenantName, HostRegistry<ApplicationId>> applicationHostRegistries = new ConcurrentHashMap<>(); - - public HostRegistry<TenantName> getTenantHostRegistry() { - return tenantHostRegistry; - } - - public HostRegistry<ApplicationId> getApplicationHostRegistry(TenantName tenant) { - return applicationHostRegistries.get(tenant); - } - - public HostRegistry<ApplicationId> createApplicationHostRegistry(TenantName tenant) { - HostRegistry<ApplicationId> applicationIdHostRegistry = new HostRegistry<>(); - applicationHostRegistries.put(tenant, applicationIdHostRegistry); - return applicationIdHostRegistry; - } - -} diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/host/HostRegistry.java b/configserver/src/main/java/com/yahoo/vespa/config/server/host/HostRegistry.java index ec37f2598e0..1fc9d34e153 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/host/HostRegistry.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/host/HostRegistry.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.config.server.host; import java.util.*; @@ -7,6 +7,9 @@ import java.util.logging.Logger; import java.util.stream.Collectors; import com.google.common.collect.Collections2; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.TenantName; + import java.util.logging.Level; /** @@ -15,20 +18,20 @@ import java.util.logging.Level; * * @author Ulf Lilleengen */ -public class HostRegistry<T> implements HostValidator<T> { +public class HostRegistry implements HostValidator<ApplicationId> { private static final Logger log = Logger.getLogger(HostRegistry.class.getName()); - private final Map<String, T> host2KeyMap = new ConcurrentHashMap<>(); + private final Map<String, ApplicationId> host2KeyMap = new ConcurrentHashMap<>(); - public T getKeyForHost(String hostName) { + public ApplicationId getKeyForHost(String hostName) { return host2KeyMap.get(hostName); } - public synchronized void update(T key, Collection<String> newHosts) { + public synchronized void update(ApplicationId key, Collection<String> newHosts) { verifyHosts(key, newHosts); Collection<String> currentHosts = getHostsForKey(key); - log.log(Level.FINE, () -> "Setting hosts for key '" + key + "', " + + log.log(Level.INFO, () -> "Setting hosts for key '" + key + "', " + "newHosts: " + newHosts + ", " + "currentHosts: " + currentHosts); Collection<String> removedHosts = getRemovedHosts(newHosts, currentHosts); @@ -37,7 +40,7 @@ public class HostRegistry<T> implements HostValidator<T> { } @Override - public synchronized void verifyHosts(T key, Collection<String> newHosts) { + public synchronized void verifyHosts(ApplicationId key, Collection<String> newHosts) { for (String host : newHosts) { if (hostAlreadyTaken(host, key)) { throw new IllegalArgumentException("'" + key + "' tried to allocate host '" + host + @@ -46,22 +49,26 @@ public class HostRegistry<T> implements HostValidator<T> { } } - public synchronized void removeHostsForKey(T key) { + public 
synchronized void removeHostsForKey(ApplicationId key) { host2KeyMap.entrySet().removeIf(entry -> entry.getValue().equals(key)); } + public synchronized void removeHostsForKey(TenantName key) { + host2KeyMap.entrySet().removeIf(entry -> entry.getValue().tenant().equals(key)); + } + public synchronized Collection<String> getAllHosts() { return Collections.unmodifiableCollection(new ArrayList<>(host2KeyMap.keySet())); } - synchronized Collection<String> getHostsForKey(T key) { + synchronized Collection<String> getHostsForKey(ApplicationId key) { return host2KeyMap.entrySet().stream() .filter(entry -> entry.getValue().equals(key)) .map(Map.Entry::getKey) .collect(Collectors.toSet()); } - private boolean hostAlreadyTaken(String host, T key) { + private boolean hostAlreadyTaken(String host, ApplicationId key) { return host2KeyMap.containsKey(host) && !key.equals(host2KeyMap.get(host)); } @@ -76,7 +83,7 @@ public class HostRegistry<T> implements HostValidator<T> { } } - private void addHosts(T key, Collection<String> newHosts) { + private void addHosts(ApplicationId key, Collection<String> newHosts) { for (String host : newHosts) { log.log(Level.FINE, () -> "Adding " + host); host2KeyMap.put(host, key); diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java index 0a81c408ef4..fa058514d17 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ActivatedModelsBuilder.java @@ -102,9 +102,7 @@ public class ActivatedModelsBuilder extends ModelsBuilder<Application> { configDefinitionRepo, getForVersionOrLatest(applicationPackage.getFileRegistries(), modelFactory.version()).orElse(new MockFileRegistry()), new ApplicationCuratorDatabase(tenant, curator).readReindexingStatus(applicationId), - 
createStaticProvisioner(applicationPackage.getAllocatedHosts(), - modelContextProperties.applicationId(), - provisioned), + createStaticProvisioner(applicationPackage, modelContextProperties.applicationId(), provisioned), provisioned, modelContextProperties, Optional.empty(), diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java index 245b9db020b..75ab09d241b 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/ModelsBuilder.java @@ -1,22 +1,22 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.modelfactory; import com.google.common.util.concurrent.UncheckedTimeoutException; import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.component.Version; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.api.HostProvisioner; import com.yahoo.config.model.api.ModelFactory; import com.yahoo.config.model.api.Provisioned; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.ApplicationLockException; import com.yahoo.config.provision.DockerImage; import com.yahoo.config.provision.OutOfCapacityException; -import com.yahoo.component.Version; import com.yahoo.config.provision.TransientException; import com.yahoo.config.provision.Zone; import com.yahoo.lang.SettableOptional; -import java.util.logging.Level; import com.yahoo.vespa.config.server.http.InternalServerException; import 
com.yahoo.vespa.config.server.http.UnknownVespaVersionException; import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; @@ -31,6 +31,7 @@ import java.util.List; import java.util.NoSuchElementException; import java.util.Optional; import java.util.Set; +import java.util.logging.Level; import java.util.logging.Logger; import java.util.stream.Collectors; @@ -48,7 +49,7 @@ public abstract class ModelsBuilder<MODELRESULT extends ModelResult> { protected final ConfigserverConfig configserverConfig; /** True if we are running in hosted mode */ - private final boolean hosted; + protected final boolean hosted; private final Zone zone; @@ -156,7 +157,7 @@ public abstract class ModelsBuilder<MODELRESULT extends ModelResult> { Instant now, boolean buildLatestModelForThisMajor, int majorVersion) { - List<MODELRESULT> allApplicationVersions = new ArrayList<>(); + List<MODELRESULT> builtModelVersions = new ArrayList<>(); Optional<Version> latest = Optional.empty(); if (buildLatestModelForThisMajor) { latest = Optional.of(findLatest(versions)); @@ -168,7 +169,7 @@ public abstract class ModelsBuilder<MODELRESULT extends ModelResult> { wantedNodeVespaVersion, allocatedHosts.asOptional()); allocatedHosts.set(latestModelVersion.getModel().allocatedHosts()); // Update with additional clusters allocated - allApplicationVersions.add(latestModelVersion); + builtModelVersions.add(latestModelVersion); } // load old model versions @@ -181,28 +182,28 @@ public abstract class ModelsBuilder<MODELRESULT extends ModelResult> { for (Version version : versions) { if (latest.isPresent() && version.equals(latest.get())) continue; // already loaded - MODELRESULT modelVersion; try { - modelVersion = buildModelVersion(modelFactoryRegistry.getFactory(version), - applicationPackage, - applicationId, - wantedDockerImageRepository, - wantedNodeVespaVersion, - allocatedHosts.asOptional()); + MODELRESULT modelVersion = buildModelVersion(modelFactoryRegistry.getFactory(version), + 
applicationPackage, + applicationId, + wantedDockerImageRepository, + wantedNodeVespaVersion, + allocatedHosts.asOptional()); allocatedHosts.set(modelVersion.getModel().allocatedHosts()); // Update with additional clusters allocated - allApplicationVersions.add(modelVersion); + builtModelVersions.add(modelVersion); } catch (RuntimeException e) { // allow failure to create old config models if there is a validation override that allow skipping old // config models (which is always true for manually deployed zones) - if (allApplicationVersions.size() > 0 && allApplicationVersions.get(0).getModel().skipOldConfigModels(now)) - log.log(Level.INFO, applicationId + ": Skipping old version (due to validation override)"); + if (builtModelVersions.size() > 0 && builtModelVersions.get(0).getModel().skipOldConfigModels(now)) + log.log(Level.INFO, applicationId + ": Failed to build version " + version + + ", but allow failure due to validation override ´skipOldConfigModels´"); else { log.log(Level.SEVERE, applicationId + ": Failed to build version " + version); throw e; } } } - return allApplicationVersions; + return builtModelVersions; } private Set<Version> versionsToBuild(Set<Version> versions, Version wantedVersion, int majorVersion, AllocatedHosts allocatedHosts) { @@ -250,12 +251,20 @@ public abstract class ModelsBuilder<MODELRESULT extends ModelResult> { * returns empty otherwise, which may either mean that no hosts are allocated or that we are running * non-hosted and should default to use hosts defined in the application package, depending on context */ - Optional<HostProvisioner> createStaticProvisioner(Optional<AllocatedHosts> allocatedHosts, - ApplicationId applicationId, - Provisioned provisioned) { + HostProvisioner createStaticProvisioner(ApplicationPackage applicationPackage, + ApplicationId applicationId, + Provisioned provisioned) { + Optional<AllocatedHosts> allocatedHosts = applicationPackage.getAllocatedHosts(); if (hosted && allocatedHosts.isPresent()) - 
return Optional.of(new StaticProvisioner(allocatedHosts.get(), createNodeRepositoryProvisioner(applicationId, provisioned).get())); - return Optional.empty(); + return createStaticProvisionerForHosted(allocatedHosts.get(), createNodeRepositoryProvisioner(applicationId, provisioned).get()); + return DeployState.getDefaultModelHostProvisioner(applicationPackage); + } + + /** + * Returns a host provisioner returning the previously allocated hosts + */ + HostProvisioner createStaticProvisionerForHosted(AllocatedHosts allocatedHosts, HostProvisioner nodeRepositoryProvisioner) { + return new StaticProvisioner(allocatedHosts, nodeRepositoryProvisioner); } Optional<HostProvisioner> createNodeRepositoryProvisioner(ApplicationId applicationId, Provisioned provisioned) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java index a064e8a9cac..7606aacff15 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/modelfactory/PreparedModelsBuilder.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.config.server.modelfactory; import com.yahoo.cloud.config.ConfigserverConfig; @@ -17,6 +17,7 @@ import com.yahoo.config.model.api.Provisioned; import com.yahoo.config.model.api.ValidationParameters; import com.yahoo.config.model.api.ValidationParameters.IgnoreValidationErrors; import com.yahoo.config.model.application.provider.FilesApplicationPackage; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.DockerImage; @@ -28,7 +29,6 @@ import com.yahoo.vespa.config.server.deploy.ModelContextImpl; import com.yahoo.vespa.config.server.filedistribution.FileDistributionProvider; import com.yahoo.vespa.config.server.host.HostValidator; import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; -import com.yahoo.vespa.config.server.provision.StaticProvisioner; import com.yahoo.vespa.config.server.session.PrepareParams; import com.yahoo.vespa.curator.Curator; @@ -101,7 +101,7 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P configDefinitionRepo, fileDistributionProvider.getFileRegistry(), new ApplicationCuratorDatabase(applicationId.tenant(), curator).readReindexingStatus(applicationId), - createHostProvisioner(allocatedHosts, provisioned), + createHostProvisioner(applicationPackage, provisioned), provisioned, properties, getAppDir(applicationPackage), @@ -119,7 +119,7 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P log.log(Level.FINE, "Done building model " + modelVersion + " for " + applicationId); params.getTimeoutBudget().assertNotTimedOut(() -> "prepare timed out after building model " + modelVersion + " (timeout " + params.getTimeoutBudget().timeout() + "): " + applicationId); - return new PreparedModelsBuilder.PreparedModelResult(modelVersion, result.getModel(), fileDistributionProvider, result.getConfigChangeActions()); + return 
new PreparedModelResult(modelVersion, result.getModel(), fileDistributionProvider, result.getConfigChangeActions()); } private Optional<Model> modelOf(Version version) { @@ -127,22 +127,19 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P return currentActiveApplicationSet.get().get(version).map(Application::getModel); } - // This method is an excellent demonstration of what happens when one is too liberal with Optional - // -bratseth, who had to write the below :-\ - private Optional<HostProvisioner> createHostProvisioner(Optional<AllocatedHosts> allocatedHosts, - Provisioned provisioned) { - Optional<HostProvisioner> nodeRepositoryProvisioner = createNodeRepositoryProvisioner(properties.applicationId(), - provisioned); - if (allocatedHosts.isEmpty()) return nodeRepositoryProvisioner; - - Optional<HostProvisioner> staticProvisioner = createStaticProvisioner(allocatedHosts, - properties.applicationId(), - provisioned); - if (staticProvisioner.isEmpty()) return Optional.empty(); // Since we have hosts allocated this means we are on non-hosted - + private HostProvisioner createHostProvisioner(ApplicationPackage applicationPackage, Provisioned provisioned) { + HostProvisioner defaultHostProvisioner = DeployState.getDefaultModelHostProvisioner(applicationPackage); + // Note: nodeRepositoryProvisioner will always be present when hosted is true + Optional<HostProvisioner> nodeRepositoryProvisioner = createNodeRepositoryProvisioner(properties.applicationId(), provisioned); + Optional<AllocatedHosts> allocatedHosts = applicationPackage.getAllocatedHosts(); + + if (allocatedHosts.isEmpty()) return nodeRepositoryProvisioner.orElse(defaultHostProvisioner); + // Nodes are already allocated by a model and we should use them unless this model requests hosts from a // previously unallocated cluster. This allows future models to stop allocate certain clusters. 
- return Optional.of(new StaticProvisioner(allocatedHosts.get(), nodeRepositoryProvisioner.get())); + if (hosted) return createStaticProvisionerForHosted(allocatedHosts.get(), nodeRepositoryProvisioner.get()); + + return defaultHostProvisioner; } private Optional<File> getAppDir(ApplicationPackage applicationPackage) { @@ -166,11 +163,13 @@ public class PreparedModelsBuilder extends ModelsBuilder<PreparedModelsBuilder.P public final Version version; public final Model model; - public final com.yahoo.vespa.config.server.filedistribution.FileDistributionProvider fileDistributionProvider; + public final FileDistributionProvider fileDistributionProvider; public final List<ConfigChangeAction> actions; - public PreparedModelResult(Version version, Model model, - com.yahoo.vespa.config.server.filedistribution.FileDistributionProvider fileDistributionProvider, List<ConfigChangeAction> actions) { + public PreparedModelResult(Version version, + Model model, + FileDistributionProvider fileDistributionProvider, + List<ConfigChangeAction> actions) { this.version = version; this.model = model; this.fileDistributionProvider = fileDistributionProvider; diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java index e64859e7267..370ae72bbbd 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/RpcServer.java @@ -34,7 +34,6 @@ import com.yahoo.vespa.config.server.RequestHandler; import com.yahoo.vespa.config.server.SuperModelRequestHandler; import com.yahoo.vespa.config.server.application.ApplicationSet; import com.yahoo.vespa.config.server.filedistribution.FileServer; -import com.yahoo.vespa.config.server.host.HostRegistries; import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.monitoring.MetricUpdater; import 
com.yahoo.vespa.config.server.monitoring.MetricUpdaterFactory; @@ -92,7 +91,7 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener { private final DelayedConfigResponses delayedConfigResponses; - private final HostRegistry<TenantName> hostRegistry; + private final HostRegistry hostRegistry; private final Map<TenantName, Tenant> tenants = new ConcurrentHashMap<>(); private final Map<ApplicationId, ApplicationState> applicationStateMap = new ConcurrentHashMap<>(); private final SuperModelRequestHandler superModelRequestHandler; @@ -122,7 +121,7 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener { */ @Inject public RpcServer(ConfigserverConfig config, SuperModelRequestHandler superModelRequestHandler, - MetricUpdaterFactory metrics, HostRegistries hostRegistries, + MetricUpdaterFactory metrics, HostRegistry hostRegistry, HostLivenessTracker hostLivenessTracker, FileServer fileServer, RpcAuthorizer rpcAuthorizer, RpcRequestHandlerProvider handlerProvider) { this.superModelRequestHandler = superModelRequestHandler; @@ -136,7 +135,7 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener { 0, TimeUnit.SECONDS, workQueue, ThreadFactoryFactory.getDaemonThreadFactory(THREADPOOL_NAME)); delayedConfigResponses = new DelayedConfigResponses(this, config.numDelayedResponseThreads()); spec = new Spec(null, config.rpcport()); - hostRegistry = hostRegistries.getTenantHostRegistry(); + this.hostRegistry = hostRegistry; this.useRequestVersion = config.useVespaVersionInRequest(); this.hostedVespa = config.hostedVespa(); this.canReturnEmptySentinelConfig = config.canReturnEmptySentinelConfig(); @@ -303,14 +302,14 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener { } @Override - public void hostsUpdated(TenantName tenant, Collection<String> newHosts) { + public void hostsUpdated(ApplicationId applicationId, Collection<String> newHosts) { log.log(Level.FINE, "Updating hosts in tenant 
host registry '" + hostRegistry + "' with " + newHosts); - hostRegistry.update(tenant, newHosts); + hostRegistry.update(applicationId, newHosts); } @Override - public void verifyHostsAreAvailable(TenantName tenant, Collection<String> newHosts) { - hostRegistry.verifyHosts(tenant, newHosts); + public void verifyHostsAreAvailable(ApplicationId applicationId, Collection<String> newHosts) { + hostRegistry.verifyHosts(applicationId, newHosts); } @Override @@ -334,8 +333,8 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener { Optional<TenantName> resolveTenant(JRTServerConfigRequest request, Trace trace) { if ("*".equals(request.getConfigKey().getConfigId())) return Optional.of(ApplicationId.global().tenant()); String hostname = request.getClientHostName(); - TenantName tenant = hostRegistry.getKeyForHost(hostname); - if (tenant == null) { + ApplicationId applicationId = hostRegistry.getKeyForHost(hostname); + if (applicationId == null) { if (GetConfigProcessor.logDebug(trace)) { String message = "Did not find tenant for host '" + hostname + "', using " + TenantName.defaultName(); log.log(Level.FINE, message); @@ -344,7 +343,7 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener { } return Optional.empty(); } - return Optional.of(tenant); + return Optional.of(applicationId.tenant()); } public ConfigResponse resolveConfig(JRTServerConfigRequest request, GetConfigContext context, Optional<Version> vespaVersion) { @@ -425,7 +424,8 @@ public class RpcServer implements Runnable, ReloadListener, TenantListener { @Override public void onTenantDelete(TenantName tenant) { - log.log(Level.FINE, TenantRepository.logPre(tenant)+"Tenant deleted, removing request handler and cleaning host registry"); + log.log(Level.FINE, TenantRepository.logPre(tenant) + + "Tenant deleted, removing request handler and cleaning host registry"); tenants.remove(tenant); hostRegistry.removeHostsForKey(tenant); } diff --git 
a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/security/DefaultRpcAuthorizerProvider.java b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/security/DefaultRpcAuthorizerProvider.java index 8d1d4f58e37..242c401de92 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/security/DefaultRpcAuthorizerProvider.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/security/DefaultRpcAuthorizerProvider.java @@ -1,12 +1,14 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.rpc.security; import com.google.inject.Inject; import com.yahoo.cloud.config.ConfigserverConfig; +import com.yahoo.config.provision.ApplicationId; +import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.security.NodeIdentifier; import com.yahoo.container.di.componentgraph.Provider; import com.yahoo.security.tls.TransportSecurityUtils; -import com.yahoo.vespa.config.server.host.HostRegistries; +import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.rpc.RequestHandlerProvider; /** @@ -21,13 +23,13 @@ public class DefaultRpcAuthorizerProvider implements Provider<RpcAuthorizer> { @Inject public DefaultRpcAuthorizerProvider(ConfigserverConfig config, NodeIdentifier nodeIdentifier, - HostRegistries hostRegistries, + HostRegistry hostRegistry, RequestHandlerProvider handlerProvider) { boolean useMultiTenantAuthorizer = TransportSecurityUtils.isTransportSecurityEnabled() && config.multitenant() && config.hostedVespa(); this.rpcAuthorizer = useMultiTenantAuthorizer - ? new MultiTenantRpcAuthorizer(nodeIdentifier, hostRegistries, handlerProvider, getThreadPoolSize(config)) + ? 
new MultiTenantRpcAuthorizer(nodeIdentifier, hostRegistry, handlerProvider, getThreadPoolSize(config)) : new NoopRpcAuthorizer(); } diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/security/MultiTenantRpcAuthorizer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/security/MultiTenantRpcAuthorizer.java index 49a8df3d0e4..8353e3fab1f 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/security/MultiTenantRpcAuthorizer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/rpc/security/MultiTenantRpcAuthorizer.java @@ -1,4 +1,4 @@ -// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.vespa.config.server.rpc.security; import com.yahoo.cloud.config.SentinelConfig; @@ -16,7 +16,6 @@ import com.yahoo.security.tls.TransportSecurityUtils; import com.yahoo.vespa.config.ConfigKey; import com.yahoo.vespa.config.protocol.JRTServerConfigRequestV3; import com.yahoo.vespa.config.server.RequestHandler; -import com.yahoo.vespa.config.server.host.HostRegistries; import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.rpc.RequestHandlerProvider; @@ -34,7 +33,6 @@ import java.util.logging.Logger; import static com.yahoo.vespa.config.server.rpc.security.AuthorizationException.Type; import static com.yahoo.yolean.Exceptions.throwUnchecked; - /** * A {@link RpcAuthorizer} that perform access control for configserver RPC methods when TLS and multi-tenant mode are enabled. 
* @@ -45,22 +43,22 @@ public class MultiTenantRpcAuthorizer implements RpcAuthorizer { private static final Logger log = Logger.getLogger(MultiTenantRpcAuthorizer.class.getName()); private final NodeIdentifier nodeIdentifier; - private final HostRegistry<TenantName> hostRegistry; + private final HostRegistry hostRegistry; private final RequestHandlerProvider handlerProvider; private final Executor executor; public MultiTenantRpcAuthorizer(NodeIdentifier nodeIdentifier, - HostRegistries hostRegistries, + HostRegistry hostRegistry, RequestHandlerProvider handlerProvider, int threadPoolSize) { this(nodeIdentifier, - hostRegistries.getTenantHostRegistry(), + hostRegistry, handlerProvider, Executors.newFixedThreadPool(threadPoolSize, new DaemonThreadFactory("multi-tenant-rpc-authorizer-"))); } MultiTenantRpcAuthorizer(NodeIdentifier nodeIdentifier, - HostRegistry<TenantName> hostRegistry, + HostRegistry hostRegistry, RequestHandlerProvider handlerProvider, Executor executor) { this.nodeIdentifier = nodeIdentifier; @@ -108,14 +106,14 @@ public class MultiTenantRpcAuthorizer implements RpcAuthorizer { return; // global config access ok } else { String hostname = configRequest.getClientHostName(); - Optional<TenantName> tenantName = Optional.ofNullable(hostRegistry.getKeyForHost(hostname)); - if (tenantName.isEmpty()) { + ApplicationId applicationId = hostRegistry.getKeyForHost(hostname); + if (applicationId == null) { if (isConfigKeyForSentinelConfig(configKey)) { return; // config processor will return empty sentinel config for unknown nodes } throw new AuthorizationException(Type.SILENT, String.format("Host '%s' not found in host registry for [%s]", hostname, configKey)); } - RequestHandler tenantHandler = getTenantHandler(tenantName.get()); + RequestHandler tenantHandler = getTenantHandler(applicationId.tenant()); ApplicationId resolvedApplication = tenantHandler.resolveApplicationId(hostname); ApplicationId peerOwner = applicationId(peerIdentity); if 
(peerOwner.equals(resolvedApplication)) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java index f1775492003..5609de68391 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionPreparer.java @@ -47,10 +47,7 @@ import com.yahoo.vespa.config.server.tenant.EndpointCertificateRetriever; import com.yahoo.vespa.config.server.tenant.TenantRepository; import com.yahoo.vespa.curator.Curator; import com.yahoo.vespa.flags.FlagSource; -import org.xml.sax.SAXException; -import javax.xml.parsers.ParserConfigurationException; -import javax.xml.transform.TransformerException; import java.io.File; import java.io.IOException; import java.time.Instant; @@ -174,11 +171,11 @@ public class SessionPreparer { Preparation(HostValidator<ApplicationId> hostValidator, DeployLogger logger, PrepareParams params, Optional<ApplicationSet> currentActiveApplicationSet, Path tenantPath, - File serverDbSessionDir, ApplicationPackage preprocessedApplicationPackage, + File serverDbSessionDir, ApplicationPackage applicationPackage, SessionZooKeeperClient sessionZooKeeperClient) { this.logger = logger; this.params = params; - this.applicationPackage = preprocessedApplicationPackage; + this.applicationPackage = applicationPackage; this.sessionZooKeeperClient = sessionZooKeeperClient; this.applicationId = params.getApplicationId(); this.dockerImageRepository = params.dockerImageRepository(); @@ -402,6 +399,7 @@ public class SessionPreparer { * This class ensures these constraints and returns a reconciliated set of nodes which should be activated, * given a set of model activation results. 
*/ + @SuppressWarnings("unused") private static final class ReconciliatedHostAllocations { public ReconciliatedHostAllocations(List<PreparedModelsBuilder.PreparedModelResult> results) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java index 64b757037d0..592198cbbef 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/session/SessionRepository.java @@ -73,6 +73,7 @@ public class SessionRepository { private static final FilenameFilter sessionApplicationsFilter = (dir, name) -> name.matches("\\d+"); private static final long nonExistingActiveSessionId = 0; + private final Object monitor = new Object(); private final Map<Long, LocalSession> localSessionCache = new ConcurrentHashMap<>(); private final Map<Long, RemoteSession> remoteSessionCache = new ConcurrentHashMap<>(); private final Map<Long, SessionStateWatcher> sessionStateWatchers = new HashMap<>(); @@ -202,8 +203,8 @@ public class SessionRepository { } /** - * This method is used when creating a session based on a remote session and the distributed application package - * It does not wait for session being created on other servers + * Creates a local session based on a remote session and the distributed application package. + * Does not wait for session being created on other servers. 
*/ private void createLocalSession(File applicationFile, ApplicationId applicationId, long sessionId) { try { @@ -548,19 +549,25 @@ public class SessionRepository { } } - private ApplicationPackage createApplicationPackage(File applicationFile, ApplicationId applicationId, - long sessionId, boolean internalRedeploy) throws IOException { - Optional<Long> activeSessionId = getActiveSessionId(applicationId); - File userApplicationDir = getSessionAppDir(sessionId); - copyApp(applicationFile, userApplicationDir); - ApplicationPackage applicationPackage = createApplication(applicationFile, - userApplicationDir, - applicationId, - sessionId, - activeSessionId, - internalRedeploy); - applicationPackage.writeMetaData(); - return applicationPackage; + private ApplicationPackage createApplicationPackage(File applicationFile, + ApplicationId applicationId, + long sessionId, + boolean internalRedeploy) throws IOException { + // Synchronize to avoid threads trying to create an application package concurrently + // (e.g. 
a maintainer and an external deployment) + synchronized (monitor) { + Optional<Long> activeSessionId = getActiveSessionId(applicationId); + File userApplicationDir = getSessionAppDir(sessionId); + copyApp(applicationFile, userApplicationDir); + ApplicationPackage applicationPackage = createApplication(applicationFile, + userApplicationDir, + applicationId, + sessionId, + activeSessionId, + internalRedeploy); + applicationPackage.writeMetaData(); + return applicationPackage; + } } public Optional<ApplicationSet> getActiveApplicationSet(ApplicationId appId) { diff --git a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java index 5c15b72eaac..896c2efaa56 100644 --- a/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java +++ b/configserver/src/main/java/com/yahoo/vespa/config/server/tenant/TenantRepository.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.config.server.tenant; import com.google.common.collect.ImmutableSet; @@ -14,6 +14,7 @@ import com.yahoo.transaction.Transaction; import com.yahoo.vespa.config.server.GlobalComponentRegistry; import com.yahoo.vespa.config.server.application.TenantApplications; import com.yahoo.vespa.config.server.deploy.TenantFileSystemDirs; +import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.monitoring.MetricUpdater; import com.yahoo.vespa.config.server.session.SessionRepository; import com.yahoo.vespa.curator.Curator; @@ -228,7 +229,7 @@ public class TenantRepository { componentRegistry.getMetrics(), componentRegistry.getReloadListener(), componentRegistry.getConfigserverConfig(), - componentRegistry.getHostRegistries().createApplicationHostRegistry(tenantName), + componentRegistry.hostRegistry(), new TenantFileSystemDirs(componentRegistry.getConfigServerDB(), tenantName), componentRegistry.getClock()); SessionRepository sessionRepository = new SessionRepository(tenantName, diff --git a/configserver/src/main/resources/configserver-app/services.xml b/configserver/src/main/resources/configserver-app/services.xml index dac10ac4e59..0ac548c54d8 100644 --- a/configserver/src/main/resources/configserver-app/services.xml +++ b/configserver/src/main/resources/configserver-app/services.xml @@ -27,7 +27,6 @@ <component id="com.yahoo.vespa.config.server.InjectedGlobalComponentRegistry" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.tenant.TenantRepository" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.application.PermanentApplicationPackage" bundle="configserver" /> - <component id="com.yahoo.vespa.config.server.host.HostRegistries" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.ApplicationRepository" bundle="configserver" /> <component id="com.yahoo.vespa.config.server.version.VersionState" bundle="configserver" /> <component 
id="com.yahoo.vespa.config.server.zookeeper.ConfigCurator" bundle="configserver" /> diff --git a/configserver/src/test/apps/app-jdisc-only-restart/hosts.xml b/configserver/src/test/apps/app-jdisc-only-restart/hosts.xml index f4256c9fc81..ab70b288ba6 100644 --- a/configserver/src/test/apps/app-jdisc-only-restart/hosts.xml +++ b/configserver/src/test/apps/app-jdisc-only-restart/hosts.xml @@ -1,7 +1,7 @@ <?xml version="1.0" encoding="utf-8" ?> <!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> <hosts> - <host name="mytesthost"> + <host name="mytesthost2"> <alias>node1</alias> </host> </hosts> diff --git a/configserver/src/test/apps/app-jdisc-only/hosts.xml b/configserver/src/test/apps/app-jdisc-only/hosts.xml index f4256c9fc81..ab70b288ba6 100644 --- a/configserver/src/test/apps/app-jdisc-only/hosts.xml +++ b/configserver/src/test/apps/app-jdisc-only/hosts.xml @@ -1,7 +1,7 @@ <?xml version="1.0" encoding="utf-8" ?> <!-- Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
--> <hosts> - <host name="mytesthost"> + <host name="mytesthost2"> <alias>node1</alias> </host> </hosts> diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java index 0cab49a9f00..cb91a01ab55 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ApplicationRepositoryTest.java @@ -192,7 +192,7 @@ public class ApplicationRepositoryTest { PrepareResult result = prepareAndActivate(testAppJdiscOnlyRestart); assertTrue(result.configChangeActions().getRefeedActions().isEmpty()); assertTrue(result.configChangeActions().getRestartActions().isEmpty()); - assertEquals(HostFilter.hostname("mytesthost"), provisioner.lastRestartFilter()); + assertEquals(HostFilter.hostname("mytesthost2"), provisioner.lastRestartFilter()); } @Test @@ -345,7 +345,7 @@ public class ApplicationRepositoryTest { // Deploy another app (with id fooId) ApplicationId fooId = applicationId(tenant2); PrepareParams prepareParams2 = new PrepareParams.Builder().applicationId(fooId).build(); - deployApp(testApp, prepareParams2); + deployApp(testAppJdiscOnly, prepareParams2); assertNotNull(applicationRepository.getActiveSession(fooId)); // Delete app with id fooId, should not affect original app @@ -530,14 +530,14 @@ public class ApplicationRepositoryTest { list.add(new NetworkPorts.Allocation(19081, "logserver", "admin/logserver", "unused/1")); list.add(new NetworkPorts.Allocation(19082, "logserver", "admin/logserver", "unused/2")); list.add(new NetworkPorts.Allocation(19083, "logserver", "admin/logserver", "unused/3")); - list.add(new NetworkPorts.Allocation(19089, "logd", "hosts/mytesthost/logd", "http")); - list.add(new NetworkPorts.Allocation(19090, "configproxy", "hosts/mytesthost/configproxy", "rpc")); - list.add(new NetworkPorts.Allocation(19092, 
"metricsproxy-container", "admin/metrics/mytesthost", "http")); - list.add(new NetworkPorts.Allocation(19093, "metricsproxy-container", "admin/metrics/mytesthost", "http/1")); - list.add(new NetworkPorts.Allocation(19094, "metricsproxy-container", "admin/metrics/mytesthost", "rpc/admin")); - list.add(new NetworkPorts.Allocation(19095, "metricsproxy-container", "admin/metrics/mytesthost", "rpc/metrics")); - list.add(new NetworkPorts.Allocation(19097, "config-sentinel", "hosts/mytesthost/sentinel", "rpc")); - list.add(new NetworkPorts.Allocation(19098, "config-sentinel", "hosts/mytesthost/sentinel", "http")); + list.add(new NetworkPorts.Allocation(19089, "logd", "hosts/mytesthost2/logd", "http")); + list.add(new NetworkPorts.Allocation(19090, "configproxy", "hosts/mytesthost2/configproxy", "rpc")); + list.add(new NetworkPorts.Allocation(19092, "metricsproxy-container", "admin/metrics/mytesthost2", "http")); + list.add(new NetworkPorts.Allocation(19093, "metricsproxy-container", "admin/metrics/mytesthost2", "http/1")); + list.add(new NetworkPorts.Allocation(19094, "metricsproxy-container", "admin/metrics/mytesthost2", "rpc/admin")); + list.add(new NetworkPorts.Allocation(19095, "metricsproxy-container", "admin/metrics/mytesthost2", "rpc/metrics")); + list.add(new NetworkPorts.Allocation(19097, "config-sentinel", "hosts/mytesthost2/sentinel", "rpc")); + list.add(new NetworkPorts.Allocation(19098, "config-sentinel", "hosts/mytesthost2/sentinel", "http")); list.add(new NetworkPorts.Allocation(19099, "slobrok", "admin/slobrok.0", "rpc")); list.add(new NetworkPorts.Allocation(19100, "container", "container/container.0", "http/1")); list.add(new NetworkPorts.Allocation(19101, "container", "container/container.0", "messaging")); @@ -547,7 +547,7 @@ public class ApplicationRepositoryTest { AllocatedHosts info = session.getAllocatedHosts(); assertNotNull(info); assertThat(info.getHosts().size(), is(1)); - assertTrue(info.getHosts().contains(new HostSpec("mytesthost", + 
assertTrue(info.getHosts().contains(new HostSpec("mytesthost2", Collections.emptyList(), Optional.empty()))); Optional<NetworkPorts> portsCopy = info.getHosts().iterator().next().networkPorts(); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/InjectedGlobalComponentRegistryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/InjectedGlobalComponentRegistryTest.java index bf54c2b309e..4f6642610dd 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/InjectedGlobalComponentRegistryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/InjectedGlobalComponentRegistryTest.java @@ -8,12 +8,12 @@ import com.yahoo.config.provision.Zone; import com.yahoo.vespa.config.server.application.PermanentApplicationPackage; import com.yahoo.vespa.config.server.filedistribution.FileServer; import com.yahoo.vespa.config.server.host.ConfigRequestHostLivenessTracker; -import com.yahoo.vespa.config.server.host.HostRegistries; +import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.modelfactory.ModelFactoryRegistry; import com.yahoo.vespa.config.server.monitoring.Metrics; import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; -import com.yahoo.vespa.config.server.rpc.RpcServer; import com.yahoo.vespa.config.server.rpc.RpcRequestHandlerProvider; +import com.yahoo.vespa.config.server.rpc.RpcServer; import com.yahoo.vespa.config.server.rpc.security.NoopRpcAuthorizer; import com.yahoo.vespa.config.server.session.SessionPreparer; import com.yahoo.vespa.config.server.session.SessionTest; @@ -46,7 +46,6 @@ public class InjectedGlobalComponentRegistryTest { private RpcServer rpcServer; private ConfigDefinitionRepo defRepo; private PermanentApplicationPackage permanentApplicationPackage; - private HostRegistries hostRegistries; private GlobalComponentRegistry globalComponentRegistry; private ModelFactoryRegistry modelFactoryRegistry; private Zone zone; @@ -65,20 +64,21 @@ 
public class InjectedGlobalComponentRegistryTest { .configServerDBDir(temporaryFolder.newFolder("serverdb").getAbsolutePath()) .configDefinitionsDir(temporaryFolder.newFolder("configdefinitions").getAbsolutePath())); sessionPreparer = new SessionTest.MockSessionPreparer(); + HostRegistry hostRegistry = new HostRegistry(); rpcServer = new RpcServer(configserverConfig, null, Metrics.createTestMetrics(), - new HostRegistries(), new ConfigRequestHostLivenessTracker(), + hostRegistry, new ConfigRequestHostLivenessTracker(), new FileServer(temporaryFolder.newFolder("filereferences")), new NoopRpcAuthorizer(), new RpcRequestHandlerProvider()); - SuperModelGenerationCounter generationCounter = new SuperModelGenerationCounter(curator); defRepo = new StaticConfigDefinitionRepo(); permanentApplicationPackage = new PermanentApplicationPackage(configserverConfig); - hostRegistries = new HostRegistries(); HostProvisionerProvider hostProvisionerProvider = HostProvisionerProvider.withProvisioner(new MockProvisioner()); zone = Zone.defaultZone(); globalComponentRegistry = - new InjectedGlobalComponentRegistry(curator, configCurator, metrics, modelFactoryRegistry, sessionPreparer, rpcServer, configserverConfig, - generationCounter, defRepo, permanentApplicationPackage, hostRegistries, hostProvisionerProvider, zone, - new ConfigServerDB(configserverConfig), new InMemoryFlagSource(), new MockSecretStore()); + new InjectedGlobalComponentRegistry(curator, configCurator, metrics, modelFactoryRegistry, sessionPreparer, + rpcServer, configserverConfig, defRepo, permanentApplicationPackage, + hostProvisionerProvider, zone, + new ConfigServerDB(configserverConfig), new InMemoryFlagSource(), + new MockSecretStore(), hostRegistry); } @Test @@ -92,7 +92,6 @@ public class InjectedGlobalComponentRegistryTest { assertThat(globalComponentRegistry.getTenantListener().hashCode(), is(rpcServer.hashCode())); assertThat(globalComponentRegistry.getStaticConfigDefinitionRepo(), is(defRepo)); 
assertThat(globalComponentRegistry.getPermanentApplicationPackage(), is(permanentApplicationPackage)); - assertThat(globalComponentRegistry.getHostRegistries(), is(hostRegistries)); assertThat(globalComponentRegistry.getZone(), is (zone)); assertTrue(globalComponentRegistry.getHostProvisioner().isPresent()); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/ModelContextImplTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/ModelContextImplTest.java index d094ad09bec..3e27e0b61ea 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/ModelContextImplTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/ModelContextImplTest.java @@ -3,11 +3,14 @@ package com.yahoo.vespa.config.server; import com.yahoo.cloud.config.ConfigserverConfig; import com.yahoo.component.Version; +import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.model.api.ContainerEndpoint; +import com.yahoo.config.model.api.HostProvisioner; import com.yahoo.config.model.api.ModelContext; import com.yahoo.config.model.api.Provisioned; import com.yahoo.config.model.application.provider.BaseDeployLogger; import com.yahoo.config.model.application.provider.MockFileRegistry; +import com.yahoo.config.model.deploy.DeployState; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Zone; @@ -45,15 +48,17 @@ public class ModelContextImplTest { .hostedVespa(false) .build(); + ApplicationPackage applicationPackage = MockApplicationPackage.createEmpty(); + HostProvisioner hostProvisioner = DeployState.getDefaultModelHostProvisioner(applicationPackage); ModelContext context = new ModelContextImpl( - MockApplicationPackage.createEmpty(), + applicationPackage, Optional.empty(), Optional.empty(), new BaseDeployLogger(), new StaticConfigDefinitionRepo(), new MockFileRegistry(), Optional.empty(), - Optional.empty(), + hostProvisioner, 
new Provisioned(), new ModelContextImpl.Properties( ApplicationId.defaultId(), @@ -72,7 +77,7 @@ public class ModelContextImplTest { new Version(7), new Version(8)); assertTrue(context.applicationPackage() instanceof MockApplicationPackage); - assertFalse(context.hostProvisioner().isPresent()); + assertEquals(hostProvisioner, context.getHostProvisioner()); assertFalse(context.permanentApplicationPackage().isPresent()); assertFalse(context.previousModel().isPresent()); assertTrue(context.getFileRegistry() instanceof MockFileRegistry); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/SuperModelControllerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/SuperModelControllerTest.java index 965374f2aa4..c58bd6d6b0a 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/SuperModelControllerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/SuperModelControllerTest.java @@ -83,34 +83,6 @@ public class SuperModelControllerTest { } @Test - public void test_lb_config_multiple_apps_legacy_super_model() throws IOException, SAXException { - Map<ApplicationId, ApplicationInfo> models = new LinkedHashMap<>(); - TenantName t1 = TenantName.from("t1"); - TenantName t2 = TenantName.from("t2"); - File testApp1 = new File("src/test/resources/deploy/app"); - File testApp2 = new File("src/test/resources/deploy/advancedapp"); - File testApp3 = new File("src/test/resources/deploy/advancedapp"); - - ApplicationId simple = applicationId("mysimpleapp", t1); - ApplicationId advanced = applicationId("myadvancedapp", t1); - ApplicationId tooAdvanced = applicationId("minetooadvancedapp", t2); - models.put(simple, createApplicationInfo(testApp1, simple, 4L)); - models.put(advanced, createApplicationInfo(testApp2, advanced, 4L)); - models.put(tooAdvanced, createApplicationInfo(testApp3, tooAdvanced, 4L)); - - SuperModel superModel = new SuperModel(models, true); - SuperModelController han = new SuperModelController(new 
SuperModelConfigProvider(superModel, Zone.defaultZone(), new InMemoryFlagSource()), new TestConfigDefinitionRepo(), 2, new UncompressedConfigResponseFactory()); - LbServicesConfig.Builder lb = new LbServicesConfig.Builder(); - han.getSuperModel().getConfig(lb); - LbServicesConfig lbc = new LbServicesConfig(lb); - assertThat(lbc.tenants().size(), is(2)); - assertThat(lbc.tenants("t1").applications().size(), is(2)); - assertThat(lbc.tenants("t2").applications().size(), is(1)); - assertThat(lbc.tenants("t2").applications("minetooadvancedapp:prod:default:default").hosts().size(), is(1)); - assertQrServer(lbc.tenants("t2").applications("minetooadvancedapp:prod:default:default")); - } - - @Test public void test_lb_config_multiple_apps() throws IOException, SAXException { Map<ApplicationId, ApplicationInfo> models = new LinkedHashMap<>(); TenantName t1 = TenantName.from("t1"); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/TestComponentRegistry.java b/configserver/src/test/java/com/yahoo/vespa/config/server/TestComponentRegistry.java index e6652c3c5e1..cbff99c04dc 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/TestComponentRegistry.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/TestComponentRegistry.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.config.server; import com.yahoo.cloud.config.ConfigserverConfig; @@ -12,12 +12,12 @@ import com.yahoo.config.provision.Zone; import com.yahoo.container.jdisc.secretstore.SecretStore; import com.yahoo.vespa.config.server.application.PermanentApplicationPackage; import com.yahoo.vespa.config.server.application.TenantApplicationsTest; -import com.yahoo.vespa.config.server.host.HostRegistries; +import com.yahoo.vespa.config.server.filedistribution.FileDistributionFactory; +import com.yahoo.vespa.config.server.filedistribution.MockFileDistributionFactory; +import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.modelfactory.ModelFactoryRegistry; import com.yahoo.vespa.config.server.monitoring.Metrics; import com.yahoo.vespa.config.server.provision.HostProvisionerProvider; -import com.yahoo.vespa.config.server.filedistribution.FileDistributionFactory; -import com.yahoo.vespa.config.server.filedistribution.MockFileDistributionFactory; import com.yahoo.vespa.config.server.session.SessionPreparer; import com.yahoo.vespa.config.server.tenant.MockTenantListener; import com.yahoo.vespa.config.server.tenant.TenantListener; @@ -50,7 +50,6 @@ public class TestComponentRegistry implements GlobalComponentRegistry { private final ReloadListener reloadListener; private final TenantListener tenantListener; private final PermanentApplicationPackage permanentApplicationPackage; - private final HostRegistries hostRegistries; private final FileDistributionFactory fileDistributionFactory; private final ModelFactoryRegistry modelFactoryRegistry; private final Optional<Provisioner> hostProvisioner; @@ -61,12 +60,12 @@ public class TestComponentRegistry implements GlobalComponentRegistry { private final ExecutorService zkCacheExecutor; private final SecretStore secretStore; private final FlagSource flagSource; + private final HostRegistry hostRegistry; private TestComponentRegistry(Curator curator, ConfigCurator configCurator, 
Metrics metrics, ModelFactoryRegistry modelFactoryRegistry, PermanentApplicationPackage permanentApplicationPackage, FileDistributionFactory fileDistributionFactory, - HostRegistries hostRegistries, ConfigserverConfig configserverConfig, SessionPreparer sessionPreparer, Optional<Provisioner> hostProvisioner, @@ -76,7 +75,8 @@ public class TestComponentRegistry implements GlobalComponentRegistry { Zone zone, Clock clock, SecretStore secretStore, - FlagSource flagSource) { + FlagSource flagSource, + HostRegistry hostRegistry) { this.curator = curator; this.configCurator = configCurator; this.metrics = metrics; @@ -85,7 +85,6 @@ public class TestComponentRegistry implements GlobalComponentRegistry { this.tenantListener = tenantListener; this.defRepo = defRepo; this.permanentApplicationPackage = permanentApplicationPackage; - this.hostRegistries = hostRegistries; this.fileDistributionFactory = fileDistributionFactory; this.modelFactoryRegistry = modelFactoryRegistry; this.hostProvisioner = hostProvisioner; @@ -97,6 +96,7 @@ public class TestComponentRegistry implements GlobalComponentRegistry { this.zkCacheExecutor = new InThreadExecutorService(); this.secretStore = secretStore; this.flagSource = flagSource; + this.hostRegistry = hostRegistry; } public static class Builder { @@ -112,13 +112,13 @@ public class TestComponentRegistry implements GlobalComponentRegistry { private ReloadListener reloadListener = new TenantApplicationsTest.MockReloadListener(); private final MockTenantListener tenantListener = new MockTenantListener(); private Optional<PermanentApplicationPackage> permanentApplicationPackage = Optional.empty(); - private final HostRegistries hostRegistries = new HostRegistries(); private final Optional<FileDistributionFactory> fileDistributionFactory = Optional.empty(); private ModelFactoryRegistry modelFactoryRegistry = new ModelFactoryRegistry(Collections.singletonList(new VespaModelFactory(new NullConfigModelRegistry()))); private Optional<Provisioner> 
hostProvisioner = Optional.empty(); private Zone zone = Zone.defaultZone(); private Clock clock = Clock.systemUTC(); private FlagSource flagSource = new InMemoryFlagSource(); + private HostRegistry hostRegistry = new HostRegistry(); public Builder configServerConfig(ConfigserverConfig configserverConfig) { this.configserverConfig = configserverConfig; @@ -175,6 +175,11 @@ public class TestComponentRegistry implements GlobalComponentRegistry { return this; } + public Builder hostRegistry(HostRegistry hostRegistry) { + this.hostRegistry = hostRegistry; + return this; + } + public TestComponentRegistry build() { final PermanentApplicationPackage permApp = this.permanentApplicationPackage .orElse(new PermanentApplicationPackage(configserverConfig)); @@ -188,9 +193,9 @@ public class TestComponentRegistry implements GlobalComponentRegistry { configserverConfig, defRepo, curator, zone, flagSource, secretStore); return new TestComponentRegistry(curator, ConfigCurator.create(curator), metrics, modelFactoryRegistry, - permApp, fileDistributionProvider, hostRegistries, configserverConfig, + permApp, fileDistributionProvider, configserverConfig, sessionPreparer, hostProvisioner, defRepo, reloadListener, tenantListener, - zone, clock, secretStore, flagSource); + zone, clock, secretStore, flagSource, hostRegistry); } } @@ -213,8 +218,6 @@ public class TestComponentRegistry implements GlobalComponentRegistry { @Override public PermanentApplicationPackage getPermanentApplicationPackage() { return permanentApplicationPackage; } @Override - public HostRegistries getHostRegistries() { return hostRegistries;} - @Override public ModelFactoryRegistry getModelFactoryRegistry() { return modelFactoryRegistry; } @Override public Optional<Provisioner> getHostProvisioner() { @@ -247,6 +250,11 @@ public class TestComponentRegistry implements GlobalComponentRegistry { return secretStore; } + @Override + public HostRegistry hostRegistry() { + return hostRegistry; + } + public 
FileDistributionFactory getFileDistributionFactory() { return fileDistributionFactory; } } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java index 947308962d4..ddbe97f4389 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/application/TenantApplicationsTest.java @@ -151,12 +151,12 @@ public class TenantApplicationsTest { } @Override - public void hostsUpdated(TenantName tenant, Collection<String> newHosts) { - tenantHosts.put(tenant.value(), newHosts); + public void hostsUpdated(ApplicationId applicationId, Collection<String> newHosts) { + tenantHosts.put(applicationId.tenant().value(), newHosts); } @Override - public void verifyHostsAreAvailable(TenantName tenant, Collection<String> newHosts) { + public void verifyHostsAreAvailable(ApplicationId applicationId, Collection<String> newHosts) { } @Override diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/host/HostRegistryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/host/HostRegistryTest.java index 63dfb1d01bd..173e9f3e148 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/host/HostRegistryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/host/HostRegistryTest.java @@ -1,37 +1,44 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.config.server.host; +import com.yahoo.config.provision.ApplicationId; +import org.junit.Test; + import java.util.ArrayList; import java.util.Collection; import java.util.List; -import org.junit.Test; - import static org.hamcrest.collection.IsIterableContainingInOrder.contains; import static org.hamcrest.core.Is.is; -import static org.junit.Assert.*; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertThat; /** * @author Ulf Lilleengen */ public class HostRegistryTest { + + private final ApplicationId foo = ApplicationId.from("foo", "app1", "default"); + private final ApplicationId bar = ApplicationId.from("bar", "app2", "default"); + @Test public void old_hosts_are_removed() { - HostRegistry<String> reg = new HostRegistry<>(); + HostRegistry reg = new HostRegistry(); assertNull(reg.getKeyForHost("foo.com")); - reg.update("fookey", List.of("foo.com", "bar.com", "baz.com")); - assertGetKey(reg, "foo.com", "fookey"); - assertGetKey(reg, "bar.com", "fookey"); - assertGetKey(reg, "baz.com", "fookey"); + reg.update(foo, List.of("foo.com", "bar.com", "baz.com")); + assertGetKey(reg, "foo.com", foo); + assertGetKey(reg, "bar.com", foo); + assertGetKey(reg, "baz.com", foo); assertThat(reg.getAllHosts().size(), is(3)); - reg.update("fookey", List.of("bar.com", "baz.com")); + reg.update(foo, List.of("bar.com", "baz.com")); assertNull(reg.getKeyForHost("foo.com")); - assertGetKey(reg, "bar.com", "fookey"); - assertGetKey(reg, "baz.com", "fookey"); + assertGetKey(reg, "bar.com", foo); + assertGetKey(reg, "baz.com", foo); assertThat(reg.getAllHosts().size(), is(2)); assertThat(reg.getAllHosts(), contains("bar.com", "baz.com")); - reg.removeHostsForKey("fookey"); + reg.removeHostsForKey(foo); assertThat(reg.getAllHosts().size(), is(0)); assertNull(reg.getKeyForHost("foo.com")); assertNull(reg.getKeyForHost("bar.com")); @@ -39,51 +46,51 @@ public class HostRegistryTest { @Test 
public void multiple_keys_are_handled() { - HostRegistry<String> reg = new HostRegistry<>(); - reg.update("fookey", List.of("foo.com", "bar.com")); - reg.update("barkey", List.of("baz.com", "quux.com")); - assertGetKey(reg, "foo.com", "fookey"); - assertGetKey(reg, "bar.com", "fookey"); - assertGetKey(reg, "baz.com", "barkey"); - assertGetKey(reg, "quux.com", "barkey"); + HostRegistry reg = new HostRegistry(); + reg.update(foo, List.of("foo.com", "bar.com")); + reg.update(bar, List.of("baz.com", "quux.com")); + assertGetKey(reg, "foo.com", foo); + assertGetKey(reg, "bar.com", foo); + assertGetKey(reg, "baz.com", bar); + assertGetKey(reg, "quux.com", bar); } @Test(expected = IllegalArgumentException.class) public void keys_cannot_overlap() { - HostRegistry<String> reg = new HostRegistry<>(); - reg.update("fookey", List.of("foo.com", "bar.com")); - reg.update("barkey", List.of("bar.com", "baz.com")); + HostRegistry reg = new HostRegistry(); + reg.update(foo, List.of("foo.com", "bar.com")); + reg.update(bar, List.of("bar.com", "baz.com")); } @Test public void all_hosts_are_returned() { - HostRegistry<String> reg = new HostRegistry<>(); - reg.update("fookey", List.of("foo.com", "bar.com")); - reg.update("barkey", List.of("baz.com", "quux.com")); + HostRegistry reg = new HostRegistry(); + reg.update(foo, List.of("foo.com", "bar.com")); + reg.update(bar, List.of("baz.com", "quux.com")); assertThat(reg.getAllHosts().size(), is(4)); } @Test public void ensure_that_collection_is_copied() { - HostRegistry<String> reg = new HostRegistry<>(); + HostRegistry reg = new HostRegistry(); List<String> hosts = new ArrayList<>(List.of("foo.com", "bar.com", "baz.com")); - reg.update("fookey", hosts); - assertThat(reg.getHostsForKey("fookey").size(), is(3)); + reg.update(foo, hosts); + assertThat(reg.getHostsForKey(foo).size(), is(3)); hosts.remove(2); - assertThat(reg.getHostsForKey("fookey").size(), is(3)); + assertThat(reg.getHostsForKey(foo).size(), is(3)); } @Test public void 
ensure_that_underlying_hosts_do_not_change() { - HostRegistry<String> reg = new HostRegistry<>(); - reg.update("fookey", List.of("foo.com", "bar.com", "baz.com")); + HostRegistry reg = new HostRegistry(); + reg.update(foo, List.of("foo.com", "bar.com", "baz.com")); Collection<String> hosts = reg.getAllHosts(); assertThat(hosts.size(), is(3)); - reg.update("fookey", List.of("foo.com")); + reg.update(foo, List.of("foo.com")); assertThat(hosts.size(), is(3)); } - private void assertGetKey(HostRegistry<String> reg, String host, String expectedKey) { + private void assertGetKey(HostRegistry reg, String host, ApplicationId expectedKey) { assertNotNull(reg.getKeyForHost(host)); assertThat(reg.getKeyForHost(host), is(expectedKey)); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java index 126dfe88141..809bcdc1a6e 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationContentHandlerTest.java @@ -36,14 +36,12 @@ import static org.junit.Assert.assertThat; public class ApplicationContentHandlerTest extends ContentHandlerTestBase { private static final File testApp = new File("src/test/apps/content"); - private static final File testApp2 = new File("src/test/apps/content2"); private final TenantName tenantName1 = TenantName.from("mofet"); private final TenantName tenantName2 = TenantName.from("bla"); private final String baseServer = "http://foo:1337"; private final ApplicationId appId1 = new ApplicationId.Builder().tenant(tenantName1).applicationName("foo").instanceName("quux").build(); - private final ApplicationId appId2 = new ApplicationId.Builder().tenant(tenantName2).applicationName("foo").instanceName("quux").build(); private ApplicationRepository applicationRepository; 
private ApplicationHandler handler; @@ -77,8 +75,6 @@ public class ApplicationContentHandlerTest extends ContentHandlerTestBase { .build(); applicationRepository.deploy(testApp, prepareParams(appId1)); - applicationRepository.deploy(testApp2, prepareParams(appId2)); - handler = new ApplicationHandler(ApplicationHandler.testOnlyContext(), Zone.defaultZone(), applicationRepository); @@ -113,14 +109,6 @@ public class ApplicationContentHandlerTest extends ContentHandlerTestBase { } @Test - public void require_that_multiple_tenants_are_handled() throws IOException { - assertContent("/test.txt", "foo\n"); - pathPrefix = createPath(appId2, Zone.defaultZone()); - baseUrl = baseServer + pathPrefix; - assertContent("/test.txt", "bar\n"); - } - - @Test public void require_that_get_does_not_set_write_flag() throws IOException { Tenant tenant1 = applicationRepository.getTenant(appId1); Session session = applicationRepository.getActiveLocalSession(tenant1, appId1); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java index a0c63c8bba1..3fb999e85f7 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/http/v2/ApplicationHandlerTest.java @@ -76,6 +76,7 @@ import static org.mockito.Mockito.when; public class ApplicationHandlerTest { private static final File testApp = new File("src/test/apps/app"); + private static final File testAppJdiscOnly = new File("src/test/apps/app-jdisc-only"); private final static TenantName mytenantName = TenantName.from("mytenant"); private final static ApplicationId myTenantApplicationId = ApplicationId.from(mytenantName, ApplicationName.defaultName(), InstanceName.defaultName()); @@ -150,7 +151,7 @@ public class ApplicationHandlerTest { .instanceName("quux") .build(); PrepareParams 
prepareParams2 = new PrepareParams.Builder().applicationId(fooId).build(); - applicationRepository.deploy(testApp, prepareParams2); + applicationRepository.deploy(testAppJdiscOnly, prepareParams2); assertApplicationExists(fooId, Zone.defaultZone()); deleteAndAssertOKResponseMocked(fooId, true); diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/TenantsMaintainerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/TenantsMaintainerTest.java index 9bd7a25faf2..a072bf62852 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/TenantsMaintainerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/maintenance/TenantsMaintainerTest.java @@ -56,7 +56,8 @@ public class TenantsMaintainerTest { assertNotNull(tenantRepository.getTenant(TenantName.defaultName())); assertNotNull(tenantRepository.getTenant(TenantRepository.HOSTED_VESPA_TENANT)); - // Add tenant again and deploy + // Delete app, add tenant again and deploy + tester.applicationRepository().delete(applicationId(shouldNotBeDeleted)); tenantRepository.addTenant(shouldBeDeleted); tester.deployApp(applicationPackage, prepareParams(shouldBeDeleted)); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/MockRpcServer.java b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/MockRpcServer.java index 7f4733f0b7c..997633eeb53 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/MockRpcServer.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/MockRpcServer.java @@ -1,4 +1,4 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.vespa.config.server.rpc; import com.yahoo.cloud.config.ConfigserverConfig; @@ -8,7 +8,7 @@ import com.yahoo.vespa.config.protocol.JRTServerConfigRequest; import com.yahoo.vespa.config.server.GetConfigContext; import com.yahoo.vespa.config.server.filedistribution.FileServer; import com.yahoo.vespa.config.server.host.ConfigRequestHostLivenessTracker; -import com.yahoo.vespa.config.server.host.HostRegistries; +import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.monitoring.Metrics; import com.yahoo.vespa.config.server.rpc.security.NoopRpcAuthorizer; @@ -37,7 +37,7 @@ public class MockRpcServer extends RpcServer { super(createConfig(port), null, Metrics.createTestMetrics(), - new HostRegistries(), + new HostRegistry(), new ConfigRequestHostLivenessTracker(), new FileServer(tempDir), new NoopRpcAuthorizer(), diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcServerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcServerTest.java index 735eae2700f..5a41eff3cc9 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcServerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcServerTest.java @@ -64,31 +64,19 @@ public class RpcServerTest { try (RpcTester tester = new RpcTester(applicationId, temporaryFolder)) { ApplicationRepository applicationRepository = tester.applicationRepository(); applicationRepository.deploy(testApp, new PrepareParams.Builder().applicationId(applicationId).build()); - TenantApplications applicationRepo = tester.tenant().getApplicationRepo(); - RemoteSession activeSession = applicationRepository.getActiveRemoteSession(applicationId); - ApplicationSet applicationSet = tester.tenant().getSessionRepository().ensureApplicationLoaded(activeSession); - applicationRepo.activateApplication(applicationSet, activeSession.getSessionId()); testPrintStatistics(tester); testGetConfig(tester); 
testEnabled(tester); testApplicationNotLoadedErrorWhenAppDeleted(tester); - testEmptySentinelConfigWhenAppDeletedOnHostedVespa(); } } - private void testApplicationNotLoadedErrorWhenAppDeleted(RpcTester tester) throws InterruptedException, IOException { - tester.rpcServer().onTenantDelete(tenantName); - tester.rpcServer().onTenantsLoaded(); + private void testApplicationNotLoadedErrorWhenAppDeleted(RpcTester tester) { + tester.applicationRepository().delete(applicationId); JRTClientConfigRequest clientReq = createSimpleRequest(); tester.performRequest(clientReq.getRequest()); assertFalse(clientReq.validateResponse()); assertThat(clientReq.errorCode(), is(ErrorCode.APPLICATION_NOT_LOADED)); - tester.stopRpc(); - tester.createAndStartRpcServer(); - tester.rpcServer().onTenantsLoaded(); - clientReq = createSimpleRequest(); - tester.performRequest(clientReq.getRequest()); - assertTrue(clientReq.validateResponse()); } @Test diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java index 2b2ed13fcfe..8284aacc97e 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/RpcTester.java @@ -12,7 +12,6 @@ import com.yahoo.jrt.Supervisor; import com.yahoo.jrt.Transport; import com.yahoo.net.HostName; import com.yahoo.test.ManualClock; -import com.yahoo.vespa.config.GenerationCounter; import com.yahoo.vespa.config.server.ApplicationRepository; import com.yahoo.vespa.config.server.MemoryGenerationCounter; import com.yahoo.vespa.config.server.MockProvisioner; @@ -24,7 +23,7 @@ import com.yahoo.vespa.config.server.TestConfigDefinitionRepo; import com.yahoo.vespa.config.server.application.OrchestratorMock; import com.yahoo.vespa.config.server.filedistribution.FileServer; import com.yahoo.vespa.config.server.host.ConfigRequestHostLivenessTracker; -import 
com.yahoo.vespa.config.server.host.HostRegistries; +import com.yahoo.vespa.config.server.host.HostRegistry; import com.yahoo.vespa.config.server.monitoring.Metrics; import com.yahoo.vespa.config.server.rpc.security.NoopRpcAuthorizer; import com.yahoo.vespa.config.server.tenant.Tenant; @@ -55,15 +54,15 @@ public class RpcTester implements AutoCloseable { private final ManualClock clock = new ManualClock(Instant.ofEpochMilli(100)); private final String myHostname = HostName.getLocalhost(); private final HostLivenessTracker hostLivenessTracker = new ConfigRequestHostLivenessTracker(clock); - private final GenerationCounter generationCounter; private final Spec spec; - private RpcServer rpcServer; + private final RpcServer rpcServer; private Thread t; private Supervisor sup; private final ApplicationId applicationId; private final TenantName tenantName; private final TenantRepository tenantRepository; + private final HostRegistry hostRegistry = new HostRegistry(); private final ApplicationRepository applicationRepository; private final List<Integer> allocatedPorts = new ArrayList<>(); @@ -85,23 +84,27 @@ public class RpcTester implements AutoCloseable { .configDefinitionsDir(temporaryFolder.newFolder().getAbsolutePath()) .fileReferencesDir(temporaryFolder.newFolder().getAbsolutePath()); configserverConfig = new ConfigserverConfig(configBuilder); + rpcServer = createRpcServer(configserverConfig); TestComponentRegistry componentRegistry = new TestComponentRegistry.Builder() .configDefinitionRepo(new TestConfigDefinitionRepo()) .configServerConfig(configserverConfig) + .reloadListener(rpcServer) + .hostRegistry(hostRegistry) .build(); tenantRepository = new TenantRepository(componentRegistry); tenantRepository.addTenant(tenantName); + startRpcServer(); applicationRepository = new ApplicationRepository.Builder() .withTenantRepository(tenantRepository) + .withConfigserverConfig(configserverConfig) .withProvisioner(new MockProvisioner()) .withOrchestrator(new 
OrchestratorMock()) .build(); - generationCounter = new MemoryGenerationCounter(); - createAndStartRpcServer(); assertFalse(hostLivenessTracker.lastRequestFrom(myHostname).isPresent()); } public void close() { + rpcServer.stop(); for (Integer port : allocatedPorts) { PortRangeAllocator.releasePort(port); } @@ -113,24 +116,25 @@ public class RpcTester implements AutoCloseable { return port; } - void createAndStartRpcServer() throws IOException { - HostRegistries hostRegistries = new HostRegistries(); - hostRegistries.createApplicationHostRegistry(tenantName).update(applicationId, List.of("localhost")); - hostRegistries.getTenantHostRegistry().update(tenantName, List.of("localhost")); - rpcServer = new RpcServer(configserverConfig, - new SuperModelRequestHandler(new TestConfigDefinitionRepo(), - configserverConfig, - new SuperModelManager( - configserverConfig, - Zone.defaultZone() , - generationCounter, - new InMemoryFlagSource())), - Metrics.createTestMetrics(), - hostRegistries, - hostLivenessTracker, - new FileServer(temporaryFolder.newFolder()), - new NoopRpcAuthorizer(), - new RpcRequestHandlerProvider()); + RpcServer createRpcServer(ConfigserverConfig config) throws IOException { + return new RpcServer(config, + new SuperModelRequestHandler(new TestConfigDefinitionRepo(), + configserverConfig, + new SuperModelManager( + config, + Zone.defaultZone(), + new MemoryGenerationCounter(), + new InMemoryFlagSource())), + Metrics.createTestMetrics(), + hostRegistry, + hostLivenessTracker, + new FileServer(temporaryFolder.newFolder()), + new NoopRpcAuthorizer(), + new RpcRequestHandlerProvider()); + } + + void startRpcServer() { + hostRegistry.update(applicationId, List.of("localhost")); rpcServer.onTenantCreate(tenantRepository.getTenant(tenantName)); t = new Thread(rpcServer); t.start(); @@ -165,7 +169,7 @@ public class RpcTester implements AutoCloseable { void performRequest(Request req) { clock.advance(Duration.ofMillis(10)); - sup.connect(spec).invokeSync(req, 
120.0); + sup.connect(spec).invokeSync(req, 10.0); if (req.methodName().equals(RpcServer.getConfigMethodName)) assertEquals(clock.instant(), hostLivenessTracker.lastRequestFrom(myHostname).get()); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/security/MultiTenantRpcAuthorizerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/security/MultiTenantRpcAuthorizerTest.java index 9e1edb35b8f..12debc347de 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/security/MultiTenantRpcAuthorizerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/rpc/security/MultiTenantRpcAuthorizerTest.java @@ -7,7 +7,6 @@ import com.yahoo.config.FileReference; import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.HostName; import com.yahoo.config.provision.NodeType; -import com.yahoo.config.provision.TenantName; import com.yahoo.config.provision.security.NodeIdentifier; import com.yahoo.config.provision.security.NodeIdentifierException; import com.yahoo.config.provision.security.NodeIdentity; @@ -66,7 +65,7 @@ public class MultiTenantRpcAuthorizerTest { @Test public void configserver_can_access_files_and_config() throws InterruptedException, ExecutionException { RpcAuthorizer authorizer = createAuthorizer(new NodeIdentity.Builder(NodeType.config).build(), - new HostRegistry<>()); + new HostRegistry()); Request configRequest = createConfigRequest(new ConfigKey<>("name", "configid", "namespace"), HOSTNAME); authorizer.authorizeConfigRequest(configRequest) @@ -83,8 +82,8 @@ public class MultiTenantRpcAuthorizerTest { .applicationId(APPLICATION_ID) .build(); - HostRegistry<TenantName> hostRegistry = new HostRegistry<>(); - hostRegistry.update(APPLICATION_ID.tenant(), List.of(HOSTNAME.value())); + HostRegistry hostRegistry = new HostRegistry(); + hostRegistry.update(APPLICATION_ID, List.of(HOSTNAME.value())); RpcAuthorizer authorizer = createAuthorizer(identity, hostRegistry); 
@@ -99,7 +98,7 @@ public class MultiTenantRpcAuthorizerTest { @Test public void proxy_node_can_access_lbservice_config() throws ExecutionException, InterruptedException { - RpcAuthorizer authorizer = createAuthorizer(new NodeIdentity.Builder(NodeType.proxy).build(), new HostRegistry<>()); + RpcAuthorizer authorizer = createAuthorizer(new NodeIdentity.Builder(NodeType.proxy).build(), new HostRegistry()); Request configRequest = createConfigRequest( new ConfigKey<>(LbServicesConfig.CONFIG_DEF_NAME, "*", LbServicesConfig.CONFIG_DEF_NAMESPACE), @@ -110,7 +109,7 @@ public class MultiTenantRpcAuthorizerTest { @Test public void tenant_node_cannot_access_lbservice_config() throws ExecutionException, InterruptedException { - RpcAuthorizer authorizer = createAuthorizer(new NodeIdentity.Builder(NodeType.tenant).build(), new HostRegistry<>()); + RpcAuthorizer authorizer = createAuthorizer(new NodeIdentity.Builder(NodeType.tenant).build(), new HostRegistry()); Request configRequest = createConfigRequest( new ConfigKey<>(LbServicesConfig.CONFIG_DEF_NAME, "*", LbServicesConfig.CONFIG_DEF_NAMESPACE), @@ -129,8 +128,8 @@ public class MultiTenantRpcAuthorizerTest { .applicationId(APPLICATION_ID) .build(); - HostRegistry<TenantName> hostRegistry = new HostRegistry<>(); - hostRegistry.update(APPLICATION_ID.tenant(), List.of(HOSTNAME.value())); + HostRegistry hostRegistry = new HostRegistry(); + hostRegistry.update(APPLICATION_ID, List.of(HOSTNAME.value())); RpcAuthorizer authorizer = createAuthorizer(identity, hostRegistry); @@ -149,8 +148,8 @@ public class MultiTenantRpcAuthorizerTest { .applicationId(EVIL_APP_ID) .build(); - HostRegistry<TenantName> hostRegistry = new HostRegistry<>(); - hostRegistry.update(APPLICATION_ID.tenant(), List.of(HOSTNAME.value())); + HostRegistry hostRegistry = new HostRegistry(); + hostRegistry.update(APPLICATION_ID, List.of(HOSTNAME.value())); RpcAuthorizer authorizer = createAuthorizer(identity, hostRegistry); @@ -169,7 +168,7 @@ public class 
MultiTenantRpcAuthorizerTest { .applicationId(EVIL_APP_ID) .build(); - HostRegistry<TenantName> hostRegistry = new HostRegistry<>(); + HostRegistry hostRegistry = new HostRegistry(); RpcAuthorizer authorizer = createAuthorizer(identity, hostRegistry); @@ -188,8 +187,8 @@ public class MultiTenantRpcAuthorizerTest { .applicationId(EVIL_APP_ID) .build(); - HostRegistry<TenantName> hostRegistry = new HostRegistry<>(); - hostRegistry.update(EVIL_APP_ID.tenant(), List.of(HOSTNAME.value())); + HostRegistry hostRegistry = new HostRegistry(); + hostRegistry.update(EVIL_APP_ID, List.of(HOSTNAME.value())); RpcAuthorizer authorizer = createAuthorizer(identity, hostRegistry); @@ -208,7 +207,7 @@ public class MultiTenantRpcAuthorizerTest { .applicationId(APPLICATION_ID) .build(); - HostRegistry<TenantName> hostRegistry = new HostRegistry<>(); + HostRegistry hostRegistry = new HostRegistry(); RpcAuthorizer authorizer = createAuthorizer(identity, hostRegistry); @@ -219,7 +218,7 @@ public class MultiTenantRpcAuthorizerTest { } - private static RpcAuthorizer createAuthorizer(NodeIdentity identity, HostRegistry<TenantName> hostRegistry) { + private static RpcAuthorizer createAuthorizer(NodeIdentity identity, HostRegistry hostRegistry) { return new MultiTenantRpcAuthorizer( new StaticNodeIdentifier(identity), hostRegistry, diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java index 76958264d84..90d3bddc88d 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionPreparerTest.java @@ -187,7 +187,7 @@ public class SessionPreparerTest { @Test(expected = InvalidApplicationException.class) public void require_exception_for_overlapping_host() throws IOException { FilesApplicationPackage app = getApplicationPackage(testApp); - 
HostRegistry<ApplicationId> hostValidator = new HostRegistry<>(); + HostRegistry hostValidator = new HostRegistry(); hostValidator.update(applicationId("foo"), Collections.singletonList("mytesthost")); preparer.prepare(hostValidator, new BaseDeployLogger(), new PrepareParams.Builder().applicationId(applicationId("default")).build(), Optional.empty(), Instant.now(), app.getAppDir(), app, createSessionZooKeeperClient()); @@ -200,7 +200,7 @@ public class SessionPreparerTest { if (level.equals(Level.WARNING) && message.contains("The host mytesthost is already in use")) logged.append("ok"); }; FilesApplicationPackage app = getApplicationPackage(testApp); - HostRegistry<ApplicationId> hostValidator = new HostRegistry<>(); + HostRegistry hostValidator = new HostRegistry(); ApplicationId applicationId = applicationId(); hostValidator.update(applicationId, Collections.singletonList("mytesthost")); preparer.prepare(hostValidator, logger, new PrepareParams.Builder().applicationId(applicationId).build(), @@ -367,7 +367,7 @@ public class SessionPreparerTest { private PrepareResult prepare(File app, PrepareParams params, long sessionId) throws IOException { FilesApplicationPackage applicationPackage = getApplicationPackage(app); - return preparer.prepare(new HostRegistry<>(), getLogger(), params, + return preparer.prepare(new HostRegistry(), getLogger(), params, Optional.empty(), Instant.now(), applicationPackage.getAppDir(), applicationPackage, createSessionZooKeeperClient(sessionId)); } diff --git a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java index bfcfc7d6e43..c7612937d47 100644 --- a/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java +++ b/configserver/src/test/java/com/yahoo/vespa/config/server/session/SessionRepositoryTest.java @@ -60,6 +60,7 @@ public class SessionRepositoryTest { private 
static final TenantName tenantName = TenantName.defaultName(); private static final ApplicationId applicationId = ApplicationId.from(tenantName.value(), "testApp", "default"); private static final File testApp = new File("src/test/apps/app"); + private static final File appJdiscOnly = new File("src/test/apps/app-jdisc-only"); private MockCurator curator; private TenantRepository tenantRepository; @@ -135,7 +136,7 @@ public class SessionRepositoryTest { // tenant is "newTenant" TenantName newTenant = TenantName.from("newTenant"); tenantRepository.addTenant(newTenant); - long sessionId = deploy(ApplicationId.from(newTenant.value(), "testapp", "default")); + long sessionId = deploy(ApplicationId.from(newTenant.value(), "testapp", "default"), appJdiscOnly); SessionRepository sessionRepository2 = tenantRepository.getTenant(newTenant).getSessionRepository(); assertNotNull(sessionRepository2.getLocalSession(sessionId)); } diff --git a/configserver/src/test/resources/deploy/advancedapp/services.xml b/configserver/src/test/resources/deploy/advancedapp/services.xml index e3d5aea585b..b8e93b14317 100644 --- a/configserver/src/test/resources/deploy/advancedapp/services.xml +++ b/configserver/src/test/resources/deploy/advancedapp/services.xml @@ -22,7 +22,7 @@ <documents> <document type="keyvalue" mode="index"/> </documents> - <nodes>> + <nodes> <node hostalias="node1" distribution-key="0"/> </nodes> </content> diff --git a/container-dev/pom.xml b/container-dev/pom.xml index dd2d9ceb188..48f409dfc2c 100644 --- a/container-dev/pom.xml +++ b/container-dev/pom.xml @@ -140,6 +140,10 @@ <groupId>net.java.dev.jna</groupId> <artifactId>jna</artifactId> </exclusion> + <exclusion> + <groupId>io.airlift</groupId> + <artifactId>aircompressor</artifactId> + </exclusion> </exclusions> </dependency> <dependency> diff --git a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentNode.java 
b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentNode.java index 92695565d47..b6fa4241e26 100644 --- a/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentNode.java +++ b/container-di/src/main/java/com/yahoo/container/di/componentgraph/core/ComponentNode.java @@ -1,4 +1,4 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.container.di.componentgraph.core; import com.google.inject.Inject; @@ -16,6 +16,8 @@ import java.lang.reflect.InvocationTargetException; import java.lang.reflect.Modifier; import java.lang.reflect.ParameterizedType; import java.lang.reflect.Type; +import java.time.Duration; +import java.time.Instant; import java.util.ArrayList; import java.util.Arrays; import java.util.Comparator; @@ -30,6 +32,7 @@ import static com.yahoo.container.di.componentgraph.core.Exceptions.cutStackTrac import static com.yahoo.container.di.componentgraph.core.Exceptions.removeStackTrace; import static com.yahoo.container.di.componentgraph.core.Keys.createKey; import static java.util.logging.Level.FINE; +import static java.util.logging.Level.INFO; /** * @author Tony Vaagenes @@ -148,12 +151,14 @@ public class ComponentNode extends Node { Object instance; try { - log.log(FINE, "Constructing " + idAndType()); + log.log(FINE, () -> "Constructing " + idAndType()); + Instant start = Instant.now(); instance = constructor.newInstance(actualArguments.toArray()); - log.log(FINE, "Finished constructing " + idAndType()); + Duration duration = Duration.between(start, Instant.now()); + log.log(duration.compareTo(Duration.ofMinutes(1)) > 0 ? 
INFO : FINE, + () -> "Finished constructing " + idAndType() + " in " + duration); } catch (InvocationTargetException | InstantiationException | IllegalAccessException e) { StackTraceElement dependencyInjectorMarker = new StackTraceElement("============= Dependency Injection =============", "newInstance", null, -1); - throw removeStackTrace(new ComponentConstructorException("Error constructing " + idAndType() + ": " + e.getMessage(), cutStackTraceAtConstructor(e.getCause(), dependencyInjectorMarker))); } diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index 6f48ae5b41a..705318cb8de 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -884,10 +884,12 @@ "public void <init>(java.lang.String, java.lang.String)", "public int getTargetNumHits()", "public java.lang.String getIndexName()", + "public double getDistanceThreshold()", "public int getHnswExploreAdditionalHits()", "public boolean getAllowApproximate()", "public java.lang.String getQueryTensorName()", "public void setTargetNumHits(int)", + "public void setDistanceThreshold(double)", "public void setHnswExploreAdditionalHits(int)", "public void setAllowApproximate(boolean)", "public void setIndexName(java.lang.String)", @@ -5619,6 +5621,19 @@ ], "fields": [] }, + "com.yahoo.search.query.profile.CompoundNameChildCache": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public", + "final" + ], + "methods": [ + "public void <init>()", + "public com.yahoo.processing.request.CompoundName append(com.yahoo.processing.request.CompoundName, java.lang.String)" + ], + "fields": [] + }, "com.yahoo.search.query.profile.CopyOnWriteContent": { "superClass": "com.yahoo.component.provider.FreezableClass", "interfaces": [ @@ -6129,7 +6144,7 @@ ], "methods": [ "public void <init>()", - "public void put(com.yahoo.processing.request.CompoundName, com.yahoo.search.query.profile.DimensionBinding, java.lang.Object)", + "public void 
put(com.yahoo.processing.request.CompoundName, com.yahoo.search.query.profile.compiled.Binding, java.lang.Object)", "public com.yahoo.search.query.profile.compiled.DimensionalMap build()" ], "fields": [] @@ -6156,7 +6171,7 @@ "methods": [ "public void <init>()", "public java.lang.Object valueFor(com.yahoo.search.query.profile.compiled.Binding)", - "public void add(java.lang.Object, com.yahoo.search.query.profile.DimensionBinding)", + "public void add(java.lang.Object, com.yahoo.search.query.profile.compiled.Binding)", "public com.yahoo.search.query.profile.compiled.DimensionalValue build(java.util.Map)" ], "fields": [] diff --git a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java index e237463582f..bb95cbad178 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java @@ -22,6 +22,7 @@ public class NearestNeighborItem extends SimpleTaggableItem { private int targetNumHits = 0; private int hnswExploreAdditionalHits = 0; + private double distanceThreshold = Double.POSITIVE_INFINITY; private boolean approximate = true; private String field; private final String queryTensorName; @@ -37,6 +38,9 @@ public class NearestNeighborItem extends SimpleTaggableItem { /** Returns the field name */ public String getIndexName() { return field; } + /** Returns the distance threshold for nearest-neighbor hits */ + public double getDistanceThreshold () { return this.distanceThreshold ; } + /** Returns the number of extra hits to explore in HNSW algorithm */ public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; } @@ -49,6 +53,9 @@ public class NearestNeighborItem extends SimpleTaggableItem { /** Set the K number of hits to produce */ public void setTargetNumHits(int target) { this.targetNumHits = target; } + /** Set the distance 
threshold for nearest-neighbor hits */ + public void setDistanceThreshold(double threshold) { this.distanceThreshold = threshold; } + /** Set the number of extra hits to explore in HNSW algorithm */ public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; } @@ -72,9 +79,18 @@ public class NearestNeighborItem extends SimpleTaggableItem { super.encodeThis(buffer); putString(field, buffer); putString(queryTensorName, buffer); + int approxNum = (approximate ? 1 : 0); + // should become always-true later: + boolean sendDistanceThreshold = (distanceThreshold < Double.POSITIVE_INFINITY); + if (sendDistanceThreshold) { + approxNum |= 0x40; // temporary flag bit + } IntegerCompressor.putCompressedPositiveNumber(targetNumHits, buffer); - IntegerCompressor.putCompressedPositiveNumber((approximate ? 1 : 0), buffer); + IntegerCompressor.putCompressedPositiveNumber(approxNum, buffer); IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer); + if (sendDistanceThreshold) { + buffer.putDouble(distanceThreshold); + } return 1; // number of encoded stack dump items } @@ -83,6 +99,7 @@ public class NearestNeighborItem extends SimpleTaggableItem { buffer.append("{field=").append(field); buffer.append(",queryTensorName=").append(queryTensorName); buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits); + buffer.append(",distanceThreshold=").append(distanceThreshold); buffer.append(",approximate=").append(approximate); buffer.append(",targetHits=").append(targetNumHits).append("}"); } @@ -93,6 +110,7 @@ public class NearestNeighborItem extends SimpleTaggableItem { discloser.addProperty("field", field); discloser.addProperty("queryTensorName", queryTensorName); discloser.addProperty("hnsw.exploreAdditionalHits", hnswExploreAdditionalHits); + discloser.addProperty("distanceThreshold", distanceThreshold); discloser.addProperty("approximate", approximate); discloser.addProperty("targetHits", targetNumHits); } 
diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java index d8fb7b46440..c60e1bf39cb 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/InterleavedSearchInvoker.java @@ -43,6 +43,7 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM private final SearchCluster searchCluster; private final LinkedBlockingQueue<SearchInvoker> availableForProcessing; private final Set<Integer> alreadyFailedNodes; + private final boolean isContentWellBalanced; private Query query; private boolean adaptiveTimeoutCalculated = false; @@ -59,13 +60,14 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM private boolean timedOut = false; private boolean degradedByMatchPhase = false; - public InterleavedSearchInvoker(Collection<SearchInvoker> invokers, SearchCluster searchCluster, Set<Integer> alreadyFailedNodes) { + public InterleavedSearchInvoker(Collection<SearchInvoker> invokers, boolean isContentWellBalanced, SearchCluster searchCluster, Set<Integer> alreadyFailedNodes) { super(Optional.empty()); this.invokers = Collections.newSetFromMap(new IdentityHashMap<>()); this.invokers.addAll(invokers); this.searchCluster = searchCluster; this.availableForProcessing = newQueue(); this.alreadyFailedNodes = alreadyFailedNodes; + this.isContentWellBalanced = isContentWellBalanced; } /** @@ -82,10 +84,13 @@ public class InterleavedSearchInvoker extends SearchInvoker implements ResponseM int originalHits = query.getHits(); int originalOffset = query.getOffset(); int neededHits = originalHits + originalOffset; - Double topkProbabilityOverrride = query.properties().getDouble(Dispatcher.topKProbability); - int q = (topkProbabilityOverrride != null) - ? 
searchCluster.estimateHitsToFetch(neededHits, invokers.size(), topkProbabilityOverrride) - : searchCluster.estimateHitsToFetch(neededHits, invokers.size()); + int q = neededHits; + if (isContentWellBalanced) { + Double topkProbabilityOverrride = query.properties().getDouble(Dispatcher.topKProbability); + q = (topkProbabilityOverrride != null) + ? searchCluster.estimateHitsToFetch(neededHits, invokers.size(), topkProbabilityOverrride) + : searchCluster.estimateHitsToFetch(neededHits, invokers.size()); + } query.setHits(q); query.setOffset(0); diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/InvokerFactory.java b/container-search/src/main/java/com/yahoo/search/dispatch/InvokerFactory.java index 03160e6c9c7..f65e0e43757 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/InvokerFactory.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/InvokerFactory.java @@ -46,12 +46,12 @@ public abstract class InvokerFactory { * @return the invoker or empty if some node in the * list is invalid and the remaining coverage is not sufficient */ - public Optional<SearchInvoker> createSearchInvoker(VespaBackEndSearcher searcher, - Query query, - OptionalInt groupId, - List<Node> nodes, - boolean acceptIncompleteCoverage, - int maxHits) { + Optional<SearchInvoker> createSearchInvoker(VespaBackEndSearcher searcher, + Query query, + OptionalInt groupId, + List<Node> nodes, + boolean acceptIncompleteCoverage, + int maxHits) { List<SearchInvoker> invokers = new ArrayList<>(nodes.size()); Set<Integer> failed = null; for (Node node : nodes) { @@ -90,7 +90,7 @@ public abstract class InvokerFactory { if (invokers.size() == 1 && failed == null) { return Optional.of(invokers.get(0)); } else { - return Optional.of(new InterleavedSearchInvoker(invokers, searchCluster, failed)); + return Optional.of(new InterleavedSearchInvoker(invokers, searchCluster.isGroupWellBalanced(groupId), searchCluster, failed)); } } diff --git 
a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/Group.java b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/Group.java index ec616a18e09..e5066797b06 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/Group.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/Group.java @@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableList; import java.util.List; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicLong; +import java.util.logging.Logger; /** * A group in a search cluster. This class is multithread safe. @@ -22,6 +23,9 @@ public class Group { private final AtomicBoolean hasFullCoverage = new AtomicBoolean(true); private final AtomicLong activeDocuments = new AtomicLong(0); private final AtomicBoolean isBlockingWrites = new AtomicBoolean(false); + private final AtomicBoolean isContentWellBalanced = new AtomicBoolean(true); + private final static double MAX_UNBALANCE = 0.10; // If documents on a node is more than 10% off from the average the group is unbalanced + private static final Logger log = Logger.getLogger(Group.class.getName()); public Group(int id, List<Node> nodes) { this.id = id; @@ -53,25 +57,34 @@ public class Group { } public int workingNodes() { - int nodesUp = 0; - for (Node node : nodes) { - if (node.isWorking() == Boolean.TRUE) { - nodesUp++; - } - } - return nodesUp; + return (int) nodes.stream().filter(node -> node.isWorking() == Boolean.TRUE).count(); } void aggregateNodeValues() { - activeDocuments.set(nodes.stream().filter(node -> node.isWorking() == Boolean.TRUE).mapToLong(Node::getActiveDocuments).sum()); - isBlockingWrites.set(nodes.stream().anyMatch(node -> node.isBlockingWrites())); + long activeDocs = nodes.stream().filter(node -> node.isWorking() == Boolean.TRUE).mapToLong(Node::getActiveDocuments).sum(); + activeDocuments.set(activeDocs); + 
isBlockingWrites.set(nodes.stream().anyMatch(Node::isBlockingWrites)); + int numWorkingNodes = workingNodes(); + if (numWorkingNodes > 0) { + long average = activeDocs / numWorkingNodes; + long deviation = nodes.stream().filter(node -> node.isWorking() == Boolean.TRUE).mapToLong(node -> Math.abs(node.getActiveDocuments() - average)).sum(); + boolean isDeviationSmall = deviation <= (activeDocs * MAX_UNBALANCE); + if ((!isContentWellBalanced.get() || isDeviationSmall != isContentWellBalanced.get()) && (activeDocs > 0)) { + log.info("Content is " + (isDeviationSmall ? "" : "not ") + "well balanced. Current deviation = " + deviation*100/activeDocs + " %" + + ". activeDocs = " + activeDocs + ", deviation = " + deviation + ", average = " + average); + isContentWellBalanced.set(isDeviationSmall); + } + } else { + isContentWellBalanced.set(true); + } } - /** Returns the active documents on this node. If unknown, 0 is returned. */ + /** Returns the active documents on this group. If unknown, 0 is returned. 
*/ long getActiveDocuments() { return activeDocuments.get(); } /** Returns whether any node in this group is currently blocking write operations */ public boolean isBlockingWrites() { return isBlockingWrites.get(); } + public boolean isContentWellBalanced() { return isContentWellBalanced.get(); } public boolean isFullCoverageStatusChanged(boolean hasFullCoverageNow) { boolean previousState = hasFullCoverage.getAndSet(hasFullCoverageNow); diff --git a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java index 2f62b07ac04..1897c0af8bc 100644 --- a/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java +++ b/container-search/src/main/java/com/yahoo/search/dispatch/searchcluster/SearchCluster.java @@ -368,6 +368,12 @@ public class SearchCluster implements NodeManager<Node> { return workingNodes + nodesAllowedDown >= nodesInGroup; } + public boolean isGroupWellBalanced(OptionalInt groupId) { + if (groupId.isEmpty()) return false; + Group group = groups().get(groupId.getAsInt()); + return (group != null) && group.isContentWellBalanced(); + } + /** * Calculate whether a subset of nodes in a group has enough coverage */ diff --git a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java index 30d741f465c..5f1f26b77e9 100644 --- a/container-search/src/main/java/com/yahoo/search/query/SelectParser.java +++ b/container-search/src/main/java/com/yahoo/search/query/SelectParser.java @@ -78,6 +78,7 @@ import static com.yahoo.search.yql.YqlParser.CONNECTIVITY; import static com.yahoo.search.yql.YqlParser.DEFAULT_TARGET_NUM_HITS; import static com.yahoo.search.yql.YqlParser.DESCENDING_HITS_ORDER; import static com.yahoo.search.yql.YqlParser.DISTANCE; +import static com.yahoo.search.yql.YqlParser.DISTANCE_THRESHOLD; import static 
com.yahoo.search.yql.YqlParser.DOT_PRODUCT; import static com.yahoo.search.yql.YqlParser.EQUIV; import static com.yahoo.search.yql.YqlParser.FILTER; @@ -481,6 +482,10 @@ public class SelectParser implements Parser { if (TARGET_NUM_HITS.equals(annotation_name)){ item.setTargetNumHits((int)(annotation_value.asDouble())); } + if (DISTANCE_THRESHOLD.equals(annotation_name)) { + double distanceThreshold = annotation_value.asDouble(); + item.setDistanceThreshold(distanceThreshold); + } if (HNSW_EXPLORE_ADDITIONAL_HITS.equals(annotation_name)) { int hnswExploreAdditionalHits = (int)(annotation_value.asDouble()); item.setHnswExploreAdditionalHits(hnswExploreAdditionalHits); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/AllValuesQueryProfileVisitor.java b/container-search/src/main/java/com/yahoo/search/query/profile/AllValuesQueryProfileVisitor.java index 68bf112133a..b24bf1195eb 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/AllValuesQueryProfileVisitor.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/AllValuesQueryProfileVisitor.java @@ -16,8 +16,8 @@ final class AllValuesQueryProfileVisitor extends PrefixQueryProfileVisitor { private final Map<String, ValueWithSource> values = new HashMap<>(); /* Lists all values starting at prefix */ - public AllValuesQueryProfileVisitor(CompoundName prefix) { - super(prefix); + public AllValuesQueryProfileVisitor(CompoundName prefix, CompoundNameChildCache pathCache) { + super(prefix, pathCache); } @Override @@ -43,7 +43,7 @@ final class AllValuesQueryProfileVisitor extends PrefixQueryProfileVisitor { QueryProfile owner, DimensionValues variant, DimensionBinding binding) { - CompoundName fullName = currentPrefix.append(key); + CompoundName fullName = cache.append(currentPrefix, key); ValueWithSource existing = values.get(fullName.toString()); diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/CompoundNameChildCache.java 
b/container-search/src/main/java/com/yahoo/search/query/profile/CompoundNameChildCache.java new file mode 100644 index 00000000000..4163e45ae61 --- /dev/null +++ b/container-search/src/main/java/com/yahoo/search/query/profile/CompoundNameChildCache.java @@ -0,0 +1,24 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.query.profile; + +import com.yahoo.processing.request.CompoundName; + +import java.util.HashMap; +import java.util.Map; + +/** + * Cache for compound names created through {@link CompoundName#append(String)}. + * Creating new {@link CompoundName}s can be expensive, and since they are immutable, they + * are safe to cache and reuse. Use this if you will create <em>a lot</em> of them, by appending suffixes. + * + * @author jonmv + */ +public final class CompoundNameChildCache { + + private final Map<CompoundName, Map<String, CompoundName>> cache = new HashMap<>(); + + public CompoundName append(CompoundName prefix, String suffix) { + return cache.computeIfAbsent(prefix, __ -> new HashMap<>()).computeIfAbsent(suffix, prefix::append); + } + +} diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/DimensionBinding.java b/container-search/src/main/java/com/yahoo/search/query/profile/DimensionBinding.java index e0edf9f9894..43462e8f327 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/DimensionBinding.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/DimensionBinding.java @@ -2,10 +2,9 @@ package com.yahoo.search.query.profile; import java.util.ArrayList; -import java.util.Collections; -import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.Objects; /** * An immutable, binding of a list of dimensions to dimension values @@ -24,13 +23,13 @@ public class DimensionBinding { private Map<String, String> context; public static final DimensionBinding nullBinding = - 
new DimensionBinding(Collections.unmodifiableList(Collections.emptyList()), DimensionValues.empty, null); + new DimensionBinding(List.of(), DimensionValues.empty, null); public static final DimensionBinding invalidBinding = - new DimensionBinding(Collections.unmodifiableList(Collections.emptyList()), DimensionValues.empty, null); + new DimensionBinding(List.of(), DimensionValues.empty, null); /** Whether the value array contains only nulls */ - private boolean containsAllNulls; + private final boolean containsAllNulls; // NOTE: Map must be ordered public static DimensionBinding createFrom(Map<String, String> values) { @@ -125,72 +124,54 @@ public class DimensionBinding { * * @return the combined binding, or the special invalidBinding if these two bindings are incompatible */ - public DimensionBinding combineWith(DimensionBinding binding) { - List<String> combinedDimensions = combineDimensions(getDimensions(), binding.getDimensions()); - if (combinedDimensions == null) return invalidBinding; - - // not runtime, so assume we don't need to preserve values outside the dimensions - Map<String, String> combinedValues = combineValues(getContext(), binding.getContext()); - if (combinedValues == null) return invalidBinding; - - return DimensionBinding.createFrom(combinedDimensions, combinedValues); - } - - /** - * Returns a combined list of dimensions from two separate lists, - * or null if they are incompatible. - * This is to combine two lists to one such that the partial order in both is preserved - * (or return null if impossible). 
- */ - private List<String> combineDimensions(List<String> d1, List<String> d2) { - if (d1.equals(d2)) return d1; - if (d1.isEmpty()) return d2; - if (d2.isEmpty()) return d1; - - List<String> combined = new ArrayList<>(); - int d1Index = 0, d2Index = 0; - while (d1Index < d1.size() && d2Index < d2.size()) { - if (d1.get(d1Index).equals(d2.get(d2Index))) { // agreement on next element - combined.add(d1.get(d1Index)); - d1Index++; - d2Index++; + public DimensionBinding combineWith(DimensionBinding other) { + List<String> d1 = getDimensions(); + List<String> d2 = other.getDimensions(); + DimensionValues v1 = getValues(); + DimensionValues v2 = other.getValues(); + List<String> dimensions = new ArrayList<>(); + List<String> values = new ArrayList<>(); + int i1 = 0, i2 = 0; + while (i1 < d1.size() && i2 < d2.size()) { + if (d1.get(i1).equals(d2.get(i2))) { // agreement on next dimension + String s1 = v1.get(i1), s2 = v2.get(i2); + if (s1 == null) + values.add(s2); + else if (s2 == null || s1.equals(s2)) + values.add(s1); + else + return invalidBinding; // disagreement on next value + + dimensions.add(d1.get(i1)); + i1++; + i2++; } - else if ( ! d2.contains(d1.get(d1Index))) { // next in d1 is independent from d2 - combined.add(d1.get(d1Index++)); + else if ( ! d2.contains(d1.get(i1))) { // next dimension in d1 is independent from d2 + dimensions.add(d1.get(i1)); + values.add(v1.get(i1)); + i1++; } - else if ( ! d1.contains(d2.get(d2Index))) { // next in d2 is independent from d1 - combined.add(d2.get(d2Index++)); + else if ( ! 
d1.contains(d2.get(i2))) { // next dimension in d2 is independent from d1 + dimensions.add(d2.get(i2)); + values.add(v2.get(i2)); + i2++; } else { - return null; // not independent and no agreement + return invalidBinding; // not independent and no agreement } } - if (d1Index < d1.size()) - combined.addAll(d1.subList(d1Index, d1.size())); - else if (d2Index < d2.size()) - combined.addAll(d2.subList(d2Index, d2.size())); - - return combined; - } - - /** - * Returns a combined map of dimension values from two separate maps, - * or null if they are incompatible. - */ - private Map<String, String> combineValues(Map<String, String> m1, Map<String, String> m2) { - if (m1.isEmpty()) return m2; - if (m2.isEmpty()) return m1; - Map<String, String> combinedValues = null; - for (Map.Entry<String, String> m2Entry : m2.entrySet()) { - if (m2Entry.getValue() == null) continue; - String m1Value = m1.get(m2Entry.getKey()); - if (m1Value != null && ! m1Value.equals(m2Entry.getValue())) - return null; // conflicting values of a key - if (combinedValues == null) - combinedValues = new LinkedHashMap<>(m1); - combinedValues.put(m2Entry.getKey(), m2Entry.getValue()); + while (i1 < d1.size()) { + dimensions.add(d1.get(i1)); + values.add(v1.get(i1)); + i1++; + } + while (i2 < d2.size()) { + dimensions.add(d2.get(i2)); + values.add(v2.get(i2)); + i2++; } - return combinedValues == null ? 
m1 : combinedValues; + + return DimensionBinding.createFrom(dimensions, DimensionValues.createFrom(values.toArray(new String[0]))); } /** Returns true if this == invalidBinding */ @@ -223,7 +204,7 @@ public class DimensionBinding { @Override public int hashCode() { - return dimensions.hashCode() + 17 * values.hashCode(); + return Objects.hash(dimensions, values); } } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/DimensionValues.java b/container-search/src/main/java/com/yahoo/search/query/profile/DimensionValues.java index 7c3307223c3..f3c4548c491 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/DimensionValues.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/DimensionValues.java @@ -80,7 +80,7 @@ public class DimensionValues implements Comparable<DimensionValues> { public boolean equals(Object o) { if (this == o) return true; if ( ! (o instanceof DimensionValues)) return false; - DimensionValues other = (DimensionValues)o; + DimensionValues other = (DimensionValues) o; for (int i = 0; i < this.size() || i < other.size(); i++) { if (get(i) == null) { if (other.get(i) != null) return false; diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/PrefixQueryProfileVisitor.java b/container-search/src/main/java/com/yahoo/search/query/profile/PrefixQueryProfileVisitor.java index 690a48f8124..b53fc4f96f2 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/PrefixQueryProfileVisitor.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/PrefixQueryProfileVisitor.java @@ -3,6 +3,11 @@ package com.yahoo.search.query.profile; import com.yahoo.processing.request.CompoundName; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.HashMap; +import java.util.Map; + /** * A query profile visitor which keeps track of name prefixes and can skip values outside a given prefix * @@ -10,18 +15,22 @@ import 
com.yahoo.processing.request.CompoundName; */ abstract class PrefixQueryProfileVisitor extends QueryProfileVisitor { + protected final CompoundNameChildCache cache; + /** Only call onValue/onQueryProfile for nodes having this prefix */ private final CompoundName prefix; /** The current prefix, relative to prefix. */ protected CompoundName currentPrefix = CompoundName.empty; + private final Deque<CompoundName> currentPrefixes = new ArrayDeque<>(); private int prefixComponentIndex = -1; - public PrefixQueryProfileVisitor(CompoundName prefix) { + public PrefixQueryProfileVisitor(CompoundName prefix, CompoundNameChildCache cache) { if (prefix == null) prefix = CompoundName.empty; this.prefix = prefix; + this.cache = cache; } @Override @@ -40,18 +49,19 @@ abstract class PrefixQueryProfileVisitor extends QueryProfileVisitor { @Override public final boolean enter(String name) { - prefixComponentIndex++; - if (prefixComponentIndex-1 < prefix.size()) return true; // we're in the given prefix, which should not be included in the name - currentPrefix = currentPrefix.append(name); + if (prefixComponentIndex++ < prefix.size()) return true; // we're in the given prefix, which should not be included in the name + if ( ! name.isEmpty()) { + currentPrefixes.push(currentPrefix); + currentPrefix = cache.append(currentPrefix, name); + } return true; } @Override public final void leave(String name) { - prefixComponentIndex--; - if (prefixComponentIndex < prefix.size()) return; // we're in the given prefix, which should not be included in the name - if ( ! name.isEmpty() && ! currentPrefix.isEmpty()) - currentPrefix = currentPrefix.first(currentPrefix.size() - 1); + if (--prefixComponentIndex < prefix.size()) return; // we're in the given prefix, which should not be included in the name + if ( ! 
name.isEmpty()) + currentPrefix = currentPrefixes.pop(); } /** diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java index be4a683d9d2..4371955ae63 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfile.java @@ -18,7 +18,6 @@ import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.Set; import java.util.regex.Matcher; import java.util.regex.Pattern; @@ -266,10 +265,13 @@ public class QueryProfile extends FreezableSimpleComponent implements Cloneable } AllValuesQueryProfileVisitor visitValues(CompoundName prefix, Map<String, String> context) { - DimensionBinding dimensionBinding = DimensionBinding.createFrom(getDimensions(), context); + return visitValues(prefix, context, new CompoundNameChildCache()); + } - AllValuesQueryProfileVisitor visitor = new AllValuesQueryProfileVisitor(prefix); - accept(visitor, dimensionBinding, null); + AllValuesQueryProfileVisitor visitValues(CompoundName prefix, Map<String, String> context, + CompoundNameChildCache pathCache) { + AllValuesQueryProfileVisitor visitor = new AllValuesQueryProfileVisitor(prefix, pathCache); + accept(visitor, DimensionBinding.createFrom(getDimensions(), context), null); return visitor; } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileCompiler.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileCompiler.java index 3010c6a9d09..29a997a75dd 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileCompiler.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileCompiler.java @@ -2,17 +2,24 @@ package com.yahoo.search.query.profile; import com.yahoo.processing.request.CompoundName; +import 
com.yahoo.search.query.profile.compiled.Binding; import com.yahoo.search.query.profile.compiled.CompiledQueryProfile; import com.yahoo.search.query.profile.compiled.CompiledQueryProfileRegistry; import com.yahoo.search.query.profile.compiled.DimensionalMap; import com.yahoo.search.query.profile.compiled.ValueWithSource; import com.yahoo.search.query.profile.types.QueryProfileType; +import java.util.ArrayList; +import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; +import java.util.List; import java.util.Map; import java.util.Set; +import java.util.SortedMap; +import java.util.TreeMap; +import java.util.function.BiConsumer; import java.util.logging.Logger; -import java.util.stream.Collectors; /** * Compile a set of query profiles into compiled profiles. @@ -42,17 +49,20 @@ public class QueryProfileCompiler { variants.add(new DimensionBindingForPath(DimensionBinding.nullBinding, CompoundName.empty)); // if this contains no variants log.fine(() -> "Compiling " + in + " having " + variants.size() + " variants"); - for (DimensionBindingForPath variant : variants) { - log.finer(() -> " Compiling variant " + variant); - for (Map.Entry<String, ValueWithSource> entry : in.visitValues(variant.path(), variant.binding().getContext()).valuesWithSource().entrySet()) { - CompoundName fullName = variant.path().append(entry.getKey()); - values.put(fullName, variant.binding(), entry.getValue()); + CompoundNameChildCache pathCache = new CompoundNameChildCache(); + Map<DimensionBinding, Binding> bindingCache = new HashMap<>(); + for (var variant : variants) { + log.finer(() -> "Compiling variant " + variant); + Binding variantBinding = bindingCache.computeIfAbsent(variant.binding(), Binding::createFrom); + for (var entry : in.visitValues(variant.path(), variant.binding().getContext(), pathCache).valuesWithSource().entrySet()) { + CompoundName fullName = pathCache.append(variant.path, entry.getKey()); + values.put(fullName, variantBinding, entry.getValue()); 
if (entry.getValue().isUnoverridable()) - unoverridables.put(fullName, variant.binding(), Boolean.TRUE); + unoverridables.put(fullName, variantBinding, Boolean.TRUE); if (entry.getValue().isQueryProfile()) - references.put(fullName, variant.binding(), Boolean.TRUE); + references.put(fullName, variantBinding, Boolean.TRUE); if (entry.getValue().queryProfileType() != null) - types.put(fullName, variant.binding(), entry.getValue().queryProfileType()); + types.put(fullName, variantBinding, entry.getValue().queryProfileType()); } } @@ -96,17 +106,31 @@ public class QueryProfileCompiler { * I.e if we have the variants [-,b=b1], [a=a1,-], [a=a2,-], * this returns the variants [a=a1,b=b1], [a=a2,b=b1] * - * This is necessary because left-specified values takes precedence, such that resolving [a=a1,b=b1] would - * lead us to the compiled profile [a=a1,-], which may contain default values for properties where + * This is necessary because left-specified values takes precedence, and resolving [a=a1,b=b1] would + * otherwise lead us to the compiled profile [a=a1,-], which may contain default values for properties where * we should have preferred variant values in [-,b=b1]. */ private static Set<DimensionBindingForPath> wildcardExpanded(Set<DimensionBindingForPath> variants) { Set<DimensionBindingForPath> expanded = new HashSet<>(); - for (var variant : variants) { - if (hasWildcardBeforeEnd(variant.binding())) - expanded.addAll(wildcardExpanded(variant, variants)); - } + PathTree trie = new PathTree(); + for (var variant : variants) + trie.add(variant); + + // Visit all variant prefixes, grouped on path, and all their unique child bindings. + trie.forEachPrefixAndChildren((prefixes, childBindings) -> { + Set<DimensionBinding> processed = new HashSet<>(); + for (DimensionBindingForPath prefix : prefixes) + if (processed.add(prefix.binding())) // Only compute once for similar bindings, since path is equal. 
+ if (hasWildcardBeforeEnd(prefix.binding())) + for (DimensionBinding childBinding : childBindings) + if (childBinding != prefix.binding()) { + DimensionBinding combined = prefix.binding().combineWith(childBinding); + if ( ! combined.isInvalid()) + expanded.add(new DimensionBindingForPath(combined, prefix.path())); + } + }); + return expanded; } @@ -118,20 +142,6 @@ public class QueryProfileCompiler { return false; } - private static Set<DimensionBindingForPath> wildcardExpanded(DimensionBindingForPath variantToExpand, - Set<DimensionBindingForPath> variants) { - Set<DimensionBindingForPath> expanded = new HashSet<>(); - for (var variant : variants) { - if (variant.binding().isNull()) continue; - if ( ! variant.path().hasPrefix(variantToExpand.path())) continue; - DimensionBinding combined = variantToExpand.binding().combineWith(variant.binding()); - if ( ! combined.isInvalid() ) - expanded.add(new DimensionBindingForPath(combined, variantToExpand.path())); - } - return expanded; - } - - /** Generates a set of all the (legal) combinations of the variants in the given sets */ private static Set<DimensionBindingForPath> combined(Set<DimensionBindingForPath> v1s, Set<DimensionBindingForPath> v2s) { @@ -214,4 +224,61 @@ public class QueryProfileCompiler { } + + /** + * Simple trie for CompoundName paths. + * + * @author jonmv + */ + static class PathTree { + + private final Node root = new Node(0); + + void add(DimensionBindingForPath entry) { + root.add(entry); + } + + /** Performs action on sets of path prefixes against all their (common) children. 
*/ + void forEachPrefixAndChildren(BiConsumer<Collection<DimensionBindingForPath>, Collection<DimensionBinding>> action) { + root.visit(action); + } + + private static class Node { + + private final int depth; + private final SortedMap<String, Node> children = new TreeMap<>(); + private final List<DimensionBindingForPath> elements = new ArrayList<>(); + + private Node(int depth) { + this.depth = depth; + } + + /** Performs action on the elements of this against all child element bindings, then returns the union of these two sets. */ + Set<DimensionBinding> visit(BiConsumer<Collection<DimensionBindingForPath>, Collection<DimensionBinding>> action) { + Set<DimensionBinding> allChildren = new HashSet<>(); + for (Node child : children.values()) + allChildren.addAll(child.visit(action)); + + for (DimensionBindingForPath element : elements) + if ( ! element.binding().isNull()) + allChildren.add(element.binding()); + + action.accept(elements, allChildren); + + return allChildren; + } + + void add(DimensionBindingForPath entry) { + if (depth == entry.path().size()) + elements.add(entry); + else + children.computeIfAbsent(entry.path().get(depth), + __ -> new Node(depth + 1)).add(entry); + } + + } + + } + + } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java index e873c80add1..46430a3041a 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/Binding.java @@ -33,7 +33,7 @@ public class Binding implements Comparable<Binding> { private final int hashCode; - public static final Binding nullBinding = new Binding(Integer.MAX_VALUE, Collections.<String,String>emptyMap()); + public static final Binding nullBinding = new Binding(Integer.MAX_VALUE, Map.of()); public static Binding createFrom(DimensionBinding dimensionBinding) { if 
(dimensionBinding.getDimensions().size() > maxDimensions) diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/DimensionalMap.java b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/DimensionalMap.java index b6bd6dc5a6a..6dc5f61c1f6 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/DimensionalMap.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/DimensionalMap.java @@ -46,14 +46,9 @@ public class DimensionalMap<VALUE> { private final Map<CompoundName, DimensionalValue.Builder<VALUE>> entries = new HashMap<>(); - // TODO: DimensionBinding -> Binding? - public void put(CompoundName key, DimensionBinding binding, VALUE value) { - DimensionalValue.Builder<VALUE> entry = entries.get(key); - if (entry == null) { - entry = new DimensionalValue.Builder<>(); - entries.put(key, entry); - } - entry.add(value, binding); + public void put(CompoundName key, Binding binding, VALUE value) { + entries.computeIfAbsent(key, __ -> new DimensionalValue.Builder<>()) + .add(value, binding); } public DimensionalMap<VALUE> build() { diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/DimensionalValue.java b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/DimensionalValue.java index fb62cfca7d3..afe07b09d41 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/compiled/DimensionalValue.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/compiled/DimensionalValue.java @@ -76,15 +76,11 @@ public class DimensionalValue<VALUE> { return null; } - public void add(VALUE value, DimensionBinding variantBinding) { + public void add(VALUE value, Binding variantBinding) { // Note: We know we can index by the value because its possible types are constrained // to what query profiles allow: String, primitives and query profiles (wrapped as a ValueWithSource) - Value.Builder<VALUE> variant = 
buildableVariants.get(value); - if (variant == null) { - variant = new Value.Builder<>(value); - buildableVariants.put(value, variant); - } - variant.addVariant(variantBinding, value); + buildableVariants.computeIfAbsent(value, Value.Builder::new) + .addVariant(variantBinding, value); } public DimensionalValue<VALUE> build(Map<CompoundName, DimensionalValue.Builder<VALUE>> entries) { @@ -156,8 +152,8 @@ public class DimensionalValue<VALUE> { /** Add a binding this holds for */ @SuppressWarnings("unchecked") - public void addVariant(DimensionBinding binding, VALUE newValue) { - variants.add(Binding.createFrom(binding)); + public void addVariant(Binding binding, VALUE newValue) { + variants.add(binding); // We're combining values for efficiency, so remove incorrect provenance info if (value instanceof ValueWithSource) { diff --git a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java index a38e48fd89d..f4a36ea51ab 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java +++ b/container-search/src/main/java/com/yahoo/search/yql/VespaSerializer.java @@ -723,7 +723,13 @@ public class VespaSerializer { destination.append(leafAnnotations(item)); comma(destination, initLen); int targetNumHits = item.getTargetNumHits(); - annotationKey(destination, "targetNumHits").append(targetNumHits); + annotationKey(destination, YqlParser.TARGET_NUM_HITS).append(targetNumHits); + double distanceThreshold = item.getDistanceThreshold(); + if (distanceThreshold < Double.POSITIVE_INFINITY) { + comma(destination, initLen); + String key = YqlParser.DISTANCE_THRESHOLD; + annotationKey(destination, key).append(distanceThreshold); + } int explore = item.getHnswExploreAdditionalHits(); if (explore != 0) { comma(destination, initLen); diff --git a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java 
b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java index 739aae0e277..f37aeb4c1e0 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java +++ b/container-search/src/main/java/com/yahoo/search/yql/YqlParser.java @@ -156,6 +156,7 @@ public class YqlParser implements Parser { public static final String FILTER = "filter"; public static final String GEO_LOCATION = "geoLocation"; public static final String HIT_LIMIT = "hitLimit"; + public static final String DISTANCE_THRESHOLD = "distanceThreshold"; public static final String HNSW_EXPLORE_ADDITIONAL_HITS = "hnsw.exploreAdditionalHits"; public static final String IMPLICIT_TRANSFORMS = "implicitTransforms"; public static final String LABEL = "label"; @@ -459,6 +460,11 @@ public class YqlParser implements Parser { if (targetNumHits != null) { item.setTargetNumHits(targetNumHits); } + Double distanceThreshold = getAnnotation(ast, DISTANCE_THRESHOLD, + Double.class, null, "maximum distance allowed from query point"); + if (distanceThreshold != null) { + item.setDistanceThreshold(distanceThreshold); + } Integer hnswExploreAdditionalHits = getAnnotation(ast, HNSW_EXPLORE_ADDITIONAL_HITS, Integer.class, null, "number of extra hits to explore for HNSW algorithm"); if (hnswExploreAdditionalHits != null) { diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java index 2bfa778a2ba..8cab7884152 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/InterleavedSearchInvokerTest.java @@ -204,8 +204,8 @@ public class InterleavedSearchInvokerTest { private static final List<Double> A5Aux = Arrays.asList(-1.0,11.0,8.5,7.5,-7.0,3.0,2.0); private static final List<Double> B5Aux = Arrays.asList(9.0,8.0,-3.0,7.0,6.0,1.0, -1.0); - private void 
validateThatTopKProbabilityOverrideTakesEffect(Double topKProbability, int expectedK) throws IOException { - InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5); + private void validateThatTopKProbabilityOverrideTakesEffect(Double topKProbability, int expectedK, boolean isContentWellBalanced) throws IOException { + InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5, isContentWellBalanced); query.setHits(8); query.properties().set(Dispatcher.topKProbability, topKProbability); SearchInvoker [] invokers = invoker.invokers().toArray(new SearchInvoker[0]); @@ -227,13 +227,17 @@ public class InterleavedSearchInvokerTest { @Test public void requireThatTopKProbabilityOverrideTakesEffect() throws IOException { - validateThatTopKProbabilityOverrideTakesEffect(null, 8); - validateThatTopKProbabilityOverrideTakesEffect(0.8, 7); + validateThatTopKProbabilityOverrideTakesEffect(null, 8, true); + validateThatTopKProbabilityOverrideTakesEffect(0.8, 7, true); + } + @Test + public void requireThatTopKProbabilityOverrideIsDisabledOnContentSkew() throws IOException { + validateThatTopKProbabilityOverrideTakesEffect(0.8, 8, false); } @Test public void requireThatMergeOfConcreteHitsObeySorting() throws IOException { - InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5); + InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5, true); query.setHits(12); Result result = invoker.search(query, null); assertEquals(10, result.hits().size()); @@ -242,7 +246,7 @@ public class InterleavedSearchInvokerTest { assertEquals(0, result.getQuery().getOffset()); assertEquals(12, result.getQuery().getHits()); - invoker = createInterLeavedTestInvoker(B5, A5); + invoker = createInterLeavedTestInvoker(B5, A5, true); result = invoker.search(query, null); assertEquals(10, result.hits().size()); assertEquals(11.0, result.hits().get(0).getRelevance().getScore(), DELTA); @@ -253,7 +257,7 @@ public class InterleavedSearchInvokerTest { 
@Test public void requireThatMergeOfConcreteHitsObeyOffset() throws IOException { - InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5); + InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5, B5, true); query.setHits(3); query.setOffset(5); Result result = invoker.search(query, null); @@ -263,7 +267,7 @@ public class InterleavedSearchInvokerTest { assertEquals(0, result.getQuery().getOffset()); assertEquals(3, result.getQuery().getHits()); - invoker = createInterLeavedTestInvoker(B5, A5); + invoker = createInterLeavedTestInvoker(B5, A5, true); query.setOffset(5); result = invoker.search(query, null); assertEquals(3, result.hits().size()); @@ -275,7 +279,7 @@ public class InterleavedSearchInvokerTest { @Test public void requireThatMergeOfConcreteHitsObeyOffsetWithAuxilliaryStuff() throws IOException { - InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5Aux, B5Aux); + InterleavedSearchInvoker invoker = createInterLeavedTestInvoker(A5Aux, B5Aux, true); query.setHits(3); query.setOffset(5); Result result = invoker.search(query, null); @@ -286,7 +290,7 @@ public class InterleavedSearchInvokerTest { assertEquals(0, result.getQuery().getOffset()); assertEquals(3, result.getQuery().getHits()); - invoker = createInterLeavedTestInvoker(B5Aux, A5Aux); + invoker = createInterLeavedTestInvoker(B5Aux, A5Aux, true); query.setOffset(5); result = invoker.search(query, null); assertEquals(7, result.hits().size()); @@ -297,12 +301,13 @@ public class InterleavedSearchInvokerTest { assertEquals(3, result.getQuery().getHits()); } - private static InterleavedSearchInvoker createInterLeavedTestInvoker(List<Double> a, List<Double> b) { + private static InterleavedSearchInvoker createInterLeavedTestInvoker(List<Double> a, List<Double> b, + boolean isContentWellBalanced) { SearchCluster cluster = new MockSearchCluster("!", 1, 2); List<SearchInvoker> invokers = new ArrayList<>(); invokers.add(createInvoker(a, 0)); 
invokers.add(createInvoker(b, 1)); - InterleavedSearchInvoker invoker = new InterleavedSearchInvoker(invokers, cluster, Collections.emptySet()); + InterleavedSearchInvoker invoker = new InterleavedSearchInvoker(invokers, isContentWellBalanced, cluster, Collections.emptySet()); invoker.responseAvailable(invokers.get(0)); invoker.responseAvailable(invokers.get(1)); return invoker; @@ -353,7 +358,7 @@ public class InterleavedSearchInvokerTest { invokers.add(new MockInvoker(i)); } - return new InterleavedSearchInvoker(invokers, searchCluster, null) { + return new InterleavedSearchInvoker(invokers, false, searchCluster, null) { @Override protected long currentTime() { return clock.millis(); diff --git a/container-search/src/test/java/com/yahoo/search/dispatch/searchcluster/SearchClusterTest.java b/container-search/src/test/java/com/yahoo/search/dispatch/searchcluster/SearchClusterTest.java index 09024150a9a..c6fd48836fe 100644 --- a/container-search/src/test/java/com/yahoo/search/dispatch/searchcluster/SearchClusterTest.java +++ b/container-search/src/test/java/com/yahoo/search/dispatch/searchcluster/SearchClusterTest.java @@ -8,7 +8,6 @@ import com.yahoo.net.HostName; import com.yahoo.prelude.Pong; import com.yahoo.search.cluster.ClusterMonitor; import com.yahoo.search.dispatch.MockSearchCluster; -import com.yahoo.search.dispatch.TopKEstimator; import com.yahoo.search.result.ErrorMessage; import org.junit.Test; @@ -335,4 +334,48 @@ public class SearchClusterTest { assertEquals(3, node.getLastReceivedPongId()); } + @Test + public void requireThatEmptyGroupIsInBalance() { + Group group = new Group(0, new ArrayList<>()); + assertTrue(group.isContentWellBalanced()); + group.aggregateNodeValues(); + assertTrue(group.isContentWellBalanced()); + } + + @Test + public void requireThatSingleNodeGroupIsInBalance() { + Group group = new Group(0, Arrays.asList(new Node(1, "n", 1))); + group.nodes().forEach(node -> node.setWorking(true)); + assertTrue(group.isContentWellBalanced()); 
+ group.aggregateNodeValues(); + assertTrue(group.isContentWellBalanced()); + group.nodes().get(0).setActiveDocuments(1000); + group.aggregateNodeValues(); + assertTrue(group.isContentWellBalanced()); + } + + @Test + public void requireThatMultiNodeGroupDetectsBalance() { + Group group = new Group(0, Arrays.asList(new Node(1, "n1", 1), new Node(2, "n2", 1))); + assertTrue(group.isContentWellBalanced()); + group.nodes().forEach(node -> node.setWorking(true)); + assertTrue(group.isContentWellBalanced()); + group.aggregateNodeValues(); + assertTrue(group.isContentWellBalanced()); + group.nodes().get(0).setActiveDocuments(1000); + group.aggregateNodeValues(); + assertFalse(group.isContentWellBalanced()); + group.nodes().get(1).setActiveDocuments(100); + group.aggregateNodeValues(); + assertFalse(group.isContentWellBalanced()); + group.nodes().get(1).setActiveDocuments(800); + group.aggregateNodeValues(); + assertFalse(group.isContentWellBalanced()); + group.nodes().get(1).setActiveDocuments(818); + group.aggregateNodeValues(); + assertFalse(group.isContentWellBalanced()); + group.nodes().get(1).setActiveDocuments(819); + group.aggregateNodeValues(); + assertTrue(group.isContentWellBalanced()); + } } diff --git a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java index c49603737a6..72956b5b6eb 100644 --- a/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/searchers/ValidateNearestNeighborTestCase.java @@ -138,6 +138,7 @@ public class ValidateNearestNeighborTestCase { r.append("field=").append(field); r.append(",queryTensorName=").append(qt); r.append(",hnsw.exploreAdditionalHits=0"); + r.append(",distanceThreshold=Infinity"); r.append(",approximate=true"); r.append(",targetHits=").append(th); r.append("} ").append(errmsg); diff --git 
a/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java index 63840b0f5ec..a44a9f25b62 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/VespaSerializerTestCase.java @@ -139,6 +139,7 @@ public class VespaSerializerTestCase { parseAndConfirm("[{\"targetNumHits\": 1, \"hnsw.exploreAdditionalHits\": 76}]nearestNeighbor(semantic_embedding, my_property)"); parseAndConfirm("[{\"targetNumHits\": 2, \"approximate\": false}]nearestNeighbor(semantic_embedding, my_property)"); parseAndConfirm("[{\"targetNumHits\": 3, \"hnsw.exploreAdditionalHits\": 67, \"approximate\": false}]nearestNeighbor(semantic_embedding, my_property)"); + parseAndConfirm("[{\"targetNumHits\": 4, \"distanceThreshold\": 100100.25}]nearestNeighbor(semantic_embedding, my_property)"); } @Test diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java index f5e22e30f45..2d88351f9ea 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java @@ -568,11 +568,15 @@ public class YqlParserTestCase { @Test public void testNearestNeighbor() { assertParse("select foo from bar where nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=0}"); assertParse("select foo from bar where [{\"targetHits\": 37}]nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR 
{field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=37}"); assertParse("select foo from bar where [{\"approximate\": false, \"hnsw.exploreAdditionalHits\": 8, \"targetHits\": 3}]nearestNeighbor(semantic_embedding, my_vector);", - "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,approximate=false,targetHits=3}"); + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=8,distanceThreshold=Infinity,approximate=false,targetHits=3}"); + + assertParse("select foo from bar where [{\"targetHits\": 7, \"distanceThreshold\": 100100.25}]nearestNeighbor(semantic_embedding, my_vector);", + "NEAREST_NEIGHBOR {field=semantic_embedding,queryTensorName=my_vector,hnsw.exploreAdditionalHits=0,distanceThreshold=100100.25,approximate=true,targetHits=7}"); + } @Test diff --git a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java index c802eb18c0f..3239a97a094 100644 --- a/container-search/src/test/java/com/yahoo/select/SelectTestCase.java +++ b/container-search/src/test/java/com/yahoo/select/SelectTestCase.java @@ -537,10 +537,10 @@ public class SelectTestCase { @Test public void testNearestNeighbor() { assertParse("{ \"nearestNeighbor\": [ \"f1field\", \"q2prop\" ] }", - "NEAREST_NEIGHBOR {field=f1field,queryTensorName=q2prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=0}"); + "NEAREST_NEIGHBOR {field=f1field,queryTensorName=q2prop,hnsw.exploreAdditionalHits=0,distanceThreshold=Infinity,approximate=true,targetHits=0}"); - assertParse("{ \"nearestNeighbor\": { \"children\" : [ \"f3field\", \"q4prop\" ], \"attributes\" : {\"targetHits\": 37} }}", - "NEAREST_NEIGHBOR 
{field=f3field,queryTensorName=q4prop,hnsw.exploreAdditionalHits=0,approximate=true,targetHits=37}"); + assertParse("{ \"nearestNeighbor\": { \"children\" : [ \"f3field\", \"q4prop\" ], \"attributes\" : {\"targetHits\": 37, \"hnsw.exploreAdditionalHits\": 42, \"distanceThreshold\": 100100.25 } }}", + "NEAREST_NEIGHBOR {field=f3field,queryTensorName=q4prop,hnsw.exploreAdditionalHits=42,distanceThreshold=100100.25,approximate=true,targetHits=37}"); } @Test diff --git a/container-test/pom.xml b/container-test/pom.xml index dfbe5f755ac..07134e8aa7f 100644 --- a/container-test/pom.xml +++ b/container-test/pom.xml @@ -65,6 +65,10 @@ <groupId>xerces</groupId> <artifactId>xercesImpl</artifactId> </dependency> - + <dependency> + <groupId>io.airlift</groupId> + <artifactId>aircompressor</artifactId> + <scope>compile</scope> + </dependency> </dependencies> </project> diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java index 0c9a415beab..11940b30ac1 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/integration/configserver/ConfigServer.java @@ -144,4 +144,8 @@ public interface ConfigServer { /** Get maximum resources consumed */ QuotaUsage getQuotaUsage(DeploymentId deploymentId); + + /** Sets suspension status — whether application node operations are orchestrated — for the given deployment. 
*/ + void setSuspension(DeploymentId deploymentId, boolean suspend); + } diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java index 2acf7c93925..12df0a5e0a7 100644 --- a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/PathGroup.java @@ -169,8 +169,10 @@ enum PathGroup { "/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/deploy/{job}", "/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/dev/region/{region}", "/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/dev/region/{region}/deploy", + "/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/dev/region/{region}/suspend", "/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/perf/region/{region}", "/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/perf/region/{region}/deploy", + "/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/perf/region/{region}/suspend", "/application/v4/tenant/{tenant}/application/{application}/environment/dev/region/{region}/instance/{instance}", "/application/v4/tenant/{tenant}/application/{application}/environment/dev/region/{region}/instance/{instance}/deploy", "/application/v4/tenant/{tenant}/application/{application}/environment/perf/region/{region}/instance/{instance}", diff --git a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java index 92f902dc0f7..46d1dc76b57 100644 --- 
a/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java +++ b/controller-api/src/main/java/com/yahoo/vespa/hosted/controller/api/role/SecurityContext.java @@ -2,6 +2,7 @@ package com.yahoo.vespa.hosted.controller.api.role; import java.security.Principal; +import java.time.Instant; import java.util.Objects; import java.util.Set; @@ -14,10 +15,16 @@ public class SecurityContext { private final Principal principal; private final Set<Role> roles; + private final Instant issuedAt; - public SecurityContext(Principal principal, Set<Role> roles) { + public SecurityContext(Principal principal, Set<Role> roles, Instant issuedAt) { this.principal = Objects.requireNonNull(principal); this.roles = Set.copyOf(roles); + this.issuedAt = Objects.requireNonNull(issuedAt); + } + + public SecurityContext(Principal principal, Set<Role> roles) { + this(principal, roles, Instant.EPOCH); } public Principal principal() { @@ -28,18 +35,23 @@ public class SecurityContext { return roles; } + public Instant issuedAt() { + return issuedAt; + } + @Override public boolean equals(Object o) { if (this == o) return true; if (o == null || getClass() != o.getClass()) return false; SecurityContext that = (SecurityContext) o; return Objects.equals(principal, that.principal) && - Objects.equals(roles, that.roles); + Objects.equals(roles, that.roles) && + Objects.equals(issuedAt, that.issuedAt); } @Override public int hashCode() { - return Objects.hash(principal, roles); + return Objects.hash(principal, roles, issuedAt); } @Override @@ -47,6 +59,7 @@ public class SecurityContext { return "SecurityContext{" + "principal=" + principal + ", roles=" + roles + + ", issuedAt=" + issuedAt + '}'; } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java index 832668bf9f7..e071221dd05 100644 --- 
a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/ApplicationController.java @@ -985,4 +985,9 @@ public class ApplicationController { return Map.copyOf(warnings); } + /** Sets suspension status of the given deployment in its zone. */ + public void setSuspension(DeploymentId deploymentId, boolean suspend) { + configServer.setSuspension(deploymentId, suspend); + } + } diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java index 0dff1600751..67490bf9d8c 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiHandler.java @@ -292,6 +292,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/{environment}/region/{region}/reindex")) return reindex(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/{environment}/region/{region}/reindexing")) return enableReindexing(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/{environment}/region/{region}/restart")) return restart(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); + if 
(path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/{environment}/region/{region}/suspend")) return suspend(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), true); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}")) return deploy(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/deploy")) return deploy(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); // legacy synonym of the above if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/restart")) return restart(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); @@ -319,6 +320,7 @@ public class ApplicationApiHandler extends LoggingRequestHandler { if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/job/{jobtype}/pause")) return resume(appIdFromPath(path), jobTypeFromPath(path)); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/{environment}/region/{region}")) return deactivate(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/{environment}/region/{region}/reindexing")) return disableReindexing(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), 
request); + if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/{environment}/region/{region}/suspend")) return suspend(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), false); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/instance/{instance}/environment/{environment}/region/{region}/global-rotation/override")) return setGlobalRotationOverride(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), true, request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}")) return deactivate(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), request); if (path.matches("/application/v4/tenant/{tenant}/application/{application}/environment/{environment}/region/{region}/instance/{instance}/global-rotation/override")) return setGlobalRotationOverride(path.get("tenant"), path.get("application"), path.get("instance"), path.get("environment"), path.get("region"), true, request); @@ -1651,6 +1653,14 @@ public class ApplicationApiHandler extends LoggingRequestHandler { return new MessageResponse("Requested restart of " + deploymentId); } + /** Set suspension status of the given deployment. */ + private HttpResponse suspend(String tenantName, String applicationName, String instanceName, String environment, String region, boolean suspend) { + DeploymentId deploymentId = new DeploymentId(ApplicationId.from(tenantName, applicationName, instanceName), + requireZone(environment, region)); + controller.applications().setSuspension(deploymentId, suspend); + return new MessageResponse((suspend ? 
"Suspended" : "Resumed") + " orchestration of " + deploymentId); + } + private HttpResponse jobDeploy(ApplicationId id, JobType type, HttpRequest request) { if ( ! type.environment().isManuallyDeployed() && ! isOperator(request)) throw new IllegalArgumentException("Direct deployments are only allowed to manually deployed environments."); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java index 569c22d8bf6..8fcbb365804 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/ControllerTest.java @@ -667,7 +667,7 @@ public class ControllerTest { DeploymentId deployment2 = context.deploymentIdIn(ZoneId.from(Environment.prod, RegionName.from("us-east-3"))); assertFalse(tester.configServer().isSuspended(deployment1)); assertFalse(tester.configServer().isSuspended(deployment2)); - tester.configServer().setSuspended(deployment1, true); + tester.configServer().setSuspension(deployment1, true); assertTrue(tester.configServer().isSuspended(deployment1)); assertFalse(tester.configServer().isSuspended(deployment2)); } diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java index 9b1eff60831..724ba61da2b 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/deployment/DeploymentTriggerTest.java @@ -253,7 +253,7 @@ public class DeploymentTriggerTest { var application = tester.newDeploymentContext().submit().deploy(); // The first production zone is suspended: - 
tester.configServer().setSuspended(application.deploymentIdIn(ZoneId.from("prod", "us-central-1")), true); + tester.configServer().setSuspension(application.deploymentIdIn(ZoneId.from("prod", "us-central-1")), true); // A new change needs to be pushed out, but should not go beyond the suspended zone: application.submit() @@ -265,7 +265,7 @@ public class DeploymentTriggerTest { application.assertNotRunning(productionUsWest1); // The zone is unsuspended so jobs start: - tester.configServer().setSuspended(application.deploymentIdIn(ZoneId.from("prod", "us-central-1")), false); + tester.configServer().setSuspension(application.deploymentIdIn(ZoneId.from("prod", "us-central-1")), false); tester.triggerJobs(); application.runJob(productionUsWest1).runJob(productionUsEast3); assertEquals(Change.empty(), application.instance().change()); diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java index ab31b7e21fe..7753570b72d 100644 --- a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/integration/ConfigServerMock.java @@ -244,7 +244,8 @@ public class ConfigServerMock extends AbstractComponent implements ConfigServer return Optional.ofNullable(applications.get(new DeploymentId(id, zone))); } - public void setSuspended(DeploymentId deployment, boolean suspend) { + @Override + public void setSuspension(DeploymentId deployment, boolean suspend) { if (suspend) suspendedApplications.add(deployment); else diff --git a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java index 3d1375601ad..d9234c9a28e 100644 --- 
a/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java +++ b/controller-server/src/test/java/com/yahoo/vespa/hosted/controller/restapi/application/ApplicationApiTest.java @@ -113,6 +113,7 @@ public class ApplicationApiTest extends ControllerContainerTest { "z/4jKSTHwbYR8wdsOSrJGVEUPbS2nguIJ64OJH7gFnxM6sxUVj+Nm2HlXw==\n" + "-----END PUBLIC KEY-----\n"; private static final String quotedPemPublicKey = pemPublicKey.replaceAll("\\n", "\\\\n"); + private static final String accessDenied = "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}"; private static final ApplicationPackage applicationPackageDefault = new ApplicationPackageBuilder() .instances("default") @@ -262,13 +263,13 @@ public class ApplicationApiTest extends ControllerContainerTest { tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/otheruser/deploy/dev-us-east-1", POST) .userIdentity(OTHER_USER_ID) .data(createApplicationDeployData(applicationPackageInstance1, false)), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // DELETE a dev deployment is not generally allowed under user instance tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/otheruser/environment/dev/region/us-east-1", DELETE) .userIdentity(OTHER_USER_ID), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // When the user is a tenant admin, user instances are allowed. 
@@ -648,6 +649,21 @@ public class ApplicationApiTest extends ControllerContainerTest { .screwdriverIdentity(SCREWDRIVER_ID), "{\"message\":\"Requested restart of tenant1.application1.instance1 in prod.us-central-1\"}", 200); + // POST a 'suspend application' in dev environment + tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/environment/dev/region/us-east-1/suspend", POST) + .userIdentity(USER_ID), + "{\"message\":\"Suspended orchestration of tenant1.application1.instance1 in dev.us-east-1\"}"); + + // POST a 'resume application' in dev environment + tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/environment/dev/region/us-east-1/suspend", DELETE) + .userIdentity(USER_ID), + "{\"message\":\"Resumed orchestration of tenant1.application1.instance1 in dev.us-east-1\"}"); + + // POST a 'suspend application' in prod environment fails + tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1/environment/prod/region/us-east-3/suspend", POST) + .userIdentity(USER_ID), + accessDenied, 403); + // GET suspended tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-central-1/instance/instance1/suspended", GET) .userIdentity(USER_ID), @@ -1060,7 +1076,7 @@ public class ApplicationApiTest extends ControllerContainerTest { .userIdentity(USER_ID) .oktaAccessToken(OKTA_AT).oktaIdentityToken(OKTA_IT) .data("{\"athensDomain\":\"domain1\", \"property\":\"property1\"}"), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // GET non-existing tenant @@ -1216,7 +1232,7 @@ public class ApplicationApiTest extends ControllerContainerTest { // DELETE tenant again returns 403 as tenant access cannot be determined when the tenant does not exist tester.assertResponse(request("/application/v4/tenant/tenant1", DELETE) .userIdentity(USER_ID), 
- "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // Create legancy tenant name containing underscores @@ -1271,7 +1287,7 @@ public class ApplicationApiTest extends ControllerContainerTest { tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/instance1", POST) .userIdentity(unauthorizedUser) .oktaAccessToken(OKTA_AT).oktaIdentityToken(OKTA_IT), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // (Create it with the right tenant id) @@ -1286,13 +1302,13 @@ public class ApplicationApiTest extends ControllerContainerTest { tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/environment/prod/region/us-west-1/instance/default/deploy", POST) .data(entity) .userIdentity(USER_ID), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // Deleting an application for an Athens domain the user is not admin for is disallowed tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1", DELETE) .userIdentity(unauthorizedUser), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // Create another instance under the application @@ -1313,7 +1329,7 @@ public class ApplicationApiTest extends ControllerContainerTest { tester.assertResponse(request("/application/v4/tenant/tenant1", PUT) .data("{\"athensDomain\":\"domain1\", \"property\":\"property1\"}") .userIdentity(unauthorizedUser), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // Change Athens domain @@ -1328,7 +1344,7 @@ public class ApplicationApiTest extends ControllerContainerTest { // Deleting a tenant for an Athens domain the user is not admin for is disallowed tester.assertResponse(request("/application/v4/tenant/tenant1", DELETE) .userIdentity(unauthorizedUser), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); } @@ 
-1394,7 +1410,7 @@ public class ApplicationApiTest extends ControllerContainerTest { tester.assertResponse(request("/application/v4/tenant/tenant1/application/application1/instance/new-user/deploy/dev-us-east-1", POST) .data(entity) .userIdentity(userId), - "{\n \"code\" : 403,\n \"message\" : \"Access denied\"\n}", + accessDenied, 403); // Add "new-user" to the admin role, to allow service launches. diff --git a/dist/vespa.spec b/dist/vespa.spec index b4bbb13e7d6..e897a03e58a 100644 --- a/dist/vespa.spec +++ b/dist/vespa.spec @@ -58,7 +58,7 @@ BuildRequires: vespa-gtest >= 1.8.1-1 BuildRequires: vespa-icu-devel >= 65.1.0-1 BuildRequires: vespa-lz4-devel >= 1.9.2-2 BuildRequires: vespa-onnxruntime-devel = 1.4.0 -BuildRequires: vespa-openssl-devel >= 1.1.1g-1 +BuildRequires: vespa-openssl-devel >= 1.1.1i-1 BuildRequires: vespa-protobuf-devel >= 3.7.0-4 BuildRequires: vespa-libzstd-devel >= 1.4.5-2 %endif @@ -167,7 +167,7 @@ Requires: llvm7.0 Requires: vespa-icu >= 65.1.0-1 Requires: vespa-lz4 >= 1.9.2-2 Requires: vespa-onnxruntime = 1.4.0 -Requires: vespa-openssl >= 1.1.1g-1 +Requires: vespa-openssl >= 1.1.1i-1 Requires: vespa-protobuf >= 3.7.0-4 Requires: vespa-telegraf >= 1.1.1-1 Requires: vespa-valgrind >= 3.16.0-1 @@ -251,7 +251,7 @@ Summary: Vespa - The open big data serving engine - base C++ libs Requires: xxhash-libs >= 0.8.0 %if 0%{?el7} -Requires: vespa-openssl >= 1.1.1g-1 +Requires: vespa-openssl >= 1.1.1i-1 %else Requires: openssl-libs %endif diff --git a/docproc/src/test/java/com/yahoo/docproc/util/SplitterJoinerTestCase.java b/docproc/src/test/java/com/yahoo/docproc/util/SplitterJoinerTestCase.java index aa55d5b6a41..6c8c485aacb 100644 --- a/docproc/src/test/java/com/yahoo/docproc/util/SplitterJoinerTestCase.java +++ b/docproc/src/test/java/com/yahoo/docproc/util/SplitterJoinerTestCase.java @@ -18,7 +18,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; /** - * @author <a 
href="mailto:einarmr@yahoo-inc.com">Einar M R Rosenvinge</a> + * @author Einar M R Rosenvinge */ @SuppressWarnings({"unchecked"}) public class SplitterJoinerTestCase { diff --git a/document/src/main/java/com/yahoo/document/BucketId.java b/document/src/main/java/com/yahoo/document/BucketId.java index 36ec53c51d3..f6d3dc34af1 100755 --- a/document/src/main/java/com/yahoo/document/BucketId.java +++ b/document/src/main/java/com/yahoo/document/BucketId.java @@ -74,9 +74,9 @@ public class BucketId implements Comparable<BucketId> { public int compareTo(BucketId other) { if (id >>> 32 == other.id >>> 32) { - if ((id & 0xFFFFFFFFl) > (other.id & 0xFFFFFFFFl)) { + if ((id & 0xFFFFFFFFL) > (other.id & 0xFFFFFFFFL)) { return 1; - } else if ((id & 0xFFFFFFFFl) < (other.id & 0xFFFFFFFFl)) { + } else if ((id & 0xFFFFFFFFL) < (other.id & 0xFFFFFFFFL)) { return -1; } return 0; @@ -97,8 +97,8 @@ public class BucketId implements Comparable<BucketId> { public long getId() { int notUsed = 64 - getUsedBits(); - long usedMask = (0xFFFFFFFFFFFFFFFFl << notUsed) >>> notUsed; - long countMask = (0xFFFFFFFFFFFFFFFFl >>> (64 - COUNT_BITS)) << (64 - COUNT_BITS); + long usedMask = (0xFFFFFFFFFFFFFFFFL << notUsed) >>> notUsed; + long countMask = (0xFFFFFFFFFFFFFFFFL >>> (64 - COUNT_BITS)) << (64 - COUNT_BITS); return id & (usedMask | countMask); } diff --git a/document/src/test/resources/tensor/multi_cell_tensor__cpp b/document/src/test/resources/tensor/multi_cell_tensor__cpp Binary files differindex deb53463fb5..9adda236a4a 100644 --- a/document/src/test/resources/tensor/multi_cell_tensor__cpp +++ b/document/src/test/resources/tensor/multi_cell_tensor__cpp diff --git a/document/src/vespa/document/update/tensor_partial_update.cpp b/document/src/vespa/document/update/tensor_partial_update.cpp index fbc60cc09af..f763c92741c 100644 --- a/document/src/vespa/document/update/tensor_partial_update.cpp +++ b/document/src/vespa/document/update/tensor_partial_update.cpp @@ -5,6 +5,7 @@ #include 
<vespa/vespalib/util/overload.h> #include <vespa/vespalib/util/typify.h> #include <vespa/vespalib/util/visit_ranges.h> +#include <vespa/vespalib/util/shared_string_repo.h> #include <cassert> #include <set> @@ -43,7 +44,8 @@ struct DenseCoords { } ~DenseCoords(); void clear() { offset = 0; current = 0; } - void convert_label(vespalib::stringref label) { + void convert_label(label_t label_id) { + vespalib::string label = SharedStringRepo::Handle::string_from_id(label_id); uint32_t coord = 0; for (char c : label) { if (c < '0' || c > '9') { // bad char @@ -71,9 +73,9 @@ struct DenseCoords { DenseCoords::~DenseCoords() = default; struct SparseCoords { - std::vector<vespalib::stringref> addr; - std::vector<vespalib::stringref *> next_result_refs; - std::vector<const vespalib::stringref *> lookup_refs; + std::vector<label_t> addr; + std::vector<label_t *> next_result_refs; + std::vector<const label_t *> lookup_refs; std::vector<size_t> lookup_view_dims; SparseCoords(size_t sz) : addr(sz), next_result_refs(sz), lookup_refs(sz), lookup_view_dims(sz) @@ -327,7 +329,7 @@ calc_mapped_dimension_indexes(const ValueType& input_type, struct ModifierCoords { - std::vector<const vespalib::stringref *> lookup_refs; + std::vector<const label_t *> lookup_refs; std::vector<size_t> lookup_view_dims; ModifierCoords(const SparseCoords& input_coords, diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java index 875b74025d0..5def71e2d81 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/MessageBusAsyncSession.java @@ -277,13 +277,13 @@ public class MessageBusAsyncSession implements MessageBusSession, AsyncSession { } private static Response.Outcome toOutcome(Reply reply) { - if ( reply instanceof UpdateDocumentReply && ! 
((UpdateDocumentReply) reply).wasFound() - || reply instanceof RemoveDocumentReply && ! ((RemoveDocumentReply) reply).wasFound()) - return NOT_FOUND; if (reply.getErrorCodes().contains(DocumentProtocol.ERROR_NO_SPACE)) return INSUFFICIENT_STORAGE; if (reply.getErrorCodes().contains(DocumentProtocol.ERROR_TEST_AND_SET_CONDITION_FAILED)) return CONDITION_FAILED; + if ( reply instanceof UpdateDocumentReply && ! ((UpdateDocumentReply) reply).wasFound() + || reply instanceof RemoveDocumentReply && ! ((RemoveDocumentReply) reply).wasFound()) + return NOT_FOUND; return ERROR; } diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/DocumentProtocol.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/DocumentProtocol.java index f5b4920fa3f..2680ed011af 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/DocumentProtocol.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/DocumentProtocol.java @@ -429,7 +429,7 @@ public class DocumentProtocol implements Protocol { * @param ctx the context whose children to merge */ public static void merge(RoutingContext ctx) { - merge(ctx, new HashSet<Integer>(0)); + merge(ctx, new HashSet<>(0)); } /** @@ -475,7 +475,7 @@ public class DocumentProtocol implements Protocol { * @return the merged Reply */ public static Reply merge(List<Reply> replies) { - return merge(replies, new HashSet<Integer>(0)).second; + return merge(replies, new HashSet<>(0)).second; } /** diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ReplyMerger.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ReplyMerger.java index a8ec53c5c97..630a8588495 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ReplyMerger.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/ReplyMerger.java @@ -3,6 +3,7 @@ package com.yahoo.documentapi.messagebus.protocol; import 
com.yahoo.collections.Tuple2; import com.yahoo.messagebus.EmptyReply; +import com.yahoo.messagebus.Error; import com.yahoo.messagebus.Reply; /** @@ -57,15 +58,22 @@ final class ReplyMerger { return; } if (error == null) { - error = new EmptyReply(); - r.swapState(error); - return; + error = r; + } + else if (mostSevereErrorCode(r) > mostSevereErrorCode(error)) { + error.getErrors().forEach(r::addError); + error = r; } - for (int j = 0; j < r.getNumErrors(); ++j) { - error.addError(r.getError(j)); + else { + r.getErrors().forEach(error::addError); } } + private static int mostSevereErrorCode(Reply reply) { + return reply.getErrors().mapToInt(Error::getCode).max() + .orElseThrow(() -> new IllegalArgumentException(reply + " has no errors")); + } + private boolean handleReplyWithOnlyIgnoredErrors(Reply r) { if (DocumentProtocol.hasOnlyErrorsOfType(r, DocumentProtocol.ERROR_MESSAGE_IGNORED)) { if (ignore == null) { @@ -96,7 +104,7 @@ final class ReplyMerger { } private Tuple2<Integer, Reply> createEmptyReplyResult() { - return new Tuple2<>(null, (Reply)new EmptyReply()); + return new Tuple2<>(null, new EmptyReply()); } public Tuple2<Integer, Reply> mergedReply() { diff --git a/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/ReplyMergerTestCase.java b/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/ReplyMergerTestCase.java index 157b4a6585b..11d74800a79 100644 --- a/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/ReplyMergerTestCase.java +++ b/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/ReplyMergerTestCase.java @@ -38,7 +38,7 @@ public class ReplyMergerTestCase { } @Test - public void mergingSingleReplyWithOneErrorReturnsEmptyReplyWithError() { + public void mergingSingleReplyWithOneErrorReturnsSameReplyWithError() { Reply r1 = new EmptyReply(); Error error = new Error(1234, "oh no!"); r1.addError(error); @@ -46,12 +46,12 @@ public class ReplyMergerTestCase { Tuple2<Integer, Reply> ret = 
merger.mergedReply(); assertNull(ret.first); - assertNotSame(r1, ret.second); + assertSame(r1, ret.second); assertThatErrorsMatch(new Error[] { error }, ret); } @Test - public void mergingSingleReplyWithMultipleErrorsReturnsEmptyReplyWithAllErrors() { + public void mergingSingleReplyWithMultipleErrorsReturnsSameReplyWithAllErrors() { Reply r1 = new EmptyReply(); Error errors[] = new Error[] { new Error(1234, "oh no!"), new Error(4567, "oh dear!"), @@ -62,12 +62,12 @@ public class ReplyMergerTestCase { Tuple2<Integer, Reply> ret = merger.mergedReply(); assertNull(ret.first); - assertNotSame(r1, ret.second); + assertSame(r1, ret.second); assertThatErrorsMatch(errors, ret); } @Test - public void mergingMultipleRepliesWithMultipleErrorsReturnsEmptyReplyWithAllErrors() { + public void mergingMultipleRepliesWithMultipleErrorsReturnsMostSevereReplyWithAllErrors() { Reply r1 = new EmptyReply(); Reply r2 = new EmptyReply(); Error errors[] = new Error[] { @@ -81,7 +81,7 @@ public class ReplyMergerTestCase { Tuple2<Integer, Reply> ret = merger.mergedReply(); assertNull(ret.first); - assertNotSame(r1, ret.second); + assertSame(r1, ret.second); assertNotSame(r2, ret.second); assertThatErrorsMatch(errors, ret); } @@ -143,7 +143,7 @@ public class ReplyMergerTestCase { merger.merge(1, r2); Tuple2<Integer, Reply> ret = merger.mergedReply(); assertNull(ret.first); - assertNotSame(r1, ret.second); + assertSame(r1, ret.second); assertNotSame(r2, ret.second); // All errors from replies with errors are included, not those that // are fully ignored. @@ -182,7 +182,7 @@ public class ReplyMergerTestCase { } @Test - // TODO: This seems wrong, and is probably a consequence of TAS being added later than this logic was written. + // TODO jonmv: This seems wrong, and is probably a consequence of TAS being implemented after reply merging. 
public void returnErrorDocumentReplyWhereDocWasFoundWhichIsProbablyWrong() { Error e1 = new Error(DocumentProtocol.ERROR_TEST_AND_SET_CONDITION_FAILED, "fail"); UpdateDocumentReply r1 = new UpdateDocumentReply(); @@ -197,7 +197,7 @@ public class ReplyMergerTestCase { merger.merge(2, r3); Tuple2<Integer, Reply> ret = merger.mergedReply(); assertNull(ret.first); - assertNotSame(r1, ret.second); + assertSame(r1, ret.second); assertThatErrorsMatch(new Error[] { e1 }, ret); } diff --git a/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/PolicyTestFrame.java b/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/PolicyTestFrame.java index 89d5db62899..92e94256411 100755 --- a/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/PolicyTestFrame.java +++ b/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/PolicyTestFrame.java @@ -155,16 +155,6 @@ public class PolicyTestFrame { } /** - * This is a convenience method for invoking {@link #assertMerge(Map,List,List)} with no expected value. - * - * @param replies The errors to set in the leaf node replies. - * @param expectedErrors The list of expected errors in the merged reply. - */ - public void assertMergeError(Map<String, Integer> replies, List<Integer> expectedErrors) { - assertMerge(replies, expectedErrors, null); - } - - /** * This is a convenience method for invoking {@link this#assertMerge(Map,List,List)} with no expected errors. * * @param replies The errors to set in the leaf node replies. 
@@ -233,10 +223,10 @@ public class PolicyTestFrame { Map<String, Integer> replies = new HashMap<>(); replies.put(recipient, ErrorCode.NONE); - assertMergeOk(replies, Arrays.asList(recipient)); + assertMergeOk(replies, List.of(recipient)); replies.put(recipient, ErrorCode.TRANSIENT_ERROR); - assertMergeError(replies, Arrays.asList(ErrorCode.TRANSIENT_ERROR)); + assertMerge(replies, List.of(ErrorCode.TRANSIENT_ERROR), List.of(recipient)); } /** @@ -252,28 +242,33 @@ public class PolicyTestFrame { Map<String, Integer> replies = new HashMap<>(); replies.put(recipientOne, ErrorCode.NONE); replies.put(recipientTwo, ErrorCode.NONE); - assertMergeOk(replies, Arrays.asList(recipientOne, recipientTwo)); + assertMergeOk(replies, List.of(recipientOne, recipientTwo)); replies.put(recipientOne, ErrorCode.TRANSIENT_ERROR); replies.put(recipientTwo, ErrorCode.NONE); - assertMergeError(replies, Arrays.asList(ErrorCode.TRANSIENT_ERROR)); + assertMerge(replies, List.of(ErrorCode.TRANSIENT_ERROR), List.of(recipientOne)); + + replies.put(recipientOne, ErrorCode.TRANSIENT_ERROR); + replies.put(recipientTwo, ErrorCode.FATAL_ERROR); + assertMerge(replies, List.of(ErrorCode.TRANSIENT_ERROR, ErrorCode.FATAL_ERROR), List.of(recipientTwo)); replies.put(recipientOne, ErrorCode.TRANSIENT_ERROR); replies.put(recipientTwo, ErrorCode.TRANSIENT_ERROR); - assertMergeError(replies, Arrays.asList(ErrorCode.TRANSIENT_ERROR, ErrorCode.TRANSIENT_ERROR)); + assertMerge(replies, Arrays.asList(ErrorCode.TRANSIENT_ERROR, ErrorCode.TRANSIENT_ERROR), List.of(recipientOne, recipientTwo)); replies.put(recipientOne, ErrorCode.NONE); replies.put(recipientTwo, DocumentProtocol.ERROR_MESSAGE_IGNORED); - assertMergeOk(replies, Arrays.asList(recipientOne)); + assertMergeOk(replies, List.of(recipientOne)); replies.put(recipientOne, DocumentProtocol.ERROR_MESSAGE_IGNORED); replies.put(recipientTwo, ErrorCode.NONE); - assertMergeOk(replies, Arrays.asList(recipientTwo)); + assertMergeOk(replies, List.of(recipientTwo)); 
replies.put(recipientOne, DocumentProtocol.ERROR_MESSAGE_IGNORED); replies.put(recipientTwo, DocumentProtocol.ERROR_MESSAGE_IGNORED); - assertMergeError(replies, Arrays.asList(DocumentProtocol.ERROR_MESSAGE_IGNORED, - DocumentProtocol.ERROR_MESSAGE_IGNORED)); + assertMerge(replies, List.of(DocumentProtocol.ERROR_MESSAGE_IGNORED, + DocumentProtocol.ERROR_MESSAGE_IGNORED), + null); // Only ignored errors specifically causes an EmptyReply. } /** diff --git a/eval/src/tests/eval/fast_value/fast_value_test.cpp b/eval/src/tests/eval/fast_value/fast_value_test.cpp index 03658d8351b..e809fb1bcda 100644 --- a/eval/src/tests/eval/fast_value/fast_value_test.cpp +++ b/eval/src/tests/eval/fast_value/fast_value_test.cpp @@ -8,6 +8,8 @@ using namespace vespalib; using namespace vespalib::eval; +using Handle = SharedStringRepo::Handle; + TEST(FastCellsTest, push_back_fast_works) { FastCells<float> cells(3); EXPECT_EQ(cells.capacity, 4); @@ -60,38 +62,37 @@ TEST(FastCellsTest, add_cells_works) { using SA = std::vector<vespalib::stringref>; -TEST(FastValueBuilderTest, dense_add_subspace_robustness) { +TEST(FastValueBuilderTest, scalar_add_subspace_robustness) { auto factory = FastValueBuilderFactory::get(); - ValueType type = ValueType::from_spec("tensor(x[2])"); + ValueType type = ValueType::from_spec("double"); auto builder = factory.create_value_builder<double>(type); - auto subspace = builder->add_subspace({}); + auto subspace = builder->add_subspace(); subspace[0] = 17.0; - subspace[1] = 666; - auto other = builder->add_subspace({}); - other[1] = 42.0; + auto other = builder->add_subspace(); + other[0] = 42.0; auto value = builder->build(std::move(builder)); + EXPECT_EQ(value->index().size(), 1); auto actual = spec_from_value(*value); - auto expected = TensorSpec("tensor(x[2])"). - add({{"x", 0}}, 17.0). - add({{"x", 1}}, 42.0); - EXPECT_EQ(actual, expected); + auto expected = TensorSpec("double"). 
+ add({}, 42.0); + EXPECT_EQ(actual, expected); } -TEST(FastValueBuilderTest, sparse_add_subspace_robustness) { +TEST(FastValueBuilderTest, dense_add_subspace_robustness) { auto factory = FastValueBuilderFactory::get(); - ValueType type = ValueType::from_spec("tensor(x{})"); + ValueType type = ValueType::from_spec("tensor(x[2])"); auto builder = factory.create_value_builder<double>(type); - auto subspace = builder->add_subspace(SA{"foo"}); + auto subspace = builder->add_subspace(); subspace[0] = 17.0; - subspace = builder->add_subspace(SA{"bar"}); - subspace[0] = 18.0; - auto other = builder->add_subspace(SA{"foo"}); - other[0] = 42.0; + subspace[1] = 666; + auto other = builder->add_subspace(); + other[1] = 42.0; auto value = builder->build(std::move(builder)); + EXPECT_EQ(value->index().size(), 1); auto actual = spec_from_value(*value); - auto expected = TensorSpec("tensor(x{})"). - add({{"x", "bar"}}, 18.0). - add({{"x", "foo"}}, 42.0); + auto expected = TensorSpec("tensor(x[2])"). + add({{"x", 0}}, 17.0). + add({{"x", 1}}, 42.0); EXPECT_EQ(actual, expected); } @@ -100,21 +101,43 @@ TEST(FastValueBuilderTest, mixed_add_subspace_robustness) { ValueType type = ValueType::from_spec("tensor(x{},y[2])"); auto builder = factory.create_value_builder<double>(type); auto subspace = builder->add_subspace(SA{"foo"}); - subspace[0] = 17.0; - subspace[1] = 666; + subspace[0] = 1.0; + subspace[1] = 5.0; subspace = builder->add_subspace(SA{"bar"}); - subspace[0] = 18.0; - subspace[1] = 19.0; + subspace[0] = 2.0; + subspace[1] = 10.0; auto other = builder->add_subspace(SA{"foo"}); - other[1] = 42.0; + other[0] = 3.0; + other[1] = 15.0; auto value = builder->build(std::move(builder)); - auto actual = spec_from_value(*value); - auto expected = TensorSpec("tensor(x{},y[2])"). - add({{"x", "foo"}, {"y", 0}}, 17.0). - add({{"x", "bar"}, {"y", 0}}, 18.0). - add({{"x", "bar"}, {"y", 1}}, 19.0). 
- add({{"x", "foo"}, {"y", 1}}, 42.0); - EXPECT_EQ(actual, expected); + EXPECT_EQ(value->index().size(), 3); + Handle foo("foo"); + Handle bar("bar"); + label_t label; + label_t *label_ptr = &label; + size_t subspace_idx; + auto get_subspace = [&]() { + auto cells = value->cells().typify<double>(); + return ConstArrayRef<double>(cells.begin() + subspace_idx * 2, 2); + }; + auto view = value->index().create_view({}); + view->lookup({}); + while (view->next_result({&label_ptr, 1}, subspace_idx)) { + if (label == bar.id()) { + auto values = get_subspace(); + EXPECT_EQ(values[0], 2.0); + EXPECT_EQ(values[1], 10.0); + } else { + EXPECT_EQ(label, foo.id()); + auto values = get_subspace(); + if (values[0] == 1) { + EXPECT_EQ(values[1], 5.0); + } else { + EXPECT_EQ(values[0], 3.0); + EXPECT_EQ(values[1], 15.0); + } + } + } } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/tests/eval/simple_value/simple_value_test.cpp b/eval/src/tests/eval/simple_value/simple_value_test.cpp index c05f9976e1a..1691d5c263c 100644 --- a/eval/src/tests/eval/simple_value/simple_value_test.cpp +++ b/eval/src/tests/eval/simple_value/simple_value_test.cpp @@ -16,8 +16,12 @@ using namespace vespalib::eval::test; using vespalib::make_string_short::fmt; -using PA = std::vector<vespalib::stringref *>; -using CPA = std::vector<const vespalib::stringref *>; +using PA = std::vector<label_t *>; +using CPA = std::vector<const label_t *>; + +using Handle = SharedStringRepo::Handle; + +vespalib::string as_str(label_t label) { return Handle::string_from_id(label); } std::vector<Layout> layouts = { {}, @@ -98,17 +102,18 @@ TEST(SimpleValueTest, simple_value_can_be_built_and_inspected) { std::unique_ptr<Value> value = builder->build(std::move(builder)); EXPECT_EQ(value->index().size(), 6); auto view = value->index().create_view({0}); - vespalib::stringref query = "b"; - vespalib::stringref label; + Handle query_handle("b"); + label_t query = query_handle.id(); + label_t label; size_t subspace; + 
std::map<vespalib::string,size_t> result; view->lookup(CPA{&query}); - EXPECT_TRUE(view->next_result(PA{&label}, subspace)); - EXPECT_EQ(label, "aa"); - EXPECT_EQ(subspace, 2); - EXPECT_TRUE(view->next_result(PA{&label}, subspace)); - EXPECT_EQ(label, "bb"); - EXPECT_EQ(subspace, 3); - EXPECT_FALSE(view->next_result(PA{&label}, subspace)); + while (view->next_result(PA{&label}, subspace)) { + result[as_str(label)] = subspace; + } + EXPECT_EQ(result.size(), 2); + EXPECT_EQ(result["aa"], 2); + EXPECT_EQ(result["bb"], 3); } TEST(SimpleValueTest, new_generic_join_works_for_simple_values) { diff --git a/eval/src/tests/streamed/value/streamed_value_test.cpp b/eval/src/tests/streamed/value/streamed_value_test.cpp index 05d6e20451c..5221c4eda64 100644 --- a/eval/src/tests/streamed/value/streamed_value_test.cpp +++ b/eval/src/tests/streamed/value/streamed_value_test.cpp @@ -16,8 +16,12 @@ using namespace vespalib::eval::test; using vespalib::make_string_short::fmt; -using PA = std::vector<vespalib::stringref *>; -using CPA = std::vector<const vespalib::stringref *>; +using PA = std::vector<label_t *>; +using CPA = std::vector<const label_t *>; + +using Handle = SharedStringRepo::Handle; + +vespalib::string as_str(label_t label) { return Handle::string_from_id(label); } std::vector<Layout> layouts = { {}, @@ -98,17 +102,18 @@ TEST(StreamedValueTest, streamed_value_can_be_built_and_inspected) { std::unique_ptr<Value> value = builder->build(std::move(builder)); EXPECT_EQ(value->index().size(), 6); auto view = value->index().create_view({0}); - vespalib::stringref query = "b"; - vespalib::stringref label; + Handle query_handle("b"); + label_t query = query_handle.id(); + label_t label; size_t subspace; + std::map<vespalib::string,size_t> result; view->lookup(CPA{&query}); - EXPECT_TRUE(view->next_result(PA{&label}, subspace)); - EXPECT_EQ(label, "aa"); - EXPECT_EQ(subspace, 2); - EXPECT_TRUE(view->next_result(PA{&label}, subspace)); - EXPECT_EQ(label, "bb"); - 
EXPECT_EQ(subspace, 3); - EXPECT_FALSE(view->next_result(PA{&label}, subspace)); + while (view->next_result(PA{&label}, subspace)) { + result[as_str(label)] = subspace; + } + EXPECT_EQ(result.size(), 2); + EXPECT_EQ(result["aa"], 2); + EXPECT_EQ(result["bb"], 3); } TEST(StreamedValueTest, new_generic_join_works_for_streamed_values) { diff --git a/eval/src/vespa/eval/eval/CMakeLists.txt b/eval/src/vespa/eval/eval/CMakeLists.txt index 01eeff49662..5f8dd478a7b 100644 --- a/eval/src/vespa/eval/eval/CMakeLists.txt +++ b/eval/src/vespa/eval/eval/CMakeLists.txt @@ -10,6 +10,7 @@ vespa_add_library(eval_eval OBJECT delete_node.cpp dense_cells_value.cpp double_value_builder.cpp + fast_addr_map.cpp fast_forest.cpp fast_sparse_map.cpp fast_value.cpp diff --git a/eval/src/vespa/eval/eval/fast_addr_map.cpp b/eval/src/vespa/eval/eval/fast_addr_map.cpp new file mode 100644 index 00000000000..73163f411e6 --- /dev/null +++ b/eval/src/vespa/eval/eval/fast_addr_map.cpp @@ -0,0 +1,9 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#include "fast_addr_map.h" + +namespace vespalib::eval { + +FastAddrMap::~FastAddrMap() = default; + +} diff --git a/eval/src/vespa/eval/eval/fast_addr_map.h b/eval/src/vespa/eval/eval/fast_addr_map.h new file mode 100644 index 00000000000..a8a82718a28 --- /dev/null +++ b/eval/src/vespa/eval/eval/fast_addr_map.h @@ -0,0 +1,152 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#pragma once + +#include "label.h" +#include "memory_usage_stuff.h" +#include <vespa/vespalib/util/arrayref.h> +#include <vespa/vespalib/stllike/identity.h> +#include <vespa/vespalib/stllike/hashtable.h> +#include <vespa/vespalib/util/shared_string_repo.h> +#include <vector> + +namespace vespalib::eval { + +/** + * A wrapper around vespalib::hashtable, using it to map a list of + * labels (a sparse address) to an integer value (dense subspace + * index). Labels are represented by string enum values stored and + * handled outside this class. + **/ +class FastAddrMap +{ +public: + // label hasing functions + static constexpr uint32_t hash_label(label_t label) { return label; } + static constexpr uint32_t hash_label(const label_t *label) { return *label; } + static constexpr uint32_t combine_label_hash(uint32_t full_hash, uint32_t next_hash) { + return ((full_hash * 31) + next_hash); + } + template <typename T> + static constexpr uint32_t hash_labels(ConstArrayRef<T> addr) { + uint32_t hash = 0; + for (const T &label: addr) { + hash = combine_label_hash(hash, hash_label(label)); + } + return hash; + } + + // typed uint32_t index used to identify sparse address/dense subspace + struct Tag { + uint32_t idx; + static constexpr uint32_t npos() { return uint32_t(-1); } + static constexpr Tag make_invalid() { return Tag{npos()}; } + constexpr bool valid() const { return (idx != npos()); } + }; + + // sparse hash set entry + struct Entry { + Tag tag; + uint32_t hash; + }; + + // alternative key(s) used for lookup in sparse hash set + template <typename T> struct AltKey { + ConstArrayRef<T> key; + uint32_t hash; + }; + + // view able to convert tags into sparse addresses + struct LabelView { + size_t addr_size; + const std::vector<label_t> &labels; + LabelView(size_t num_mapped_dims, SharedStringRepo::HandleView handle_view) + : addr_size(num_mapped_dims), labels(handle_view.handles()) {} + ConstArrayRef<label_t> get_addr(size_t idx) const { + return {&labels[idx * 
addr_size], addr_size}; + } + }; + + // hashing functor for sparse hash set + struct Hash { + template <typename T> + constexpr uint32_t operator()(const AltKey<T> &key) const { return key.hash; } + constexpr uint32_t operator()(const Entry &entry) const { return entry.hash; } + }; + + // equality functor for sparse hash set + struct Equal { + const LabelView &label_view; + Equal(const LabelView &label_view_in) : label_view(label_view_in) {} + static constexpr bool eq_labels(label_t a, label_t b) { return (a == b); } + static constexpr bool eq_labels(label_t a, const label_t *b) { return (a == *b); } + template <typename T> + bool operator()(const Entry &a, const AltKey<T> &b) const { + if ((a.hash != b.hash) || (b.key.size() != label_view.addr_size)) { + return false; + } + auto a_key = label_view.get_addr(a.tag.idx); + for (size_t i = 0; i < a_key.size(); ++i) { + if (!eq_labels(a_key[i], b.key[i])) { + return false; + } + } + return true; + } + }; + + using HashType = hashtable<Entry, Entry, Hash, Equal, Identity, hashtable_base::and_modulator>; + +private: + LabelView _labels; + HashType _map; + +public: + FastAddrMap(size_t num_mapped_dims, SharedStringRepo::HandleView handle_view, size_t expected_subspaces) + : _labels(num_mapped_dims, handle_view), + _map(expected_subspaces * 2, Hash(), Equal(_labels)) {} + ~FastAddrMap(); + FastAddrMap(const FastAddrMap &) = delete; + FastAddrMap &operator=(const FastAddrMap &) = delete; + FastAddrMap(FastAddrMap &&) = delete; + FastAddrMap &operator=(FastAddrMap &&) = delete; + static constexpr size_t npos() { return -1; } + ConstArrayRef<label_t> get_addr(size_t idx) const { return _labels.get_addr(idx); } + size_t size() const { return _map.size(); } + constexpr size_t addr_size() const { return _labels.addr_size; } + template <typename T> + size_t lookup(ConstArrayRef<T> addr, uint32_t hash) const { + AltKey<T> key{addr, hash}; + auto pos = _map.find(key); + return (pos == _map.end()) ? 
npos() : pos->tag.idx; + } + template <typename T> + size_t lookup(ConstArrayRef<T> addr) const { + return lookup(addr, hash_labels(addr)); + } + void add_mapping(uint32_t hash) { + uint32_t idx = _map.size(); + _map.force_insert(Entry{{idx}, hash}); + } + template <typename F> + void each_map_entry(F &&f) const { + _map.for_each([&](const auto &entry) + { + f(entry.tag.idx, entry.hash); + }); + } + MemoryUsage estimate_extra_memory_usage() const { + MemoryUsage extra_usage; + size_t map_self_size = sizeof(_map); + size_t map_used = _map.getMemoryUsed(); + size_t map_allocated = _map.getMemoryConsumption(); + // avoid double-counting the map itself + map_used = std::min(map_used, map_used - map_self_size); + map_allocated = std::min(map_allocated, map_allocated - map_self_size); + extra_usage.incUsedBytes(map_used); + extra_usage.incAllocatedBytes(map_allocated); + return extra_usage; + } +}; + +} diff --git a/eval/src/vespa/eval/eval/fast_value.cpp b/eval/src/vespa/eval/eval/fast_value.cpp index 116e561a868..96d0fa84149 100644 --- a/eval/src/vespa/eval/eval/fast_value.cpp +++ b/eval/src/vespa/eval/eval/fast_value.cpp @@ -11,7 +11,7 @@ namespace vespalib::eval { namespace { struct CreateFastValueBuilderBase { - template <typename T> static std::unique_ptr<ValueBuilderBase> invoke(const ValueType &type, + template <typename T, typename R2> static std::unique_ptr<ValueBuilderBase> invoke(const ValueType &type, size_t num_mapped_dims, size_t subspace_size, size_t expected_subspaces) { assert(check_cell_type<T>(type.cell_type())); @@ -20,7 +20,7 @@ struct CreateFastValueBuilderBase { } else if (num_mapped_dims == 0) { return std::make_unique<FastDenseValue<T>>(type, subspace_size); } else { - return std::make_unique<FastValue<T>>(type, num_mapped_dims, subspace_size, expected_subspaces); + return std::make_unique<FastValue<T,R2::value>>(type, num_mapped_dims, subspace_size, expected_subspaces); } } }; @@ -32,11 +32,11 @@ struct CreateFastValueBuilderBase { 
std::unique_ptr<Value::Index::View> FastValueIndex::create_view(const std::vector<size_t> &dims) const { - if (map.num_dims() == 0) { + if (map.addr_size() == 0) { return TrivialIndex::get().create_view(dims); } else if (dims.empty()) { return std::make_unique<FastIterateView>(map); - } else if (dims.size() == map.num_dims()) { + } else if (dims.size() == map.addr_size()) { return std::make_unique<FastLookupView>(map); } else { return std::make_unique<FastFilterView>(map, dims); @@ -49,10 +49,11 @@ FastValueBuilderFactory::FastValueBuilderFactory() = default; FastValueBuilderFactory FastValueBuilderFactory::_factory; std::unique_ptr<ValueBuilderBase> -FastValueBuilderFactory::create_value_builder_base(const ValueType &type, size_t num_mapped_dims, size_t subspace_size, - size_t expected_subspaces) const +FastValueBuilderFactory::create_value_builder_base(const ValueType &type, bool transient, size_t num_mapped_dims, size_t subspace_size, + size_t expected_subspaces) const { - return typify_invoke<1,TypifyCellType,CreateFastValueBuilderBase>(type.cell_type(), type, num_mapped_dims, subspace_size, expected_subspaces); + using MyTypify = TypifyValue<TypifyCellType,TypifyBool>; + return typify_invoke<2,MyTypify,CreateFastValueBuilderBase>(type.cell_type(), transient, type, num_mapped_dims, subspace_size, expected_subspaces); } //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/fast_value.h b/eval/src/vespa/eval/eval/fast_value.h index ac924ecc6eb..c6280b492db 100644 --- a/eval/src/vespa/eval/eval/fast_value.h +++ b/eval/src/vespa/eval/eval/fast_value.h @@ -19,7 +19,7 @@ class FastValueBuilderFactory : public ValueBuilderFactory { private: FastValueBuilderFactory(); static FastValueBuilderFactory _factory; - std::unique_ptr<ValueBuilderBase> create_value_builder_base(const ValueType &type, + std::unique_ptr<ValueBuilderBase> create_value_builder_base(const ValueType &type, bool transient, size_t 
num_mapped_dims, size_t subspace_size, size_t expected_subspaces) const override; public: static const FastValueBuilderFactory &get() { return _factory; } diff --git a/eval/src/vespa/eval/eval/fast_value.hpp b/eval/src/vespa/eval/eval/fast_value.hpp index 9914378cc9e..972aa68b8bd 100644 --- a/eval/src/vespa/eval/eval/fast_value.hpp +++ b/eval/src/vespa/eval/eval/fast_value.hpp @@ -1,11 +1,10 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "value.h" -#include "fast_sparse_map.h" +#include "fast_addr_map.h" #include "inline_operation.h" #include <vespa/eval/instruction/generic_join.h> -#include <vespa/vespalib/stllike/hash_map.hpp> -#include <vespa/vespalib/util/alloc.h> +#include <vespa/vespalib/stllike/hashtable.hpp> namespace vespalib::eval { @@ -18,22 +17,22 @@ namespace { // look up a full address in the map directly struct FastLookupView : public Value::Index::View { - const FastSparseMap &map; - size_t subspace; + const FastAddrMap &map; + size_t subspace; - FastLookupView(const FastSparseMap &map_in) - : map(map_in), subspace(FastSparseMap::npos()) {} + FastLookupView(const FastAddrMap &map_in) + : map(map_in), subspace(FastAddrMap::npos()) {} - void lookup(ConstArrayRef<const vespalib::stringref*> addr) override { + void lookup(ConstArrayRef<const label_t*> addr) override { subspace = map.lookup(addr); } - bool next_result(ConstArrayRef<vespalib::stringref*>, size_t &idx_out) override { - if (subspace == FastSparseMap::npos()) { + bool next_result(ConstArrayRef<label_t*>, size_t &idx_out) override { + if (subspace == FastAddrMap::npos()) { return false; } idx_out = subspace; - subspace = FastSparseMap::npos(); + subspace = FastAddrMap::npos(); return true; } }; @@ -43,30 +42,27 @@ struct FastLookupView : public Value::Index::View { // find matching mappings for a partial address with brute force filtering struct FastFilterView : public Value::Index::View { - using Label = 
FastSparseMap::HashedLabel; - - size_t num_mapped_dims; - const std::vector<Label> &labels; + const FastAddrMap &map; std::vector<size_t> match_dims; std::vector<size_t> extract_dims; - std::vector<Label> query; + std::vector<label_t> query; size_t pos; - bool is_match() const { + bool is_match(ConstArrayRef<label_t> addr) const { for (size_t i = 0; i < query.size(); ++i) { - if (query[i].hash != labels[pos + match_dims[i]].hash) { + if (query[i] != addr[match_dims[i]]) { return false; } } return true; } - FastFilterView(const FastSparseMap &map, const std::vector<size_t> &match_dims_in) - : num_mapped_dims(map.num_dims()), labels(map.labels()), match_dims(match_dims_in), - extract_dims(), query(match_dims.size(), Label()), pos(labels.size()) + FastFilterView(const FastAddrMap &map_in, const std::vector<size_t> &match_dims_in) + : map(map_in), match_dims(match_dims_in), + extract_dims(), query(match_dims.size()), pos(FastAddrMap::npos()) { auto my_pos = match_dims.begin(); - for (size_t i = 0; i < num_mapped_dims; ++i) { + for (size_t i = 0; i < map.addr_size(); ++i) { if ((my_pos == match_dims.end()) || (*my_pos != i)) { extract_dims.push_back(i); } else { @@ -74,29 +70,29 @@ struct FastFilterView : public Value::Index::View { } } assert(my_pos == match_dims.end()); - assert((match_dims.size() + extract_dims.size()) == num_mapped_dims); + assert((match_dims.size() + extract_dims.size()) == map.addr_size()); } - void lookup(ConstArrayRef<const vespalib::stringref*> addr) override { + void lookup(ConstArrayRef<const label_t*> addr) override { assert(addr.size() == query.size()); for (size_t i = 0; i < addr.size(); ++i) { - query[i] = Label(*addr[i]); + query[i] = *addr[i]; } pos = 0; } - bool next_result(ConstArrayRef<vespalib::stringref*> addr_out, size_t &idx_out) override { - while (pos < labels.size()) { - if (is_match()) { + bool next_result(ConstArrayRef<label_t*> addr_out, size_t &idx_out) override { + while (pos < map.size()) { + auto addr = map.get_addr(pos); + 
if (is_match(addr)) { assert(addr_out.size() == extract_dims.size()); for (size_t i = 0; i < extract_dims.size(); ++i) { - *addr_out[i] = labels[pos + extract_dims[i]].label; + *addr_out[i] = addr[extract_dims[i]]; } - idx_out = (pos / num_mapped_dims); // is this expensive? - pos += num_mapped_dims; + idx_out = pos++; return true; } - pos += num_mapped_dims; + ++pos; } return false; } @@ -107,29 +103,26 @@ struct FastFilterView : public Value::Index::View { // iterate all mappings struct FastIterateView : public Value::Index::View { - using Labels = std::vector<FastSparseMap::HashedLabel>; - - size_t num_mapped_dims; - const Labels &labels; - size_t pos; + const FastAddrMap &map; + size_t pos; - FastIterateView(const FastSparseMap &map) - : num_mapped_dims(map.num_dims()), labels(map.labels()), pos(labels.size()) {} + FastIterateView(const FastAddrMap &map_in) + : map(map_in), pos(FastAddrMap::npos()) {} - void lookup(ConstArrayRef<const vespalib::stringref*>) override { + void lookup(ConstArrayRef<const label_t*>) override { pos = 0; } - bool next_result(ConstArrayRef<vespalib::stringref*> addr_out, size_t &idx_out) override { - if (pos >= labels.size()) { + bool next_result(ConstArrayRef<label_t*> addr_out, size_t &idx_out) override { + if (pos >= map.size()) { return false; } - assert(addr_out.size() == num_mapped_dims); - for (size_t i = 0; i < num_mapped_dims; ++i) { - *addr_out[i] = labels[pos + i].label; + auto addr = map.get_addr(pos); + assert(addr.size() == addr_out.size()); + for (size_t i = 0; i < addr.size(); ++i) { + *addr_out[i] = addr[i]; } - idx_out = (pos / num_mapped_dims); // is this expensive? - pos += num_mapped_dims; + idx_out = pos++; return true; } }; @@ -145,9 +138,9 @@ using JoinAddrSource = instruction::SparseJoinPlan::Source; // operations by calling inline functions directly. 
struct FastValueIndex final : Value::Index { - FastSparseMap map; - FastValueIndex(size_t num_mapped_dims_in, size_t expected_subspaces_in) - : map(num_mapped_dims_in, expected_subspaces_in) {} + FastAddrMap map; + FastValueIndex(size_t num_mapped_dims_in, SharedStringRepo::HandleView handle_view, size_t expected_subspaces_in) + : map(num_mapped_dims_in, handle_view, expected_subspaces_in) {} template <typename LCT, typename RCT, typename OCT, typename Fun> static const Value &sparse_full_overlap_join(const ValueType &res_type, const Fun &fun, @@ -220,31 +213,64 @@ struct FastCells { //----------------------------------------------------------------------------- -template <typename T> +template <typename T, bool transient> struct FastValue final : Value, ValueBuilder<T> { + using Handles = std::conditional<transient, + SharedStringRepo::WeakHandles, + SharedStringRepo::StrongHandles>::type; + ValueType my_type; size_t my_subspace_size; + Handles my_handles; FastValueIndex my_index; FastCells<T> my_cells; FastValue(const ValueType &type_in, size_t num_mapped_dims_in, size_t subspace_size_in, size_t expected_subspaces_in) : my_type(type_in), my_subspace_size(subspace_size_in), - my_index(num_mapped_dims_in, expected_subspaces_in), + my_handles(expected_subspaces_in * num_mapped_dims_in), + my_index(num_mapped_dims_in, my_handles.view(), expected_subspaces_in), my_cells(subspace_size_in * expected_subspaces_in) {} ~FastValue() override; const ValueType &type() const override { return my_type; } const Value::Index &index() const override { return my_index; } TypedCells cells() const override { return TypedCells(my_cells.memory, get_cell_type<T>(), my_cells.size); } + void add_mapping(ConstArrayRef<vespalib::stringref> addr) { + if constexpr (transient) { + (void) addr; + abort(); // cannot use this for transient values + } else { + uint32_t hash = 0; + for (const auto &label: addr) { + hash = FastAddrMap::combine_label_hash(hash, 
FastAddrMap::hash_label(my_handles.add(label))); + } + my_index.map.add_mapping(hash); + } + } + void add_mapping(ConstArrayRef<label_t> addr) { + uint32_t hash = 0; + for (label_t label: addr) { + hash = FastAddrMap::combine_label_hash(hash, FastAddrMap::hash_label(label)); + my_handles.add(label); + } + my_index.map.add_mapping(hash); + } + void add_mapping(ConstArrayRef<label_t> addr, uint32_t hash) { + for (label_t label: addr) { + my_handles.add(label); + } + my_index.map.add_mapping(hash); + } ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref> addr) override { - size_t idx = my_index.map.add_mapping(addr) * my_subspace_size; - if (__builtin_expect((idx == my_cells.size), true)) { - return my_cells.add_cells(my_subspace_size); - } - return ArrayRef<T>(my_cells.get(idx), my_subspace_size); + add_mapping(addr); + return my_cells.add_cells(my_subspace_size); + } + ArrayRef<T> add_subspace(ConstArrayRef<label_t> addr) override { + add_mapping(addr); + return my_cells.add_cells(my_subspace_size); } std::unique_ptr<Value> build(std::unique_ptr<ValueBuilder<T>> self) override { - if (my_index.map.num_dims() == 0) { + if (my_index.map.addr_size() == 0) { assert(my_index.map.size() == 1); } assert(my_cells.size == (my_index.map.size() * my_subspace_size)); @@ -254,13 +280,14 @@ struct FastValue final : Value, ValueBuilder<T> { return std::unique_ptr<Value>(this); } MemoryUsage get_memory_usage() const override { - MemoryUsage usage = self_memory_usage<FastValue<T>>(); + MemoryUsage usage = self_memory_usage<FastValue<T,transient>>(); + usage.merge(vector_extra_memory_usage(my_handles.view().handles())); usage.merge(my_index.map.estimate_extra_memory_usage()); usage.merge(my_cells.estimate_extra_memory_usage()); return usage; } }; -template <typename T> FastValue<T>::~FastValue() = default; +template <typename T,bool transient> FastValue<T,transient>::~FastValue() = default; //----------------------------------------------------------------------------- @@ 
-282,6 +309,9 @@ struct FastDenseValue final : Value, ValueBuilder<T> { ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref>) override { return ArrayRef<T>(my_cells.get(0), my_cells.size); } + ArrayRef<T> add_subspace(ConstArrayRef<label_t>) override { + return ArrayRef<T>(my_cells.get(0), my_cells.size); + } std::unique_ptr<Value> build(std::unique_ptr<ValueBuilder<T>> self) override { ValueBuilder<T>* me = this; assert(me == self.get()); @@ -289,7 +319,7 @@ struct FastDenseValue final : Value, ValueBuilder<T> { return std::unique_ptr<Value>(this); } MemoryUsage get_memory_usage() const override { - MemoryUsage usage = self_memory_usage<FastValue<T>>(); + MemoryUsage usage = self_memory_usage<FastDenseValue<T>>(); usage.merge(my_cells.estimate_extra_memory_usage()); return usage; } @@ -302,6 +332,7 @@ template <typename T> struct FastScalarBuilder final : ValueBuilder<T> { T _value; ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref>) final override { return ArrayRef<T>(&_value, 1); } + ArrayRef<T> add_subspace(ConstArrayRef<label_t>) final override { return ArrayRef<T>(&_value, 1); }; std::unique_ptr<Value> build(std::unique_ptr<ValueBuilder<T>>) final override { return std::make_unique<ScalarValue<T>>(_value); } }; @@ -313,19 +344,16 @@ FastValueIndex::sparse_full_overlap_join(const ValueType &res_type, const Fun &f const FastValueIndex &lhs, const FastValueIndex &rhs, ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash) { - auto &result = stash.create<FastValue<OCT>>(res_type, lhs.map.num_dims(), 1, lhs.map.size()); - auto &result_map = result.my_index.map; - lhs.map.each_map_entry([&](auto lhs_subspace, auto hash) - { - auto rhs_subspace = rhs.map.lookup(hash); - if (rhs_subspace != FastSparseMap::npos()) { - auto idx = result_map.add_mapping(lhs.map.make_addr(lhs_subspace), hash); - if (__builtin_expect((idx == result.my_cells.size), true)) { - auto cell_value = fun(lhs_cells[lhs_subspace], rhs_cells[rhs_subspace]); - 
result.my_cells.push_back_fast(cell_value); - } - } - }); + auto &result = stash.create<FastValue<OCT,true>>(res_type, lhs.map.addr_size(), 1, lhs.map.size()); + lhs.map.each_map_entry([&](auto lhs_subspace, auto hash) { + auto lhs_addr = lhs.map.get_addr(lhs_subspace); + auto rhs_subspace = rhs.map.lookup(lhs_addr, hash); + if (rhs_subspace != FastAddrMap::npos()) { + result.add_mapping(lhs_addr, hash); + auto cell_value = fun(lhs_cells[lhs_subspace], rhs_cells[rhs_subspace]); + result.my_cells.push_back_fast(cell_value); + } + }); return result; } @@ -338,10 +366,9 @@ FastValueIndex::sparse_no_overlap_join(const ValueType &res_type, const Fun &fun const std::vector<JoinAddrSource> &addr_sources, ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash) { - using HashedLabelRef = std::reference_wrapper<const FastSparseMap::HashedLabel>; size_t num_mapped_dims = addr_sources.size(); - auto &result = stash.create<FastValue<OCT>>(res_type, num_mapped_dims, 1, lhs.map.size()*rhs.map.size()); - std::vector<HashedLabelRef> output_addr(num_mapped_dims, FastSparseMap::empty_label); + auto &result = stash.create<FastValue<OCT,true>>(res_type, num_mapped_dims, 1, lhs.map.size()*rhs.map.size()); + std::vector<label_t> output_addr(num_mapped_dims); std::vector<size_t> store_lhs_idx; std::vector<size_t> store_rhs_idx; size_t out_idx = 0; @@ -359,24 +386,22 @@ FastValueIndex::sparse_no_overlap_join(const ValueType &res_type, const Fun &fun } assert(out_idx == output_addr.size()); for (size_t lhs_subspace = 0; lhs_subspace < lhs.map.size(); ++lhs_subspace) { - auto l_addr = lhs.map.make_addr(lhs_subspace); + auto l_addr = lhs.map.get_addr(lhs_subspace); assert(l_addr.size() == store_lhs_idx.size()); for (size_t i = 0; i < store_lhs_idx.size(); ++i) { size_t addr_idx = store_lhs_idx[i]; output_addr[addr_idx] = l_addr[i]; } for (size_t rhs_subspace = 0; rhs_subspace < rhs.map.size(); ++rhs_subspace) { - auto r_addr = rhs.map.make_addr(rhs_subspace); + auto r_addr = 
rhs.map.get_addr(rhs_subspace); assert(r_addr.size() == store_rhs_idx.size()); for (size_t i = 0; i < store_rhs_idx.size(); ++i) { size_t addr_idx = store_rhs_idx[i]; output_addr[addr_idx] = r_addr[i]; } - auto idx = result.my_index.map.add_mapping(ConstArrayRef(output_addr)); - if (__builtin_expect((idx == result.my_cells.size), true)) { - auto cell_value = fun(lhs_cells[lhs_subspace], rhs_cells[rhs_subspace]); - result.my_cells.push_back_fast(cell_value); - } + result.add_mapping(ConstArrayRef(output_addr)); + auto cell_value = fun(lhs_cells[lhs_subspace], rhs_cells[rhs_subspace]); + result.my_cells.push_back_fast(cell_value); } } return result; @@ -391,22 +416,22 @@ FastValueIndex::sparse_only_merge(const ValueType &res_type, const Fun &fun, ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash) { size_t guess_size = lhs.map.size() + rhs.map.size(); - auto &result = stash.create<FastValue<OCT>>(res_type, lhs.map.num_dims(), 1, guess_size); - result.my_index = lhs; - for (auto val : lhs_cells) { - result.my_cells.push_back_fast(val); - } + auto &result = stash.create<FastValue<OCT,true>>(res_type, lhs.map.addr_size(), 1, guess_size); + lhs.map.each_map_entry([&](auto lhs_subspace, auto hash) + { + result.add_mapping(lhs.map.get_addr(lhs_subspace), hash); + result.my_cells.push_back_fast(lhs_cells[lhs_subspace]); + }); rhs.map.each_map_entry([&](auto rhs_subspace, auto hash) { - auto lhs_subspace = lhs.map.lookup(hash); - if (lhs_subspace == FastSparseMap::npos()) { - auto idx = result.my_index.map.add_mapping(rhs.map.make_addr(rhs_subspace), hash); - if (__builtin_expect((idx == result.my_cells.size), true)) { - result.my_cells.push_back_fast(rhs_cells[rhs_subspace]); - } + auto rhs_addr = rhs.map.get_addr(rhs_subspace); + auto result_subspace = result.my_index.map.lookup(rhs_addr, hash); + if (result_subspace == FastAddrMap::npos()) { + result.add_mapping(rhs_addr, hash); + result.my_cells.push_back_fast(rhs_cells[rhs_subspace]); } else { - 
auto cell_value = fun(lhs_cells[lhs_subspace], rhs_cells[rhs_subspace]); - *result.my_cells.get(lhs_subspace) = cell_value; + OCT &out_cell = *result.my_cells.get(result_subspace); + out_cell = fun(out_cell, rhs_cells[rhs_subspace]); } }); return result; diff --git a/eval/src/vespa/eval/eval/label.h b/eval/src/vespa/eval/eval/label.h new file mode 100644 index 00000000000..931f96a4f1a --- /dev/null +++ b/eval/src/vespa/eval/eval/label.h @@ -0,0 +1,15 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <cstdint> + +namespace vespalib::eval { + +// We use string ids from SharedStringRepo as labels. Note that +// label_t represents the lightweight reference type. Other structures +// (Handle/StrongHandles) are needed to keep the id valid. + +using label_t = uint32_t; + +} diff --git a/eval/src/vespa/eval/eval/simple_value.cpp b/eval/src/vespa/eval/eval/simple_value.cpp index 113f89f77fb..0cbbb29ecf1 100644 --- a/eval/src/vespa/eval/eval/simple_value.cpp +++ b/eval/src/vespa/eval/eval/simple_value.cpp @@ -30,7 +30,8 @@ struct CreateSimpleValueBuilderBase { // look up a full address in the map directly struct SimpleLookupView : public Value::Index::View { - using Labels = std::vector<vespalib::string>; + using Handle = SharedStringRepo::Handle; + using Labels = std::vector<Handle>; using Map = std::map<Labels, size_t>; const Map ↦ @@ -38,17 +39,17 @@ struct SimpleLookupView : public Value::Index::View { Map::const_iterator pos; SimpleLookupView(const Map &map_in, size_t num_dims) - : map(map_in), my_addr(num_dims, ""), pos(map.end()) {} + : map(map_in), my_addr(num_dims), pos(map.end()) {} - void lookup(ConstArrayRef<const vespalib::stringref*> addr) override { + void lookup(ConstArrayRef<const label_t*> addr) override { assert(addr.size() == my_addr.size()); for (size_t i = 0; i < my_addr.size(); ++i) { - my_addr[i] = *addr[i]; + my_addr[i] = Handle::handle_from_id(*addr[i]); } 
pos = map.find(my_addr); } - bool next_result(ConstArrayRef<vespalib::stringref*>, size_t &idx_out) override { + bool next_result(ConstArrayRef<label_t*>, size_t &idx_out) override { if (pos == map.end()) { return false; } @@ -63,13 +64,14 @@ struct SimpleLookupView : public Value::Index::View { // find matching mappings for a partial address with brute force filtering struct SimpleFilterView : public Value::Index::View { - using Labels = std::vector<vespalib::string>; + using Handle = SharedStringRepo::Handle; + using Labels = std::vector<Handle>; using Map = std::map<Labels, size_t>; const Map ↦ std::vector<size_t> match_dims; std::vector<size_t> extract_dims; - std::vector<vespalib::string> query; + std::vector<Handle> query; Map::const_iterator pos; bool is_match() const { @@ -82,7 +84,7 @@ struct SimpleFilterView : public Value::Index::View { } SimpleFilterView(const Map &map_in, const std::vector<size_t> &match_dims_in, size_t num_dims) - : map(map_in), match_dims(match_dims_in), extract_dims(), query(match_dims.size(), ""), pos(map.end()) + : map(map_in), match_dims(match_dims_in), extract_dims(), query(match_dims.size()), pos(map.end()) { auto my_pos = match_dims.begin(); for (size_t i = 0; i < num_dims; ++i) { @@ -96,20 +98,20 @@ struct SimpleFilterView : public Value::Index::View { assert((match_dims.size() + extract_dims.size()) == num_dims); } - void lookup(ConstArrayRef<const vespalib::stringref*> addr) override { + void lookup(ConstArrayRef<const label_t*> addr) override { assert(addr.size() == query.size()); for (size_t i = 0; i < addr.size(); ++i) { - query[i] = *addr[i]; + query[i] = Handle::handle_from_id(*addr[i]); } pos = map.begin(); } - bool next_result(ConstArrayRef<vespalib::stringref*> addr_out, size_t &idx_out) override { + bool next_result(ConstArrayRef<label_t*> addr_out, size_t &idx_out) override { while (pos != map.end()) { if (is_match()) { assert(addr_out.size() == extract_dims.size()); for (size_t i = 0; i < extract_dims.size(); 
++i) { - *addr_out[i] = pos->first[extract_dims[i]]; + *addr_out[i] = pos->first[extract_dims[i]].id(); } idx_out = pos->second; ++pos; @@ -126,7 +128,8 @@ struct SimpleFilterView : public Value::Index::View { // iterate all mappings struct SimpleIterateView : public Value::Index::View { - using Labels = std::vector<vespalib::string>; + using Handle = SharedStringRepo::Handle; + using Labels = std::vector<Handle>; using Map = std::map<Labels, size_t>; const Map ↦ @@ -135,17 +138,17 @@ struct SimpleIterateView : public Value::Index::View { SimpleIterateView(const Map &map_in) : map(map_in), pos(map.end()) {} - void lookup(ConstArrayRef<const vespalib::stringref*>) override { + void lookup(ConstArrayRef<const label_t*>) override { pos = map.begin(); } - bool next_result(ConstArrayRef<vespalib::stringref*> addr_out, size_t &idx_out) override { + bool next_result(ConstArrayRef<label_t*> addr_out, size_t &idx_out) override { if (pos == map.end()) { return false; } assert(addr_out.size() == pos->first.size()); for (size_t i = 0; i < addr_out.size(); ++i) { - *addr_out[i] = pos->first[i]; + *addr_out[i] = pos->first[i].id(); } idx_out = pos->second; ++pos; @@ -182,6 +185,17 @@ SimpleValue::add_mapping(ConstArrayRef<vespalib::stringref> addr) assert(was_inserted); } +void +SimpleValue::add_mapping(ConstArrayRef<label_t> addr) +{ + Labels my_addr; + for(label_t label: addr) { + my_addr.emplace_back(Handle::handle_from_id(label)); + } + auto [ignore, was_inserted] = _index.emplace(my_addr, _index.size()); + assert(was_inserted); +} + MemoryUsage SimpleValue::estimate_extra_memory_usage() const { @@ -246,15 +260,26 @@ SimpleValueT<T>::add_subspace(ConstArrayRef<vespalib::stringref> addr) return ArrayRef<T>(&_cells[old_size], subspace_size()); } +template <typename T> +ArrayRef<T> +SimpleValueT<T>::add_subspace(ConstArrayRef<label_t> addr) +{ + size_t old_size = _cells.size(); + add_mapping(addr); + _cells.resize(old_size + subspace_size(), 
std::numeric_limits<T>::quiet_NaN()); + return ArrayRef<T>(&_cells[old_size], subspace_size()); +} + //----------------------------------------------------------------------------- SimpleValueBuilderFactory::SimpleValueBuilderFactory() = default; SimpleValueBuilderFactory SimpleValueBuilderFactory::_factory; std::unique_ptr<ValueBuilderBase> -SimpleValueBuilderFactory::create_value_builder_base(const ValueType &type, size_t num_mapped_dims, size_t subspace_size, +SimpleValueBuilderFactory::create_value_builder_base(const ValueType &type, bool transient, size_t num_mapped_dims, size_t subspace_size, size_t expected_subspaces) const { + (void) transient; return typify_invoke<1,TypifyCellType,CreateSimpleValueBuilderBase>(type.cell_type(), type, num_mapped_dims, subspace_size, expected_subspaces); } diff --git a/eval/src/vespa/eval/eval/simple_value.h b/eval/src/vespa/eval/eval/simple_value.h index 590c0b4ef16..1fd645b704c 100644 --- a/eval/src/vespa/eval/eval/simple_value.h +++ b/eval/src/vespa/eval/eval/simple_value.h @@ -3,7 +3,7 @@ #pragma once #include "value.h" -#include <vespa/vespalib/stllike/string.h> +#include <vespa/vespalib/util/shared_string_repo.h> #include <vector> #include <map> @@ -26,7 +26,8 @@ class TensorSpec; class SimpleValue : public Value, public Value::Index { private: - using Labels = std::vector<vespalib::string>; + using Handle = SharedStringRepo::Handle; + using Labels = std::vector<Handle>; ValueType _type; size_t _num_mapped_dims; @@ -36,6 +37,7 @@ protected: size_t num_mapped_dims() const { return _num_mapped_dims; } size_t subspace_size() const { return _subspace_size; } void add_mapping(ConstArrayRef<vespalib::stringref> addr); + void add_mapping(ConstArrayRef<label_t> addr); MemoryUsage estimate_extra_memory_usage() const; public: SimpleValue(const ValueType &type, size_t num_mapped_dims_in, size_t subspace_size_in); @@ -62,6 +64,7 @@ public: ~SimpleValueT() override; TypedCells cells() const override { return 
TypedCells(ConstArrayRef<T>(_cells)); } ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref> addr) override; + ArrayRef<T> add_subspace(ConstArrayRef<label_t> addr) override; std::unique_ptr<Value> build(std::unique_ptr<ValueBuilder<T>> self) override { if (num_mapped_dims() == 0) { assert(size() == 1); @@ -87,7 +90,7 @@ class SimpleValueBuilderFactory : public ValueBuilderFactory { private: SimpleValueBuilderFactory(); static SimpleValueBuilderFactory _factory; - std::unique_ptr<ValueBuilderBase> create_value_builder_base(const ValueType &type, + std::unique_ptr<ValueBuilderBase> create_value_builder_base(const ValueType &type, bool transient, size_t num_mapped_dims, size_t subspace_size, size_t expected_subspaces) const override; public: static const SimpleValueBuilderFactory &get() { return _factory; } diff --git a/eval/src/vespa/eval/eval/value.cpp b/eval/src/vespa/eval/eval/value.cpp index 7abc8d568cb..73c7c40636c 100644 --- a/eval/src/vespa/eval/eval/value.cpp +++ b/eval/src/vespa/eval/eval/value.cpp @@ -12,8 +12,8 @@ namespace { struct TrivialView : Value::Index::View { bool first = false; - void lookup(ConstArrayRef<const vespalib::stringref*> ) override { first = true; } - bool next_result(ConstArrayRef<vespalib::stringref*> , size_t &idx_out) override { + void lookup(ConstArrayRef<const label_t*> ) override { first = true; } + bool next_result(ConstArrayRef<label_t*> , size_t &idx_out) override { if (first) { idx_out = 0; first = false; diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h index 186c3698dcd..2efb7d7c1e4 100644 --- a/eval/src/vespa/eval/eval/value.h +++ b/eval/src/vespa/eval/eval/value.h @@ -2,6 +2,7 @@ #pragma once +#include "label.h" #include "memory_usage_stuff.h" #include "value_type.h" #include "typed_cells.h" @@ -36,13 +37,13 @@ struct Value { // partial address for the dimensions given to // create_view. Results from the lookup is extracted using // the next_result function. 
- virtual void lookup(ConstArrayRef<const vespalib::stringref*> addr) = 0; + virtual void lookup(ConstArrayRef<const label_t*> addr) = 0; // Extract the next result (if any) from the previous // lookup into the given partial address and index. Only // the labels for the dimensions NOT specified in // create_view will be extracted here. - virtual bool next_result(ConstArrayRef<vespalib::stringref*> addr_out, size_t &idx_out) = 0; + virtual bool next_result(ConstArrayRef<label_t*> addr_out, size_t &idx_out) = 0; virtual ~View() {} }; @@ -163,6 +164,14 @@ struct ValueBuilder : ValueBuilderBase { // is not allowed. virtual ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref> addr) = 0; + // add a dense subspace for the given address where labels are + // specified by shared string repo ids. Note that the caller is + // responsible for making sure the ids are valid 'long enough'. + virtual ArrayRef<T> add_subspace(ConstArrayRef<label_t> addr) = 0; + + // convenience function to add a subspace with an empty address + ArrayRef<T> add_subspace() { return add_subspace(ConstArrayRef<label_t>()); } + // Given the ownership of the builder itself, produce the newly // created value. This means that builders can only be used once, // it also means values can build themselves. @@ -179,26 +188,40 @@ struct ValueBuilder : ValueBuilderBase { * builder. With interoperability between all values. 
**/ struct ValueBuilderFactory { +private: template <typename T> - std::unique_ptr<ValueBuilder<T>> create_value_builder(const ValueType &type, + std::unique_ptr<ValueBuilder<T>> create_value_builder(const ValueType &type, bool transient, size_t num_mapped_dims_in, size_t subspace_size_in, size_t expected_subspaces) const { assert(check_cell_type<T>(type.cell_type())); - auto base = create_value_builder_base(type, num_mapped_dims_in, subspace_size_in, expected_subspaces); + auto base = create_value_builder_base(type, transient, num_mapped_dims_in, subspace_size_in, expected_subspaces); ValueBuilder<T> *builder = dynamic_cast<ValueBuilder<T>*>(base.get()); assert(builder); base.release(); return std::unique_ptr<ValueBuilder<T>>(builder); } +public: + template <typename T> + std::unique_ptr<ValueBuilder<T>> create_value_builder(const ValueType &type, + size_t num_mapped_dims_in, size_t subspace_size_in, size_t expected_subspaces) const + { + return create_value_builder<T>(type, false, num_mapped_dims_in, subspace_size_in, expected_subspaces); + } + template <typename T> + std::unique_ptr<ValueBuilder<T>> create_transient_value_builder(const ValueType &type, + size_t num_mapped_dims_in, size_t subspace_size_in, size_t expected_subspaces) const + { + return create_value_builder<T>(type, true, num_mapped_dims_in, subspace_size_in, expected_subspaces); + } template <typename T> std::unique_ptr<ValueBuilder<T>> create_value_builder(const ValueType &type) const { - return create_value_builder<T>(type, type.count_mapped_dimensions(), type.dense_subspace_size(), 1); + return create_value_builder<T>(type, false, type.count_mapped_dimensions(), type.dense_subspace_size(), 1); } std::unique_ptr<Value> copy(const Value &value) const; virtual ~ValueBuilderFactory() {} protected: - virtual std::unique_ptr<ValueBuilderBase> create_value_builder_base(const ValueType &type, + virtual std::unique_ptr<ValueBuilderBase> create_value_builder_base(const ValueType &type, bool transient, 
size_t num_mapped_dims_in, size_t subspace_size_in, size_t expected_subspaces) const = 0; }; diff --git a/eval/src/vespa/eval/eval/value_codec.cpp b/eval/src/vespa/eval/eval/value_codec.cpp index 923d3f29cd3..53131da86d8 100644 --- a/eval/src/vespa/eval/eval/value_codec.cpp +++ b/eval/src/vespa/eval/eval/value_codec.cpp @@ -7,6 +7,7 @@ #include <vespa/vespalib/util/exceptions.h> #include <vespa/vespalib/util/typify.h> #include <vespa/vespalib/util/stringfmt.h> +#include <vespa/vespalib/util/shared_string_repo.h> using vespalib::make_string_short::fmt; @@ -128,9 +129,10 @@ size_t maybe_decode_num_blocks(nbostream &input, bool has_mapped_dims, const For return 1; } -void encode_mapped_labels(nbostream &output, size_t num_mapped_dims, const std::vector<vespalib::stringref> &addr) { +void encode_mapped_labels(nbostream &output, size_t num_mapped_dims, const std::vector<label_t> &addr) { for (size_t i = 0; i < num_mapped_dims; ++i) { - output.writeSmallString(addr[i]); + vespalib::string str = SharedStringRepo::Handle::string_from_id(addr[i]); + output.writeSmallString(str); } } @@ -175,7 +177,7 @@ struct ContentDecoder { } // add implicit empty subspace if ((state.num_mapped_dims == 0) && (state.num_blocks == 0)) { - for (T &cell: builder->add_subspace({})) { + for (T &cell: builder->add_subspace()) { cell = T{}; } } @@ -229,8 +231,8 @@ struct CreateTensorSpecFromValue { TensorSpec spec(value.type().to_spec()); size_t subspace_id = 0; size_t subspace_size = value.type().dense_subspace_size(); - std::vector<vespalib::stringref> labels(value.type().count_mapped_dimensions()); - std::vector<vespalib::stringref*> label_refs; + std::vector<label_t> labels(value.type().count_mapped_dimensions()); + std::vector<label_t*> label_refs; for (auto &label: labels) { label_refs.push_back(&label); } @@ -241,7 +243,7 @@ struct CreateTensorSpecFromValue { TensorSpec::Address addr; for (const auto &dim: value.type().dimensions()) { if (dim.is_mapped()) { - addr.emplace(dim.name, 
labels[label_idx++]); + addr.emplace(dim.name, SharedStringRepo::Handle::string_from_id(labels[label_idx++])); } } for (size_t i = 0; i < subspace_size; ++i) { @@ -270,8 +272,8 @@ struct EncodeState { struct ContentEncoder { template<typename T> static void invoke(const Value &value, const EncodeState &state, nbostream &output) { - std::vector<vespalib::stringref> address(state.num_mapped_dims); - std::vector<vespalib::stringref*> a_refs(state.num_mapped_dims);; + std::vector<label_t> address(state.num_mapped_dims); + std::vector<label_t*> a_refs(state.num_mapped_dims);; for (size_t i = 0; i < state.num_mapped_dims; ++i) { a_refs[i] = &address[i]; } diff --git a/eval/src/vespa/eval/instruction/generic_concat.cpp b/eval/src/vespa/eval/instruction/generic_concat.cpp index fa9d2192b99..5d8ab7187c0 100644 --- a/eval/src/vespa/eval/instruction/generic_concat.cpp +++ b/eval/src/vespa/eval/instruction/generic_concat.cpp @@ -47,10 +47,10 @@ generic_concat(const Value &a, const Value &b, auto a_cells = a.cells().typify<LCT>(); auto b_cells = b.cells().typify<RCT>(); SparseJoinState sparse(sparse_plan, a.index(), b.index()); - auto builder = factory.create_value_builder<OCT>(res_type, - sparse_plan.sources.size(), - dense_plan.output_size, - sparse.first_index.size()); + auto builder = factory.create_transient_value_builder<OCT>(res_type, + sparse_plan.sources.size(), + dense_plan.output_size, + sparse.first_index.size()); auto outer = sparse.first_index.create_view({}); auto inner = sparse.second_index.create_view(sparse.second_view_dims); outer->lookup({}); diff --git a/eval/src/vespa/eval/instruction/generic_create.cpp b/eval/src/vespa/eval/instruction/generic_create.cpp index 02c89e0b43f..6e30da846e7 100644 --- a/eval/src/vespa/eval/instruction/generic_create.cpp +++ b/eval/src/vespa/eval/instruction/generic_create.cpp @@ -5,6 +5,7 @@ #include <vespa/eval/eval/array_array_map.h> #include <vespa/vespalib/util/stash.h> #include <vespa/vespalib/util/typify.h> +#include 
<vespa/vespalib/util/shared_string_repo.h> #include <cassert> using namespace vespalib::eval::tensor_function; @@ -13,6 +14,7 @@ namespace vespalib::eval::instruction { using State = InterpretedFunction::State; using Instruction = InterpretedFunction::Instruction; +using Handle = SharedStringRepo::Handle; namespace { @@ -21,12 +23,12 @@ struct CreateParam { size_t num_mapped_dims; size_t dense_subspace_size; size_t num_children; - ArrayArrayMap<vespalib::string,size_t> my_spec; + ArrayArrayMap<Handle,size_t> my_spec; const ValueBuilderFactory &factory; static constexpr size_t npos = -1; - ArrayRef<size_t> indexes(ConstArrayRef<vespalib::string> key) { + ArrayRef<size_t> indexes(ConstArrayRef<Handle> key) { auto [tag, first_time] = my_spec.lookup_or_add_entry(key); auto rv = my_spec.get_values(tag); if (first_time) { @@ -49,7 +51,7 @@ struct CreateParam { { size_t last_child = num_children - 1; for (const auto & entry : spec_in) { - std::vector<vespalib::string> sparse_key; + std::vector<Handle> sparse_key; size_t dense_key = 0; auto dim = res_type.dimensions().begin(); auto binding = entry.first.begin(); @@ -58,7 +60,7 @@ struct CreateParam { assert(dim->name == binding->first); assert(dim->is_mapped() == binding->second.is_mapped()); if (dim->is_mapped()) { - sparse_key.push_back(binding->second.name); + sparse_key.push_back(Handle(binding->second.name)); } else { assert(binding->second.index < dim->size); dense_key = (dense_key * dim->size) + binding->second.index; @@ -76,16 +78,16 @@ struct CreateParam { template <typename T> void my_generic_create_op(State &state, uint64_t param_in) { const auto ¶m = unwrap_param<CreateParam>(param_in); - auto builder = param.factory.create_value_builder<T>(param.res_type, - param.num_mapped_dims, - param.dense_subspace_size, - param.my_spec.size()); - std::vector<vespalib::stringref> sparse_addr; + auto builder = param.factory.create_transient_value_builder<T>(param.res_type, + param.num_mapped_dims, + 
param.dense_subspace_size, + param.my_spec.size()); + std::vector<label_t> sparse_addr; param.my_spec.each_entry([&](const auto &key, const auto &values) { sparse_addr.clear(); for (const auto & label : key) { - sparse_addr.push_back(label); + sparse_addr.push_back(label.id()); } T *dst = builder->add_subspace(sparse_addr).begin(); for (size_t stack_idx : values) { diff --git a/eval/src/vespa/eval/instruction/generic_join.cpp b/eval/src/vespa/eval/instruction/generic_join.cpp index 026df5aa993..e0dc0feea28 100644 --- a/eval/src/vespa/eval/instruction/generic_join.cpp +++ b/eval/src/vespa/eval/instruction/generic_join.cpp @@ -41,7 +41,7 @@ generic_mixed_join(const Value &lhs, const Value &rhs, const JoinParam ¶m) if (param.sparse_plan.lhs_overlap.empty() && param.sparse_plan.rhs_overlap.empty()) { expected_subspaces = sparse.first_index.size() * sparse.second_index.size(); } - auto builder = param.factory.create_value_builder<OCT>(param.res_type, param.sparse_plan.sources.size(), param.dense_plan.out_size, expected_subspaces); + auto builder = param.factory.create_transient_value_builder<OCT>(param.res_type, param.sparse_plan.sources.size(), param.dense_plan.out_size, expected_subspaces); auto outer = sparse.first_index.create_view({}); auto inner = sparse.second_index.create_view(sparse.second_view_dims); outer->lookup({}); @@ -92,7 +92,7 @@ void my_sparse_no_overlap_join_op(State &state, uint64_t param_in) { SparseJoinState sparse(param.sparse_plan, lhs.index(), rhs.index()); auto guess = lhs.index().size() * rhs.index().size(); assert(param.dense_plan.out_size == 1); - auto builder = param.factory.create_value_builder<OCT>(param.res_type, param.sparse_plan.sources.size(), 1, guess); + auto builder = param.factory.create_transient_value_builder<OCT>(param.res_type, param.sparse_plan.sources.size(), 1, guess); auto outer = sparse.first_index.create_view({}); assert(sparse.second_view_dims.empty()); auto inner = sparse.second_index.create_view({}); @@ -131,7 +131,7 
@@ void my_sparse_full_overlap_join_op(State &state, uint64_t param_in) { } Fun fun(param.function); SparseJoinState sparse(param.sparse_plan, lhs_index, rhs_index); - auto builder = param.factory.create_value_builder<OCT>(param.res_type, param.sparse_plan.sources.size(), param.dense_plan.out_size, sparse.first_index.size()); + auto builder = param.factory.create_transient_value_builder<OCT>(param.res_type, param.sparse_plan.sources.size(), param.dense_plan.out_size, sparse.first_index.size()); auto outer = sparse.first_index.create_view({}); auto inner = sparse.second_index.create_view(sparse.second_view_dims); outer->lookup({}); diff --git a/eval/src/vespa/eval/instruction/generic_join.h b/eval/src/vespa/eval/instruction/generic_join.h index 988286be980..217f3195dec 100644 --- a/eval/src/vespa/eval/instruction/generic_join.h +++ b/eval/src/vespa/eval/instruction/generic_join.h @@ -68,10 +68,10 @@ struct SparseJoinState { const Value::Index &first_index; const Value::Index &second_index; const std::vector<size_t> &second_view_dims; - std::vector<vespalib::stringref> full_address; - std::vector<vespalib::stringref*> first_address; - std::vector<const vespalib::stringref*> address_overlap; - std::vector<vespalib::stringref*> second_only_address; + std::vector<label_t> full_address; + std::vector<label_t*> first_address; + std::vector<const label_t*> address_overlap; + std::vector<label_t*> second_only_address; size_t lhs_subspace; size_t rhs_subspace; size_t &first_subspace; diff --git a/eval/src/vespa/eval/instruction/generic_merge.cpp b/eval/src/vespa/eval/instruction/generic_merge.cpp index 02749a04eb9..107cb805d74 100644 --- a/eval/src/vespa/eval/instruction/generic_merge.cpp +++ b/eval/src/vespa/eval/instruction/generic_merge.cpp @@ -63,10 +63,10 @@ generic_mixed_merge(const Value &a, const Value &b, const size_t num_mapped = params.num_mapped_dimensions; const size_t subspace_size = params.dense_subspace_size; size_t guess_subspaces = 
std::max(a.index().size(), b.index().size()); - auto builder = params.factory.create_value_builder<OCT>(params.res_type, num_mapped, subspace_size, guess_subspaces); - std::vector<vespalib::stringref> address(num_mapped); - std::vector<const vespalib::stringref *> addr_cref; - std::vector<vespalib::stringref *> addr_ref; + auto builder = params.factory.create_transient_value_builder<OCT>(params.res_type, num_mapped, subspace_size, guess_subspaces); + std::vector<label_t> address(num_mapped); + std::vector<const label_t *> addr_cref; + std::vector<label_t *> addr_ref; for (auto & ref : address) { addr_cref.push_back(&ref); addr_ref.push_back(&ref); diff --git a/eval/src/vespa/eval/instruction/generic_peek.cpp b/eval/src/vespa/eval/instruction/generic_peek.cpp index 66538911890..d94742ae15c 100644 --- a/eval/src/vespa/eval/instruction/generic_peek.cpp +++ b/eval/src/vespa/eval/instruction/generic_peek.cpp @@ -7,6 +7,7 @@ #include <vespa/vespalib/util/stash.h> #include <vespa/vespalib/util/typify.h> #include <vespa/vespalib/util/visit_ranges.h> +#include <vespa/vespalib/util/shared_string_repo.h> #include <cassert> using namespace vespalib::eval::tensor_function; @@ -16,6 +17,8 @@ namespace vespalib::eval::instruction { using State = InterpretedFunction::State; using Instruction = InterpretedFunction::Instruction; +using Handle = SharedStringRepo::Handle; + namespace { static constexpr size_t npos = -1; @@ -35,28 +38,43 @@ size_t count_children(const Spec &spec) } struct DimSpec { - vespalib::stringref name; - GenericPeek::SpecMap::mapped_type child_or_label; + enum class DimType { CHILD_IDX, LABEL_IDX, LABEL_STR }; + vespalib::string name; + DimType dim_type; + size_t idx; + Handle str; + static DimSpec from_child(const vespalib::string &name_in, size_t child_idx) { + return {name_in, DimType::CHILD_IDX, child_idx, Handle()}; + } + static DimSpec from_label(const vespalib::string &name_in, const TensorSpec::Label &label) { + if (label.is_mapped()) { + return 
{name_in, DimType::LABEL_STR, 0, Handle(label.name)}; + } else { + assert(label.is_indexed()); + return {name_in, DimType::LABEL_IDX, label.index, Handle()}; + } + } + ~DimSpec(); bool has_child() const { - return std::holds_alternative<size_t>(child_or_label); + return (dim_type == DimType::CHILD_IDX); } bool has_label() const { - return std::holds_alternative<TensorSpec::Label>(child_or_label); + return (dim_type != DimType::CHILD_IDX); } size_t get_child_idx() const { - return std::get<size_t>(child_or_label); + assert(dim_type == DimType::CHILD_IDX); + return idx; } - vespalib::stringref get_label_name() const { - auto & label = std::get<TensorSpec::Label>(child_or_label); - assert(label.is_mapped()); - return label.name; + label_t get_label_name() const { + assert(dim_type == DimType::LABEL_STR); + return str.id(); } size_t get_label_index() const { - auto & label = std::get<TensorSpec::Label>(child_or_label); - assert(label.is_indexed()); - return label.index; + assert(dim_type == DimType::LABEL_IDX); + return idx; } }; +DimSpec::~DimSpec() = default; struct ExtractedSpecs { using Dimension = ValueType::Dimension; @@ -85,7 +103,11 @@ struct ExtractedSpecs { dimensions.push_back(a); const auto & [spec_dim_name, child_or_label] = b; assert(a.name == spec_dim_name); - specs.emplace_back(DimSpec{a.name, child_or_label}); + if (std::holds_alternative<size_t>(child_or_label)) { + specs.push_back(DimSpec::from_child(a.name, std::get<size_t>(child_or_label))); + } else { + specs.push_back(DimSpec::from_label(a.name, std::get<TensorSpec::Label>(child_or_label))); + } } } }; @@ -181,22 +203,21 @@ struct DensePlan { }; struct SparseState { - std::vector<vespalib::string> view_addr; - std::vector<vespalib::stringref> view_refs; - std::vector<const vespalib::stringref *> lookup_refs; - std::vector<vespalib::stringref> output_addr; - std::vector<vespalib::stringref *> fetch_addr; - - SparseState(std::vector<vespalib::string> view_addr_in, size_t out_dims) - : 
view_addr(std::move(view_addr_in)), - view_refs(view_addr.size()), + std::vector<Handle> handles; + std::vector<label_t> view_addr; + std::vector<const label_t *> lookup_refs; + std::vector<label_t> output_addr; + std::vector<label_t *> fetch_addr; + + SparseState(std::vector<Handle> handles_in, std::vector<label_t> view_addr_in, size_t out_dims) + : handles(std::move(handles_in)), + view_addr(std::move(view_addr_in)), lookup_refs(view_addr.size()), output_addr(out_dims), fetch_addr(out_dims) { for (size_t i = 0; i < view_addr.size(); ++i) { - view_refs[i] = view_addr[i]; - lookup_refs[i] = &view_refs[i]; + lookup_refs[i] = &view_addr[i]; } for (size_t i = 0; i < out_dims; ++i) { fetch_addr[i] = &output_addr[i]; @@ -236,17 +257,19 @@ struct SparsePlan { template <typename Getter> SparseState make_state(const Getter &get_child_value) const { - std::vector<vespalib::string> view_addr; + std::vector<Handle> handles; + std::vector<label_t> view_addr; for (const auto & dim : lookup_specs) { if (dim.has_child()) { int64_t child_value = get_child_value(dim.get_child_idx()); - view_addr.push_back(vespalib::make_string("%" PRId64, child_value)); + handles.emplace_back(vespalib::make_string("%" PRId64, child_value)); + view_addr.push_back(handles.back().id()); } else { view_addr.push_back(dim.get_label_name()); } } assert(view_addr.size() == view_dims.size()); - return SparseState(std::move(view_addr), out_mapped_dims); + return SparseState(std::move(handles), std::move(view_addr), out_mapped_dims); } }; SparsePlan::~SparsePlan() = default; @@ -284,10 +307,10 @@ generic_mixed_peek(const ValueType &res_type, { auto input_cells = input_value.cells().typify<ICT>(); size_t bad_guess = 1; - auto builder = factory.create_value_builder<OCT>(res_type, - sparse_plan.out_mapped_dims, - dense_plan.out_dense_size, - bad_guess); + auto builder = factory.create_transient_value_builder<OCT>(res_type, + sparse_plan.out_mapped_dims, + dense_plan.out_dense_size, + bad_guess); size_t 
filled_subspaces = 0; size_t dense_offset = dense_plan.get_offset(get_child_value); if (dense_offset != npos) { @@ -304,7 +327,7 @@ generic_mixed_peek(const ValueType &res_type, } } if ((sparse_plan.out_mapped_dims == 0) && (filled_subspaces == 0)) { - for (auto & v : builder->add_subspace({})) { + for (auto & v : builder->add_subspace()) { v = OCT{}; } } diff --git a/eval/src/vespa/eval/instruction/generic_reduce.cpp b/eval/src/vespa/eval/instruction/generic_reduce.cpp index afc46e8ee7d..d30186d3dd8 100644 --- a/eval/src/vespa/eval/instruction/generic_reduce.cpp +++ b/eval/src/vespa/eval/instruction/generic_reduce.cpp @@ -45,10 +45,10 @@ ReduceParam::~ReduceParam() = default; //----------------------------------------------------------------------------- struct SparseReduceState { - std::vector<vespalib::stringref> full_address; - std::vector<vespalib::stringref*> fetch_address; - std::vector<vespalib::stringref*> keep_address; - size_t subspace; + std::vector<label_t> full_address; + std::vector<label_t*> fetch_address; + std::vector<label_t*> keep_address; + size_t subspace; SparseReduceState(const SparseReducePlan &plan) : full_address(plan.keep_dims.size() + plan.num_reduce_dims), @@ -71,20 +71,20 @@ template <typename ICT, typename OCT, typename AGGR> Value::UP generic_reduce(const Value &value, const ReduceParam ¶m) { auto cells = value.cells().typify<ICT>(); - ArrayArrayMap<vespalib::stringref,AGGR> map(param.sparse_plan.keep_dims.size(), - param.dense_plan.out_size, - value.index().size()); + ArrayArrayMap<label_t,AGGR> map(param.sparse_plan.keep_dims.size(), + param.dense_plan.out_size, + value.index().size()); SparseReduceState sparse(param.sparse_plan); auto full_view = value.index().create_view({}); full_view->lookup({}); - ConstArrayRef<vespalib::stringref*> keep_addr(sparse.keep_address); + ConstArrayRef<label_t*> keep_addr(sparse.keep_address); while (full_view->next_result(sparse.fetch_address, sparse.subspace)) { auto [tag, ignore] = 
map.lookup_or_add_entry(keep_addr); AGGR *dst = map.get_values(tag).begin(); auto sample = [&](size_t src_idx, size_t dst_idx) { dst[dst_idx].sample(cells[src_idx]); }; param.dense_plan.execute(sparse.subspace * param.dense_plan.in_size, sample); } - auto builder = param.factory.create_value_builder<OCT>(param.res_type, param.sparse_plan.keep_dims.size(), param.dense_plan.out_size, map.size()); + auto builder = param.factory.create_transient_value_builder<OCT>(param.res_type, param.sparse_plan.keep_dims.size(), param.dense_plan.out_size, map.size()); map.each_entry([&](const auto &keys, const auto &values) { OCT *dst = builder->add_subspace(keys).begin(); @@ -93,7 +93,7 @@ generic_reduce(const Value &value, const ReduceParam ¶m) { } }); if ((map.size() == 0) && param.sparse_plan.keep_dims.empty()) { - auto zero = builder->add_subspace({}); + auto zero = builder->add_subspace(); for (size_t i = 0; i < zero.size(); ++i) { zero[i] = OCT{}; } diff --git a/eval/src/vespa/eval/instruction/generic_rename.cpp b/eval/src/vespa/eval/instruction/generic_rename.cpp index 1ce18597ec2..894ef37b678 100644 --- a/eval/src/vespa/eval/instruction/generic_rename.cpp +++ b/eval/src/vespa/eval/instruction/generic_rename.cpp @@ -69,15 +69,15 @@ generic_rename(const Value &a, const ValueType &res_type, const ValueBuilderFactory &factory) { auto cells = a.cells().typify<CT>(); - std::vector<vespalib::stringref> output_address(sparse_plan.mapped_dims); - std::vector<vespalib::stringref*> input_address; + std::vector<label_t> output_address(sparse_plan.mapped_dims); + std::vector<label_t*> input_address; for (size_t maps_to : sparse_plan.output_dimensions) { input_address.push_back(&output_address[maps_to]); } - auto builder = factory.create_value_builder<CT>(res_type, - sparse_plan.mapped_dims, - dense_plan.subspace_size, - a.index().size()); + auto builder = factory.create_transient_value_builder<CT>(res_type, + sparse_plan.mapped_dims, + dense_plan.subspace_size, + a.index().size()); auto 
view = a.index().create_view({}); view->lookup({}); size_t subspace; diff --git a/eval/src/vespa/eval/streamed/streamed_value.cpp b/eval/src/vespa/eval/streamed/streamed_value.cpp index bdfe5fd4e27..06162b2200d 100644 --- a/eval/src/vespa/eval/streamed/streamed_value.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value.cpp @@ -16,8 +16,7 @@ StreamedValue<T>::get_memory_usage() const { MemoryUsage usage = self_memory_usage<StreamedValue<T>>(); usage.merge(vector_extra_memory_usage(_my_cells)); - usage.incUsedBytes(_label_buf.byteSize()); - usage.incAllocatedBytes(_label_buf.byteCapacity()); + usage.merge(vector_extra_memory_usage(_my_labels.view().handles())); return usage; } diff --git a/eval/src/vespa/eval/streamed/streamed_value.h b/eval/src/vespa/eval/streamed/streamed_value.h index 258802a53e8..94603d9d35e 100644 --- a/eval/src/vespa/eval/streamed/streamed_value.h +++ b/eval/src/vespa/eval/streamed/streamed_value.h @@ -4,6 +4,7 @@ #include <vespa/eval/eval/value_type.h> #include <vespa/eval/eval/value.h> +#include <vespa/vespalib/util/shared_string_repo.h> #include "streamed_value_index.h" #include <cassert> @@ -19,20 +20,22 @@ template <typename T> class StreamedValue : public Value { private: + using StrongHandles = SharedStringRepo::StrongHandles; + ValueType _type; std::vector<T> _my_cells; - Array<char> _label_buf; + StrongHandles _my_labels; StreamedValueIndex _my_index; public: StreamedValue(ValueType type, size_t num_mapped_dimensions, - std::vector<T> cells, size_t num_subspaces, Array<char> && label_buf) + std::vector<T> cells, size_t num_subspaces, StrongHandles && handles) : _type(std::move(type)), _my_cells(std::move(cells)), - _label_buf(std::move(label_buf)), + _my_labels(std::move(handles)), _my_index(num_mapped_dimensions, num_subspaces, - ConstArrayRef<char>(_label_buf.begin(), _label_buf.size())) + _my_labels.view().handles()) { assert(num_subspaces * _type.dense_subspace_size() == _my_cells.size()); } @@ -42,7 +45,6 @@ public: TypedCells 
cells() const final override { return TypedCells(_my_cells); } const Value::Index &index() const final override { return _my_index; } MemoryUsage get_memory_usage() const final override; - auto get_data_reference() const { return _my_index.get_data_reference(); } }; } // namespace diff --git a/eval/src/vespa/eval/streamed/streamed_value_builder.h b/eval/src/vespa/eval/streamed/streamed_value_builder.h index 5698c805756..48a01f893de 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_builder.h +++ b/eval/src/vespa/eval/streamed/streamed_value_builder.h @@ -3,7 +3,7 @@ #pragma once #include "streamed_value.h" -#include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/util/shared_string_repo.h> namespace vespalib::eval { @@ -14,12 +14,14 @@ template <typename T> class StreamedValueBuilder : public ValueBuilder<T> { private: + using StrongHandles = SharedStringRepo::StrongHandles; + ValueType _type; size_t _num_mapped_dimensions; size_t _dense_subspace_size; std::vector<T> _cells; size_t _num_subspaces; - nbostream _labels; + StrongHandles _labels; public: StreamedValueBuilder(const ValueType &type, size_t num_mapped_in, @@ -30,18 +32,26 @@ public: _dense_subspace_size(subspace_size_in), _cells(), _num_subspaces(0), - _labels() + _labels(num_mapped_in * expected_subspaces) { _cells.reserve(subspace_size_in * expected_subspaces); - // assume small sized label strings: - _labels.reserve(num_mapped_in * expected_subspaces * 3); }; ~StreamedValueBuilder(); ArrayRef<T> add_subspace(ConstArrayRef<vespalib::stringref> addr) override { for (auto label : addr) { - _labels.writeSmallString(label); + _labels.add(label); + } + size_t old_sz = _cells.size(); + _cells.resize(old_sz + _dense_subspace_size); + _num_subspaces++; + return ArrayRef<T>(&_cells[old_sz], _dense_subspace_size); + } + + ArrayRef<T> add_subspace(ConstArrayRef<label_t> addr) override { + for (auto label : addr) { + _labels.add(label); } size_t old_sz = _cells.size(); _cells.resize(old_sz + 
_dense_subspace_size); @@ -58,7 +68,7 @@ public: _num_mapped_dimensions, std::move(_cells), _num_subspaces, - _labels.extract_buffer()); + std::move(_labels)); } }; diff --git a/eval/src/vespa/eval/streamed/streamed_value_builder_factory.cpp b/eval/src/vespa/eval/streamed/streamed_value_builder_factory.cpp index aa6347a2c51..5111ba8a71e 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_builder_factory.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value_builder_factory.cpp @@ -19,10 +19,12 @@ struct SelectStreamedValueBuilder { std::unique_ptr<ValueBuilderBase> StreamedValueBuilderFactory::create_value_builder_base(const ValueType &type, + bool transient, size_t num_mapped, size_t subspace_size, size_t expected_subspaces) const { + (void) transient; return typify_invoke<1,TypifyCellType,SelectStreamedValueBuilder>( type.cell_type(), type, num_mapped, subspace_size, expected_subspaces); diff --git a/eval/src/vespa/eval/streamed/streamed_value_builder_factory.h b/eval/src/vespa/eval/streamed/streamed_value_builder_factory.h index 3f81981f429..58072aa31dc 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_builder_factory.h +++ b/eval/src/vespa/eval/streamed/streamed_value_builder_factory.h @@ -14,7 +14,7 @@ private: StreamedValueBuilderFactory() {} static StreamedValueBuilderFactory _factory; std::unique_ptr<ValueBuilderBase> create_value_builder_base( - const ValueType &type, size_t num_mapped_in, + const ValueType &type, bool transient, size_t num_mapped_in, size_t subspace_size_in, size_t expected_subspaces) const override; public: static const StreamedValueBuilderFactory &get() { return _factory; } diff --git a/eval/src/vespa/eval/streamed/streamed_value_index.cpp b/eval/src/vespa/eval/streamed/streamed_value_index.cpp index 17cf7316554..a014f2dcee9 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_index.cpp +++ b/eval/src/vespa/eval/streamed/streamed_value_index.cpp @@ -18,7 +18,7 @@ struct StreamedFilterView : Value::Index::View { 
LabelBlockStream label_blocks; std::vector<size_t> view_dims; - std::vector<vespalib::stringref> to_match; + std::vector<label_t> to_match; StreamedFilterView(LabelBlockStream labels, std::vector<size_t> view_dims_in) : label_blocks(std::move(labels)), @@ -28,7 +28,7 @@ struct StreamedFilterView : Value::Index::View to_match.reserve(view_dims.size()); } - void lookup(ConstArrayRef<const vespalib::stringref*> addr) override { + void lookup(ConstArrayRef<const label_t*> addr) override { label_blocks.reset(); to_match.clear(); for (auto ptr : addr) { @@ -37,7 +37,7 @@ struct StreamedFilterView : Value::Index::View assert(view_dims.size() == to_match.size()); } - bool next_result(ConstArrayRef<vespalib::stringref*> addr_out, size_t &idx_out) override { + bool next_result(ConstArrayRef<label_t*> addr_out, size_t &idx_out) override { while (const auto block = label_blocks.next_block()) { idx_out = block.subspace_index; bool matches = true; @@ -66,12 +66,12 @@ struct StreamedIterationView : Value::Index::View : label_blocks(std::move(labels)) {} - void lookup(ConstArrayRef<const vespalib::stringref*> addr) override { + void lookup(ConstArrayRef<const label_t*> addr) override { label_blocks.reset(); assert(addr.size() == 0); } - bool next_result(ConstArrayRef<vespalib::stringref*> addr_out, size_t &idx_out) override { + bool next_result(ConstArrayRef<label_t*> addr_out, size_t &idx_out) override { if (auto block = label_blocks.next_block()) { idx_out = block.subspace_index; size_t i = 0; @@ -90,7 +90,7 @@ struct StreamedIterationView : Value::Index::View std::unique_ptr<Value::Index::View> StreamedValueIndex::create_view(const std::vector<size_t> &dims) const { - LabelBlockStream label_stream(_data.num_subspaces, _data.labels_buffer, _data.num_mapped_dims); + LabelBlockStream label_stream(_num_subspaces, _labels_ref, _num_mapped_dims); if (dims.empty()) { return std::make_unique<StreamedIterationView>(std::move(label_stream)); } diff --git 
a/eval/src/vespa/eval/streamed/streamed_value_index.h b/eval/src/vespa/eval/streamed/streamed_value_index.h index 8fd561200c3..aa1c9a0e201 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_index.h +++ b/eval/src/vespa/eval/streamed/streamed_value_index.h @@ -3,6 +3,7 @@ #pragma once #include <vespa/eval/eval/value.h> +#include <vespa/vespalib/util/shared_string_repo.h> namespace vespalib::eval { @@ -12,25 +13,21 @@ namespace vespalib::eval { **/ class StreamedValueIndex : public Value::Index { +private: + uint32_t _num_mapped_dims; + uint32_t _num_subspaces; + const std::vector<label_t> &_labels_ref; + public: - struct SerializedDataRef { - uint32_t num_mapped_dims; - uint32_t num_subspaces; - ConstArrayRef<char> labels_buffer; - }; - StreamedValueIndex(uint32_t num_mapped_dims, uint32_t num_subspaces, ConstArrayRef<char> labels_buf) - : _data{num_mapped_dims, num_subspaces, labels_buf} + StreamedValueIndex(uint32_t num_mapped_dims, uint32_t num_subspaces, const std::vector<label_t> &labels_ref) + : _num_mapped_dims(num_mapped_dims), + _num_subspaces(num_subspaces), + _labels_ref(labels_ref) {} // index API: - size_t size() const override { return _data.num_subspaces; } + size_t size() const override { return _num_subspaces; } std::unique_ptr<View> create_view(const std::vector<size_t> &dims) const override; - - SerializedDataRef get_data_reference() const { return _data; } - -private: - SerializedDataRef _data; }; } // namespace - diff --git a/eval/src/vespa/eval/streamed/streamed_value_utils.h b/eval/src/vespa/eval/streamed/streamed_value_utils.h index b88d4df8581..6b44e052f0c 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_utils.h +++ b/eval/src/vespa/eval/streamed/streamed_value_utils.h @@ -4,24 +4,23 @@ #include <vespa/eval/eval/value.h> #include <vespa/vespalib/objects/nbostream.h> +#include <cassert> namespace vespalib::eval { /** * Reads a stream of serialized labels. - * Reading more labels than available will - * throw an exception. 
+ * Reading more labels than available will trigger an assert. **/ struct LabelStream { - nbostream source; - LabelStream(ConstArrayRef<char> data) : source(data.begin(), data.size()) {} - vespalib::stringref next_label() { - size_t str_size = source.getInt1_4Bytes(); - vespalib::stringref label(source.peek(), str_size); - source.adjustReadPos(str_size); - return label; + const std::vector<label_t> &source; + size_t pos; + LabelStream(const std::vector<label_t> &data) : source(data), pos(0) {} + label_t next_label() { + assert(pos < source.size()); + return source[pos++]; } - void reset() { source.rp(0); } + void reset() { pos = 0; } }; /** @@ -30,7 +29,7 @@ struct LabelStream { struct LabelBlock { static constexpr size_t npos = -1; size_t subspace_index; - ConstArrayRef<vespalib::stringref> address; + ConstArrayRef<label_t> address; operator bool() const { return subspace_index != npos; } }; @@ -43,7 +42,7 @@ private: size_t _num_subspaces; LabelStream _labels; size_t _subspace_index; - std::vector<vespalib::stringref> _current_address; + std::vector<label_t> _current_address; public: LabelBlock next_block() { if (_subspace_index < _num_subspaces) { @@ -62,10 +61,10 @@ public: } LabelBlockStream(uint32_t num_subspaces, - ConstArrayRef<char> label_buf, + const std::vector<label_t> &labels, uint32_t num_mapped_dims) : _num_subspaces(num_subspaces), - _labels(label_buf), + _labels(labels), _subspace_index(num_subspaces), _current_address(num_mapped_dims) {} diff --git a/eval/src/vespa/eval/streamed/streamed_value_view.h b/eval/src/vespa/eval/streamed/streamed_value_view.h index e37f442dd9a..38eb8db786f 100644 --- a/eval/src/vespa/eval/streamed/streamed_value_view.h +++ b/eval/src/vespa/eval/streamed/streamed_value_view.h @@ -24,10 +24,10 @@ private: public: StreamedValueView(const ValueType &type, size_t num_mapped_dimensions, TypedCells cells, size_t num_subspaces, - ConstArrayRef<char> labels_buf) + const std::vector<label_t> &labels) : _type(type), 
_cells_ref(cells), - _my_index(num_mapped_dimensions, num_subspaces, labels_buf) + _my_index(num_mapped_dimensions, num_subspaces, labels) { assert(num_subspaces * _type.dense_subspace_size() == _cells_ref.size); } @@ -39,7 +39,6 @@ public: MemoryUsage get_memory_usage() const final override { return self_memory_usage<StreamedValueView>(); } - auto get_data_reference() const { return _my_index.get_data_reference(); } }; } // namespace diff --git a/fastos/src/tests/CMakeLists.txt b/fastos/src/tests/CMakeLists.txt index cc4928cc7d0..8c0d255df32 100644 --- a/fastos/src/tests/CMakeLists.txt +++ b/fastos/src/tests/CMakeLists.txt @@ -19,20 +19,6 @@ vespa_add_executable(fastos_thread_stats_test_app TEST fastos ) vespa_add_test(NAME fastos_thread_stats_test_app NO_VALGRIND COMMAND fastos_thread_stats_test_app) -vespa_add_executable(fastos_thread_sleep_test_app TEST - SOURCES - thread_sleep_test.cpp - DEPENDS - fastos -) -vespa_add_test(NAME fastos_thread_sleep_test_app NO_VALGRIND COMMAND fastos_thread_sleep_test_app) -vespa_add_executable(fastos_thread_mutex_test_app TEST - SOURCES - thread_mutex_test.cpp - DEPENDS - fastos -) -vespa_add_test(NAME fastos_thread_mutex_test_app NO_VALGRIND COMMAND fastos_thread_mutex_test_app) vespa_add_executable(fastos_thread_joinwait_test_app TEST SOURCES thread_joinwait_test.cpp @@ -40,13 +26,6 @@ vespa_add_executable(fastos_thread_joinwait_test_app TEST fastos ) vespa_add_test(NAME fastos_thread_joinwait_test_app NO_VALGRIND COMMAND fastos_thread_joinwait_test_app) -vespa_add_executable(fastos_thread_bounce_test_app TEST - SOURCES - thread_bounce_test.cpp - DEPENDS - fastos -) -vespa_add_test(NAME fastos_thread_bounce_test_app NO_VALGRIND COMMAND fastos_thread_bounce_test_app) vespa_add_executable(fastos_threadtest_app TEST SOURCES threadtest.cpp diff --git a/fastos/src/tests/job.h b/fastos/src/tests/job.h index 1d35ec95270..35e1d02a9d3 100644 --- a/fastos/src/tests/job.h +++ b/fastos/src/tests/job.h @@ -7,17 +7,13 @@ enum JobCode { - 
PRINT_MESSAGE_AND_WAIT3SEC, + PRINT_MESSAGE_AND_WAIT3MSEC, INCREASE_NUMBER, - PRIORITY_TEST, WAIT_FOR_BREAK_FLAG, WAIT_FOR_THREAD_TO_FINISH, WAIT_FOR_CONDITION, - BOUNCE_CONDITIONS, TEST_ID, WAIT2SEC_AND_SIGNALCOND, - HOLD_MUTEX_FOR2SEC, - WAIT_2_SEC, SILENTNOP, NOP }; @@ -34,7 +30,6 @@ public: std::mutex *mutex; std::condition_variable *condition; FastOS_ThreadInterface *otherThread, *ownThread; - double *timebuf; double average; int result; FastOS_ThreadId _threadId; @@ -49,7 +44,6 @@ public: condition(nullptr), otherThread(nullptr), ownThread(nullptr), - timebuf(nullptr), average(0.0), result(-1), _threadId(), diff --git a/fastos/src/tests/thread_bounce_test.cpp b/fastos/src/tests/thread_bounce_test.cpp deleted file mode 100644 index 488002341e9..00000000000 --- a/fastos/src/tests/thread_bounce_test.cpp +++ /dev/null @@ -1,98 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "tests.h" -#include "job.h" -#include "thread_test_base.hpp" - -using namespace std::chrono; - -class Thread_Bounce_Test : public ThreadTestBase -{ - int Main () override; - - void BounceTest(void) - { - TestHeader("Bounce Test"); - - FastOS_ThreadPool pool(128 * 1024); - std::mutex mutex1; - std::condition_variable cond1; - std::mutex mutex2; - std::condition_variable cond2; - Job job1; - Job job2; - int cnt1; - int cnt2; - int cntsum; - int lastcntsum; - - job1.code = BOUNCE_CONDITIONS; - job2.code = BOUNCE_CONDITIONS; - job1.otherjob = &job2; - job2.otherjob = &job1; - job1.mutex = &mutex1; - job1.condition = &cond1; - job2.mutex = &mutex2; - job2.condition = &cond2; - - job1.ownThread = pool.NewThread(this, static_cast<void *>(&job1)); - job2.ownThread = pool.NewThread(this, static_cast<void *>(&job2)); - - lastcntsum = -1; - for (int iter = 0; iter < 8; iter++) { - steady_clock::time_point start = steady_clock::now(); - - nanoseconds left = steady_clock::now() - start; - while (left < 1000ms) { - 
std::this_thread::sleep_for(1000ms - left); - left = steady_clock::now() - start; - } - - mutex1.lock(); - cnt1 = job1.bouncewakeupcnt; - mutex1.unlock(); - mutex2.lock(); - cnt2 = job2.bouncewakeupcnt; - mutex2.unlock(); - cntsum = cnt1 + cnt2; - Progress(lastcntsum != cntsum, "%d bounces", cntsum); - lastcntsum = cntsum; - } - - job1.ownThread->SetBreakFlag(); - mutex1.lock(); - job1.bouncewakeup = true; - cond1.notify_one(); - mutex1.unlock(); - - job2.ownThread->SetBreakFlag(); - mutex2.lock(); - job2.bouncewakeup = true; - cond2.notify_one(); - mutex2.unlock(); - - pool.Close(); - Progress(true, "Pool closed."); - PrintSeparator(); - } - -}; - -int Thread_Bounce_Test::Main () -{ - printf("grep for the string '%s' to detect failures.\n\n", failString); - time_t before = time(0); - - BounceTest(); - - { time_t now = time(0); printf("[%ld seconds]\n", now-before); before = now; } - printf("END OF TEST (%s)\n", _argv[0]); - return allWasOk() ? 0 : 1; -} - -int main (int argc, char **argv) -{ - Thread_Bounce_Test app; - setvbuf(stdout, nullptr, _IOLBF, 8192); - return app.Entry(argc, argv); -} diff --git a/fastos/src/tests/thread_joinwait_test.cpp b/fastos/src/tests/thread_joinwait_test.cpp index 7153a05f836..6330a52b5f0 100644 --- a/fastos/src/tests/thread_joinwait_test.cpp +++ b/fastos/src/tests/thread_joinwait_test.cpp @@ -45,9 +45,9 @@ class Thread_JoinWait_Test : public ThreadTestBase break; } - if(rc) + if (rc) { - jobs[lastThreadNum].code = (((variant & 2) != 0) ? NOP : PRINT_MESSAGE_AND_WAIT3SEC); + jobs[lastThreadNum].code = (((variant & 2) != 0) ? 
NOP : PRINT_MESSAGE_AND_WAIT3MSEC); jobs[lastThreadNum].message = strdup("This is the thread that others wait for."); FastOS_ThreadInterface *lastThread; @@ -59,10 +59,9 @@ class Thread_JoinWait_Test : public ThreadTestBase rc = (lastThread != nullptr); Progress(rc, "Creating last thread"); - if(rc) + if (rc) { - for(i=0; i<lastThreadNum; i++) - { + for(i=0; i<lastThreadNum; i++) { jobs[i].otherThread = lastThread; } } @@ -70,9 +69,9 @@ class Thread_JoinWait_Test : public ThreadTestBase jobMutex.unlock(); - if((variant & 1) != 0) + if ((variant & 1) != 0) { - for(i=0; i<lastThreadNum; i++) + for (i=0; i<lastThreadNum; i++) { Progress(true, "Waiting for thread %d to finish using Join()", i+1); jobs[i].ownThread->Join(); diff --git a/fastos/src/tests/thread_mutex_test.cpp b/fastos/src/tests/thread_mutex_test.cpp deleted file mode 100644 index 6d3f8c3c5f0..00000000000 --- a/fastos/src/tests/thread_mutex_test.cpp +++ /dev/null @@ -1,194 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -#include "tests.h" -#include "job.h" -#include "thread_test_base.hpp" - -#define MUTEX_TEST_THREADS 6 -#define MAX_THREADS 7 - -class Thread_Mutex_Test : public ThreadTestBase -{ - int Main () override; - - void MutexTest (bool usingMutex) - { - if(usingMutex) - TestHeader("Mutex Test"); - else - TestHeader("Not Using Mutex Test"); - - - FastOS_ThreadPool *pool = new FastOS_ThreadPool(128*1024, MAX_THREADS); - - if(Progress(pool != nullptr, "Allocating ThreadPool")) - { - int i; - Job jobs[MUTEX_TEST_THREADS]; - std::mutex *myMutex=nullptr; - - if(usingMutex) { - myMutex = new std::mutex; - } - - for(i=0; i<MUTEX_TEST_THREADS; i++) - { - jobs[i].code = INCREASE_NUMBER; - jobs[i].mutex = myMutex; - } - - number = 0; - - for(i=0; i<MUTEX_TEST_THREADS; i++) - { - bool rc = (nullptr != pool->NewThread(this, - static_cast<void *>(&jobs[i]))); - Progress(rc, "Creating Thread with%s mutex", (usingMutex ? 
"" : "out")); - }; - - WaitForThreadsToFinish(jobs, MUTEX_TEST_THREADS); - - - for(i=0; i<MUTEX_TEST_THREADS; i++) - { - Progress(true, "Thread returned with resultcode %d", jobs[i].result); - } - - bool wasOk=true; - int concurrentHits=0; - - for(i=0; i<MUTEX_TEST_THREADS; i++) - { - int result = jobs[i].result; - - if(usingMutex) - { - if((result % INCREASE_NUMBER_AMOUNT) != 0) - { - wasOk = false; - Progress(false, "Mutex locking did not work (%d).", result); - break; - } - } - else - { - if((result != 0) && - (result != INCREASE_NUMBER_AMOUNT*MUTEX_TEST_THREADS) && - (result % INCREASE_NUMBER_AMOUNT) == 0) - { - if((++concurrentHits) == 2) - { - wasOk = false; - Progress(false, "Very unlikely that threads are running " - "concurrently (%d)", jobs[i].result); - break; - } - } - } - } - - if(wasOk) - { - if(usingMutex) - { - Progress(true, "Using the mutex, the returned numbers were alligned."); - } - else - { - Progress(true, "Returned numbers were not alligned. " - "This was the expected result."); - } - } - - if(myMutex != nullptr) - delete(myMutex); - - Progress(true, "Closing threadpool..."); - pool->Close(); - - Progress(true, "Deleting threadpool..."); - delete(pool); - } - PrintSeparator(); - } - - void TryLockTest () - { - TestHeader("Mutex TryLock Test"); - - FastOS_ThreadPool pool(128*1024); - Job job; - std::mutex mtx; - - job.code = HOLD_MUTEX_FOR2SEC; - job.result = -1; - job.mutex = &mtx; - job.ownThread = pool.NewThread(this, - static_cast<void *>(&job)); - - Progress(job.ownThread !=nullptr, "Creating thread"); - - if(job.ownThread != nullptr) - { - bool lockrc; - - std::this_thread::sleep_for(1s); - - for(int i=0; i<5; i++) - { - lockrc = mtx.try_lock(); - Progress(!lockrc, "We should not get the mutex lock just yet (%s)", - lockrc ? "got it" : "didn't get it"); - if(lockrc) { - mtx.unlock(); - break; - } - } - - std::this_thread::sleep_for(2s); - - lockrc = mtx.try_lock(); - Progress(lockrc, "We should get the mutex lock now (%s)", - lockrc ? 
"got it" : "didn't get it"); - - if(lockrc) - mtx.unlock(); - - Progress(true, "Attempting to do normal lock..."); - mtx.lock(); - Progress(true, "Got lock. Attempt to do normal unlock..."); - mtx.unlock(); - Progress(true, "Unlock OK."); - } - - Progress(true, "Waiting for threads to finish using pool.Close()..."); - pool.Close(); - Progress(true, "Pool closed."); - - PrintSeparator(); - } - -}; - -int Thread_Mutex_Test::Main () -{ - printf("grep for the string '%s' to detect failures.\n\n", failString); - time_t before = time(0); - - MutexTest(true); - { time_t now = time(0); printf("[%ld seconds]\n", now-before); before = now; } - MutexTest(false); - { time_t now = time(0); printf("[%ld seconds]\n", now-before); before = now; } - TryLockTest(); - { time_t now = time(0); printf("[%ld seconds]\n", now-before); before = now; } - - printf("END OF TEST (%s)\n", _argv[0]); - return allWasOk() ? 0 : 1; -} - -int main (int argc, char **argv) -{ - Thread_Mutex_Test app; - setvbuf(stdout, nullptr, _IOLBF, 8192); - return app.Entry(argc, argv); -} diff --git a/fastos/src/tests/thread_sleep_test.cpp b/fastos/src/tests/thread_sleep_test.cpp deleted file mode 100644 index 209b7d3f880..00000000000 --- a/fastos/src/tests/thread_sleep_test.cpp +++ /dev/null @@ -1,52 +0,0 @@ -// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
- -#include "tests.h" -#include "job.h" -#include "thread_test_base.hpp" - -class Thread_Sleep_Test : public ThreadTestBase -{ - int Main () override; - - void CreateSingleThread () - { - TestHeader("Create Single Thread Test"); - - FastOS_ThreadPool *pool = new FastOS_ThreadPool(128*1024); - - if(Progress(pool != nullptr, "Allocating ThreadPool")) - { - bool rc = (nullptr != pool->NewThread(this, nullptr)); - Progress(rc, "Creating Thread"); - - Progress(true, "Sleeping 3 seconds"); - std::this_thread::sleep_for(3s); - } - - Progress(true, "Closing threadpool..."); - pool->Close(); - - Progress(true, "Deleting threadpool..."); - delete(pool); - PrintSeparator(); - } -}; - -int Thread_Sleep_Test::Main () -{ - printf("grep for the string '%s' to detect failures.\n\n", failString); - time_t before = time(0); - - CreateSingleThread(); - { time_t now = time(0); printf("[%ld seconds]\n", now-before); before = now; } - - printf("END OF TEST (%s)\n", _argv[0]); - return allWasOk() ? 0 : 1; -} - -int main (int argc, char **argv) -{ - Thread_Sleep_Test app; - setvbuf(stdout, nullptr, _IOLBF, 8192); - return app.Entry(argc, argv); -} diff --git a/fastos/src/tests/thread_stats_test.cpp b/fastos/src/tests/thread_stats_test.cpp index a9d304d411f..9dadda20a14 100644 --- a/fastos/src/tests/thread_stats_test.cpp +++ b/fastos/src/tests/thread_stats_test.cpp @@ -18,20 +18,14 @@ class Thread_Stats_Test : public ThreadTestBase Job job[2]; inactiveThreads = pool.GetNumInactiveThreads(); - Progress(inactiveThreads == 0, "Initial inactive threads = %d", - inactiveThreads); + Progress(inactiveThreads == 0, "Initial inactive threads = %d", inactiveThreads); activeThreads = pool.GetNumActiveThreads(); - Progress(activeThreads == 0, "Initial active threads = %d", - activeThreads); + Progress(activeThreads == 0, "Initial active threads = %d", activeThreads); startedThreads = pool.GetNumStartedThreads(); - Progress(startedThreads == 0, "Initial started threads = %d", - startedThreads); + 
Progress(startedThreads == 0, "Initial started threads = %d", startedThreads); job[0].code = WAIT_FOR_BREAK_FLAG; - job[0].ownThread = pool.NewThread(this, - static_cast<void *>(&job[0])); - - std::this_thread::sleep_for(1s); + job[0].ownThread = pool.NewThread(this, static_cast<void *>(&job[0])); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); @@ -41,10 +35,7 @@ class Thread_Stats_Test : public ThreadTestBase Progress(startedThreads == 1, "Started threads = %d", startedThreads); job[1].code = WAIT_FOR_BREAK_FLAG; - job[1].ownThread = pool.NewThread(this, - static_cast<void *>(&job[1])); - - std::this_thread::sleep_for(1s); + job[1].ownThread = pool.NewThread(this, static_cast<void *>(&job[1])); inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); @@ -57,7 +48,11 @@ class Thread_Stats_Test : public ThreadTestBase job[0].ownThread->SetBreakFlag(); job[1].ownThread->SetBreakFlag(); - std::this_thread::sleep_for(3s); + job[0].ownThread->Join(); + job[1].ownThread->Join(); + while (pool.GetNumInactiveThreads() != 2) { + std::this_thread::sleep_for(1ms); + } inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 2, "Inactive threads = %d", inactiveThreads); @@ -66,14 +61,11 @@ class Thread_Stats_Test : public ThreadTestBase startedThreads = pool.GetNumStartedThreads(); Progress(startedThreads == 2, "Started threads = %d", startedThreads); - Progress(true, "Repeating process in the same pool..."); job[0].code = WAIT_FOR_BREAK_FLAG; job[0].ownThread = pool.NewThread(this, static_cast<void *>(&job[0])); - std::this_thread::sleep_for(1s); - inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 1, "Inactive threads = %d", inactiveThreads); activeThreads = pool.GetNumActiveThreads(); @@ -84,8 +76,6 @@ class Thread_Stats_Test : public ThreadTestBase job[1].code = WAIT_FOR_BREAK_FLAG; 
job[1].ownThread = pool.NewThread(this, static_cast<void *>(&job[1])); - std::this_thread::sleep_for(1s); - inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 0, "Inactive threads = %d", inactiveThreads); activeThreads = pool.GetNumActiveThreads(); @@ -97,7 +87,11 @@ class Thread_Stats_Test : public ThreadTestBase job[0].ownThread->SetBreakFlag(); job[1].ownThread->SetBreakFlag(); - std::this_thread::sleep_for(3s); + job[0].ownThread->Join(); + job[1].ownThread->Join(); + while (pool.GetNumInactiveThreads() != 2) { + std::this_thread::sleep_for(1ms); + } inactiveThreads = pool.GetNumInactiveThreads(); Progress(inactiveThreads == 2, "Inactive threads = %d", inactiveThreads); @@ -106,7 +100,6 @@ class Thread_Stats_Test : public ThreadTestBase startedThreads = pool.GetNumStartedThreads(); Progress(startedThreads == 4, "Started threads = %d", startedThreads); - pool.Close(); Progress(true, "Pool closed."); diff --git a/fastos/src/tests/thread_test_base.hpp b/fastos/src/tests/thread_test_base.hpp index c4f7ed76ea7..e77f61dddb3 100644 --- a/fastos/src/tests/thread_test_base.hpp +++ b/fastos/src/tests/thread_test_base.hpp @@ -20,7 +20,7 @@ public: : printMutex() { } - virtual ~ThreadTestBase() {}; + virtual ~ThreadTestBase() {} void PrintProgress (char *string) override { @@ -48,7 +48,7 @@ public: } } - std::this_thread::sleep_for(500ms); + std::this_thread::sleep_for(1us); if(threadsFinished) break; @@ -84,12 +84,12 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) break; } - case PRINT_MESSAGE_AND_WAIT3SEC: + case PRINT_MESSAGE_AND_WAIT3MSEC: { Progress(true, "Thread printing message: [%s]", job->message); job->result = strlen(job->message); - std::this_thread::sleep_for(3s); + std::this_thread::sleep_for(3ms); break; } @@ -110,7 +110,7 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) number = number + 2; if(i == sleepOn) - std::this_thread::sleep_for(1s); + std::this_thread::sleep_for(1ms); } guard = 
std::unique_lock<std::mutex>(); @@ -124,10 +124,9 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) { for(;;) { - std::this_thread::sleep_for(1s); + std::this_thread::sleep_for(1us); - if(thread->GetBreakFlag()) - { + if (thread->GetBreakFlag()) { Progress(true, "Thread %p got breakflag", thread); break; } @@ -159,24 +158,6 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) break; } - case BOUNCE_CONDITIONS: - { - while (!thread->GetBreakFlag()) { - { - std::lock_guard<std::mutex> guard(*job->otherjob->mutex); - job->otherjob->bouncewakeupcnt++; - job->otherjob->bouncewakeup = true; - job->otherjob->condition->notify_one(); - } - std::unique_lock<std::mutex> guard(*job->mutex); - while (!job->bouncewakeup) { - job->condition->wait_for(guard, 1ms); - } - job->bouncewakeup = false; - } - break; - } - case TEST_ID: { job->mutex->lock(); // Initially the parent threads owns the lock @@ -199,23 +180,6 @@ void ThreadTestBase::Run (FastOS_ThreadInterface *thread, void *arg) break; } - case HOLD_MUTEX_FOR2SEC: - { - { - std::lock_guard<std::mutex> guard(*job->mutex); - std::this_thread::sleep_for(2s); - } - job->result = 1; - break; - } - - case WAIT_2_SEC: - { - std::this_thread::sleep_for(2s); - job->result = 1; - break; - } - default: Progress(false, "Unknown jobcode"); break; diff --git a/fastos/src/tests/threadtest.cpp b/fastos/src/tests/threadtest.cpp index 1fa9820c8d7..129e067f229 100644 --- a/fastos/src/tests/threadtest.cpp +++ b/fastos/src/tests/threadtest.cpp @@ -6,7 +6,6 @@ #include <cstdlib> #include <chrono> -#define MUTEX_TEST_THREADS 6 #define MAX_THREADS 7 using namespace std::chrono; @@ -62,7 +61,7 @@ class ThreadTest : public ThreadTestBase for(i=0; i<MAX_THREADS; i++) { - jobs[i].code = PRINT_MESSAGE_AND_WAIT3SEC; + jobs[i].code = PRINT_MESSAGE_AND_WAIT3MSEC; jobs[i].message = static_cast<char *>(malloc(100)); sprintf(jobs[i].message, "Thread %d invocation", i+1); } @@ -103,62 +102,6 @@ class ThreadTest : public 
ThreadTestBase PrintSeparator(); } - - void HowManyThreadsTest () - { - #define HOW_MAX_THREADS (1024) - TestHeader("How Many Threads Test"); - - FastOS_ThreadPool *pool = new FastOS_ThreadPool(128*1024, HOW_MAX_THREADS); - - if(Progress(pool != nullptr, "Allocating ThreadPool")) - { - int i; - Job jobs[HOW_MAX_THREADS]; - - for(i=0; i<HOW_MAX_THREADS; i++) - { - jobs[i].code = PRINT_MESSAGE_AND_WAIT3SEC; - jobs[i].message = static_cast<char *>(malloc(100)); - sprintf(jobs[i].message, "Thread %d invocation", i+1); - } - - for(i=0; i<HOW_MAX_THREADS; i++) - { - if(i==HOW_MAX_THREADS) - { - bool rc = (nullptr == pool->NewThread(this, - static_cast<void *>(&jobs[0]))); - Progress(rc, "Creating too many threads should fail."); - } - else - { - bool rc = (nullptr != pool->NewThread(this, - static_cast<void *>(&jobs[i]))); - Progress(rc, "Creating Thread"); - } - }; - - WaitForThreadsToFinish(jobs, HOW_MAX_THREADS); - - Progress(true, "Verifying result codes..."); - for(i=0; i<HOW_MAX_THREADS; i++) - { - Progress(jobs[i].result == - static_cast<int>(strlen(jobs[i].message)), - "Checking result code from thread (%d==%d)", - jobs[i].result, strlen(jobs[i].message)); - } - - Progress(true, "Closing threadpool..."); - pool->Close(); - - Progress(true, "Deleting threadpool..."); - delete(pool); - } - PrintSeparator(); - } - void CreateSingleThreadAndJoin () { TestHeader("Create Single Thread And Join Test"); diff --git a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java index f1eaf522308..8f6d55060e4 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/Flags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/Flags.java @@ -106,13 +106,6 @@ public class Flags { "Takes effect at redeployment", ZONE_ID, APPLICATION_ID); - public static final UnboundBooleanFlag USE_DIRECT_STORAGE_API_RPC = defineFeatureFlag( - "use-direct-storage-api-rpc", false, - List.of("geirst"), "2020-12-02", "2021-02-01", - "Whether to use direct 
RPC for Storage API communication between content cluster nodes.", - "Takes effect at restart of distributor and content node process", - ZONE_ID, APPLICATION_ID); - public static final UnboundBooleanFlag USE_FAST_VALUE_TENSOR_IMPLEMENTATION = defineFeatureFlag( "use-fast-value-tensor-implementation", false, List.of("geirst"), "2020-12-02", "2021-02-01", diff --git a/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java b/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java index a3e2a11a79c..823662a74f2 100644 --- a/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java +++ b/flags/src/main/java/com/yahoo/vespa/flags/PermanentFlags.java @@ -131,6 +131,12 @@ public class PermanentFlags { "takes effect on browser reload of api/user/v1/user", CONSOLE_USER_EMAIL); + public static final UnboundLongFlag INVALIDATE_CONSOLE_SESSIONS = defineLongFlag( + "invalidate-console-sessions", 0, + "Invalidate console sessions (cookies) issued before this unix timestamp", + "Takes effect on next api request" + ); + private PermanentFlags() {} private static UnboundBooleanFlag defineFeatureFlag( diff --git a/hosted-api/src/main/java/ai/vespa/hosted/api/ControllerHttpClient.java b/hosted-api/src/main/java/ai/vespa/hosted/api/ControllerHttpClient.java index 3a848b33c76..f17816f224d 100644 --- a/hosted-api/src/main/java/ai/vespa/hosted/api/ControllerHttpClient.java +++ b/hosted-api/src/main/java/ai/vespa/hosted/api/ControllerHttpClient.java @@ -121,6 +121,13 @@ public abstract class ControllerHttpClient { DELETE))); } + /** Sets suspension status of the given application in the given zone. */ + public String suspend(ApplicationId id, ZoneId zone, boolean suspend) { + return toMessage(send(request(HttpRequest.newBuilder(suspendPath(id, zone)) + .timeout(Duration.ofSeconds(10)), + suspend ? POST : DELETE))); + } + /** Returns the default {@link ZoneId} for the given environment, if any. 
*/ public ZoneId defaultZone(Environment environment) { Inspector rootObject = toInspector(send(request(HttpRequest.newBuilder(defaultRegionPath(environment)) @@ -225,6 +232,10 @@ public abstract class ControllerHttpClient { "region", zone.region().value()); } + private URI suspendPath(ApplicationId id, ZoneId zone) { + return concatenated(deploymentPath(id, zone), "suspend"); + } + private URI deploymentJobPath(ApplicationId id, ZoneId zone) { return concatenated(instancePath(id), "deploy", jobNameOf(zone)); diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/NGramExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/NGramExpression.java index adf3e4ecaaa..2c56f0e356b 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/NGramExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/NGramExpression.java @@ -52,9 +52,12 @@ public final class NGramExpression extends Expression { // This expression is already executed for this input instance return; } - SpanList spanList = input.setSpanTree(new SpanTree(SpanTrees.LINGUISTICS)).spanList(); + StringFieldValue output = input.clone(); + ctx.setValue(output); + + SpanList spanList = output.setSpanTree(new SpanTree(SpanTrees.LINGUISTICS)).spanList(); int lastPosition = 0; - for (Iterator<GramSplitter.Gram> it = linguistics.getGramSplitter().split(input.getString(), gramSize); it.hasNext();) { + for (Iterator<GramSplitter.Gram> it = linguistics.getGramSplitter().split(output.getString(), gramSize); it.hasNext();) { GramSplitter.Gram gram = it.next(); // if there is a gap before this gram, then annotate the gram as punctuation // (technically it may be of various types, but it does not matter - we just @@ -64,15 +67,15 @@ public final class NGramExpression extends Expression { } // annotate gram as a word term - String gramString = gram.extractFrom(input.getString()); + 
String gramString = gram.extractFrom(output.getString()); typedSpan(gram.getStart(), gram.getCodePointCount(), TokenType.ALPHABETIC, spanList). annotate(LinguisticsAnnotator.lowerCaseTermAnnotation(gramString, gramString)); lastPosition = gram.getStart() + gram.getCodePointCount(); } // handle punctuation at the end - if (lastPosition < input.toString().length()) { - typedSpan(lastPosition, input.toString().length() - lastPosition, TokenType.PUNCTUATION, spanList); + if (lastPosition < output.toString().length()) { + typedSpan(lastPosition, output.toString().length() - lastPosition, TokenType.PUNCTUATION, spanList); } } diff --git a/messagebus/src/main/java/com/yahoo/messagebus/routing/RoutingTable.java b/messagebus/src/main/java/com/yahoo/messagebus/routing/RoutingTable.java index 82f409dbb44..e3410cdba7d 100644 --- a/messagebus/src/main/java/com/yahoo/messagebus/routing/RoutingTable.java +++ b/messagebus/src/main/java/com/yahoo/messagebus/routing/RoutingTable.java @@ -13,8 +13,8 @@ import java.util.Map; */ public class RoutingTable { - private final Map<String, HopBlueprint> hops = new LinkedHashMap<String, HopBlueprint>(); - private final Map<String, Route> routes = new LinkedHashMap<String, Route>(); + private final Map<String, HopBlueprint> hops = new LinkedHashMap<>(); + private final Map<String, Route> routes = new LinkedHashMap<>(); /** * Creates a new routing table based on a given specification. This also verifies the integrity of the table. 
diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java index 5c635551692..b0b61e8a6b2 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/NodeList.java @@ -12,8 +12,10 @@ import com.yahoo.config.provision.NodeType; import java.util.Comparator; import java.util.EnumSet; import java.util.List; +import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; import java.util.stream.Stream; @@ -213,6 +215,28 @@ public class NodeList extends AbstractFilteringList<Node, NodeList> { first().get().resources()); } + /** Returns the nodes that are allocated on an exclusive network switch within its cluster */ + public NodeList onExclusiveSwitch(NodeList clusterHosts) { + ensureSingleCluster(); + Map<String, Long> switchCount = clusterHosts.stream() + .flatMap(host -> host.switchHostname().stream()) + .collect(Collectors.groupingBy(Function.identity(), + Collectors.counting())); + return matching(node -> { + Optional<Node> nodeOnSwitch = clusterHosts.parentOf(node); + if (node.parentHostname().isPresent()) { + if (nodeOnSwitch.isEmpty()) { + throw new IllegalArgumentException("Parent of " + node + ", " + node.parentHostname().get() + + ", not found in given cluster hosts"); + } + } else { + nodeOnSwitch = Optional.of(node); + } + Optional<String> allocatedSwitch = nodeOnSwitch.flatMap(Node::switchHostname); + return allocatedSwitch.isEmpty() || switchCount.get(allocatedSwitch.get()) == 1; + }); + } + private void ensureSingleCluster() { if (isEmpty()) return; diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java index 
4a5c28fe0c8..778a3656dca 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporter.java @@ -79,6 +79,7 @@ public class MetricsReporter extends NodeRepositoryMaintainer { updateTenantUsageMetrics(nodes); updateRepairTicketMetrics(nodes); updateAllocationMetrics(nodes); + updateExclusiveSwitchMetrics(nodes); return true; } @@ -102,11 +103,24 @@ public class MetricsReporter extends NodeRepositoryMaintainer { } else { nonActiveFraction = (double) nonActiveNodes / (double) activeNodes; } - Map<String, String> dimensions = new HashMap<>(dimensions(clusterKey.application)); - dimensions.put("clusterId", clusterKey.cluster.value()); - metric.set("nodes.active", activeNodes, getContext(dimensions)); - metric.set("nodes.nonActive", nonActiveNodes, getContext(dimensions)); - metric.set("nodes.nonActiveFraction", nonActiveFraction, getContext(dimensions)); + Metric.Context context = getContext(dimensions(clusterKey.application, clusterKey.cluster)); + metric.set("nodes.active", activeNodes, context); + metric.set("nodes.nonActive", nonActiveNodes, context); + metric.set("nodes.nonActiveFraction", nonActiveFraction, context); + }); + } + + private void updateExclusiveSwitchMetrics(NodeList nodes) { + Map<ClusterKey, List<Node>> byCluster = nodes.stream() + .filter(node -> node.type() == NodeType.tenant) + .filter(node -> node.state() == State.active) + .filter(node -> node.allocation().isPresent()) + .collect(Collectors.groupingBy(node -> new ClusterKey(node.allocation().get().owner(), node.allocation().get().membership().cluster().id()))); + byCluster.forEach((clusterKey, clusterNodes) -> { + NodeList clusterHosts = nodes.parentsOf(NodeList.copyOf(clusterNodes)); + long nodesOnExclusiveSwitch = NodeList.copyOf(clusterNodes).onExclusiveSwitch(clusterHosts).size(); + double exclusiveSwitchRatio = nodesOnExclusiveSwitch / (double) 
clusterNodes.size(); + metric.set("nodes.exclusiveSwitchFraction", exclusiveSwitchRatio, getContext(dimensions(clusterKey.application, clusterKey.cluster))); }); } @@ -340,6 +354,12 @@ public class MetricsReporter extends NodeRepositoryMaintainer { .forEach((status, number) -> metric.set("hostedVespa.breakfixedHosts", number, getContext(Map.of("status", status)))); } + static Map<String, String> dimensions(ApplicationId application, ClusterSpec.Id cluster) { + Map<String, String> dimensions = new HashMap<>(dimensions(application)); + dimensions.put("clusterId", cluster.value()); + return dimensions; + } + private static Map<String, String> dimensions(ApplicationId application) { return Map.of("tenantName", application.tenant().value(), "applicationId", application.serializedForm().replace(':', '.'), diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintainer.java b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintainer.java index 84dd2c6a8c3..fdbf199898a 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintainer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/NodeRepositoryMaintainer.java @@ -25,8 +25,8 @@ public abstract class NodeRepositoryMaintainer extends Maintainer { private final NodeRepository nodeRepository; public NodeRepositoryMaintainer(NodeRepository nodeRepository, Duration interval, Metric metric) { - super(null, interval, nodeRepository.clock().instant(), nodeRepository.jobControl(), jobMetrics(metric), - nodeRepository.database().cluster()); + super(null, interval, nodeRepository.clock().instant(), nodeRepository.jobControl(), + jobMetrics(metric), nodeRepository.database().cluster()); this.nodeRepository = nodeRepository; } diff --git a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/SwitchRebalancer.java 
b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/SwitchRebalancer.java index ee02beb168f..e545b3d97ee 100644 --- a/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/SwitchRebalancer.java +++ b/node-repository/src/main/java/com/yahoo/vespa/hosted/provision/maintenance/SwitchRebalancer.java @@ -13,7 +13,7 @@ import com.yahoo.vespa.hosted.provision.node.Agent; import java.time.Duration; import java.util.HashSet; -import java.util.Optional; +import java.util.List; import java.util.Set; /** @@ -67,17 +67,12 @@ public class SwitchRebalancer extends NodeMover<Move> { } /** Returns whether allocatedNode is on an exclusive switch */ - private boolean onExclusiveSwitch(Node allocatedNode, NodeList clusterHosts) { - Optional<String> allocatedSwitch = clusterHosts.parentOf(allocatedNode).flatMap(Node::switchHostname); - if (allocatedSwitch.isEmpty()) return true; - return clusterHosts.stream() - .flatMap(host -> host.switchHostname().stream()) - .filter(switchHostname -> switchHostname.equals(allocatedSwitch.get())) - .count() == 1; + private static boolean onExclusiveSwitch(Node allocatedNode, NodeList clusterHosts) { + return !NodeList.copyOf(List.of(allocatedNode)).onExclusiveSwitch(clusterHosts).isEmpty(); } /** Returns whether allocating a node on toHost would increase the number of exclusive switches */ - private boolean increasesExclusiveSwitches(NodeList clusterNodes, NodeList clusterHosts, Node toHost) { + private static boolean increasesExclusiveSwitches(NodeList clusterNodes, NodeList clusterHosts, Node toHost) { if (toHost.switchHostname().isEmpty()) return false; Set<String> activeSwitches = new HashSet<>(); int unknownSwitches = 0; diff --git a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java index 3e4887b6998..0a4ba497558 100644 --- 
a/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java +++ b/node-repository/src/test/java/com/yahoo/vespa/hosted/provision/maintenance/MetricsReporterTest.java @@ -6,20 +6,16 @@ import com.yahoo.config.provision.ApplicationId; import com.yahoo.config.provision.Capacity; import com.yahoo.config.provision.ClusterMembership; import com.yahoo.config.provision.ClusterResources; -import com.yahoo.config.provision.DockerImage; +import com.yahoo.config.provision.ClusterSpec; import com.yahoo.config.provision.NodeFlavors; import com.yahoo.config.provision.NodeResources; import com.yahoo.config.provision.NodeType; -import com.yahoo.config.provision.Zone; import com.yahoo.jdisc.Metric; import com.yahoo.transaction.Mutex; import com.yahoo.transaction.NestedTransaction; import com.yahoo.vespa.applicationmodel.ApplicationInstance; import com.yahoo.vespa.applicationmodel.ApplicationInstanceReference; -import com.yahoo.vespa.curator.Curator; -import com.yahoo.vespa.curator.mock.MockCurator; import com.yahoo.vespa.curator.stats.LockStats; -import com.yahoo.vespa.flags.InMemoryFlagSource; import com.yahoo.vespa.hosted.provision.LockedNodeList; import com.yahoo.vespa.hosted.provision.Node; import com.yahoo.vespa.hosted.provision.NodeRepository; @@ -27,10 +23,8 @@ import com.yahoo.vespa.hosted.provision.node.Agent; import com.yahoo.vespa.hosted.provision.node.Allocation; import com.yahoo.vespa.hosted.provision.node.Generation; import com.yahoo.vespa.hosted.provision.node.IP; -import com.yahoo.vespa.hosted.provision.provisioning.EmptyProvisionServiceProvider; import com.yahoo.vespa.hosted.provision.provisioning.FlavorConfigBuilder; import com.yahoo.vespa.hosted.provision.provisioning.ProvisioningTester; -import com.yahoo.vespa.hosted.provision.testutils.MockNameResolver; import com.yahoo.vespa.orchestrator.Orchestrator; import com.yahoo.vespa.orchestrator.status.HostInfo; import com.yahoo.vespa.orchestrator.status.HostStatus; @@ -39,7 
+33,6 @@ import com.yahoo.vespa.service.monitor.ServiceMonitor; import org.junit.Before; import org.junit.Test; -import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.util.List; @@ -85,7 +78,10 @@ public class MetricsReporterTest { @Test public void test_registered_metric() { NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("default"); - ProvisioningTester tester = new ProvisioningTester.Builder().flavors(nodeFlavors.getFlavors()).build(); + Orchestrator orchestrator = mock(Orchestrator.class); + when(orchestrator.getHostInfo(eq(reference), any())).thenReturn( + HostInfo.createSuspended(HostStatus.ALLOWED_TO_BE_DOWN, Instant.ofEpochSecond(1))); + ProvisioningTester tester = new ProvisioningTester.Builder().flavors(nodeFlavors.getFlavors()).orchestrator(orchestrator).build(); NodeRepository nodeRepository = tester.nodeRepository(); tester.makeProvisionedNodes(1, "default", NodeType.tenant, 0); tester.makeProvisionedNodes(1, "default", NodeType.proxy, 0); @@ -132,17 +128,8 @@ public class MetricsReporterTest { tester.clock().setInstant(Instant.ofEpochSecond(124)); - Orchestrator orchestrator = mock(Orchestrator.class); - when(orchestrator.getHostInfo(eq(reference), any())).thenReturn( - HostInfo.createSuspended(HostStatus.ALLOWED_TO_BE_DOWN, Instant.ofEpochSecond(1))); - TestMetric metric = new TestMetric(); - MetricsReporter metricsReporter = new MetricsReporter(nodeRepository, - metric, - orchestrator, - serviceMonitor, - () -> 42, - LONG_INTERVAL); + MetricsReporter metricsReporter = metricsReporter(metric, tester); metricsReporter.maintain(); // Verify sum of values across dimensions, and remove these metrics to avoid checking against @@ -167,17 +154,8 @@ public class MetricsReporterTest { @Test public void docker_metrics() { NodeFlavors nodeFlavors = FlavorConfigBuilder.createDummies("host", "docker", "docker2"); - Curator curator = new MockCurator(); - NodeRepository nodeRepository = new NodeRepository(nodeFlavors, 
- new EmptyProvisionServiceProvider(), - curator, - Clock.systemUTC(), - Zone.defaultZone(), - new MockNameResolver().mockAnyLookup(), - DockerImage.fromString("docker-registry.domain.tld:8080/dist/vespa"), - new InMemoryFlagSource(), - true, - 0, 1000); + ProvisioningTester tester = new ProvisioningTester.Builder().flavors(nodeFlavors.getFlavors()).build(); + NodeRepository nodeRepository = tester.nodeRepository(); // Allow 4 containers Set<String> ipAddressPool = Set.of("::2", "::3", "::4", "::5"); @@ -210,12 +188,7 @@ public class MetricsReporterTest { when(orchestrator.getHostInfo(eq(reference), any())).thenReturn(HostInfo.createNoRemarks()); TestMetric metric = new TestMetric(); - MetricsReporter metricsReporter = new MetricsReporter(nodeRepository, - metric, - orchestrator, - serviceMonitor, - () -> 42, - LONG_INTERVAL); + MetricsReporter metricsReporter = metricsReporter(metric, tester); metricsReporter.maintain(); assertEquals(0, metric.values.get("hostedVespa.readyHosts")); // Only tenants counts @@ -246,13 +219,7 @@ public class MetricsReporterTest { tester.makeReadyHosts(5, new NodeResources(64, 256, 2000, 10)); tester.activateTenantHosts(); TestMetric metric = new TestMetric(); - MetricsReporter metricsReporter = new MetricsReporter(tester.nodeRepository(), - metric, - tester.orchestrator(), - serviceMonitor, - () -> 42, - LONG_INTERVAL); - + MetricsReporter metricsReporter = metricsReporter(metric, tester); // Application is deployed ApplicationId application = ApplicationId.from("t1", "a1", "default"); @@ -279,6 +246,46 @@ public class MetricsReporterTest { assertEquals(3, getMetric("nodes.nonActive", metric, dimensions)); } + @Test + public void exclusive_switch_ratio() { + ProvisioningTester tester = new ProvisioningTester.Builder().build(); + ClusterSpec spec = ClusterSpec.request(ClusterSpec.Type.container, ClusterSpec.Id.from("c1")).vespaVersion("1").build(); + Capacity capacity = Capacity.from(new ClusterResources(4, 1, new NodeResources(4, 8, 
50, 1))); + ApplicationId app = ApplicationId.from("t1", "a1", "default"); + TestMetric metric = new TestMetric(); + MetricsReporter metricsReporter = metricsReporter(metric, tester); + + // Provision initial hosts on two switches + NodeResources hostResources = new NodeResources(8, 16, 500, 10); + List<Node> hosts0 = tester.makeReadyNodes(4, hostResources, NodeType.host, 5); + tester.activateTenantHosts(); + String switch0 = "switch0"; + String switch1 = "switch1"; + tester.patchNode(hosts0.get(0), (host) -> host.withSwitchHostname(switch0)); + tester.patchNodes(hosts0.subList(1, hosts0.size()), (host) -> host.withSwitchHostname(switch1)); + + // Deploy application + tester.deploy(app, spec, capacity); + tester.assertSwitches(Set.of(switch0, switch1), app, spec.id()); + metricsReporter.maintain(); + assertEquals(0.25D, getMetric("nodes.exclusiveSwitchFraction", metric, MetricsReporter.dimensions(app, spec.id())).doubleValue(), Double.MIN_VALUE); + + // More exclusive switches become available + List<Node> hosts1 = tester.makeReadyNodes(2, hostResources, NodeType.host, 5); + tester.activateTenantHosts(); + String switch2 = "switch2"; + String switch3 = "switch3"; + tester.patchNode(hosts1.get(0), (host) -> host.withSwitchHostname(switch2)); + tester.patchNode(hosts1.get(1), (host) -> host.withSwitchHostname(switch3)); + + // Another cluster is added + ClusterSpec spec2 = ClusterSpec.request(ClusterSpec.Type.content, ClusterSpec.Id.from("c2")).vespaVersion("1").build(); + tester.deploy(app, spec2, capacity); + tester.assertSwitches(Set.of(switch0, switch1, switch2, switch3), app, spec2.id()); + metricsReporter.maintain(); + assertEquals(1D, getMetric("nodes.exclusiveSwitchFraction", metric, MetricsReporter.dimensions(app, spec2.id())).doubleValue(), Double.MIN_VALUE); + } + private Number getMetric(String name, TestMetric metric, Map<String, String> dimensions) { List<TestMetric.TestContext> metrics = metric.context.get(name).stream() .filter(ctx -> 
ctx.properties.entrySet().containsAll(dimensions.entrySet())) @@ -306,4 +313,13 @@ public class MetricsReporterTest { return Optional.empty(); } + private MetricsReporter metricsReporter(TestMetric metric, ProvisioningTester tester) { + return new MetricsReporter(tester.nodeRepository(), + metric, + tester.orchestrator(), + serviceMonitor, + () -> 42, + LONG_INTERVAL); + } + } diff --git a/orchestrator-restapi/src/main/java/com/yahoo/vespa/orchestrator/restapi/ApplicationSuspensionApi.java b/orchestrator-restapi/src/main/java/com/yahoo/vespa/orchestrator/restapi/ApplicationSuspensionApi.java index 1c597a73d01..e44f6fa0df7 100644 --- a/orchestrator-restapi/src/main/java/com/yahoo/vespa/orchestrator/restapi/ApplicationSuspensionApi.java +++ b/orchestrator-restapi/src/main/java/com/yahoo/vespa/orchestrator/restapi/ApplicationSuspensionApi.java @@ -18,6 +18,7 @@ import java.util.Set; * * @author smorgrav */ +@Path("/orchestrator" + ApplicationSuspensionApi.PATH_PREFIX) public interface ApplicationSuspensionApi { /** * Path prefix for this api. Resources implementing this API should use this with a @Path annotation. 
diff --git a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/status/ZkStatusServiceTest.java b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/status/ZkStatusServiceTest.java index 230290a632a..9a79828bc7b 100644 --- a/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/status/ZkStatusServiceTest.java +++ b/orchestrator/src/test/java/com/yahoo/vespa/orchestrator/status/ZkStatusServiceTest.java @@ -225,6 +225,7 @@ public class ZkStatusServiceTest { }; } + @SuppressWarnings("deprecation") private static void killSession(CuratorFramework curatorFramework, TestingServer testingServer) { try { KillSession.kill(curatorFramework.getZookeeperClient().getZooKeeper(), testingServer.getConnectString()); diff --git a/parent/pom.xml b/parent/pom.xml index 114968c02ef..3ba84254bf9 100644 --- a/parent/pom.xml +++ b/parent/pom.xml @@ -756,6 +756,11 @@ <artifactId>xercesImpl</artifactId> <version>2.12.0</version> </dependency> + <dependency> + <groupId>io.airlift</groupId> + <artifactId>aircompressor</artifactId> + <version>0.17</version> + </dependency> </dependencies> </dependencyManagement> diff --git a/processing/src/main/java/com/yahoo/processing/request/CompoundName.java b/processing/src/main/java/com/yahoo/processing/request/CompoundName.java index 09c0879fdbf..432c7473c2b 100644 --- a/processing/src/main/java/com/yahoo/processing/request/CompoundName.java +++ b/processing/src/main/java/com/yahoo/processing/request/CompoundName.java @@ -140,7 +140,7 @@ public final class CompoundName { if (nameParts.length == 0) return this; if (isEmpty()) return fromComponents(nameParts); - List<String> newCompounds = new ArrayList<>(nameParts.length+compounds.size()); + List<String> newCompounds = new ArrayList<>(nameParts.length + compounds.size()); newCompounds.addAll(Arrays.asList(nameParts)); newCompounds.addAll(this.compounds); return new CompoundName(newCompounds); @@ -192,7 +192,7 @@ public final class CompoundName { this + "' only have " + compounds.size() + " 
components."); if (n == 1) return rest(); if (compounds.size() == n) return empty; - return rest.rest(n-1); + return rest.rest(n - 1); } /** diff --git a/searchcore/src/tests/proton/attribute/attributeflush_test.cpp b/searchcore/src/tests/proton/attribute/attributeflush_test.cpp index d98de21ec5e..8a814e25bd5 100644 --- a/searchcore/src/tests/proton/attribute/attributeflush_test.cpp +++ b/searchcore/src/tests/proton/attribute/attributeflush_test.cpp @@ -532,12 +532,12 @@ Test::requireThatLastFlushTimeIsReported() IFlushTarget::SP ft = am.getFlushable("a9"); EXPECT_EQUAL(seconds(stat._modifiedTime), duration_cast<seconds>(ft->getLastFlushTime().time_since_epoch())); { // updated flush time after nothing to flush - std::this_thread::sleep_for(8000ms); + std::this_thread::sleep_for(1100ms); std::chrono::seconds now = duration_cast<seconds>(vespalib::system_clock::now().time_since_epoch()); Executor::Task::UP task = ft->initFlush(200, std::make_shared<search::FlushToken>()); EXPECT_FALSE(task); EXPECT_LESS(seconds(stat._modifiedTime), ft->getLastFlushTime().time_since_epoch()); - EXPECT_APPROX(now.count(), duration_cast<seconds>(ft->getLastFlushTime().time_since_epoch()).count(), 8); + EXPECT_APPROX(now.count(), duration_cast<seconds>(ft->getLastFlushTime().time_since_epoch()).count(), 3); } } } diff --git a/searchcore/src/tests/proton/index/CMakeLists.txt b/searchcore/src/tests/proton/index/CMakeLists.txt index 4a5baabe0ee..130ed2e97f4 100644 --- a/searchcore/src/tests/proton/index/CMakeLists.txt +++ b/searchcore/src/tests/proton/index/CMakeLists.txt @@ -9,6 +9,10 @@ vespa_add_executable(searchcore_indexmanager_test_app TEST searchcore_pcommon GTest::GTest ) + +vespa_add_test(NAME searchcore_indexmanager_test_app + COMMAND searchcore_indexmanager_test_app) + vespa_add_executable(searchcore_fusionrunner_test_app TEST SOURCES fusionrunner_test.cpp @@ -17,12 +21,20 @@ vespa_add_executable(searchcore_fusionrunner_test_app TEST searchcore_index searchcore_pcommon ) + 
+vespa_add_test(NAME searchcore_fusionrunner_test_app + COMMAND searchcore_fusionrunner_test_app) + vespa_add_executable(searchcore_diskindexcleaner_test_app TEST SOURCES diskindexcleaner_test.cpp DEPENDS searchcore_index ) + +vespa_add_test(NAME searchcore_diskindexcleaner_test_app + COMMAND searchcore_diskindexcleaner_test_app) + vespa_add_executable(searchcore_indexcollection_test_app TEST SOURCES indexcollection_test.cpp @@ -30,5 +42,6 @@ vespa_add_executable(searchcore_indexcollection_test_app TEST searchcore_index GTest::GTest ) -vespa_add_test(NAME searchcore_index_test COMMAND ${CMAKE_CURRENT_SOURCE_DIR}/index_test.sh - DEPENDS searchcore_indexmanager_test_app searchcore_fusionrunner_test_app searchcore_diskindexcleaner_test_app searchcore_indexcollection_test_app) + +vespa_add_test(NAME searchcore_indexcollection_test_app + COMMAND searchcore_indexcollection_test_app) diff --git a/searchcore/src/tests/proton/index/diskindexcleaner_test.cpp b/searchcore/src/tests/proton/index/diskindexcleaner_test.cpp index 86f120aa403..f2f133a10cd 100644 --- a/searchcore/src/tests/proton/index/diskindexcleaner_test.cpp +++ b/searchcore/src/tests/proton/index/diskindexcleaner_test.cpp @@ -27,7 +27,7 @@ public: int Main() override; }; -const string index_dir = "test_data"; +const string index_dir = "diskindexcleaner_test_data"; void removeTestData() { FastOS_FileInterface::EmptyAndRemoveDirectory(index_dir.c_str()); diff --git a/searchcore/src/tests/proton/index/fusionrunner_test.cpp b/searchcore/src/tests/proton/index/fusionrunner_test.cpp index acd5c86fd5d..80e6e8b3db8 100644 --- a/searchcore/src/tests/proton/index/fusionrunner_test.cpp +++ b/searchcore/src/tests/proton/index/fusionrunner_test.cpp @@ -80,6 +80,7 @@ class Test : public vespalib::TestApp { void requireThatFusionCanRunOnMultipleDiskIndexes(); void requireThatOldFusionIndexCanBePartOfNewFusion(); void requireThatSelectorsCanBeRebased(); + void requireThatFusionCanBeStopped(); public: Test() @@ -111,6 +112,7 @@ 
Test::Main() TEST_CALL(requireThatFusionCanRunOnMultipleDiskIndexes()); TEST_CALL(requireThatOldFusionIndexCanBePartOfNewFusion()); TEST_CALL(requireThatSelectorsCanBeRebased()); + TEST_CALL(requireThatFusionCanBeStopped()); TEST_DONE(); } @@ -324,6 +326,17 @@ void Test::requireThatSelectorsCanBeRebased() { checkResults(fusion_id, disk_id, 3); } +void +Test::requireThatFusionCanBeStopped() +{ + createIndex(base_dir, disk_id[0]); + createIndex(base_dir, disk_id[1]); + auto flush_token = std::make_shared<search::FlushToken>(); + flush_token->request_stop(); + uint32_t fusion_id = _fusion_runner->fuse(_fusion_spec, 0u, _ops, flush_token); + EXPECT_EQUAL(0u, fusion_id); +} + } // namespace TEST_APPHOOK(Test); diff --git a/searchcore/src/tests/proton/index/index_test.sh b/searchcore/src/tests/proton/index/index_test.sh deleted file mode 100755 index 5cffb9838da..00000000000 --- a/searchcore/src/tests/proton/index/index_test.sh +++ /dev/null @@ -1,8 +0,0 @@ -#!/bin/bash -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
-set -e - -$VALGRIND ./searchcore_diskindexcleaner_test_app -$VALGRIND ./searchcore_fusionrunner_test_app -$VALGRIND ./searchcore_indexcollection_test_app -$VALGRIND ./searchcore_indexmanager_test_app diff --git a/searchcore/src/tests/proton/index/indexmanager_test.cpp b/searchcore/src/tests/proton/index/indexmanager_test.cpp index 67eb11cee3e..5ce4a1e949d 100644 --- a/searchcore/src/tests/proton/index/indexmanager_test.cpp +++ b/searchcore/src/tests/proton/index/indexmanager_test.cpp @@ -303,14 +303,14 @@ TEST_F(IndexManagerTest, require_that_memory_index_is_flushed) EXPECT_EQ(seconds(stat._modifiedTime), duration_cast<seconds>(target.getLastFlushTime().time_since_epoch())); // updated serial number & flush time when nothing to flush - std::this_thread::sleep_for(8s); + std::this_thread::sleep_for(2s); std::chrono::seconds now = duration_cast<seconds>(vespalib::system_clock::now().time_since_epoch()); vespalib::Executor::Task::UP task; runAsMaster([&]() { task = target.initFlush(2, std::make_shared<search::FlushToken>()); }); EXPECT_FALSE(task); EXPECT_EQ(2u, target.getFlushedSerialNum()); EXPECT_LT(seconds(stat._modifiedTime), duration_cast<seconds>(target.getLastFlushTime().time_since_epoch())); - EXPECT_NEAR(now.count(), duration_cast<seconds>(target.getLastFlushTime().time_since_epoch()).count(), 8); + EXPECT_NEAR(now.count(), duration_cast<seconds>(target.getLastFlushTime().time_since_epoch()).count(), 2); } } @@ -830,6 +830,30 @@ TEST_F(IndexManagerTest, field_length_info_is_loaded_from_disk_index_during_star expect_field_length_info(1, 2, *as_memory_index(*sources, 1)); } +TEST_F(IndexManagerTest, fusion_can_be_stopped) +{ + resetIndexManager(); + + addDocument(docid); + flushIndexManager(); + addDocument(docid); + flushIndexManager(); + + IndexFusionTarget target(_index_manager->getMaintainer()); + auto flush_token = std::make_shared<search::FlushToken>(); + flush_token->request_stop(); + vespalib::Executor::Task::UP fusionTask = target.initFlush(1, 
flush_token); + fusionTask->run(); + + FusionSpec spec = _index_manager->getMaintainer().getFusionSpec(); + set<uint32_t> fusion_ids = readDiskIds(index_dir, "fusion"); + EXPECT_TRUE(fusion_ids.empty()); + EXPECT_EQ(0u, spec.last_fusion_id); + EXPECT_EQ(2u, spec.flush_ids.size()); + EXPECT_EQ(1u, spec.flush_ids[0]); + EXPECT_EQ(2u, spec.flush_ids[1]); +} + } // namespace int diff --git a/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp b/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp index 5f35ecc916d..ab5b1ac5937 100644 --- a/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp +++ b/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.cpp @@ -94,7 +94,9 @@ FlushEngine::FlushEngine(std::shared_ptr<flushengine::ITlsStatsFactory> tlsStats _strategyLock(), _strategyCond(), _tlsStatsFactory(std::move(tlsStatsFactory)), - _pendingPrune() + _pendingPrune(), + _normal_flush_token(std::make_shared<search::FlushToken>()), + _gc_flush_token(std::make_shared<search::FlushToken>()) { } FlushEngine::~FlushEngine() @@ -117,6 +119,7 @@ FlushEngine::close() { std::lock_guard<std::mutex> strategyGuard(_strategyLock); std::lock_guard<std::mutex> guard(_lock); + _gc_flush_token->request_stop(); _closed = true; _cond.notify_all(); } @@ -269,6 +272,16 @@ FlushEngine::getSortedTargetList() return ret; } +std::shared_ptr<search::IFlushToken> +FlushEngine::get_flush_token(const FlushContext& ctx) +{ + if (ctx.getTarget()->getType() == IFlushTarget::Type::GC) { + return _gc_flush_token; + } else { + return _normal_flush_token; + } +} + FlushContext::SP FlushEngine::initNextFlush(const FlushContext::List &lst) { @@ -277,7 +290,7 @@ FlushEngine::initNextFlush(const FlushContext::List &lst) if (LOG_WOULD_LOG(event)) { EventLogger::flushInit(it->getName()); } - if (it->initFlush(std::make_shared<search::FlushToken>())) { + if (it->initFlush(get_flush_token(*it))) { ctx = it; break; } @@ -294,7 +307,7 @@ FlushEngine::flushAll(const 
FlushContext::List &lst) LOG(debug, "%ld targets to flush.", lst.size()); for (const FlushContext::SP & ctx : lst) { if (wait(0)) { - if (ctx->initFlush(std::make_shared<search::FlushToken>())) { + if (ctx->initFlush(get_flush_token(*ctx))) { logTarget("initiated", *ctx); _executor.execute(std::make_unique<FlushTask>(initFlush(*ctx), *this, ctx)); } else { diff --git a/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.h b/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.h index 160423c7c68..f51e93f0fbd 100644 --- a/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.h +++ b/searchcore/src/vespa/searchcore/proton/flushengine/flushengine.h @@ -12,6 +12,8 @@ #include <mutex> #include <condition_variable> +namespace search { class FlushToken; } + namespace proton { namespace flushengine { class ITlsStatsFactory; } @@ -63,9 +65,12 @@ private: std::condition_variable _strategyCond; std::shared_ptr<flushengine::ITlsStatsFactory> _tlsStatsFactory; std::set<IFlushHandler::SP> _pendingPrune; + std::shared_ptr<search::FlushToken> _normal_flush_token; + std::shared_ptr<search::FlushToken> _gc_flush_token; FlushContext::List getTargetList(bool includeFlushingTargets) const; std::pair<FlushContext::List,bool> getSortedTargetList(); + std::shared_ptr<search::IFlushToken> get_flush_token(const FlushContext& ctx); FlushContext::SP initNextFlush(const FlushContext::List &lst); vespalib::string flushNextTarget(const vespalib::string & name); void flushAll(const FlushContext::List &lst); diff --git a/searchcorespi/src/vespa/searchcorespi/index/indexmaintainer.cpp b/searchcorespi/src/vespa/searchcorespi/index/indexmaintainer.cpp index e2bcb8b7629..38990f61b43 100644 --- a/searchcorespi/src/vespa/searchcorespi/index/indexmaintainer.cpp +++ b/searchcorespi/src/vespa/searchcorespi/index/indexmaintainer.cpp @@ -10,6 +10,7 @@ #include "indexwriteutilities.h" #include <vespa/fastos/file.h> #include <vespa/searchcorespi/flush/closureflushtask.h> +#include 
<vespa/searchlib/common/i_flush_token.h> #include <vespa/searchlib/index/schemautil.h> #include <vespa/searchlib/util/dirtraverse.h> #include <vespa/searchlib/util/filekit.h> @@ -984,11 +985,16 @@ IndexMaintainer::doFusion(SerialNum serialNum, std::shared_ptr<search::IFlushTok _fusion_spec.flush_ids.clear(); } - uint32_t new_fusion_id = runFusion(spec, std::move(flush_token)); + uint32_t new_fusion_id = runFusion(spec, flush_token); LockGuard lock(_fusion_lock); if (new_fusion_id == spec.last_fusion_id) { // Error running fusion. - LOG(warning, "Fusion failed for id %u.", spec.flush_ids.back()); + string fail_dir = getFusionDir(spec.flush_ids.back()); + if (flush_token->stop_requested()) { + LOG(info, "Fusion stopped for id %u, fusion dir \"%s\".", spec.flush_ids.back(), fail_dir.c_str()); + } else { + LOG(warning, "Fusion failed for id %u, fusion dir \"%s\".", spec.flush_ids.back(), fail_dir.c_str()); + } // Restore fusion spec. copy(_fusion_spec.flush_ids.begin(), _fusion_spec.flush_ids.end(), back_inserter(spec.flush_ids)); _fusion_spec.flush_ids.swap(spec.flush_ids); @@ -1020,15 +1026,19 @@ IndexMaintainer::runFusion(const FusionSpec &fusion_spec, std::shared_ptr<search serialNum = IndexReadUtilities::readSerialNum(lastFlushDir); } FusionRunner fusion_runner(_base_dir, args._schema, tuneFileAttributes, _ctx.getFileHeaderContext()); - uint32_t new_fusion_id = fusion_runner.fuse(fusion_spec, serialNum, _operations, std::move(flush_token)); + uint32_t new_fusion_id = fusion_runner.fuse(fusion_spec, serialNum, _operations, flush_token); bool ok = (new_fusion_id != 0); if (ok) { ok = IndexWriteUtilities::copySerialNumFile(getFlushDir(fusion_spec.flush_ids.back()), getFusionDir(new_fusion_id)); } if (!ok) { - LOG(error, "Fusion failed."); string fail_dir = getFusionDir(fusion_spec.flush_ids.back()); + if (flush_token->stop_requested()) { + LOG(info, "Fusion stopped, fusion dir \"%s\".", fail_dir.c_str()); + } else { + LOG(error, "Fusion failed, fusion dir \"%s\".", 
fail_dir.c_str()); + } FastOS_FileInterface::EmptyAndRemoveDirectory(fail_dir.c_str()); { LockGuard slock(_state_lock); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index d412f408350..88eccb4559f 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -875,6 +875,7 @@ "public final java.lang.String outs()", "public final java.lang.String out()", "public final java.util.List args()", + "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode arg()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode function()", "public final com.yahoo.searchlib.rankingexpression.rule.FunctionNode scalarOrTensorFunction()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorFunction()", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index b97c8316c9b..9f900ffed36 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -3,10 +3,7 @@ package com.yahoo.searchlib.rankingexpression; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; import com.yahoo.tensor.TensorType; import com.yahoo.text.Utf8; @@ -134,16 +131,7 @@ public class ExpressionFunction { public Instance expand(SerializationContext context, List<ExpressionNode> argumentValues, Deque<String> path) { Map<String, String> argumentBindings = new HashMap<>(); for (int i = 0; i < 
arguments.size() && i < argumentValues.size(); ++i) { - String key = arguments.get(i); - ExpressionNode expr = argumentValues.get(i); - String binding = expr.toString(new StringBuilder(), context, path, null).toString(); - - if ( ! (expr instanceof ReferenceNode) && ! (expr instanceof ConstantNode) && ! (expr instanceof FunctionNode) ) { - String funcName = "autogenerated_ranking_feature@" + Long.toHexString(symbolCode(key + "=" + binding)); - context.addFunctionSerialization(RankingExpression.propertyName(funcName), binding); - binding = funcName; - } - argumentBindings.put(key, binding); + argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(new StringBuilder(), context, path, null).toString()); } context = argumentBindings.isEmpty() ? context.withoutBindings() : context.withBindings(argumentBindings); String symbol = toSymbol(argumentBindings); diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 36b1f9627bb..09880b8dfc3 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -328,10 +328,30 @@ List<ExpressionNode> args() : ExpressionNode argument; } { - ( ( argument = expression() { arguments.add(argument); } ( <COMMA> argument = expression() { arguments.add(argument); } )* )? ) + ( ( argument = arg() { arguments.add(argument); } ( <COMMA> argument = arg() { arguments.add(argument); } )* )? ) { return arguments; } } +// TODO: Replace use of this for function arguments with value() +// For that to work with the current search execution framework +// we need to generate another function for the argument such that we can replace +// instances of the argument with the reference to that function in the same way +// as we replace by constants/names today (this can make for some fun combinatorial explosion). +// We should also stop doing function expansion in the toString of a function. 
+// - Jon 2014-05-02 +ExpressionNode arg() : +{ + ExpressionNode ret; + String name; + Function fnc; +} +{ + ( ret = constantPrimitive() | + LOOKAHEAD(2) ret = feature() | + name = identifier() { ret = new NameNode(name); } ) + { return ret; } +} + ExpressionNode function() : { ExpressionNode function; diff --git a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp index 51b4f1d760d..855510d0457 100644 --- a/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attributeblueprint_test.cpp @@ -341,7 +341,7 @@ public: request_ctx.set_query_tensor("query_tensor", tensor_spec); } Blueprint::UP create_blueprint() { - query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33); + query::NearestNeighborTerm term("query_tensor", attr_name, 0, Weight(0), 7, true, 33, 100100.25); return BlueprintFactoryFixture::create_blueprint(term); } }; diff --git a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp index e1bd47af358..cbdb2c9bd22 100644 --- a/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp +++ b/searchlib/src/tests/attribute/tensorattribute/tensorattribute_test.cpp @@ -13,6 +13,7 @@ #include <vespa/searchlib/tensor/dense_tensor_attribute.h> #include <vespa/searchlib/tensor/direct_tensor_attribute.h> #include <vespa/searchlib/tensor/doc_vector_access.h> +#include <vespa/searchlib/tensor/distance_functions.h> #include <vespa/searchlib/tensor/hnsw_index.h> #include <vespa/searchlib/tensor/nearest_neighbor_index.h> #include <vespa/searchlib/tensor/nearest_neighbor_index_factory.h> @@ -206,24 +207,32 @@ public: _index_value = (reinterpret_cast<const int*>(buf.buffer()))[0]; return true; } - std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k) 
const override { + std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, uint32_t explore_k, + double distance_threshold) const override + { (void) k; (void) vector; (void) explore_k; + (void) distance_threshold; return std::vector<Neighbor>(); } std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, - const search::BitVector& filter, uint32_t explore_k) const override + const search::BitVector& filter, uint32_t explore_k, + double distance_threshold) const override { (void) k; (void) vector; (void) explore_k; (void) filter; + (void) distance_threshold; return std::vector<Neighbor>(); } - const search::tensor::DistanceFunction *distance_function() const override { return nullptr; } + const search::tensor::DistanceFunction *distance_function() const override { + static search::tensor::SquaredEuclideanDistance<double> my_dist_fun; + return &my_dist_fun; + } }; class MockNearestNeighborIndexFactory : public NearestNeighborIndexFactory { @@ -563,7 +572,7 @@ void Fixture::testCompaction() { if ((_traits.use_dense_tensor_attribute && _denseTensors) || - _traits.use_direct_tensor_attribute) + ! 
_traits.use_dense_tensor_attribute) { LOG(info, "Skipping compaction test for tensor '%s' which is using free-lists", _cfg.tensorType().to_spec().c_str()); return; @@ -914,9 +923,12 @@ public: field, as_dense_tensor(), createDenseTensor(vec_2d(17, 42)), - 3, true, 5, brute_force_limit); + 3, true, 5, + 100100.25, + brute_force_limit); EXPECT_EQUAL(11u, bp->getState().estimate().estHits); EXPECT_TRUE(bp->may_approximate()); + EXPECT_EQUAL(100100.25 * 100100.25, bp->get_distance_threshold()); return bp; } }; diff --git a/searchlib/src/tests/diskindex/fusion/fusion_test.cpp b/searchlib/src/tests/diskindex/fusion/fusion_test.cpp index efc9e99bf88..4c62140b731 100644 --- a/searchlib/src/tests/diskindex/fusion/fusion_test.cpp +++ b/searchlib/src/tests/diskindex/fusion/fusion_test.cpp @@ -602,15 +602,15 @@ TEST_F(FusionTest, require_that_fusion_can_be_stopped) auto flush_token = std::make_shared<MyFlushToken>(10000); make_simple_index("stopdump2", MockFieldLengthInspector()); ASSERT_TRUE(try_merge_simple_indexes("stopdump3", {"stopdump2"}, flush_token)); - EXPECT_EQ(40, flush_token->get_checks()); + EXPECT_EQ(48, flush_token->get_checks()); vespalib::rmdir("stopdump3", true); flush_token = std::make_shared<MyFlushToken>(1); ASSERT_FALSE(try_merge_simple_indexes("stopdump3", {"stopdump2"}, flush_token)); EXPECT_EQ(12, flush_token->get_checks()); vespalib::rmdir("stopdump3", true); - flush_token = std::make_shared<MyFlushToken>(39); + flush_token = std::make_shared<MyFlushToken>(47); ASSERT_FALSE(try_merge_simple_indexes("stopdump3", {"stopdump2"}, flush_token)); - EXPECT_EQ(41, flush_token->get_checks()); + EXPECT_EQ(49, flush_token->get_checks()); clean_stopped_fusion_testdirs(); } diff --git a/searchlib/src/tests/query/query_visitor_test.cpp b/searchlib/src/tests/query/query_visitor_test.cpp index 8441dc2227f..946ad17352d 100644 --- a/searchlib/src/tests/query/query_visitor_test.cpp +++ b/searchlib/src/tests/query/query_visitor_test.cpp @@ -99,7 +99,7 @@ void 
Test::requireThatAllNodesCanBeVisited() { checkVisit<SuffixTerm>(new SimpleSuffixTerm("t", "field", 0, Weight(0))); checkVisit<PredicateQuery>(new SimplePredicateQuery(PredicateQueryTerm::UP(), "field", 0, Weight(0))); checkVisit<RegExpTerm>(new SimpleRegExpTerm("t", "field", 0, Weight(0))); - checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321)); + checkVisit<NearestNeighborTerm>(new SimpleNearestNeighborTerm("query_tensor", "doc_tensor", 0, Weight(0), 123, true, 321, 100100.25)); } } // namespace diff --git a/searchlib/src/tests/query/querybuilder_test.cpp b/searchlib/src/tests/query/querybuilder_test.cpp index 5a5a5eafb2c..30b4d2ae264 100644 --- a/searchlib/src/tests/query/querybuilder_test.cpp +++ b/searchlib/src/tests/query/querybuilder_test.cpp @@ -110,7 +110,7 @@ Node::UP createQueryTree() { builder.addStringTerm(str[5], view[5], id[5], weight[6]); builder.addStringTerm(str[6], view[6], id[6], weight[7]); } - builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33); + builder.add_nearest_neighbor_term("query_tensor", "doc_tensor", id[3], weight[5], 7, true, 33, 100100.25); } Node::UP node = builder.build(); ASSERT_TRUE(node.get()); @@ -395,8 +395,9 @@ struct MyRegExpTerm : RegExpTerm { struct MyNearestNeighborTerm : NearestNeighborTerm { MyNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, int32_t i, Weight w, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) - : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits) + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) + : NearestNeighborTerm(query_tensor_name, field_name, i, w, target_num_hits, allow_approximate, explore_additional_hits, distance_threshold) {} }; diff --git 
a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp index ad450a91f33..09790e7e360 100644 --- a/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp +++ b/searchlib/src/tests/queryeval/nearest_neighbor/nearest_neighbor_test.cpp @@ -121,11 +121,12 @@ struct Fixture }; template <bool strict> -SimpleResult find_matches(Fixture &env, const Value &qtv) { +SimpleResult find_matches(Fixture &env, const Value &qtv, double threshold = std::numeric_limits<double>::max()) { auto md = MatchData::makeTestInstance(2, 2); auto &tfmd = *(md->resolveTermField(0)); auto &attr = *(env._tensorAttr); NearestNeighborDistanceHeap dh(2); + dh.set_distance_threshold(env.dist_fun()->convert_threshold(threshold)); const BitVector *filter = env._global_filter.get(); auto search = NearestNeighborIterator::create(strict, tfmd, qtv, attr, dh, filter, env.dist_fun()); if (strict) { @@ -159,6 +160,19 @@ verify_iterator_returns_expected_results(const vespalib::string& attribute_tenso EXPECT_EQUAL(result, farExpect); result = find_matches<false>(fixture, *farTensor); EXPECT_EQUAL(result, farExpect); + + SimpleResult null_thr5_exp({1,4,6}); + result = find_matches<true>(fixture, *nullTensor, 5.0); + EXPECT_EQUAL(result, null_thr5_exp); + result = find_matches<false>(fixture, *nullTensor, 5.0); + EXPECT_EQUAL(result, null_thr5_exp); + + SimpleResult far_thr4_exp({2,5}); + result = find_matches<true>(fixture, *farTensor, 4.0); + EXPECT_EQUAL(result, far_thr4_exp); + result = find_matches<false>(fixture, *farTensor, 4.0); + EXPECT_EQUAL(result, far_thr4_exp); + } TEST("require that NearestNeighborIterator returns expected results") { diff --git a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp index 06fb95089fd..ee0a2aff80e 100644 --- 
a/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp +++ b/searchlib/src/tests/tensor/distance_functions/distance_functions_test.cpp @@ -24,10 +24,12 @@ void verify_geo_miles(const DistanceFunction *dist_fun, TypedCells t2(p2); double abstract_distance = dist_fun->calc(t1, t2); double raw_score = dist_fun->to_rawscore(abstract_distance); - double m = ((1.0/raw_score)-1.0); - double d_miles = m / 1.609344; + double km = ((1.0/raw_score)-1.0); + double d_miles = km / 1.609344; EXPECT_GE(d_miles, exp_miles*0.99); EXPECT_LE(d_miles, exp_miles*1.01); + double threshold = dist_fun->convert_threshold(km); + EXPECT_DOUBLE_EQ(threshold, abstract_distance); } @@ -50,6 +52,10 @@ TEST(DistanceFunctionsTest, euclidean_gives_expected_score) double d12 = euclid->calc(t(p1), t(p2)); EXPECT_EQ(d12, 2.0); EXPECT_DOUBLE_EQ(euclid->to_rawscore(d12), 1.0/(1.0 + sqrt(2.0))); + double threshold = euclid->convert_threshold(8.0); + EXPECT_EQ(threshold, 64.0); + threshold = euclid->convert_threshold(0.5); + EXPECT_EQ(threshold, 0.25); } TEST(DistanceFunctionsTest, angular_gives_expected_score) @@ -75,19 +81,28 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) EXPECT_DOUBLE_EQ(a23, 1.0); EXPECT_FLOAT_EQ(angular->to_rawscore(a12), 1.0/(1.0 + pi/2)); + double threshold = angular->convert_threshold(pi/2); + EXPECT_DOUBLE_EQ(threshold, 1.0); + double a14 = angular->calc(t(p1), t(p4)); double a24 = angular->calc(t(p2), t(p4)); EXPECT_FLOAT_EQ(a14, 0.5); EXPECT_FLOAT_EQ(a24, 0.5); EXPECT_FLOAT_EQ(angular->to_rawscore(a14), 1.0/(1.0 + pi/3)); + threshold = angular->convert_threshold(pi/3); + EXPECT_DOUBLE_EQ(threshold, 0.5); double a34 = angular->calc(t(p3), t(p4)); EXPECT_FLOAT_EQ(a34, (1.0 - 0.707107)); EXPECT_FLOAT_EQ(angular->to_rawscore(a34), 1.0/(1.0 + pi/4)); + threshold = angular->convert_threshold(pi/4); + EXPECT_FLOAT_EQ(threshold, a34); double a25 = angular->calc(t(p2), t(p5)); EXPECT_DOUBLE_EQ(a25, 2.0); EXPECT_FLOAT_EQ(angular->to_rawscore(a25), 
1.0/(1.0 + pi)); + threshold = angular->convert_threshold(pi); + EXPECT_FLOAT_EQ(threshold, 2.0); double a44 = angular->calc(t(p4), t(p4)); EXPECT_GE(a44, 0.0); @@ -98,6 +113,8 @@ TEST(DistanceFunctionsTest, angular_gives_expected_score) EXPECT_GE(a66, 0.0); EXPECT_LT(a66, 0.000001); EXPECT_FLOAT_EQ(angular->to_rawscore(a66), 1.0); + threshold = angular->convert_threshold(0.0); + EXPECT_FLOAT_EQ(threshold, 0.0); double a16 = angular->calc(t(p1), t(p6)); double a26 = angular->calc(t(p2), t(p6)); @@ -127,6 +144,7 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) EXPECT_DOUBLE_EQ(i12, 1.0); EXPECT_DOUBLE_EQ(i13, 1.0); EXPECT_DOUBLE_EQ(i23, 1.0); + double i14 = innerproduct->calc(t(p1), t(p4)); double i24 = innerproduct->calc(t(p2), t(p4)); EXPECT_DOUBLE_EQ(i14, 0.5); @@ -140,6 +158,13 @@ TEST(DistanceFunctionsTest, innerproduct_gives_expected_score) double i44 = innerproduct->calc(t(p4), t(p4)); EXPECT_GE(i44, 0.0); EXPECT_LT(i44, 0.000001); + + double threshold = innerproduct->convert_threshold(0.25); + EXPECT_DOUBLE_EQ(threshold, 0.25); + threshold = innerproduct->convert_threshold(0.5); + EXPECT_DOUBLE_EQ(threshold, 0.5); + threshold = innerproduct->convert_threshold(1.0); + EXPECT_DOUBLE_EQ(threshold, 1.0); } TEST(DistanceFunctionsTest, hamming_gives_expected_score) @@ -180,6 +205,13 @@ TEST(DistanceFunctionsTest, hamming_gives_expected_score) double d25 = hamming->calc(t(points[2]), t(points[5])); EXPECT_EQ(d25, 1.0); EXPECT_DOUBLE_EQ(hamming->to_rawscore(d25), 1.0/(1.0 + 1.0)); + + double threshold = hamming->convert_threshold(0.25); + EXPECT_DOUBLE_EQ(threshold, 0.25); + threshold = hamming->convert_threshold(0.5); + EXPECT_DOUBLE_EQ(threshold, 0.5); + threshold = hamming->convert_threshold(1.0); + EXPECT_DOUBLE_EQ(threshold, 1.0); } TEST(GeoDegreesTest, gives_expected_score) diff --git a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp index acc157709c0..20dc55df329 100644 --- 
a/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp +++ b/searchlib/src/tests/tensor/hnsw_index/hnsw_index_test.cpp @@ -144,11 +144,28 @@ public: if (exp_hits.size() == k) { std::vector<uint32_t> expected_by_docid = exp_hits; std::sort(expected_by_docid.begin(), expected_by_docid.end()); - auto got_by_docid = index->find_top_k(k, qv, k); + auto got_by_docid = index->find_top_k(k, qv, k, 100100.25); for (idx = 0; idx < k; ++idx) { EXPECT_EQ(expected_by_docid[idx], got_by_docid[idx].docid); } } + check_with_distance_threshold(docid); + } + void check_with_distance_threshold(uint32_t docid) { + auto qv = vectors.get_vector(docid); + uint32_t k = 3; + auto rv = index->top_k_candidates(qv, k, global_filter.get()).peek(); + std::sort(rv.begin(), rv.end(), LesserDistance()); + EXPECT_EQ(rv.size(), 3); + EXPECT_LE(rv[0].distance, rv[1].distance); + double thr = (rv[0].distance + rv[1].distance) * 0.5; + auto got_by_docid = index->find_top_k_with_filter(k, qv, *global_filter, k, thr); + EXPECT_EQ(got_by_docid.size(), 1); + EXPECT_EQ(got_by_docid[0].docid, rv[0].docid); + for (const auto & hit : got_by_docid) { + LOG(debug, "from docid=%u found docid=%u dist=%g (threshold %g)\n", + docid, hit.docid, hit.distance, thr); + } } }; diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index dfcbfbbbe2b..70a59f1575a 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -736,6 +736,7 @@ public: n.get_target_num_hits(), n.get_allow_approximate(), n.get_explore_additional_hits(), + n.get_distance_threshold(), getRequestContext().get_attribute_blueprint_params().nearest_neighbor_brute_force_limit)); } }; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp index 
6039a86580c..ce24de02281 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.cpp @@ -273,6 +273,13 @@ SimpleQueryStackDumpIterator::next() _extraIntArg1 = readCompressedPositiveInt(p); // targetNumHits _extraIntArg2 = readCompressedPositiveInt(p); // allow_approximate _extraIntArg3 = readCompressedPositiveInt(p); // explore_additional_hits + if ((_extraIntArg2 & 0x40) != 0) { + _extraIntArg2 &= ~0x40; + // in some later release, do this always: + _extraDoubleArg4 = read_double(p); // distance threshold + } else { + _extraDoubleArg4 = std::numeric_limits<double>::max(); + } _currArity = 0; } catch (...) { return false; diff --git a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h index d60765f3fe1..301929c8919 100644 --- a/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h +++ b/searchlib/src/vespa/searchlib/parsequery/stackdumpiterator.h @@ -119,6 +119,7 @@ public: uint32_t getNearDistance() const { return _extraIntArg1; } uint32_t getTargetNumHits() const { return _extraIntArg1; } + double getDistanceThreshold() const { return _extraDoubleArg4; } double getScoreThreshold() const { return _extraDoubleArg4; } double getThresholdBoostFactor() const { return _extraDoubleArg5; } bool getAllowApproximate() const { return (_extraIntArg2 != 0); } diff --git a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h index 8e6f2944ec9..8392730cd29 100644 --- a/searchlib/src/vespa/searchlib/query/tree/querybuilder.h +++ b/searchlib/src/vespa/searchlib/query/tree/querybuilder.h @@ -206,10 +206,12 @@ template <class NodeTypes> typename NodeTypes::NearestNeighborTerm * create_nearest_neighbor_term(vespalib::stringref query_tensor_name, vespalib::stringref field_name, int32_t id, Weight weight, uint32_t target_num_hits, - bool allow_approximate, 
uint32_t explore_additional_hits) + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) { return new typename NodeTypes::NearestNeighborTerm(query_tensor_name, field_name, id, weight, - target_num_hits, allow_approximate, explore_additional_hits); + target_num_hits, allow_approximate, explore_additional_hits, + distance_threshold); } template <class NodeTypes> @@ -321,9 +323,11 @@ public: } typename NodeTypes::NearestNeighborTerm &add_nearest_neighbor_term(stringref query_tensor_name, stringref field_name, int32_t id, Weight weight, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) { + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) + { adjustWeight(weight); - return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits)); + return addTerm(create_nearest_neighbor_term<NodeTypes>(query_tensor_name, field_name, id, weight, target_num_hits, allow_approximate, explore_additional_hits, distance_threshold)); } }; diff --git a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h index 600249c3e1e..4b9226f6112 100644 --- a/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h +++ b/searchlib/src/vespa/searchlib/query/tree/queryreplicator.h @@ -166,7 +166,8 @@ private: void visit(NearestNeighborTerm &node) override { replicate(node, _builder.add_nearest_neighbor_term(node.get_query_tensor_name(), node.getView(), node.getId(), node.getWeight(), node.get_target_num_hits(), - node.get_allow_approximate(), node.get_explore_additional_hits())); + node.get_allow_approximate(), node.get_explore_additional_hits(), + node.get_distance_threshold())); } }; diff --git a/searchlib/src/vespa/searchlib/query/tree/simplequery.h b/searchlib/src/vespa/searchlib/query/tree/simplequery.h index 4953f1a5b7c..db517edc348 
100644 --- a/searchlib/src/vespa/searchlib/query/tree/simplequery.h +++ b/searchlib/src/vespa/searchlib/query/tree/simplequery.h @@ -106,9 +106,11 @@ struct SimpleRegExpTerm : RegExpTerm { struct SimpleNearestNeighborTerm : NearestNeighborTerm { SimpleNearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, int32_t id, Weight weight, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) : NearestNeighborTerm(query_tensor_name, field_name, id, weight, - target_num_hits, allow_approximate, explore_additional_hits) + target_num_hits, allow_approximate, explore_additional_hits, + distance_threshold) {} }; diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp index 9af1ecee224..d2ac46157b1 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpcreator.cpp @@ -263,8 +263,10 @@ class QueryNodeConverter : public QueryVisitor { createTermNode(node, ParseItem::ITEM_NEAREST_NEIGHBOR); appendString(node.get_query_tensor_name()); appendCompressedPositiveNumber(node.get_target_num_hits()); - appendCompressedPositiveNumber(node.get_allow_approximate() ? 1 : 0); + // XXX subtract 0x40 later: + appendCompressedPositiveNumber(node.get_allow_approximate() ? 
0x41 : 0x40); appendCompressedPositiveNumber(node.get_explore_additional_hits()); + appendDouble(node.get_distance_threshold()); } public: diff --git a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h index 66702fcd85c..040ac751d25 100644 --- a/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h +++ b/searchlib/src/vespa/searchlib/query/tree/stackdumpquerycreator.h @@ -120,8 +120,10 @@ private: Weight weight = queryStack.GetWeight(); bool allow_approximate = queryStack.getAllowApproximate(); uint32_t explore_additional_hits = queryStack.getExploreAdditionalHits(); + double distance_threshold = queryStack.getDistanceThreshold(); builder.add_nearest_neighbor_term(query_tensor_name, field_name, id, weight, - target_num_hits, allow_approximate, explore_additional_hits); + target_num_hits, allow_approximate, explore_additional_hits, + distance_threshold); } else { vespalib::stringref term = queryStack.getTerm(); vespalib::stringref view = queryStack.getIndexName(); diff --git a/searchlib/src/vespa/searchlib/query/tree/termnodes.h b/searchlib/src/vespa/searchlib/query/tree/termnodes.h index 9af424716fb..e112fd6e295 100644 --- a/searchlib/src/vespa/searchlib/query/tree/termnodes.h +++ b/searchlib/src/vespa/searchlib/query/tree/termnodes.h @@ -130,22 +130,26 @@ private: uint32_t _target_num_hits; bool _allow_approximate; uint32_t _explore_additional_hits; + double _distance_threshold; public: NearestNeighborTerm(vespalib::stringref query_tensor_name, vespalib::stringref field_name, int32_t id, Weight weight, uint32_t target_num_hits, - bool allow_approximate, uint32_t explore_additional_hits) + bool allow_approximate, uint32_t explore_additional_hits, + double distance_threshold) : QueryNodeMixinType(field_name, id, weight), _query_tensor_name(query_tensor_name), _target_num_hits(target_num_hits), _allow_approximate(allow_approximate), - 
_explore_additional_hits(explore_additional_hits) + _explore_additional_hits(explore_additional_hits), + _distance_threshold(distance_threshold) {} virtual ~NearestNeighborTerm() {} const vespalib::string& get_query_tensor_name() const { return _query_tensor_name; } uint32_t get_target_num_hits() const { return _target_num_hits; } bool get_allow_approximate() const { return _allow_approximate; } uint32_t get_explore_additional_hits() const { return _explore_additional_hits; } + double get_distance_threshold() const { return _distance_threshold; } }; diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp index d3ecffd1605..01f02748664 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.cpp @@ -52,13 +52,15 @@ struct ConvertCellsSelector NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, std::unique_ptr<Value> query_tensor, - uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, double brute_force_limit) + uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, + double distance_threshold, double brute_force_limit) : ComplexLeafBlueprint(field), _attr_tensor(attr_tensor), _query_tensor(std::move(query_tensor)), _target_num_hits(target_num_hits), _approximate(approximate), _explore_additional_hits(explore_additional_hits), + _distance_threshold(std::numeric_limits<double>::max()), _brute_force_limit(brute_force_limit), _fallback_dist_fun(), _distance_heap(target_num_hits), @@ -72,9 +74,15 @@ NearestNeighborBlueprint::NearestNeighborBlueprint(const queryeval::FieldSpec& f fixup_fun(_query_tensor, _attr_tensor.getTensorType()); _fallback_dist_fun = search::tensor::make_distance_function(_attr_tensor.getConfig().distance_metric(), rct); 
_dist_fun = _fallback_dist_fun.get(); + assert(_dist_fun); auto nns_index = _attr_tensor.nearest_neighbor_index(); if (nns_index) { _dist_fun = nns_index->distance_function(); + assert(_dist_fun); + } + if (distance_threshold < std::numeric_limits<double>::max()) { + _distance_threshold = _dist_fun->convert_threshold(distance_threshold); + _distance_heap.set_distance_threshold(_distance_threshold); } uint32_t est_hits = _attr_tensor.getNumDocs(); setEstimate(HitEstimate(est_hits, false)); @@ -127,9 +135,9 @@ NearestNeighborBlueprint::perform_top_k() uint32_t k = _target_num_hits; if (_global_filter->has_filter()) { auto filter = _global_filter->filter(); - _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits); + _found_hits = nns_index->find_top_k_with_filter(k, lhs, *filter, k + _explore_additional_hits, _distance_threshold); } else { - _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits); + _found_hits = nns_index->find_top_k(k, lhs, k + _explore_additional_hits, _distance_threshold); } } } diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h index a8a0ff19246..aad43c923a2 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_blueprint.h @@ -24,6 +24,7 @@ private: uint32_t _target_num_hits; bool _approximate; uint32_t _explore_additional_hits; + double _distance_threshold; double _brute_force_limit; search::tensor::DistanceFunction::UP _fallback_dist_fun; const search::tensor::DistanceFunction *_dist_fun; @@ -36,7 +37,9 @@ public: NearestNeighborBlueprint(const queryeval::FieldSpec& field, const tensor::DenseTensorAttribute& attr_tensor, std::unique_ptr<vespalib::eval::Value> query_tensor, - uint32_t target_num_hits, bool approximate, uint32_t explore_additional_hits, double brute_force_limit); + uint32_t 
target_num_hits, bool approximate, uint32_t explore_additional_hits, + double distance_threshold, + double brute_force_limit); NearestNeighborBlueprint(const NearestNeighborBlueprint&) = delete; NearestNeighborBlueprint& operator=(const NearestNeighborBlueprint&) = delete; ~NearestNeighborBlueprint(); @@ -45,6 +48,7 @@ public: uint32_t get_target_num_hits() const { return _target_num_hits; } void set_global_filter(const GlobalFilter &global_filter) override; bool may_approximate() const { return _approximate; } + double get_distance_threshold() const { return _distance_threshold; } std::unique_ptr<SearchIterator> createLeafSearch(const search::fef::TermFieldMatchDataArray& tfmda, bool strict) const override; diff --git a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h index 3937dfba2ca..b7bdffd31c1 100644 --- a/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h +++ b/searchlib/src/vespa/searchlib/queryeval/nearest_neighbor_distance_heap.h @@ -15,15 +15,22 @@ class NearestNeighborDistanceHeap { private: std::mutex _lock; size_t _size; + double _distance_threshold; vespalib::PriorityQueue<double, std::greater<double>> _priQ; public: - explicit NearestNeighborDistanceHeap(size_t maxSize) : _size(maxSize), _priQ() { + explicit NearestNeighborDistanceHeap(size_t maxSize) + : _size(maxSize), _distance_threshold(std::numeric_limits<double>::max()), + _priQ() + { _priQ.reserve(maxSize); } + void set_distance_threshold(double distance_threshold) { + _distance_threshold = distance_threshold; + } double distanceLimit() { std::lock_guard<std::mutex> guard(_lock); if (_priQ.size() < _size) { - return std::numeric_limits<double>::max(); + return _distance_threshold; } return _priQ.front(); } diff --git a/searchlib/src/vespa/searchlib/tensor/distance_function.h b/searchlib/src/vespa/searchlib/tensor/distance_function.h index 30ad1876317..724f83b6129 100644 --- 
a/searchlib/src/vespa/searchlib/tensor/distance_function.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_function.h @@ -11,15 +11,24 @@ namespace search::tensor { /** * Interface used to calculate the distance between two n-dimensional vectors. * - * The vectors must be of same size and same type (float or double). + * The vectors must be of same size and same cell type (float or double). * The actual implementation must know which type the vectors are. */ class DistanceFunction { public: using UP = std::unique_ptr<DistanceFunction>; virtual ~DistanceFunction() {} + + // calculate internal distance (comparable) virtual double calc(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs) const = 0; + + // convert threshold (external distance units) to internal units + virtual double convert_threshold(double threshold) const = 0; + + // convert internal distance to rawscore (1.0 / (1.0 + d)) virtual double to_rawscore(double distance) const = 0; + + // calculate internal distance, early return allowed if > limit virtual double calc_with_limit(const vespalib::eval::TypedCells& lhs, const vespalib::eval::TypedCells& rhs, double limit) const = 0; diff --git a/searchlib/src/vespa/searchlib/tensor/distance_functions.h b/searchlib/src/vespa/searchlib/tensor/distance_functions.h index 8db7b1f48f1..2557a51e0d7 100644 --- a/searchlib/src/vespa/searchlib/tensor/distance_functions.h +++ b/searchlib/src/vespa/searchlib/tensor/distance_functions.h @@ -26,6 +26,9 @@ public: assert(sz == rhs_vector.size()); return _computer.squaredEuclideanDistance(&lhs_vector[0], &rhs_vector[0], sz); } + double convert_threshold(double threshold) const override { + return threshold*threshold; + } double to_rawscore(double distance) const override { double d = sqrt(distance); double score = 1.0 / (1.0 + d); @@ -75,6 +78,10 @@ public: double distance = 1.0 - cosine_similarity; // in range [0,2] return distance; } + double convert_threshold(double threshold) const override 
{ + double cosine_similarity = cos(threshold); + return 1.0 - cosine_similarity; + } double to_rawscore(double distance) const override { double cosine_similarity = 1.0 - distance; // should be in in range [-1,1] but roundoff may cause problems: @@ -112,6 +119,9 @@ public: double score = 1.0 - _computer.dotProduct(&lhs_vector[0], &rhs_vector[0], sz); return std::max(0.0, score); } + double convert_threshold(double threshold) const override { + return threshold; + } double to_rawscore(double distance) const override { double score = 1.0 / (1.0 + distance); return score; @@ -135,6 +145,11 @@ public: template <typename FloatType> class GeoDegreesDistance : public DistanceFunction { public: + // in km, as defined by IUGG, see: + // https://en.wikipedia.org/wiki/Earth_radius#Mean_radius + static constexpr double earth_mean_radius = 6371.0088; + static constexpr double degrees_to_radians = M_PI / 180.0; + GeoDegreesDistance() {} // haversine function: static double hav(double angle) { @@ -147,10 +162,10 @@ public: assert(2 == lhs_vector.size()); assert(2 == rhs_vector.size()); // convert to radians: - double lat_A = lhs_vector[0] * M_PI / 180.0; - double lat_B = rhs_vector[0] * M_PI / 180.0; - double lon_A = lhs_vector[1] * M_PI / 180.0; - double lon_B = rhs_vector[1] * M_PI / 180.0; + double lat_A = lhs_vector[0] * degrees_to_radians; + double lat_B = rhs_vector[0] * degrees_to_radians; + double lon_A = lhs_vector[1] * degrees_to_radians; + double lon_B = rhs_vector[1] * degrees_to_radians; double lat_diff = lat_A - lat_B; double lon_diff = lon_A - lon_B; @@ -163,10 +178,16 @@ public: double hav_central_angle = hav_lat + cos(lat_A)*cos(lat_B)*hav_lon; return hav_central_angle; } + double convert_threshold(double threshold) const override { + double half_angle = threshold / (2 * earth_mean_radius); + double rt_hav = sin(half_angle); + return rt_hav * rt_hav; + } double to_rawscore(double distance) const override { double hav_diff = sqrt(distance); // distance in 
kilometers: - double d = 2 * asin(hav_diff) * 6371.0088; // Earth mean radius + double d = 2 * asin(hav_diff) * earth_mean_radius; + // km to rawscore: return 1.0 / (1.0 + d); } double calc_with_limit(const vespalib::eval::TypedCells& lhs, @@ -197,6 +218,9 @@ public: } return (double)sum; } + double convert_threshold(double threshold) const override { + return threshold; + } double to_rawscore(double distance) const override { double score = 1.0 / (1.0 + distance); return score; diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp index 6488b525b7c..44b2ff2b7f1 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.cpp @@ -281,6 +281,7 @@ HnswIndex::HnswIndex(const DocVectorAccess& vectors, DistanceFunction::UP distan _level_generator(std::move(level_generator)), _cfg(cfg) { + assert(_distance_func); } HnswIndex::~HnswIndex() = default; @@ -534,7 +535,8 @@ struct NeighborsByDocId { std::vector<NearestNeighborIndex::Neighbor> HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector, - const BitVector *filter, uint32_t explore_k) const + const BitVector *filter, uint32_t explore_k, + double distance_threshold) const { std::vector<Neighbor> result; FurthestPriQ candidates = top_k_candidates(vector, std::max(k, explore_k), filter); @@ -543,6 +545,7 @@ HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector, } result.reserve(candidates.size()); for (const HnswCandidate & hit : candidates.peek()) { + if (hit.distance > distance_threshold) continue; result.emplace_back(hit.docid, hit.distance); } std::sort(result.begin(), result.end(), NeighborsByDocId()); @@ -550,16 +553,18 @@ HnswIndex::top_k_by_docid(uint32_t k, TypedCells vector, } std::vector<NearestNeighborIndex::Neighbor> -HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const +HnswIndex::find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k, + double 
distance_threshold) const { - return top_k_by_docid(k, vector, nullptr, explore_k); + return top_k_by_docid(k, vector, nullptr, explore_k, distance_threshold); } std::vector<NearestNeighborIndex::Neighbor> HnswIndex::find_top_k_with_filter(uint32_t k, TypedCells vector, - const BitVector &filter, uint32_t explore_k) const + const BitVector &filter, uint32_t explore_k, + double distance_threshold) const { - return top_k_by_docid(k, vector, &filter, explore_k); + return top_k_by_docid(k, vector, &filter, explore_k, distance_threshold); } FurthestPriQ diff --git a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h index c07a0642b2e..5bd9d17adc3 100644 --- a/searchlib/src/vespa/searchlib/tensor/hnsw_index.h +++ b/searchlib/src/vespa/searchlib/tensor/hnsw_index.h @@ -123,7 +123,8 @@ protected: void search_layer(const TypedCells& input, uint32_t neighbors_to_find, FurthestPriQ& found_neighbors, uint32_t level, const search::BitVector *filter = nullptr) const; std::vector<Neighbor> top_k_by_docid(uint32_t k, TypedCells vector, - const BitVector *filter, uint32_t explore_k) const; + const BitVector *filter, uint32_t explore_k, + double distance_threshold) const; struct PreparedAddDoc : public PrepareResult { using ReadGuard = vespalib::GenerationHandler::Guard; @@ -166,9 +167,11 @@ public: std::unique_ptr<NearestNeighborIndexSaver> make_saver() const override; bool load(const fileutil::LoadedBuffer& buf) override; - std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k) const override; + std::vector<Neighbor> find_top_k(uint32_t k, TypedCells vector, uint32_t explore_k, + double distance_threshold) const override; std::vector<Neighbor> find_top_k_with_filter(uint32_t k, TypedCells vector, - const BitVector &filter, uint32_t explore_k) const override; + const BitVector &filter, uint32_t explore_k, + double distance_threshold) const override; const DistanceFunction *distance_function() const 
override { return _distance_func.get(); } FurthestPriQ top_k_candidates(const TypedCells &vector, uint32_t k, const BitVector *filter) const; diff --git a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h index c14da0d058f..fd37cf80720 100644 --- a/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h +++ b/searchlib/src/vespa/searchlib/tensor/nearest_neighbor_index.h @@ -73,13 +73,15 @@ public: virtual std::vector<Neighbor> find_top_k(uint32_t k, vespalib::eval::TypedCells vector, - uint32_t explore_k) const = 0; + uint32_t explore_k, + double distance_threshold) const = 0; // only return neighbors where the corresponding filter bit is set virtual std::vector<Neighbor> find_top_k_with_filter(uint32_t k, vespalib::eval::TypedCells vector, const BitVector &filter, - uint32_t explore_k) const = 0; + uint32_t explore_k, + double distance_threshold) const = 0; virtual const DistanceFunction *distance_function() const = 0; }; diff --git a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp index 6e1fb1a0a2f..260ffa1a388 100644 --- a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp +++ b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.cpp @@ -3,8 +3,7 @@ #include "serialized_fast_value_attribute.h" #include "streamed_value_saver.h" #include <vespa/eval/eval/value.h> -#include <vespa/eval/eval/fast_value.hpp> -#include <vespa/eval/streamed/streamed_value_utils.h> +#include <vespa/eval/eval/fast_value.h> #include <vespa/fastlib/io/bufferedfile.h> #include <vespa/searchlib/attribute/readerbase.h> #include <vespa/searchlib/util/fileutil.h> @@ -21,127 +20,10 @@ using namespace vespalib::eval; namespace search::tensor { -namespace { - -struct ValueBlock : LabelBlock { - TypedCells cells; -}; - -class ValueBlockStream { -private: - const 
StreamedValueStore::DataFromType &_from_type; - LabelBlockStream _label_block_stream; - const char *_cells_ptr; - - size_t dsss() const { return _from_type.dense_subspace_size; } - auto cell_type() const { return _from_type.cell_type; } -public: - ValueBlock next_block() { - auto labels = _label_block_stream.next_block(); - if (labels) { - TypedCells subspace_cells(_cells_ptr, cell_type(), dsss()); - _cells_ptr += CellTypeUtils::mem_size(cell_type(), dsss()); - return ValueBlock{labels, subspace_cells}; - } else { - TypedCells none(nullptr, cell_type(), 0); - return ValueBlock{labels, none}; - } - } - - ValueBlockStream(const StreamedValueStore::DataFromType &from_type, - const StreamedValueStore::StreamedValueData &from_store) - : _from_type(from_type), - _label_block_stream(from_store.num_subspaces, - from_store.labels_buffer, - from_type.num_mapped_dimensions), - _cells_ptr((const char *)from_store.cells_ref.data) - { - _label_block_stream.reset(); - } - - ~ValueBlockStream(); -}; - -ValueBlockStream::~ValueBlockStream() = default; - -void report_problematic_subspace(size_t idx, - const StreamedValueStore::DataFromType &from_type, - const StreamedValueStore::StreamedValueData &from_store) -{ - LOG(error, "PROBLEM: add_mapping returned same index=%zu twice", idx); - FastValueIndex temp_index(from_type.num_mapped_dimensions, - from_store.num_subspaces); - auto from_start = ValueBlockStream(from_type, from_store); - while (auto redo_block = from_start.next_block()) { - if (idx == temp_index.map.add_mapping(redo_block.address)) { - vespalib::string msg = "Block with address[ "; - for (vespalib::stringref ref : redo_block.address) { - msg.append("'").append(ref).append("' "); - } - msg.append("]"); - LOG(error, "%s maps to subspace %zu", msg.c_str(), idx); - } - } -} - -/** - * This Value implementation is almost exactly like FastValue, but - * instead of owning its type and cells it just has a reference to - * data stored elsewhere. 
- * XXX: we should find a better name for this, and move it - * (together with the helper classes above) to its own file, - * and add associated unit tests. - **/ -class OnlyFastValueIndex : public Value { -private: - const ValueType &_type; - TypedCells _cells; - FastValueIndex my_index; -public: - OnlyFastValueIndex(const ValueType &type, - const StreamedValueStore::DataFromType &from_type, - const StreamedValueStore::StreamedValueData &from_store) - : _type(type), - _cells(from_store.cells_ref), - my_index(from_type.num_mapped_dimensions, - from_store.num_subspaces) - { - assert(_type.cell_type() == _cells.type); - std::vector<vespalib::stringref> address(from_type.num_mapped_dimensions); - auto block_stream = ValueBlockStream(from_type, from_store); - size_t ss = 0; - while (auto block = block_stream.next_block()) { - size_t idx = my_index.map.add_mapping(block.address); - if (idx != ss) { - report_problematic_subspace(idx, from_type, from_store); - } - ++ss; - } - assert(ss == from_store.num_subspaces); - } - - - ~OnlyFastValueIndex(); - - const ValueType &type() const final override { return _type; } - TypedCells cells() const final override { return _cells; } - const Index &index() const final override { return my_index; } - vespalib::MemoryUsage get_memory_usage() const final override { - auto usage = self_memory_usage<OnlyFastValueIndex>(); - usage.merge(my_index.map.estimate_extra_memory_usage()); - return usage; - } -}; - -OnlyFastValueIndex::~OnlyFastValueIndex() = default; - -} - SerializedFastValueAttribute::SerializedFastValueAttribute(stringref name, const Config &cfg) : TensorAttribute(name, cfg, _streamedValueStore), _tensor_type(cfg.tensorType()), - _streamedValueStore(_tensor_type), - _data_from_type(_tensor_type) + _streamedValueStore(_tensor_type) { } @@ -171,10 +53,8 @@ SerializedFastValueAttribute::getTensor(DocId docId) const if (!ref.valid()) { return {}; } - if (auto data_from_store = _streamedValueStore.get_tensor_data(ref)) { - return 
std::make_unique<OnlyFastValueIndex>(_tensor_type, - _data_from_type, - data_from_store); + if (const auto * ptr = _streamedValueStore.get_tensor_entry(ref)) { + return ptr->create_fast_value_view(_tensor_type); } return {}; } diff --git a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.h b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.h index a8c1df4913a..cc559d9b758 100644 --- a/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.h +++ b/searchlib/src/vespa/searchlib/tensor/serialized_fast_value_attribute.h @@ -19,7 +19,6 @@ namespace search::tensor { class SerializedFastValueAttribute : public TensorAttribute { vespalib::eval::ValueType _tensor_type; StreamedValueStore _streamedValueStore; // data store for serialized tensors - const StreamedValueStore::DataFromType _data_from_type; public: SerializedFastValueAttribute(vespalib::stringref baseFileName, const Config &cfg); virtual ~SerializedFastValueAttribute(); diff --git a/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp b/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp index c4579880409..ef4b711b86f 100644 --- a/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp +++ b/searchlib/src/vespa/searchlib/tensor/streamed_value_store.cpp @@ -1,99 +1,204 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
#include "streamed_value_store.h" -#include "tensor_deserialize.h" #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/value_codec.h> +#include <vespa/eval/eval/fast_value.hpp> #include <vespa/eval/streamed/streamed_value_builder_factory.h> #include <vespa/eval/streamed/streamed_value_view.h> #include <vespa/vespalib/datastore/datastore.hpp> #include <vespa/vespalib/objects/nbostream.h> +#include <vespa/vespalib/util/typify.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/log/log.h> LOG_SETUP(".searchlib.tensor.streamed_value_store"); using vespalib::datastore::Handle; +using vespalib::datastore::EntryRef; using namespace vespalib::eval; +using vespalib::ConstArrayRef; +using vespalib::MemoryUsage; namespace search::tensor { +//----------------------------------------------------------------------------- + namespace { -constexpr size_t MIN_BUFFER_ARRAYS = 1024; - -struct CellsMemBlock { - uint32_t num; - uint32_t total_sz; - const char *ptr; - CellsMemBlock(TypedCells cells) - : num(cells.size), - total_sz(CellTypeUtils::mem_size(cells.type, num)), - ptr((const char *)cells.data) - {} +template <typename CT, typename F> +void each_subspace(const Value &value, size_t num_mapped, size_t dense_size, F f) { + size_t subspace; + std::vector<label_t> addr(num_mapped); + std::vector<label_t*> refs; + refs.reserve(addr.size()); + for (label_t &label: addr) { + refs.push_back(&label); + } + auto cells = value.cells().typify<CT>(); + auto view = value.index().create_view({}); + view->lookup({}); + while (view->next_result(refs, subspace)) { + size_t offset = subspace * dense_size; + f(ConstArrayRef<label_t>(addr), ConstArrayRef<CT>(cells.begin() + offset, dense_size)); + } +} + +using TensorEntry = StreamedValueStore::TensorEntry; + +struct CreateTensorEntry { + template <typename CT> + static TensorEntry::SP invoke(const Value &value, size_t num_mapped, size_t dense_size) { + using EntryImpl = StreamedValueStore::TensorEntryImpl<CT>; + return 
std::make_shared<EntryImpl>(value, num_mapped, dense_size); + } }; -template<typename T> -void check_alignment(T *ptr, size_t align) +using HandleView = vespalib::SharedStringRepo::HandleView; + +struct MyFastValueView final : Value { + const ValueType &my_type; + FastValueIndex my_index; + TypedCells my_cells; + MyFastValueView(const ValueType &type_ref, HandleView handle_view, TypedCells cells, size_t num_mapped, size_t num_spaces) + : my_type(type_ref), + my_index(num_mapped, handle_view, num_spaces), + my_cells(cells) + { + const std::vector<label_t> &labels = handle_view.handles(); + for (size_t i = 0; i < num_spaces; ++i) { + ConstArrayRef<label_t> addr(&labels[i * num_mapped], num_mapped); + my_index.map.add_mapping(FastAddrMap::hash_labels(addr)); + } + assert(my_index.map.size() == num_spaces); + } + const ValueType &type() const override { return my_type; } + const Value::Index &index() const override { return my_index; } + TypedCells cells() const override { return my_cells; } + MemoryUsage get_memory_usage() const override { + MemoryUsage usage = self_memory_usage<MyFastValueView>(); + usage.merge(my_index.map.estimate_extra_memory_usage()); + return usage; + } +}; + +} // <unnamed> + +//----------------------------------------------------------------------------- + +StreamedValueStore::TensorEntry::~TensorEntry() = default; + +StreamedValueStore::TensorEntry::SP +StreamedValueStore::TensorEntry::create_shared_entry(const Value &value) { - static_assert(sizeof(T) == 1); - size_t ptr_val = (size_t)ptr; - size_t unalign = ptr_val & (align - 1); - assert(unalign == 0); + size_t num_mapped = value.type().count_mapped_dimensions(); + size_t dense_size = value.type().dense_subspace_size(); + return vespalib::typify_invoke<1,TypifyCellType,CreateTensorEntry>(value.type().cell_type(), value, num_mapped, dense_size); } -} // namespace <unnamed> +template <typename CT> +StreamedValueStore::TensorEntryImpl<CT>::TensorEntryImpl(const Value &value, size_t 
num_mapped, size_t dense_size) + : handles(num_mapped * value.index().size()), + cells() +{ + cells.reserve(dense_size * value.index().size()); + auto store_subspace = [&](auto addr, auto data) { + for (label_t label: addr) { + handles.add(label); + } + for (CT entry: data) { + cells.push_back(entry); + } + }; + each_subspace<CT>(value, num_mapped, dense_size, store_subspace); +} -StreamedValueStore::StreamedValueStore(const ValueType &tensor_type) - : TensorStore(_concreteStore), - _concreteStore(), - _bufferType(RefType::align(1), - MIN_BUFFER_ARRAYS, - RefType::offsetSize() / RefType::align(1)), - _tensor_type(tensor_type), - _data_from_type(_tensor_type) +template <typename CT> +Value::UP +StreamedValueStore::TensorEntryImpl<CT>::create_fast_value_view(const ValueType &type_ref) const { - _store.addType(&_bufferType); - _store.initActiveBuffers(); - size_t align = CellTypeUtils::alignment(_data_from_type.cell_type); - // max alignment we can handle is 8: - assert(align <= 8); - // alignment must be a power of two: - assert((align & (align-1)) == 0); + size_t num_mapped = type_ref.count_mapped_dimensions(); + size_t dense_size = type_ref.dense_subspace_size(); + size_t num_spaces = cells.size() / dense_size; + assert(dense_size * num_spaces == cells.size()); + assert(num_mapped * num_spaces == handles.view().handles().size()); + return std::make_unique<MyFastValueView>(type_ref, handles.view(), TypedCells(cells), num_mapped, num_spaces); } -StreamedValueStore::~StreamedValueStore() +template <typename CT> +void +StreamedValueStore::TensorEntryImpl<CT>::encode_value(const ValueType &type, vespalib::nbostream &target) const { - _store.dropBuffers(); + size_t num_mapped = type.count_mapped_dimensions(); + size_t dense_size = type.dense_subspace_size(); + size_t num_spaces = cells.size() / dense_size; + assert(dense_size * num_spaces == cells.size()); + assert(num_mapped * num_spaces == handles.view().handles().size()); + StreamedValueView my_value(type, num_mapped, 
TypedCells(cells), num_spaces, handles.view().handles()); + ::vespalib::eval::encode_value(my_value, target); } -std::pair<const char *, uint32_t> -StreamedValueStore::getRawBuffer(RefType ref) const +template <typename CT> +MemoryUsage +StreamedValueStore::TensorEntryImpl<CT>::get_memory_usage() const +{ + MemoryUsage usage = self_memory_usage<TensorEntryImpl<CT>>(); + usage.merge(vector_extra_memory_usage(handles.view().handles())); + usage.merge(vector_extra_memory_usage(cells)); + return usage; +} + +template <typename CT> +StreamedValueStore::TensorEntryImpl<CT>::~TensorEntryImpl() = default; + +//----------------------------------------------------------------------------- + +constexpr size_t MIN_BUFFER_ARRAYS = 8192; + +StreamedValueStore::TensorBufferType::TensorBufferType() + : ParentType(1, MIN_BUFFER_ARRAYS, TensorStoreType::RefType::offsetSize()) { - if (!ref.valid()) { - return std::make_pair(nullptr, 0u); - } - const char *buf = _store.getEntry<char>(ref); - uint32_t len = *reinterpret_cast<const uint32_t *>(buf); - return std::make_pair(buf + sizeof(uint32_t), len); } -Handle<char> -StreamedValueStore::allocRawBuffer(uint32_t size) +void +StreamedValueStore::TensorBufferType::cleanHold(void* buffer, size_t offset, size_t num_elems, CleanContext clean_ctx) { - if (size == 0) { - return Handle<char>(); + TensorEntry::SP* elem = static_cast<TensorEntry::SP*>(buffer) + offset; + for (size_t i = 0; i < num_elems; ++i) { + clean_ctx.extraBytesCleaned((*elem)->get_memory_usage().allocatedBytes()); + *elem = _emptyEntry; + ++elem; } - size_t extSize = size + sizeof(uint32_t); - size_t bufSize = RefType::align(extSize); - auto result = _concreteStore.rawAllocator<char>(_typeId).alloc(bufSize); - *reinterpret_cast<uint32_t *>(result.data) = size; - char *padWritePtr = result.data + extSize; - for (size_t i = extSize; i < bufSize; ++i) { - *padWritePtr++ = 0; +} + +StreamedValueStore::StreamedValueStore(const ValueType &tensor_type) + : 
TensorStore(_concrete_store), + _concrete_store(), + _tensor_type(tensor_type) +{ + _concrete_store.enableFreeLists(); +} + +StreamedValueStore::~StreamedValueStore() = default; + +EntryRef +StreamedValueStore::add_entry(TensorEntry::SP tensor) +{ + auto ref = _concrete_store.addEntry(tensor); + auto& state = _concrete_store.getBufferState(RefType(ref).bufferId()); + state.incExtraUsedBytes(tensor->get_memory_usage().allocatedBytes()); + return ref; +} + +const StreamedValueStore::TensorEntry * +StreamedValueStore::get_tensor_entry(EntryRef ref) const +{ + if (!ref.valid()) { + return nullptr; } - // Hide length of buffer (first 4 bytes) from users of the buffer. - return Handle<char>(result.ref, result.data + sizeof(uint32_t)); + const auto& entry = _concrete_store.getEntry(ref); + assert(entry); + return entry.get(); } void @@ -102,111 +207,40 @@ StreamedValueStore::holdTensor(EntryRef ref) if (!ref.valid()) { return; } - RefType iRef(ref); - const char *buf = _store.getEntry<char>(iRef); - uint32_t len = *reinterpret_cast<const uint32_t *>(buf); - _concreteStore.holdElem(ref, len + sizeof(uint32_t)); + const auto& tensor = _concrete_store.getEntry(ref); + assert(tensor); + _concrete_store.holdElem(ref, 1, tensor->get_memory_usage().allocatedBytes()); } TensorStore::EntryRef StreamedValueStore::move(EntryRef ref) { if (!ref.valid()) { - return RefType(); + return EntryRef(); } - auto oldraw = getRawBuffer(ref); - auto newraw = allocRawBuffer(oldraw.second); - memcpy(newraw.data, oldraw.first, oldraw.second); - _concreteStore.holdElem(ref, oldraw.second + sizeof(uint32_t)); - return newraw.ref; -} - -StreamedValueStore::StreamedValueData -StreamedValueStore::get_tensor_data(EntryRef ref) const -{ - StreamedValueData retval; - retval.valid = false; - auto raw = getRawBuffer(ref); - if (raw.second == 0u) { - return retval; - } - vespalib::nbostream source(raw.first, raw.second); - uint32_t num_cells = source.readValue<uint32_t>(); - check_alignment(source.peek(), 
CellTypeUtils::alignment(_data_from_type.cell_type)); - retval.cells_ref = TypedCells(source.peek(), _data_from_type.cell_type, num_cells); - source.adjustReadPos(CellTypeUtils::mem_size(_data_from_type.cell_type, num_cells)); - assert((num_cells % _data_from_type.dense_subspace_size) == 0); - retval.num_subspaces = num_cells / _data_from_type.dense_subspace_size; - retval.labels_buffer = vespalib::ConstArrayRef<char>(source.peek(), source.size()); - retval.valid = true; - return retval; + const auto& old_tensor = _concrete_store.getEntry(ref); + assert(old_tensor); + auto new_ref = add_entry(old_tensor); + _concrete_store.holdElem(ref, 1, old_tensor->get_memory_usage().allocatedBytes()); + return new_ref; } bool StreamedValueStore::encode_tensor(EntryRef ref, vespalib::nbostream &target) const { - if (auto data = get_tensor_data(ref)) { - StreamedValueView value( - _tensor_type, _data_from_type.num_mapped_dimensions, - data.cells_ref, data.num_subspaces, data.labels_buffer); - vespalib::eval::encode_value(value, target); + if (const auto * entry = get_tensor_entry(ref)) { + entry->encode_value(_tensor_type, target); return true; } else { return false; } } -void -StreamedValueStore::serialize_labels(const Value::Index &index, - vespalib::nbostream &target) const -{ - uint32_t num_subspaces = index.size(); - uint32_t num_mapped_dims = _data_from_type.num_mapped_dimensions; - std::vector<vespalib::stringref> labels(num_mapped_dims * num_subspaces); - auto view = index.create_view({}); - view->lookup({}); - std::vector<vespalib::stringref> addr(num_mapped_dims); - std::vector<vespalib::stringref *> addr_refs; - for (auto & label : addr) { - addr_refs.push_back(&label); - } - size_t subspace; - for (size_t ss = 0; ss < num_subspaces; ++ss) { - bool ok = view->next_result(addr_refs, subspace); - assert(ok); - size_t idx = subspace * num_mapped_dims; - for (auto label : addr) { - labels[idx++] = label; - } - } - bool ok = view->next_result(addr_refs, subspace); - 
assert(!ok); - for (auto label : labels) { - target.writeSmallString(label); - } -} - TensorStore::EntryRef StreamedValueStore::store_tensor(const Value &tensor) { assert(tensor.type() == _tensor_type); - CellsMemBlock cells_mem(tensor.cells()); - vespalib::nbostream stream; - stream << uint32_t(cells_mem.num); - serialize_labels(tensor.index(), stream); - size_t mem_size = stream.size() + cells_mem.total_sz; - auto raw = allocRawBuffer(mem_size); - char *target = raw.data; - memcpy(target, stream.peek(), sizeof(uint32_t)); - stream.adjustReadPos(sizeof(uint32_t)); - target += sizeof(uint32_t); - check_alignment(target, CellTypeUtils::alignment(_data_from_type.cell_type)); - memcpy(target, cells_mem.ptr, cells_mem.total_sz); - target += cells_mem.total_sz; - memcpy(target, stream.peek(), stream.size()); - target += stream.size(); - assert(target <= raw.data + mem_size); - return raw.ref; + return add_entry(TensorEntry::create_shared_entry(tensor)); } TensorStore::EntryRef diff --git a/searchlib/src/vespa/searchlib/tensor/streamed_value_store.h b/searchlib/src/vespa/searchlib/tensor/streamed_value_store.h index de94dc043d3..3a9d9a0b7b4 100644 --- a/searchlib/src/vespa/searchlib/tensor/streamed_value_store.h +++ b/searchlib/src/vespa/searchlib/tensor/streamed_value_store.h @@ -5,87 +5,71 @@ #include "tensor_store.h" #include <vespa/eval/eval/value_type.h> #include <vespa/eval/eval/value.h> +#include <vespa/eval/streamed/streamed_value.h> #include <vespa/vespalib/objects/nbostream.h> -#include <vespa/vespalib/util/typify.h> +#include <vespa/vespalib/util/shared_string_repo.h> namespace search::tensor { /** - * Class for storing tensors in memory, with a special serialization - * format that can be used directly to make a StreamedValueView. - * - * The tensor type is owned by the store itself and will not be - * serialized at all. 
- * - * The parameters for serialization (see DataFromType) are: - * - number of mapped dimensions [MD] - * - dense subspace size [DS] - * - size of each cell [CS] - currently 4 (float) or 8 (double) - * - alignment for cells - currently 4 (float) or 8 (double) - * While the tensor value to be serialized has: - * - number of dense subspaces [ND] - * - labels for dense subspaces, ND * MD strings - * - cell values, ND * DS cells (each either float or double) - * The serialization format looks like: - * - * [bytes] : [format] : [description] - * 4 : n.b.o. uint32_ t : num cells = ND * DS - * CS * ND * DS : native float or double : cells - * (depends) : n.b.o. strings : ND * MD label strings - * - * Here, n.b.o. means network byte order, or more precisely - * it's the format vespalib::nbostream uses for the given data type, - * including strings (where exact format depends on the string length). - * Note that the only unpredictably-sized data (the labels) are kept - * last. - * If we ever make a "hbostream" which uses host byte order, we - * could switch to that instead since these data are only kept in - * memory. + * Class for StreamedValue tensors in memory. 
*/ class StreamedValueStore : public TensorStore { public: - using RefType = vespalib::datastore::AlignedEntryRefT<22, 3>; - using DataStoreType = vespalib::datastore::DataStoreT<RefType>; + using Value = vespalib::eval::Value; + using ValueType = vespalib::eval::ValueType; + using Handles = vespalib::SharedStringRepo::StrongHandles; + using MemoryUsage = vespalib::MemoryUsage; - struct StreamedValueData { - bool valid; - vespalib::eval::TypedCells cells_ref; - size_t num_subspaces; - vespalib::ConstArrayRef<char> labels_buffer; - operator bool() const { return valid; } + // interface for tensor entries + struct TensorEntry { + using SP = std::shared_ptr<TensorEntry>; + virtual Value::UP create_fast_value_view(const ValueType &type_ref) const = 0; + virtual void encode_value(const ValueType &type, vespalib::nbostream &target) const = 0; + virtual MemoryUsage get_memory_usage() const = 0; + virtual ~TensorEntry(); + static TensorEntry::SP create_shared_entry(const Value &value); }; - struct DataFromType { - uint32_t num_mapped_dimensions; - uint32_t dense_subspace_size; - vespalib::eval::CellType cell_type; - - DataFromType(const vespalib::eval::ValueType& type) - : num_mapped_dimensions(type.count_mapped_dimensions()), - dense_subspace_size(type.dense_subspace_size()), - cell_type(type.cell_type()) - {} + // implementation of tensor entries + template <typename CT> + struct TensorEntryImpl : public TensorEntry { + Handles handles; + std::vector<CT> cells; + TensorEntryImpl(const Value &value, size_t num_mapped, size_t dense_size); + Value::UP create_fast_value_view(const ValueType &type_ref) const override; + void encode_value(const ValueType &type, vespalib::nbostream &target) const override; + MemoryUsage get_memory_usage() const override; + ~TensorEntryImpl() override; }; private: - DataStoreType _concreteStore; - vespalib::datastore::BufferType<char> _bufferType; - vespalib::eval::ValueType _tensor_type; - DataFromType _data_from_type; - - void 
serialize_labels(const vespalib::eval::Value::Index &index, - vespalib::nbostream &target) const; + // Note: Must use SP (instead of UP) because of fallbackCopy() and initializeReservedElements() in BufferType, + // and implementation of move(). + using TensorStoreType = vespalib::datastore::DataStore<TensorEntry::SP>; - std::pair<const char *, uint32_t> getRawBuffer(RefType ref) const; - vespalib::datastore::Handle<char> allocRawBuffer(uint32_t size); + class TensorBufferType : public vespalib::datastore::BufferType<TensorEntry::SP> { + private: + using ParentType = BufferType<TensorEntry::SP>; + using ParentType::_emptyEntry; + using CleanContext = typename ParentType::CleanContext; + public: + TensorBufferType(); + virtual void cleanHold(void* buffer, size_t offset, size_t num_elems, CleanContext clean_ctx) override; + }; + TensorStoreType _concrete_store; + const vespalib::eval::ValueType _tensor_type; + EntryRef add_entry(TensorEntry::SP tensor); public: StreamedValueStore(const vespalib::eval::ValueType &tensor_type); - virtual ~StreamedValueStore(); + ~StreamedValueStore() override; + + using RefType = TensorStoreType::RefType; - virtual void holdTensor(EntryRef ref) override; - virtual EntryRef move(EntryRef ref) override; + void holdTensor(EntryRef ref) override; + EntryRef move(EntryRef ref) override; - StreamedValueData get_tensor_data(EntryRef ref) const; + const TensorEntry * get_tensor_entry(EntryRef ref) const; bool encode_tensor(EntryRef ref, vespalib::nbostream &target) const; EntryRef store_tensor(const vespalib::eval::Value &tensor); diff --git a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp index 9f1e6bde06b..8dcca2c7b89 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/domain.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/domain.cpp @@ -45,6 +45,7 @@ Domain::Domain(const string &domainName, const string & baseDir, Executor & exec _sessionId(1), _syncMonitor(), 
_pendingSync(false), + _done_sync_tasks(), _name(domainName), _parts(), _lock(), @@ -206,13 +207,16 @@ Domain::getSynced() const void -Domain::triggerSyncNow() +Domain::triggerSyncNow(std::unique_ptr<vespalib::Executor::Task> done_sync_task) { { std::unique_lock guard(_currentChunkMonitor); commitAndTransferResponses(guard); } std::unique_lock guard(_syncMonitor); + if (done_sync_task) { + _done_sync_tasks.push_back(std::move(done_sync_task)); + } if (!_pendingSync) { _pendingSync = true; _executor.execute(makeLambdaTask([this, domainPart= getActivePart()]() { @@ -220,6 +224,11 @@ Domain::triggerSyncNow() std::lock_guard monitorGuard(_syncMonitor); _pendingSync = false; _syncCond.notify_all(); + for (auto &task : _done_sync_tasks) { + auto failed_task = _executor.execute(std::move(task)); + assert(!failed_task); + } + _done_sync_tasks.clear(); })); } } @@ -316,7 +325,7 @@ Domain::optionallyRotateFile(SerialNum serialNum) { DomainPart::SP dp = getActivePart(); if (dp->byteSize() > _config.getPartSizeLimit()) { waitPendingSync(_syncMonitor, _syncCond, _pendingSync); - triggerSyncNow(); + triggerSyncNow({}); waitPendingSync(_syncMonitor, _syncCond, _pendingSync); dp->close(); dp = std::make_shared<DomainPart>(_name, dir(), serialNum, _config.getEncoding(), diff --git a/searchlib/src/vespa/searchlib/transactionlog/domain.h b/searchlib/src/vespa/searchlib/transactionlog/domain.h index c9eb6385b15..5a80758dd0b 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/domain.h +++ b/searchlib/src/vespa/searchlib/transactionlog/domain.h @@ -36,7 +36,7 @@ public: SerialNum begin() const; SerialNum end() const; SerialNum getSynced() const; - void triggerSyncNow(); + void triggerSyncNow(std::unique_ptr<vespalib::Executor::Task> done_sync_task); bool getMarkedDeleted() const { return _markedDeleted; } void markDeleted() { _markedDeleted = true; } @@ -92,6 +92,7 @@ private: std::mutex _syncMonitor; std::condition_variable _syncCond; bool _pendingSync; + 
std::vector<std::unique_ptr<vespalib::Executor::Task>> _done_sync_tasks; vespalib::string _name; DomainPartList _parts; mutable std::mutex _lock; diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp index 0c0c9186e12..0cebc056569 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.cpp @@ -6,9 +6,9 @@ #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/io/fileutil.h> #include <vespa/vespalib/util/exceptions.h> +#include <vespa/vespalib/util/lambdatask.h> #include <vespa/fnet/frt/supervisor.h> #include <vespa/fnet/frt/rpcrequest.h> -#include <vespa/fnet/task.h> #include <vespa/fnet/transport.h> #include <fstream> #include <thread> @@ -28,24 +28,25 @@ namespace search::transactionlog { namespace { -class SyncHandler : public FNET_Task +class SyncHandler : public std::enable_shared_from_this<SyncHandler> { + std::atomic<bool> & _closed; FRT_RPCRequest & _req; Domain::SP _domain; TransLogServer::Session::SP _session; SerialNum _syncTo; public: - SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req, const Domain::SP &domain, - const TransLogServer::Session::SP &session, SerialNum syncTo); + SyncHandler(std::atomic<bool>& closed, FRT_RPCRequest *req, const Domain::SP &domain, + const TransLogServer::Session::SP &session, SerialNum syncTo) noexcept; - ~SyncHandler() override; - void PerformTask() override; + ~SyncHandler(); + void poll(); }; -SyncHandler::SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req, const Domain::SP &domain, - const TransLogServer::Session::SP &session, SerialNum syncTo) - : FNET_Task(supervisor->GetScheduler()), +SyncHandler::SyncHandler(std::atomic<bool>& closed, FRT_RPCRequest *req, const Domain::SP &domain, + const TransLogServer::Session::SP &session, SerialNum syncTo) noexcept + : _closed(closed), _req(*req), _domain(domain), 
_session(session), @@ -56,20 +57,20 @@ SyncHandler::SyncHandler(FRT_Supervisor *supervisor, FRT_RPCRequest *req, const SyncHandler::~SyncHandler() = default; void -SyncHandler::PerformTask() +SyncHandler::poll() { SerialNum synced(_domain->getSynced()); if (_session->getDown() || _domain->getMarkedDeleted() || - synced >= _syncTo) { + _closed.load(std::memory_order_acquire) || + synced >= _syncTo) + { FRT_Values &rvals = *_req.GetReturn(); rvals.AddInt32(0); rvals.AddInt64(synced); _req.Return(); - delete this; } else { - _domain->triggerSyncNow(); - Schedule(0.05); // Retry in 0.05 seconds + _domain->triggerSyncNow(vespalib::makeLambdaTask([self = shared_from_this()]() { self->poll(); })); } } @@ -101,7 +102,8 @@ TransLogServer::TransLogServer(const vespalib::string &name, int listenPort, con _supervisor(std::make_unique<FRT_Supervisor>(_transport.get())), _domains(), _reqQ(), - _fileHeaderContext(fileHeaderContext) + _fileHeaderContext(fileHeaderContext), + _closed(false) { int retval(0); if ((retval = makeDirectory(_baseDir.c_str())) == 0) { @@ -146,8 +148,10 @@ TransLogServer::TransLogServer(const vespalib::string &name, int listenPort, con TransLogServer::~TransLogServer() { + _closed = true; stop(); join(); + _executor.sync(); _executor.shutdown(); _executor.sync(); _transport->ShutDown(true); @@ -719,10 +723,9 @@ TransLogServer::domainSync(FRT_RPCRequest *req) req->Return(); return; } - - SyncHandler *syncHandler = new SyncHandler(_supervisor.get(), req, domain, session, syncTo); - - syncHandler->ScheduleNow(); + + auto syncHandler = std::make_shared<SyncHandler>(_closed, req, domain, session, syncTo); + syncHandler->poll(); } } diff --git a/searchlib/src/vespa/searchlib/transactionlog/translogserver.h b/searchlib/src/vespa/searchlib/transactionlog/translogserver.h index 37133615c1e..3c6efa20550 100644 --- a/searchlib/src/vespa/searchlib/transactionlog/translogserver.h +++ b/searchlib/src/vespa/searchlib/transactionlog/translogserver.h @@ -7,6 +7,7 @@ 
#include <vespa/document/util/queue.h> #include <vespa/fnet/frt/invokable.h> #include <shared_mutex> +#include <atomic> class FRT_Supervisor; class FNET_Transport; @@ -92,6 +93,7 @@ private: std::mutex _fileLock; // Protects the creating and deleting domains including file system operations. document::Queue<FRT_RPCRequest *> _reqQ; const common::FileHeaderContext &_fileHeaderContext; + std::atomic<bool> _closed; }; } diff --git a/searchlib/src/vespa/searchlib/util/postingpriorityqueue.h b/searchlib/src/vespa/searchlib/util/postingpriorityqueue.h index baf38035210..008e9055e57 100644 --- a/searchlib/src/vespa/searchlib/util/postingpriorityqueue.h +++ b/searchlib/src/vespa/searchlib/util/postingpriorityqueue.h @@ -221,7 +221,7 @@ PostingPriorityQueue<IN>::merge(OUT &out, uint32_t heapLimit, const IFlushToken& (this->*mergeHeapFunc)(out, flush_token); return; } - for (;;) { + while (!flush_token.stop_requested()) { if (_vec.size() == 1) { void (*mergeOneFunc)(OUT &out, IN &in, const IFlushToken& flush_token) = &PostingPriorityQueue<IN>::mergeOne; diff --git a/staging_vespalib/src/tests/clock/clock_test.cpp b/staging_vespalib/src/tests/clock/clock_test.cpp index 4a06787fce5..06eee21d2b9 100644 --- a/staging_vespalib/src/tests/clock/clock_test.cpp +++ b/staging_vespalib/src/tests/clock/clock_test.cpp @@ -7,7 +7,14 @@ using vespalib::Clock; using vespalib::duration; using vespalib::steady_time; +using vespalib::steady_clock; +void waitForMovement(steady_time start, Clock & clock, vespalib::duration timeout) { + steady_time startOsClock = steady_clock::now(); + while ((clock.getTimeNS() <= start) && ((steady_clock::now() - startOsClock) < timeout)) { + std::this_thread::sleep_for(1ms); + } +} TEST("Test that clock is ticking forward") { @@ -15,10 +22,12 @@ TEST("Test that clock is ticking forward") { FastOS_ThreadPool pool(0x10000); ASSERT_TRUE(pool.NewThread(clock.getRunnable(), nullptr) != nullptr); steady_time start = clock.getTimeNS(); - 
std::this_thread::sleep_for(5s); + waitForMovement(start, clock, 10s); steady_time stop = clock.getTimeNS(); EXPECT_TRUE(stop > start); - std::this_thread::sleep_for(6s); + std::this_thread::sleep_for(1s); + start = clock.getTimeNS(); + waitForMovement(start, clock, 10s); clock.stop(); steady_time stop2 = clock.getTimeNS(); EXPECT_TRUE(stop2 > stop); diff --git a/storage/src/vespa/storage/config/stor-communicationmanager.def b/storage/src/vespa/storage/config/stor-communicationmanager.def index d674ee96aa3..799490ba114 100644 --- a/storage/src/vespa/storage/config/stor-communicationmanager.def +++ b/storage/src/vespa/storage/config/stor-communicationmanager.def @@ -58,7 +58,8 @@ skip_thread bool default=false ## Whether to use direct P2P RPC protocol for all StorageAPI communication ## instead of going via MessageBus. -use_direct_storageapi_rpc bool default=false +## Deprecated, and will soon be gone as it is default on. +use_direct_storageapi_rpc bool default=true ## The number of network (FNET) threads used by the shared rpc resource. rpc.num_network_threads int default=2 restart diff --git a/storage/src/vespa/storage/storageserver/communicationmanager.cpp b/storage/src/vespa/storage/storageserver/communicationmanager.cpp index 6021673d472..25942926155 100644 --- a/storage/src/vespa/storage/storageserver/communicationmanager.cpp +++ b/storage/src/vespa/storage/storageserver/communicationmanager.cpp @@ -270,8 +270,7 @@ CommunicationManager::CommunicationManager(StorageComponentRegister& compReg, co _closed(false), _docApiConverter(configUri, std::make_shared<PlaceHolderBucketResolver>()), _thread(), - _skip_thread(false), - _use_direct_storageapi_rpc(false) + _skip_thread(false) { _component.registerMetricUpdateHook(*this, framework::SecondTime(5)); _component.registerMetric(_metrics); @@ -375,7 +374,6 @@ void CommunicationManager::configure(std::unique_ptr<CommunicationManagerConfig> { // Only allow dynamic (live) reconfiguration of message bus limits. 
_skip_thread = config->skipThread; - _use_direct_storageapi_rpc = config->useDirectStorageapiRpc; if (_mbus) { configureMessageBusLimits(*config); if (_mbus->getRPCNetwork().getPort() != config->mbusport) { @@ -574,9 +572,7 @@ CommunicationManager::sendCommand( case api::StorageMessageAddress::Protocol::STORAGE: { LOG(debug, "Send to %s: %s", address.toString().c_str(), msg->toString().c_str()); - if (_use_direct_storageapi_rpc.load(std::memory_order_relaxed) && - _storage_api_rpc_service->target_supports_direct_rpc(address)) - { + if (_storage_api_rpc_service->target_supports_direct_rpc(address)) { _storage_api_rpc_service->send_rpc_v1_request(msg); } else { auto cmd = std::make_unique<mbusprot::StorageCommand>(msg); diff --git a/storage/src/vespa/storage/storageserver/communicationmanager.h b/storage/src/vespa/storage/storageserver/communicationmanager.h index db88f95af6d..5a828e49678 100644 --- a/storage/src/vespa/storage/storageserver/communicationmanager.h +++ b/storage/src/vespa/storage/storageserver/communicationmanager.h @@ -119,7 +119,6 @@ private: DocumentApiConverter _docApiConverter; framework::Thread::UP _thread; std::atomic<bool> _skip_thread; - std::atomic<bool> _use_direct_storageapi_rpc; void updateMetrics(const MetricLockGuard &) override; void enqueue_or_process(std::shared_ptr<api::StorageMessage> msg); diff --git a/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/SuspendMojo.java b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/SuspendMojo.java new file mode 100644 index 00000000000..52057e237d7 --- /dev/null +++ b/vespa-maven-plugin/src/main/java/ai/vespa/hosted/plugin/SuspendMojo.java @@ -0,0 +1,23 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.hosted.plugin; + +import org.apache.maven.plugins.annotations.Mojo; +import org.apache.maven.plugins.annotations.Parameter; + +/** + * Sets suspension status for a Vespa application deployment. 
+ * + * @author jonmv + */ +@Mojo(name = "suspend") +public class SuspendMojo extends AbstractVespaDeploymentMojo { + + @Parameter(property = "suspend", required = true) + private boolean suspend; + + @Override + protected void doExecute() { + getLog().info(controller.suspend(id, zoneOf(environment, region), suspend)); + } + +} diff --git a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java index 47530e2c6e2..1fefe2e0c7e 100644 --- a/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java +++ b/vespaclient-container-plugin/src/main/java/com/yahoo/document/restapi/resource/DocumentV1ApiHandler.java @@ -652,7 +652,7 @@ public class DocumentV1ApiHandler extends AbstractRequestHandler { return true; if ( ! lock.tryLock()) - throw new IllegalStateException("Concurrent attempts at dispatch — this is a bug"); + throw new IllegalStateException("Concurrent attempts at dispatch — this is a bug"); try { if (operation == null) diff --git a/vespajlib/pom.xml b/vespajlib/pom.xml index 68639d30ab2..f8be4c6f8f7 100644 --- a/vespajlib/pom.xml +++ b/vespajlib/pom.xml @@ -31,6 +31,11 @@ <groupId>net.java.dev.jna</groupId> <artifactId>jna</artifactId> </dependency> + <dependency> + <groupId>io.airlift</groupId> + <artifactId>aircompressor</artifactId> + <scope>compile</scope> + </dependency> <!-- provided scope --> <dependency> diff --git a/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/JobControl.java b/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/JobControl.java index 583337203ab..2a682bcb4db 100644 --- a/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/JobControl.java +++ b/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/JobControl.java @@ -51,7 +51,7 @@ public class JobControl { public void run(String jobSimpleClassName) { var job = 
startedJobs.get(jobSimpleClassName); if (job == null) throw new IllegalArgumentException("No such job '" + jobSimpleClassName + "'"); - job.lockAndMaintain(); + job.lockAndMaintain(true); } /** Acquire lock for running given job */ diff --git a/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/JobMetrics.java b/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/JobMetrics.java index 483057a828d..d4d60723cbe 100644 --- a/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/JobMetrics.java +++ b/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/JobMetrics.java @@ -1,7 +1,6 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.concurrent.maintenance; -import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.function.BiConsumer; @@ -14,7 +13,7 @@ public class JobMetrics { private final BiConsumer<String, Long> metricConsumer; - private final Map<String, Long> incompleteRuns = new ConcurrentHashMap<>(); + private final ConcurrentHashMap<String, Long> incompleteRuns = new ConcurrentHashMap<>(); public JobMetrics(BiConsumer<String, Long> metricConsumer) { this.metricConsumer = metricConsumer; diff --git a/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/Maintainer.java b/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/Maintainer.java index 9fb5172ab0a..daad1f8fb4b 100644 --- a/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/Maintainer.java +++ b/vespajlib/src/main/java/com/yahoo/concurrent/maintenance/Maintainer.java @@ -1,7 +1,6 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.concurrent.maintenance; -import com.google.common.util.concurrent.UncheckedTimeoutException; import com.yahoo.net.HostName; import java.time.Duration; @@ -33,15 +32,15 @@ public abstract class Maintainer implements Runnable { private final ScheduledExecutorService service; private final AtomicBoolean shutDown = new AtomicBoolean(); - public Maintainer(String name, Duration interval, Instant startedAt, JobControl jobControl, JobMetrics jobMetrics, List<String> clusterHostnames) { - this(name, interval, staggeredDelay(interval, startedAt, HostName.getLocalhost(), clusterHostnames), jobControl, jobMetrics); - } - - public Maintainer(String name, Duration interval, Duration initialDelay, JobControl jobControl, JobMetrics jobMetrics) { + public Maintainer(String name, Duration interval, Instant startedAt, JobControl jobControl, + JobMetrics jobMetrics, List<String> clusterHostnames) { this.name = name; this.interval = requireInterval(interval); this.jobControl = Objects.requireNonNull(jobControl); this.jobMetrics = Objects.requireNonNull(jobMetrics); + Objects.requireNonNull(startedAt); + Objects.requireNonNull(clusterHostnames); + Duration initialDelay = staggeredDelay(interval, startedAt, HostName.getLocalhost(), clusterHostnames); service = new ScheduledThreadPoolExecutor(1, r -> new Thread(r, name() + "-worker")); service.scheduleAtFixedRate(this, initialDelay.toMillis(), interval.toMillis(), TimeUnit.MILLISECONDS); jobControl.started(name(), this); @@ -49,17 +48,7 @@ public abstract class Maintainer implements Runnable { @Override public void run() { - log.log(Level.FINE, () -> "Running " + this.getClass().getSimpleName()); - try { - if (jobControl.isActive(name())) { - lockAndMaintain(); - } - } catch (UncheckedTimeoutException ignored) { - // Another actor is running this job - } catch (Throwable e) { - log.log(Level.WARNING, this + " failed. 
Will retry in " + interval.toMinutes() + " minutes", e); - } - log.log(Level.FINE, () -> "Finished " + this.getClass().getSimpleName()); + lockAndMaintain(false); } /** Starts shutdown of this, typically by shutting down executors. {@link #awaitShutdown()} waits for shutdown to complete. */ @@ -92,17 +81,18 @@ public abstract class Maintainer implements Runnable { protected Duration interval() { return interval; } /** Run this while holding the job lock */ - @SuppressWarnings("unused") - public final void lockAndMaintain() { + public final void lockAndMaintain(boolean force) { + if (!force && !jobControl.isActive(name())) return; + log.log(Level.FINE, () -> "Running " + this.getClass().getSimpleName()); + jobMetrics.recordRunOf(name()); try (var lock = jobControl.lockJob(name())) { - try { - jobMetrics.recordRunOf(name()); - if (maintain()) jobMetrics.recordSuccessOf(name()); - } finally { - // Always forward metrics - jobMetrics.forward(name()); - } + if (maintain()) jobMetrics.recordSuccessOf(name()); + } catch (Throwable e) { + log.log(Level.WARNING, this + " failed. 
Will retry in " + interval.toMinutes() + " minutes", e); + } finally { + jobMetrics.forward(name()); } + log.log(Level.FINE, () -> "Finished " + this.getClass().getSimpleName()); } /** Returns the simple name of this job */ diff --git a/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/JobControlTest.java b/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/JobControlTest.java index a0ca9b529c5..139a2901cd3 100644 --- a/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/JobControlTest.java +++ b/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/JobControlTest.java @@ -19,8 +19,9 @@ public class JobControlTest { String job1 = "Job1"; String job2 = "Job2"; - TestMaintainer maintainer1 = new TestMaintainer(job1, jobControl); - TestMaintainer maintainer2 = new TestMaintainer(job2, jobControl); + JobMetrics metrics = new JobMetrics((job, instant) -> {}); + TestMaintainer maintainer1 = new TestMaintainer(job1, jobControl, metrics); + TestMaintainer maintainer2 = new TestMaintainer(job2, jobControl, metrics); assertEquals(2, jobControl.jobs().size()); assertTrue(jobControl.jobs().contains(job1)); assertTrue(jobControl.jobs().contains(job2)); @@ -61,7 +62,7 @@ public class JobControlTest { public void testJobControlMayDeactivateJobs() { JobControlStateMock state = new JobControlStateMock(); JobControl jobControl = new JobControl(state); - TestMaintainer mockMaintainer = new TestMaintainer(null, jobControl); + TestMaintainer mockMaintainer = new TestMaintainer(null, jobControl, new JobMetrics((job, instant) -> {})); assertTrue(jobControl.jobs().contains("TestMaintainer")); diff --git a/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/MaintainerTest.java b/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/MaintainerTest.java index 2bfaad894a5..e881d4b3ff6 100644 --- a/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/MaintainerTest.java +++ b/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/MaintainerTest.java @@ -1,6 
+1,7 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.concurrent.maintenance; +import com.google.common.util.concurrent.UncheckedTimeoutException; import org.junit.Test; import java.time.Duration; @@ -15,6 +16,8 @@ import static org.junit.Assert.assertEquals; */ public class MaintainerTest { + private final JobControl jobControl = new JobControl(new JobControlStateMock()); + @Test public void staggering() { List<String> cluster = List.of("cfg1", "cfg2", "cfg3"); @@ -41,7 +44,7 @@ public class MaintainerTest { public void success_metric() { AtomicLong consecutiveFailures = new AtomicLong(); JobMetrics jobMetrics = new JobMetrics((job, count) -> consecutiveFailures.set(count)); - TestMaintainer maintainer = new TestMaintainer(jobMetrics); + TestMaintainer maintainer = new TestMaintainer(null, jobControl, jobMetrics); // Maintainer fails twice in a row maintainer.successOnNextRun(false).run(); @@ -58,12 +61,16 @@ public class MaintainerTest { assertEquals(0, consecutiveFailures.get()); // Maintainer throws - maintainer.throwOnNextRun(true).run(); + maintainer.throwOnNextRun(new RuntimeException()).run(); assertEquals(1, consecutiveFailures.get()); // Maintainer recovers - maintainer.throwOnNextRun(false).run(); + maintainer.throwOnNextRun(null).run(); assertEquals(0, consecutiveFailures.get()); + + // Lock exception is treated as a failure + maintainer.throwOnNextRun(new UncheckedTimeoutException()).run(); + assertEquals(1, consecutiveFailures.get()); } } diff --git a/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/TestMaintainer.java b/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/TestMaintainer.java index 5eae643fe40..ea32af60208 100644 --- a/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/TestMaintainer.java +++ b/vespajlib/src/test/java/com/yahoo/concurrent/maintenance/TestMaintainer.java @@ -2,6 +2,8 @@ package com.yahoo.concurrent.maintenance; import 
java.time.Duration; +import java.time.Instant; +import java.util.List; /** * @author mpolden @@ -10,18 +12,10 @@ class TestMaintainer extends Maintainer { private int totalRuns = 0; private boolean success = true; - private boolean throwing = false; + private RuntimeException exceptionToThrow = null; public TestMaintainer(String name, JobControl jobControl, JobMetrics jobMetrics) { - super(name, Duration.ofDays(1), Duration.ofDays(1), jobControl, jobMetrics); - } - - public TestMaintainer(JobMetrics jobMetrics) { - this(null, new JobControl(new JobControlStateMock()), jobMetrics); - } - - public TestMaintainer(String name, JobControl jobControl) { - this(name, jobControl, new JobMetrics((job, instant) -> {})); + super(name, Duration.ofDays(1), Instant.now(), jobControl, jobMetrics, List.of()); } public int totalRuns() { @@ -33,14 +27,14 @@ class TestMaintainer extends Maintainer { return this; } - public TestMaintainer throwOnNextRun(boolean throwing) { - this.throwing = throwing; + public TestMaintainer throwOnNextRun(RuntimeException e) { + this.exceptionToThrow = e; return this; } @Override protected boolean maintain() { - if (throwing) throw new RuntimeException("Maintenance run failed"); + if (exceptionToThrow != null) throw exceptionToThrow; totalRuns++; return success; } diff --git a/vespalib/src/vespa/vespalib/util/shared_string_repo.cpp b/vespalib/src/vespa/vespalib/util/shared_string_repo.cpp index a5ec9540a1b..e529b1190d9 100644 --- a/vespalib/src/vespa/vespalib/util/shared_string_repo.cpp +++ b/vespalib/src/vespa/vespalib/util/shared_string_repo.cpp @@ -7,6 +7,18 @@ namespace vespalib { SharedStringRepo::Partition::~Partition() = default; void +SharedStringRepo::Partition::find_leaked_entries(size_t my_idx) const +{ + for (size_t i = 0; i < _entries.size(); ++i) { + if (!_entries[i].is_free()) { + size_t id = (((i << PART_BITS) | my_idx) + 1); + fprintf(stderr, "WARNING: shared_string_repo: leaked string id: %zu ('%s')\n", + id, 
_entries[i].str().c_str()); + } + } +} + +void SharedStringRepo::Partition::make_entries(size_t hint) { hint = std::max(hint, _entries.size() + 1); @@ -20,7 +32,12 @@ SharedStringRepo::Partition::make_entries(size_t hint) } SharedStringRepo::SharedStringRepo() = default; -SharedStringRepo::~SharedStringRepo() = default; +SharedStringRepo::~SharedStringRepo() +{ + for (size_t p = 0; p < _partitions.size(); ++p) { + _partitions[p].find_leaked_entries(p); + } +} SharedStringRepo & SharedStringRepo::get() @@ -44,6 +61,13 @@ SharedStringRepo::StrongHandles::StrongHandles(size_t expect_size) _handles.reserve(expect_size); } +SharedStringRepo::StrongHandles::StrongHandles(StrongHandles &&rhs) + : _repo(rhs._repo), + _handles(std::move(rhs._handles)) +{ + assert(rhs._handles.empty()); +} + SharedStringRepo::StrongHandles::~StrongHandles() { for (uint32_t handle: _handles) { diff --git a/vespalib/src/vespa/vespalib/util/shared_string_repo.h b/vespalib/src/vespa/vespalib/util/shared_string_repo.h index afdd3a289f9..f7137984caa 100644 --- a/vespalib/src/vespa/vespalib/util/shared_string_repo.h +++ b/vespalib/src/vespa/vespalib/util/shared_string_repo.h @@ -10,6 +10,7 @@ #include <mutex> #include <vector> #include <array> +#include <cassert> namespace vespalib { @@ -34,21 +35,43 @@ private: class alignas(64) Partition { public: - struct Entry { + class Entry { + public: static constexpr uint32_t npos = -1; - uint32_t hash; - uint32_t ref_cnt; - vespalib::string str; - explicit Entry(uint32_t next) noexcept : hash(), ref_cnt(next), str() {} + private: + uint32_t _hash; + uint32_t _ref_cnt; + vespalib::string _str; + public: + explicit Entry(uint32_t next) noexcept + : _hash(next), _ref_cnt(npos), _str() {} + constexpr uint32_t hash() const noexcept { return _hash; } + constexpr const vespalib::string &str() const noexcept { return _str; } + constexpr bool is_free() const noexcept { return (_ref_cnt == npos); } uint32_t init(const AltKey &key) { - uint32_t next = ref_cnt; - hash 
= key.hash; - ref_cnt = 1; - str = key.str; + uint32_t next = _hash; + _hash = key.hash; + _ref_cnt = 1; + _str = key.str; return next; } void fini(uint32_t next) { - ref_cnt = next; + _hash = next; + _ref_cnt = npos; + // to reset or not to reset... + // _str.reset(); + } + vespalib::string as_string() const { + assert(!is_free()); + return _str; + } + void add_ref() { + assert(!is_free()); + ++_ref_cnt; + } + bool sub_ref() { + assert(!is_free()); + return (--_ref_cnt == 0); } }; struct Key { @@ -64,7 +87,7 @@ private: Equal(const std::vector<Entry> &entries_in) : entries(entries_in) {} Equal(const Equal &rhs) = default; bool operator()(const Key &a, const Key &b) const { return (a.idx == b.idx); } - bool operator()(const Key &a, const AltKey &b) const { return ((a.hash == b.hash) && (entries[a.idx].str == b.str)); } + bool operator()(const Key &a, const AltKey &b) const { return ((a.hash == b.hash) && (entries[a.idx].str() == b.str)); } }; using HashType = hashtable<Key,Key,Hash,Equal,Identity,hashtable_base::and_modulator>; @@ -92,12 +115,13 @@ private: make_entries(64); } ~Partition(); + void find_leaked_entries(size_t my_idx) const; uint32_t resolve(const AltKey &alt_key) { std::lock_guard guard(_lock); auto pos = _hash.find(alt_key); if (pos != _hash.end()) { - ++_entries[pos->idx].ref_cnt; + _entries[pos->idx].add_ref(); return pos->idx; } else { uint32_t idx = make_entry(alt_key); @@ -108,19 +132,19 @@ private: vespalib::string as_string(uint32_t idx) { std::lock_guard guard(_lock); - return _entries[idx].str; + return _entries[idx].as_string(); } void copy(uint32_t idx) { std::lock_guard guard(_lock); - ++_entries[idx].ref_cnt; + _entries[idx].add_ref(); } void reclaim(uint32_t idx) { std::lock_guard guard(_lock); Entry &entry = _entries[idx]; - if (--entry.ref_cnt == 0) { - _hash.erase(Key{idx, entry.hash}); + if (entry.sub_ref()) { + _hash.erase(Key{idx, entry.hash()}); entry.fini(_free); _free = idx; } @@ -178,8 +202,9 @@ public: class Handle { 
private: uint32_t _id; + Handle(uint32_t weak_id) : _id(get().copy(weak_id)) {} public: - Handle() : _id(0) {} + Handle() noexcept : _id(0) {} Handle(vespalib::stringref str) : _id(get().resolve(str)) {} Handle(const Handle &rhs) : _id(get().copy(rhs._id)) {} Handle &operator=(const Handle &rhs) { @@ -196,9 +221,15 @@ public: rhs._id = 0; return *this; } - bool operator==(const Handle &rhs) const { return (_id == rhs._id); } - uint32_t id() const { return _id; } + // NB: not lexical sorting order, but can be used in maps + bool operator<(const Handle &rhs) const noexcept { return (_id < rhs._id); } + bool operator==(const Handle &rhs) const noexcept { return (_id == rhs._id); } + bool operator!=(const Handle &rhs) const noexcept { return (_id != rhs._id); } + uint32_t id() const noexcept { return _id; } + uint32_t hash() const noexcept { return _id; } vespalib::string as_string() const { return get().as_string(_id); } + static Handle handle_from_id(uint32_t weak_id) { return Handle(weak_id); } + static vespalib::string string_from_id(uint32_t weak_id) { return get().as_string(weak_id); } ~Handle() { get().reclaim(_id); } }; @@ -229,8 +260,20 @@ public: std::vector<uint32_t> _handles; public: StrongHandles(size_t expect_size); + StrongHandles(StrongHandles &&rhs); + StrongHandles(const StrongHandles &) = delete; + StrongHandles &operator=(const StrongHandles &) = delete; + StrongHandles &operator=(StrongHandles &&) = delete; ~StrongHandles(); - void add(vespalib::stringref str) { _handles.push_back(_repo.resolve(str)); } + uint32_t add(vespalib::stringref str) { + uint32_t id = _repo.resolve(str); + _handles.push_back(id); + return id; + } + void add(uint32_t handle) { + uint32_t id = _repo.copy(handle); + _handles.push_back(id); + } HandleView view() const { return HandleView(_handles); } }; }; diff --git a/vespalog/src/test/threads/testthreads.cpp b/vespalog/src/test/threads/testthreads.cpp index 465d2c6e3f8..aef7f844c7e 100644 --- 
a/vespalog/src/test/threads/testthreads.cpp +++ b/vespalog/src/test/threads/testthreads.cpp @@ -85,7 +85,7 @@ public: int ThreadTester::Main() { - std::cerr << "Testing that logging is threadsafe. 30 sec test.\n"; + std::cerr << "Testing that logging is threadsafe. 5 sec test.\n"; FastOS_ThreadPool pool(128 * 1024); const int numWriters = 30; @@ -107,8 +107,8 @@ ThreadTester::Main() steady_clock::time_point start = steady_clock::now(); // Reduced runtime to half as the test now repeats itself to test with - // buffering. (To avoid test taking a minute) - while ((steady_clock::now() - start) < 15s) { + // buffering. (To avoid test taking 5 seconds) + while ((steady_clock::now() - start) < 2.5s) { unlink(_argv[1]); std::this_thread::sleep_for(1ms); } @@ -117,7 +117,7 @@ ThreadTester::Main() loggers[i]->_useLogBuffer = true; } start = steady_clock::now(); - while ((steady_clock::now() - start) < 15s) { + while ((steady_clock::now() - start) < 2.5s) { unlink(_argv[1]); std::this_thread::sleep_for(1ms); } |