diff options
15 files changed, 164 insertions, 72 deletions
diff --git a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java index 422ceba8074..595cd97e6b6 100644 --- a/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java +++ b/config-model-api/src/main/java/com/yahoo/config/model/api/OnnxModelCost.java @@ -4,6 +4,7 @@ package com.yahoo.config.model.api; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; /** @@ -11,7 +12,7 @@ import com.yahoo.config.application.api.DeployLogger; */ public interface OnnxModelCost { - Calculator newCalculator(DeployLogger logger); + Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger); interface Calculator { long aggregatedModelCostInBytes(); @@ -20,7 +21,7 @@ public interface OnnxModelCost { } static OnnxModelCost disabled() { - return (__) -> new Calculator() { + return (__, ___) -> new Calculator() { @Override public long aggregatedModelCostInBytes() { return 0; } @Override public void registerModel(ApplicationFile path) {} @Override public void registerModel(ModelReference ref) {} diff --git a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java index 76733872882..9794cfe4ad7 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/DefaultOnnxModelCost.java @@ -4,8 +4,10 @@ package com.yahoo.vespa.model; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.api.OnnxModelCost; +import com.yahoo.vespa.model.ml.OnnxModelProbe; import java.io.IOException; import java.net.URI; @@ -29,16 +31,18 @@ import static com.yahoo.yolean.Exceptions.uncheck; public class DefaultOnnxModelCost implements OnnxModelCost { @Override - public Calculator newCalculator(DeployLogger logger) { - return new CalculatorImpl(logger); + public Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) { + return new CalculatorImpl(appPkg, logger); } private static class CalculatorImpl implements Calculator { private final DeployLogger log; + private final ApplicationPackage appPkg; private final ConcurrentMap<String, Long> modelCost = new ConcurrentHashMap<>(); - private CalculatorImpl(DeployLogger log) { + private CalculatorImpl(ApplicationPackage appPkg, DeployLogger log) { + this.appPkg = appPkg; this.log = log; } @@ -52,7 +56,17 @@ public class DefaultOnnxModelCost implements OnnxModelCost { String path = f.getPath().getRelative(); if (alreadyAnalyzed(path)) return; log.log(Level.FINE, () -> "Register model '%s'".formatted(path)); - deductJvmHeapSizeWithModelCost(f.exists() ? f.getSize() : 0, path); + if (f.exists()) { + var memoryStats = OnnxModelProbe.probeMemoryStats(appPkg, f.getPath()).orElse(null); + if (memoryStats != null) { + log.log(Level.FINE, () -> "Register model '%s' with memory stats: %s".formatted(path, memoryStats)); + deductJvmHeapSizeWithModelCost(f.getSize(), memoryStats, path); + } else { + deductJvmHeapSizeWithModelCost(f.getSize(), path); + } + } else { + deductJvmHeapSizeWithModelCost(0, path); + } } @Override @@ -92,6 +106,13 @@ public class DefaultOnnxModelCost implements OnnxModelCost { modelCost.put(source, estimatedCost); } + private void deductJvmHeapSizeWithModelCost(long size, OnnxModelProbe.MemoryStats stats, String source) { + long estimatedCost = (long)(1.1D * stats.vmSize()); + log.log(Level.FINE, () -> + "Estimated %s footprint for model of size %s ('%s')".formatted(mb(estimatedCost), mb(size), source)); + modelCost.put(source, estimatedCost); + } + private boolean alreadyAnalyzed(String source) { return modelCost.containsKey(source); } private static String mb(long bytes) { return "%dMB".formatted(bytes / (1024*1024)); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java index 0945dfaf54a..4e97b20a3a9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/ApplicationContainerCluster.java @@ -134,7 +134,8 @@ public final class ApplicationContainerCluster extends ContainerCluster<Applicat heapSizePercentageOfAvailableMemory = deployState.featureFlags().heapSizePercentage() > 0 ? Math.min(99, deployState.featureFlags().heapSizePercentage()) : defaultHeapSizePercentageOfAvailableMemory; - onnxModelCost = deployState.onnxModelCost().newCalculator(deployState.getDeployLogger()); + onnxModelCost = deployState.onnxModelCost().newCalculator( + deployState.getApplicationPackage(), deployState.getDeployLogger()); logger = deployState.getDeployLogger(); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java index 7c86267c1b6..38dda3e29ff 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/OnnxModelProbe.java @@ -18,6 +18,9 @@ import java.io.InputStream; import java.io.OutputStream; import java.nio.charset.StandardCharsets; import java.util.Map; +import java.util.Optional; + +import static com.yahoo.yolean.Exceptions.uncheck; /** * Defers to 'vespa-analyze-onnx-model' to determine the output type given @@ -29,6 +32,7 @@ import java.util.Map; public class OnnxModelProbe { private static final String binary = "vespa-analyze-onnx-model"; + private static final ObjectMapper jsonParser = new ObjectMapper(); static TensorType probeModel(ApplicationPackage app, Path modelPath, String outputName, Map<String, TensorType> inputTypes) { TensorType outputType = TensorType.empty; @@ -41,8 +45,9 @@ public class OnnxModelProbe { // Otherwise, run vespa-analyze-onnx-model if the model is available if (outputType.equals(TensorType.empty) && app.getFile(modelPath).exists()) { String jsonInput = createJsonInput(app.getFileReference(modelPath).getAbsolutePath(), inputTypes); - String jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); + var jsonOutput = callVespaAnalyzeOnnxModel(jsonInput); outputType = outputTypeFromJson(jsonOutput, outputName); + writeMemoryStats(app, modelPath, MemoryStats.fromJson(jsonOutput)); if ( ! outputType.equals(TensorType.empty)) { writeProbedOutputType(app, modelPath, contextKey, outputType); } @@ -53,6 +58,22 @@ public class OnnxModelProbe { return outputType; } + public static Optional<MemoryStats> probeMemoryStats(ApplicationPackage app, Path modelPath) { + return Optional.of(app.getFile(memoryStatsPath(modelPath))) + .filter(ApplicationFile::exists) + .map(file -> MemoryStats.fromJson(uncheck(() -> jsonParser.readTree(file.createReader())))); + } + + private static void writeMemoryStats(ApplicationPackage app, Path modelPath, MemoryStats memoryStats) throws IOException { + String path = app.getFileReference(memoryStatsPath(modelPath)).getAbsolutePath(); + IOUtils.writeFile(path, memoryStats.toJson().toPrettyString(), false); + } + + private static Path memoryStatsPath(Path modelPath) { + var fileName = OnnxModelInfo.asValidIdentifier(modelPath.getRelative()) + ".memory_stats"; + return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(fileName); + } + private static String createContextKey(String onnxName, Map<String, TensorType> inputTypes) { StringBuilder key = new StringBuilder().append(onnxName).append(":"); inputTypes.entrySet().stream().sorted(Map.Entry.comparingByKey()) @@ -95,9 +116,7 @@ public class OnnxModelProbe { return TensorType.empty; } - private static TensorType outputTypeFromJson(String json, String outputName) throws IOException { - ObjectMapper m = new ObjectMapper(); - JsonNode root = m.readTree(json); + private static TensorType outputTypeFromJson(JsonNode root, String outputName) throws IOException { if ( ! root.isObject() || ! root.has("outputs")) { return TensorType.empty; } @@ -123,7 +142,7 @@ public class OnnxModelProbe { return out.toString(); } - private static String callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { + private static JsonNode callVespaAnalyzeOnnxModel(String jsonInput) throws IOException, InterruptedException { StringBuilder output = new StringBuilder(); ProcessBuilder processBuilder = new ProcessBuilder(binary, "--probe-types"); @@ -148,7 +167,16 @@ public class OnnxModelProbe { throw new IllegalArgumentException("Error from '" + binary + "'. Return code: " + returnCode + ". " + "Output: '" + output + "'"); } - return output.toString(); + return jsonParser.readTree(output.toString()); + } + + public record MemoryStats(long vmSize, long vmRss) { + static MemoryStats fromJson(JsonNode json) { + return new MemoryStats(json.get("vm_size").asLong(), json.get("vm_rss").asLong()); + } + JsonNode toJson() { + return jsonParser.createObjectNode().put("vm_size", vmSize).put("vm_rss", vmRss); + } } } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java index 245887a5d03..447614b8396 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/JvmHeapSizeValidatorTest.java @@ -4,6 +4,7 @@ package com.yahoo.vespa.model.application.validation; import com.yahoo.config.ModelReference; import com.yahoo.config.application.api.ApplicationFile; +import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.config.application.api.DeployLogger; import com.yahoo.config.model.NullConfigModelRegistry; import com.yahoo.config.model.api.OnnxModelCost; @@ -112,7 +113,7 @@ class JvmHeapSizeValidatorTest { ModelCostDummy(long modelCost) { this.modelCost = modelCost; } - @Override public Calculator newCalculator(DeployLogger logger) { return this; } + @Override public Calculator newCalculator(ApplicationPackage appPkg, DeployLogger logger) { return this; } @Override public long aggregatedModelCostInBytes() { return totalCost.get(); } @Override public void registerModel(ApplicationFile path) {} diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java index 091836a1eea..27cd5e7e576 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/RoutingController.java @@ -525,7 +525,8 @@ public class RoutingController { } public boolean generatedEndpointsEnabled(ApplicationId instance) { - return randomizedEndpoints.with(FetchVector.Dimension.INSTANCE_ID, instance.serializedForm()).value(); + return randomizedEndpoints.with(FetchVector.Dimension.INSTANCE_ID, instance.serializedForm()) + .with(FetchVector.Dimension.TENANT_ID, instance.tenant().value()).value(); } private static void requireGeneratedEndpoints(GeneratedEndpointList generatedEndpoints, boolean declared) { diff --git a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainer.java b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainer.java index 70eeb2b9f6c..ed383175cc3 100644 --- a/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainer.java +++ b/controller-server/src/main/java/com/yahoo/vespa/hosted/controller/maintenance/CertificatePoolMaintainer.java @@ -69,7 +69,7 @@ public class CertificatePoolMaintainer extends ControllerMaintainer { // Create metric for available certificates in the pool as a fraction of configured size int poolSize = certPoolSize.value(); long available = certificatePool.stream().filter(c -> c.state() == UnassignedCertificate.State.ready).count(); - metric.set(ControllerMetrics.CERTIFICATE_POOL_AVAILABLE.baseName(), (poolSize > 0 ? (available/poolSize) : 1.0), metric.createContext(Map.of())); + metric.set(ControllerMetrics.CERTIFICATE_POOL_AVAILABLE.baseName(), (poolSize > 0 ? ((double)available/poolSize) : 1.0), metric.createContext(Map.of())); if (certificatePool.size() < poolSize) { provisionRandomizedCertificate(); diff --git a/searchlib/src/tests/attribute/searchable/attribute_weighted_set_blueprint_test.cpp b/searchlib/src/tests/attribute/searchable/attribute_weighted_set_blueprint_test.cpp index d7a854e0afc..6c6f05fd5e2 100644 --- a/searchlib/src/tests/attribute/searchable/attribute_weighted_set_blueprint_test.cpp +++ b/searchlib/src/tests/attribute/searchable/attribute_weighted_set_blueprint_test.cpp @@ -29,14 +29,14 @@ using namespace search::attribute::test; namespace { void -setupAttributeManager(MockAttributeManager &manager) +setupAttributeManager(MockAttributeManager &manager, bool isFilter) { AttributeVector::DocId docId; { - AttributeVector::SP attr_sp = AttributeFactory::createAttribute("integer", Config(BasicType("int64"))); + AttributeVector::SP attr_sp = AttributeFactory::createAttribute("integer", Config(BasicType("int64")).setIsFilter(isFilter)); manager.addAttribute(attr_sp); - IntegerAttribute *attr = (IntegerAttribute*)(attr_sp.get()); + auto *attr = (IntegerAttribute*)(attr_sp.get()); for (size_t i = 1; i < 10; ++i) { attr->addDoc(docId); assert(i == docId); @@ -45,10 +45,10 @@ setupAttributeManager(MockAttributeManager &manager) } } { - AttributeVector::SP attr_sp = AttributeFactory::createAttribute("string", Config(BasicType("string"))); + AttributeVector::SP attr_sp = AttributeFactory::createAttribute("string", Config(BasicType("string")).setIsFilter(isFilter)); manager.addAttribute(attr_sp); - StringAttribute *attr = (StringAttribute*)(attr_sp.get()); + auto *attr = (StringAttribute*)(attr_sp.get()); for (size_t i = 1; i < 10; ++i) { attr->addDoc(docId); assert(i == docId); @@ -58,9 +58,9 @@ setupAttributeManager(MockAttributeManager &manager) } { AttributeVector::SP attr_sp = AttributeFactory::createAttribute( - "multi", Config(BasicType("int64"), search::attribute::CollectionType("array"))); + "multi", Config(BasicType("int64"), search::attribute::CollectionType("array")).setIsFilter(isFilter)); manager.addAttribute(attr_sp); - IntegerAttribute *attr = (IntegerAttribute*)(attr_sp.get()); + auto *attr = (IntegerAttribute*)(attr_sp.get()); for (size_t i = 1; i < 10; ++i) { attr->addDoc(docId); assert(i == docId); @@ -78,35 +78,43 @@ struct WS { TermFieldHandle handle; std::vector<std::pair<std::string, uint32_t> > tokens; - WS(IAttributeManager & manager) : attribute_manager(manager), layout(), handle(layout.allocTermField(fieldId)), tokens() { + explicit WS(IAttributeManager & manager) + : attribute_manager(manager), + layout(), handle(layout.allocTermField(fieldId)), + tokens() + { MatchData::UP tmp = layout.createMatchData(); ASSERT_TRUE(tmp->resolveTermField(handle)->getFieldId() == fieldId); } WS &add(const std::string &token, uint32_t weight) { - tokens.push_back(std::make_pair(token, weight)); + tokens.emplace_back(token, weight); return *this; } Node::UP createNode() const { - SimpleWeightedSetTerm *node = new SimpleWeightedSetTerm(tokens.size(), "view", 0, Weight(0)); - for (size_t i = 0; i < tokens.size(); ++i) { - node->addTerm(tokens[i].first, Weight(tokens[i].second)); + auto *node = new SimpleWeightedSetTerm(tokens.size(), "view", 0, Weight(0)); + for (const auto & token : tokens) { + node->addTerm(token.first, Weight(token.second)); } return Node::UP(node); } - bool isGenericSearch(Searchable &searchable, const std::string &field, bool strict) const { + SearchIterator::UP + createSearch(Searchable &searchable, const std::string &field, bool strict) const { AttributeContext ac(attribute_manager); FakeRequestContext requestContext(&ac); MatchData::UP md = layout.createMatchData(); Node::UP node = createNode(); FieldSpecList fields; - fields.add(FieldSpec(field, fieldId, handle)); + fields.add(FieldSpec(field, fieldId, handle, ac.getAttribute(field)->getIsFilter())); queryeval::Blueprint::UP bp = searchable.createBlueprint(requestContext, fields, *node); bp->fetchPostings(queryeval::ExecuteInfo::create(strict)); SearchIterator::UP sb = bp->createSearch(*md, strict); - return (dynamic_cast<WeightedSetTermSearch*>(sb.get()) != 0); + return sb; + } + bool isWeightedSetTermSearch(Searchable &searchable, const std::string &field, bool strict) const { + return dynamic_cast<WeightedSetTermSearch *>(createSearch(searchable, field, strict).get()) != nullptr; } FakeResult search(Searchable &searchable, const std::string &field, bool strict) const { @@ -140,23 +148,58 @@ struct WS { } // namespace <unnamed> +void test_tokens(bool isFilter, const std::vector<uint32_t> & docs) { + MockAttributeManager manager; + setupAttributeManager(manager, isFilter); + AttributeBlueprintFactory adapter; + + FakeResult expect = FakeResult(); + WS ws = WS(manager); + for (uint32_t doc : docs) { + auto docS = vespalib::stringify(doc); + int32_t weight = doc * 10; + expect.doc(doc).weight(weight).pos(0); + ws.add(docS, weight); + } + + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "integer", true)); + EXPECT_TRUE(!ws.isWeightedSetTermSearch(adapter, "integer", false)); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "string", true)); + EXPECT_TRUE(!ws.isWeightedSetTermSearch(adapter, "string", false)); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "multi", true)); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "multi", false)); + + EXPECT_EQUAL(expect, ws.search(adapter, "integer", true)); + EXPECT_EQUAL(expect, ws.search(adapter, "integer", false)); + EXPECT_EQUAL(expect, ws.search(adapter, "string", true)); + EXPECT_EQUAL(expect, ws.search(adapter, "string", false)); + EXPECT_EQUAL(expect, ws.search(adapter, "multi", true)); + EXPECT_EQUAL(expect, ws.search(adapter, "multi", false)); +} TEST("attribute_weighted_set_test") { + test_tokens(false, {3, 5, 7}); + test_tokens(true, {3, 5, 7}); + test_tokens(false, {3}); +} + +TEST("attribute_weighted_set_single_token_filter_lifted_out") { MockAttributeManager manager; - setupAttributeManager(manager); + setupAttributeManager(manager, true); AttributeBlueprintFactory adapter; - FakeResult expect = FakeResult() - .doc(3).elem(0).weight(30).pos(0) - .doc(5).elem(0).weight(50).pos(0) - .doc(7).elem(0).weight(70).pos(0); - WS ws = WS(manager).add("7", 70).add("5", 50).add("3", 30); - - EXPECT_TRUE(ws.isGenericSearch(adapter, "integer", true)); - EXPECT_TRUE(!ws.isGenericSearch(adapter, "integer", false)); - EXPECT_TRUE(ws.isGenericSearch(adapter, "string", true)); - EXPECT_TRUE(!ws.isGenericSearch(adapter, "string", false)); - EXPECT_TRUE(ws.isGenericSearch(adapter, "multi", true)); - EXPECT_TRUE(ws.isGenericSearch(adapter, "multi", false)); + FakeResult expect = FakeResult().doc(3).elem(0).weight(30).pos(0); + WS ws = WS(manager).add("3", 30); + + EXPECT_EQUAL("search::FilterAttributeIteratorStrict<search::attribute::SingleNumericSearchContext<long, search::attribute::NumericMatcher<long> > >", + ws.createSearch(adapter, "integer", true)->getClassName()); + EXPECT_EQUAL("search::FilterAttributeIteratorT<search::attribute::SingleNumericSearchContext<long, search::attribute::NumericMatcher<long> > >", + ws.createSearch(adapter, "integer", false)->getClassName()); + EXPECT_EQUAL("search::FilterAttributeIteratorStrict<search::attribute::SingleEnumSearchContext<char const*, search::attribute::StringSearchContext> >", + ws.createSearch(adapter, "string", true)->getClassName()); + EXPECT_EQUAL("search::FilterAttributeIteratorT<search::attribute::SingleEnumSearchContext<char const*, search::attribute::StringSearchContext> >", + ws.createSearch(adapter, "string", false)->getClassName()); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "multi", true)); + EXPECT_TRUE(ws.isWeightedSetTermSearch(adapter, "multi", false)); EXPECT_EQUAL(expect, ws.search(adapter, "integer", true)); EXPECT_EQUAL(expect, ws.search(adapter, "integer", false)); diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp index 1519bb14554..b4cdd621b71 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_blueprint_factory.cpp @@ -337,10 +337,7 @@ public: if (tfmda.size() == 1) { // search in exactly one field fef::TermFieldMatchData &tfmd = *tfmda[0]; - return search::common::create_location_iterator(tfmd, - _attribute.getNumDocs(), - strict, - _location); + return common::create_location_iterator(tfmd, _attribute.getNumDocs(), strict, _location); } else { LOG(debug, "wrong size tfmda: %zu (fallback to old location iterator)\n", tfmda.size()); } diff --git a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp index 108128eeb39..94c560a0dae 100644 --- a/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/attribute/attribute_weighted_set_blueprint.cpp @@ -30,7 +30,7 @@ protected: const attribute::IAttributeVector &attribute() const { return _attr; } public: - UseAttr(const attribute::IAttributeVector & attr) + explicit UseAttr(const attribute::IAttributeVector & attr) : _attr(attr) {} }; @@ -40,7 +40,7 @@ class UseStringEnum : public UseAttr { public: using TokenT = uint32_t; - UseStringEnum(const IAttributeVector & attr) + explicit UseStringEnum(const IAttributeVector & attr) : UseAttr(attr) {} auto mapToken(const ISearchContext &context) const { return attribute().findFoldedEnums(context.queryTerm()->getTerm()); @@ -56,7 +56,7 @@ class UseInteger : public UseAttr { public: using TokenT = uint64_t; - UseInteger(const IAttributeVector & attr) : UseAttr(attr) {} + explicit UseInteger(const IAttributeVector & attr) : UseAttr(attr) {} std::vector<int64_t> mapToken(const ISearchContext &context) const { std::vector<int64_t> result; Int64Range range(context.getAsIntegerTerm()); @@ -157,6 +157,10 @@ AttributeWeightedSetBlueprint::createLeafSearch(const fef::TermFieldMatchDataArr assert(tfmda.size() == 1); assert(getState().numFields() == 1); fef::TermFieldMatchData &tfmd = *tfmda[0]; + bool field_is_filter = getState().fields()[0].isFilter(); + if (field_is_filter && (_contexts.size() == 1)) { + return _contexts[0]->createIterator(&tfmd, strict); + } if (strict) { // use generic weighted set search fef::MatchDataLayout layout; auto handle = layout.allocTermField(tfmd.getFieldId()); @@ -167,7 +171,6 @@ AttributeWeightedSetBlueprint::createLeafSearch(const fef::TermFieldMatchDataArr // TODO: pass ownership with unique_ptr children[i] = _contexts[i]->createIterator(child_tfmd, true).release(); } - bool field_is_filter = getState().fields()[0].isFilter(); return queryeval::WeightedSetTermSearch::create(children, tfmd, field_is_filter, _weights, std::move(match_data)); } else { // use attribute filter optimization bool isString = (_attr.isStringType() && _attr.hasEnum()); @@ -182,18 +185,16 @@ AttributeWeightedSetBlueprint::createLeafSearch(const fef::TermFieldMatchDataArr } queryeval::SearchIterator::UP -AttributeWeightedSetBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) const +AttributeWeightedSetBlueprint::createFilterSearch(bool strict, FilterConstraint) const { - (void) constraint; std::vector<std::unique_ptr<queryeval::SearchIterator>> children; children.reserve(_contexts.size()); for (auto& context : _contexts) { - auto wrapper = std::make_unique<search::queryeval::FilterWrapper>(1); + auto wrapper = std::make_unique<queryeval::FilterWrapper>(1); wrapper->wrap(context->createIterator(wrapper->tfmda()[0], strict)); children.emplace_back(std::move(wrapper)); } - search::queryeval::UnpackInfo unpack_info; - return search::queryeval::OrSearch::create(std::move(children), strict, unpack_info); + return queryeval::OrSearch::create(std::move(children), strict, queryeval::UnpackInfo()); } void diff --git a/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.cpp b/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.cpp index 3c0bae00047..e2566c94f1c 100644 --- a/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.cpp +++ b/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.cpp @@ -11,8 +11,8 @@ class DocumentWeightOrFilterSearchImpl : public DocumentWeightOrFilterSearch { AttributeIteratorPack _children; public: - DocumentWeightOrFilterSearchImpl(AttributeIteratorPack&& children); - ~DocumentWeightOrFilterSearchImpl(); + explicit DocumentWeightOrFilterSearchImpl(AttributeIteratorPack&& children); + ~DocumentWeightOrFilterSearchImpl() override; void doSeek(uint32_t docId) override; @@ -67,7 +67,7 @@ DocumentWeightOrFilterSearchImpl::doUnpack(uint32_t) { } -std::unique_ptr<search::queryeval::SearchIterator> +std::unique_ptr<queryeval::SearchIterator> DocumentWeightOrFilterSearch::create(std::vector<DocumentWeightIterator>&& children) { if (children.empty()) { diff --git a/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.h b/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.h index 62be883ab52..c601856573f 100644 --- a/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.h +++ b/searchlib/src/vespa/searchlib/attribute/document_weight_or_filter_search.h @@ -9,15 +9,15 @@ namespace search::attribute { * Filter iterator on top of document weight iterators with OR semantics used during * calculation of global filter for weighted set terms, wand terms and dot product terms. */ -class DocumentWeightOrFilterSearch : public search::queryeval::SearchIterator +class DocumentWeightOrFilterSearch : public queryeval::SearchIterator { protected: DocumentWeightOrFilterSearch() - : search::queryeval::SearchIterator() + : queryeval::SearchIterator() { } public: - static std::unique_ptr<search::queryeval::SearchIterator> create(std::vector<DocumentWeightIterator>&& children); + static std::unique_ptr<queryeval::SearchIterator> create(std::vector<DocumentWeightIterator>&& children); }; } diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.cpp index 77d9875bf69..97f6bc2e6f8 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.cpp @@ -33,10 +33,8 @@ WeightedSetTermMatchingElementsSearch::WeightedSetTermMatchingElementsSearch(con _search() { _tfmda.add(&_tfmd); - auto generic_search = bp.createLeafSearch(_tfmda, false); - auto weighted_set_term_search = dynamic_cast<WeightedSetTermSearch *>(generic_search.get()); - generic_search.release(); - _search.reset(weighted_set_term_search); + _search.reset(static_cast<WeightedSetTermSearch *>(bp.createLeafSearch(_tfmda, false).release())); + } WeightedSetTermMatchingElementsSearch::~WeightedSetTermMatchingElementsSearch() = default; diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h index 0e3c82444d7..9c8d6d88329 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_blueprint.h @@ -18,7 +18,7 @@ class WeightedSetTermBlueprint : public ComplexLeafBlueprint std::vector<Blueprint::UP> _terms; public: - WeightedSetTermBlueprint(const FieldSpec &field); + explicit WeightedSetTermBlueprint(const FieldSpec &field); WeightedSetTermBlueprint(const WeightedSetTermBlueprint &) = delete; WeightedSetTermBlueprint &operator=(const WeightedSetTermBlueprint &) = delete; ~WeightedSetTermBlueprint() override; diff --git a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp index ee3978705cf..8478a0d3c35 100644 --- a/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/weighted_set_term_search.cpp @@ -21,7 +21,7 @@ private: struct CmpDocId { const uint32_t *termPos; - CmpDocId(const uint32_t *tp) : termPos(tp) {} + explicit CmpDocId(const uint32_t *tp) : termPos(tp) {} bool operator()(const ref_t &a, const ref_t &b) const { return (termPos[a] < termPos[b]); } @@ -29,7 +29,7 @@ private: struct CmpWeight { const int32_t *weight; - CmpWeight(const int32_t *w) : weight(w) {} + explicit CmpWeight(const int32_t *w) : weight(w) {} bool operator()(const ref_t &a, const ref_t &b) const { return (weight[a] > weight[b]); } @@ -61,7 +61,7 @@ private: } public: - WeightedSetTermSearchImpl(search::fef::TermFieldMatchData &tmd, + WeightedSetTermSearchImpl(fef::TermFieldMatchData &tmd, bool field_is_filter, const std::vector<int32_t> &weights, IteratorPack &&iteratorPack) @@ -180,7 +180,7 @@ WeightedSetTermSearch::create(const std::vector<SearchIterator *> &children, //----------------------------------------------------------------------------- SearchIterator::UP -WeightedSetTermSearch::create(search::fef::TermFieldMatchData &tmd, +WeightedSetTermSearch::create(fef::TermFieldMatchData &tmd, bool field_is_filter, const std::vector<int32_t> &weights, std::vector<DocumentWeightIterator> &&iterators) |