diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2023-06-12 12:06:43 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-06-12 12:06:43 +0200 |
commit | 212677f8d36e5b41414b0f03dec370dbfe749448 (patch) | |
tree | 58900f5afe63295c5fa24ddee20957427adac9a0 | |
parent | 39f44c787daa051d8308ce6871e346a61a0d2f8d (diff) | |
parent | c63710e6b1da936e773bce6a3da6b11539478172 (diff) |
Merge pull request #27363 from vespa-engine/balder/refactor-attributenode-for-lookups
Prepare AttributeNode to handle both full vector extraction and singl…
5 files changed, 171 insertions, 107 deletions
diff --git a/searchlib/src/vespa/searchlib/aggregation/grouping.cpp b/searchlib/src/vespa/searchlib/aggregation/grouping.cpp index daff16e9029..b61cb585899 100644 --- a/searchlib/src/vespa/searchlib/aggregation/grouping.cpp +++ b/searchlib/src/vespa/searchlib/aggregation/grouping.cpp @@ -220,7 +220,7 @@ void Grouping::aggregate(const RankedHit * rankedHit, unsigned int len) preAggregate(isOrdered); HitsAggregationResult::SetOrdered pred; select(pred, pred); - if (_clock == NULL) { + if (_clock == nullptr) { aggregateWithoutClock(rankedHit, getMaxN(len)); } else { aggregateWithClock(rankedHit, getMaxN(len)); @@ -231,14 +231,14 @@ void Grouping::aggregate(const RankedHit * rankedHit, unsigned int len) void Grouping::aggregate(const RankedHit * rankedHit, unsigned int len, const BitVector * bVec) { preAggregate(false); - if (_clock == NULL) { + if (_clock == nullptr) { aggregateWithoutClock(rankedHit, getMaxN(len)); } else { aggregateWithClock(rankedHit, getMaxN(len)); } - if (bVec != NULL) { + if (bVec != nullptr) { unsigned int sz(bVec->size()); - if (_clock == NULL) { + if (_clock == nullptr) { if (getTopN() > 0) { for(DocId d(bVec->getFirstTrueBit()), i(0), m(getMaxN(sz)); (d < sz) && (i < m); d = bVec->getNextTrueBit(d+1), i++) { aggregate(d, 0.0); @@ -291,12 +291,12 @@ void Grouping::sortById() void Grouping::configureStaticStuff(const ConfigureStaticParams & params) { - if (params._attrCtx != NULL) { + if (params._attrCtx != nullptr) { AttributeNode::Configure confAttr(*params._attrCtx); select(confAttr, confAttr); } - if (params._docType != NULL) { + if (params._docType != nullptr) { DocumentAccessorNode::Configure confDoc(*params._docType); select(confDoc, confDoc); } diff --git a/searchlib/src/vespa/searchlib/expression/attributenode.cpp b/searchlib/src/vespa/searchlib/expression/attributenode.cpp index 7e46de934f0..abdc09d5e78 100644 --- a/searchlib/src/vespa/searchlib/expression/attributenode.cpp +++ b/searchlib/src/vespa/searchlib/expression/attributenode.cpp @@ -4,6 +4,7 @@ #include "resultvector.h" #include "enumattributeresult.h" #include <vespa/searchcommon/attribute/iattributecontext.h> +#include <cassert> namespace search::expression { @@ -85,13 +86,30 @@ createResult(const IAttributeVector * attribute) return std::make_unique<EnumAttributeResult>(enumRefs, attribute, 0); } +template<typename T> +std::pair<std::unique_ptr<ResultNode>, std::unique_ptr<AttributeNode::Handler>> +createSingle() { + return { std::make_unique<T>(), std::unique_ptr<AttributeNode::Handler>()}; +} + +template<typename T, typename H> +std::pair<std::unique_ptr<ResultNode>, std::unique_ptr<AttributeNode::Handler>> +createMulti() { + auto result = std::make_unique<T>(); + auto handler = std::make_unique<H>(*result); + return { std::move(result), std::move(handler)}; +} + } AttributeNode::AttributeNode() : FunctionNode(), _scratchResult(std::make_unique<AttributeResult>()), + _index(nullptr), + _keepAliveForIndexLookups(), _hasMultiValue(false), _useEnumOptimization(false), + _needExecute(true), _handler(), _attributeName() {} @@ -101,16 +119,22 @@ AttributeNode::~AttributeNode() = default; AttributeNode::AttributeNode(vespalib::stringref name) : FunctionNode(), _scratchResult(std::make_unique<AttributeResult>()), + _index(nullptr), + _keepAliveForIndexLookups(), _hasMultiValue(false), _useEnumOptimization(false), + _needExecute(true), _handler(), _attributeName(name) {} AttributeNode::AttributeNode(const IAttributeVector & attribute) : FunctionNode(), _scratchResult(createResult(&attribute)), + _index(nullptr), + _keepAliveForIndexLookups(), _hasMultiValue(attribute.hasMultiValue()), _useEnumOptimization(false), + _needExecute(true), _handler(), _attributeName(attribute.getName()) {} @@ -118,8 +142,11 @@ AttributeNode::AttributeNode(const IAttributeVector & attribute) : AttributeNode::AttributeNode(const AttributeNode & attribute) : FunctionNode(attribute), _scratchResult(attribute._scratchResult->clone()), + _index(nullptr), + _keepAliveForIndexLookups(), _hasMultiValue(attribute._hasMultiValue), _useEnumOptimization(attribute._useEnumOptimization), + _needExecute(true), _handler(), _attributeName(attribute._attributeName) { @@ -136,105 +163,105 @@ AttributeNode::operator = (const AttributeNode & attr) _useEnumOptimization = attr._useEnumOptimization; _scratchResult.reset(attr._scratchResult->clone()); _scratchResult->setDocId(0); + _handler.reset(); + _index = nullptr; + _keepAliveForIndexLookups.reset(); + _needExecute = true; } return *this; } -void -AttributeNode::onPrepare(bool preserveAccurateTypes) -{ - const IAttributeVector * attribute = _scratchResult->getAttribute(); - if (attribute != nullptr) { - BasicType::Type basicType = attribute->getBasicType(); - if (attribute->isIntegerType()) { - if (_hasMultiValue) { - if (basicType == BasicType::BOOL) { - setResultType(std::make_unique<BoolResultNodeVector>()); - _handler = std::make_unique<IntegerHandler<BoolResultNodeVector>>(updateResult()); - } else if (preserveAccurateTypes) { - switch (basicType) { - case BasicType::INT8: - setResultType(std::make_unique<Int8ResultNodeVector>()); - _handler = std::make_unique<IntegerHandler<Int8ResultNodeVector>>(updateResult()); - break; - case BasicType::INT16: - setResultType(std::make_unique<Int16ResultNodeVector>()); - _handler = std::make_unique<IntegerHandler<Int16ResultNodeVector>>(updateResult()); - break; - case BasicType::INT32: - setResultType(std::make_unique<Int32ResultNodeVector>()); - _handler = std::make_unique<IntegerHandler<Int32ResultNodeVector>>(updateResult()); - break; - case BasicType::INT64: - setResultType(std::make_unique<Int64ResultNodeVector>()); - _handler = std::make_unique<IntegerHandler<Int64ResultNodeVector>>(updateResult()); - break; - default: - throw std::runtime_error("This is no valid integer attribute " + attribute->getName()); - break; - } - } else { - setResultType(std::make_unique<IntegerResultNodeVector>()); - _handler = std::make_unique<IntegerHandler<IntegerResultNodeVector>>(updateResult()); +std::pair<std::unique_ptr<ResultNode>, std::unique_ptr<AttributeNode::Handler>> +AttributeNode::createResultAndHandler(bool preserveAccurateTypes, const attribute::IAttributeVector & attribute) const { + BasicType::Type basicType = attribute.getBasicType(); + if (attribute.isIntegerType()) { + if (_hasMultiValue) { + if (basicType == BasicType::BOOL) { + return createMulti<BoolResultNodeVector, IntegerHandler<BoolResultNodeVector>>(); + } else if (preserveAccurateTypes) { + switch (basicType) { + case BasicType::INT8: + return createMulti<Int8ResultNodeVector, IntegerHandler<Int8ResultNodeVector>>(); + case BasicType::INT16: + return createMulti<Int16ResultNodeVector, IntegerHandler<Int16ResultNodeVector>>(); + case BasicType::INT32: + return createMulti<Int32ResultNodeVector, IntegerHandler<Int32ResultNodeVector>>(); + case BasicType::INT64: + return createMulti<Int64ResultNodeVector, IntegerHandler<Int64ResultNodeVector>>(); + default: + throw std::runtime_error("This is no valid integer attribute " + attribute.getName()); } } else { - if (basicType == BasicType::BOOL) { - setResultType(std::make_unique<BoolResultNode>()); - } else if (preserveAccurateTypes) { - switch (basicType) { - case BasicType::INT8: - setResultType(std::make_unique<Int8ResultNode>()); - break; - case BasicType::INT16: - setResultType(std::make_unique<Int16ResultNode>()); - break; - case BasicType::INT32: - setResultType(std::make_unique<Int32ResultNode>()); - break; - case BasicType::INT64: - setResultType(std::make_unique<Int64ResultNode>()); - break; - default: - throw std::runtime_error("This is no valid integer attribute " + attribute->getName()); - break; - } - } else { - setResultType(std::make_unique<Int64ResultNode>()); - } + return createMulti<IntegerResultNodeVector, IntegerHandler<IntegerResultNodeVector>>(); } - } else if (attribute->isFloatingPointType()) { - if (_hasMultiValue) { - setResultType(std::make_unique<FloatResultNodeVector>()); - _handler = std::make_unique<FloatHandler>(updateResult()); + } else { + if (basicType == BasicType::BOOL) { + return createSingle<BoolResultNode>(); + } else if (preserveAccurateTypes) { + switch (basicType) { + case BasicType::INT8: + return createSingle<Int8ResultNode>(); + case BasicType::INT16: + return createSingle<Int16ResultNode>(); + case BasicType::INT32: + return createSingle<Int32ResultNode>(); + case BasicType::INT64: + return createSingle<Int64ResultNode>(); + default: + throw std::runtime_error("This is no valid integer attribute " + attribute.getName()); + } } else { - setResultType(std::make_unique<FloatResultNode>()); + return createSingle<Int64ResultNode>(); } - } else if (attribute->isStringType()) { - if (_hasMultiValue) { - if (_useEnumOptimization) { - setResultType(std::make_unique<EnumResultNodeVector>()); - _handler = std::make_unique<EnumHandler>(updateResult()); - } else { - setResultType(std::make_unique<StringResultNodeVector>()); - _handler = std::make_unique<StringHandler>(updateResult()); - } + } + } else if (attribute.isFloatingPointType()) { + if (_hasMultiValue) { + return createMulti<FloatResultNodeVector, FloatHandler>(); + } else { + return createSingle<FloatResultNode>(); + } + } else if (attribute.isStringType()) { + if (_hasMultiValue) { + if (_useEnumOptimization) { + return createMulti<EnumResultNodeVector, EnumHandler>(); } else { - if (_useEnumOptimization) { - setResultType(std::make_unique<EnumResultNode>()); - } else { - setResultType(std::make_unique<StringResultNode>()); - } + return createMulti<StringResultNodeVector, StringHandler>(); } - } else if (attribute->is_raw_type()) { - if (_hasMultiValue) { - throw std::runtime_error(make_string("Does not support multivalue raw attribute vector '%s'", - attribute->getName().c_str())); + } else { + if (_useEnumOptimization) { + return createSingle<EnumResultNode>(); } else { - setResultType(std::make_unique<RawResultNode>()); + return createSingle<StringResultNode>(); } + } + } else if (attribute.is_raw_type()) { + if (_hasMultiValue) { + throw std::runtime_error(make_string("Does not support multivalue raw attribute vector '%s'", + attribute.getName().c_str())); + } else { + return createSingle<RawResultNode>(); + } + } else { + throw std::runtime_error(make_string("Can not deduce correct resultclass for attribute vector '%s'", + attribute.getName().c_str())); + } +} + +void +AttributeNode::onPrepare(bool preserveAccurateTypes) +{ + const IAttributeVector * attribute = getAttribute(); + if (attribute != nullptr) { + auto[result, handler] = createResultAndHandler(preserveAccurateTypes, *attribute); + _handler = std::move(handler); + if (_index == nullptr) { + setResultType(std::move(result)); } else { - throw std::runtime_error(make_string("Can not deduce correct resultclass for attribute vector '%s'", - attribute->getName().c_str())); + assert(_hasMultiValue); + assert(_handler); + setResultType(result->createBaseType()); + assert(result->inherits(ResultNodeVector::classId)); + _keepAliveForIndexLookups.reset(dynamic_cast<ResultNodeVector *>(result.release())); } } } @@ -287,10 +314,24 @@ void AttributeNode::EnumHandler::handle(const AttributeResult & r) } } +void +AttributeNode::setDocId(DocId docId) { + _scratchResult->setDocId(docId); + _needExecute = true; +} + bool AttributeNode::onExecute() const { if (_handler) { - _handler->handle(*_scratchResult); + if (_needExecute) { + _handler->handle(*_scratchResult); + _needExecute = false; + } + if (_index != nullptr) { + assert(_hasMultiValue); + assert(_keepAliveForIndexLookups); + updateResult().set(_keepAliveForIndexLookups->get(_index->get())); + } } else { updateResult().set(*_scratchResult); } diff --git a/searchlib/src/vespa/searchlib/expression/attributenode.h b/searchlib/src/vespa/searchlib/expression/attributenode.h index 67ec6a3302f..abb46240d65 100644 --- a/searchlib/src/vespa/searchlib/expression/attributenode.h +++ b/searchlib/src/vespa/searchlib/expression/attributenode.h @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include "currentindex.h" #include "functionnode.h" #include "attributeresult.h" #include <vespa/vespalib/objects/objectoperation.h> @@ -19,7 +20,7 @@ public: class Configure : public vespalib::ObjectOperation, public vespalib::ObjectPredicate { public: - Configure(const search::attribute::IAttributeContext & attrCtx) : _attrCtx(attrCtx) { } + Configure(const attribute::IAttributeContext & attrCtx) : _attrCtx(attrCtx) { } private: void execute(vespalib::Identifiable &obj) override { static_cast<ExpressionNode &>(obj).wireAttributes(_attrCtx); @@ -28,7 +29,7 @@ public: bool check(const vespalib::Identifiable &obj) const override { return obj.inherits(ExpressionNode::classId); } - const search::attribute::IAttributeContext & _attrCtx; + const attribute::IAttributeContext & _attrCtx; }; class CleanupAttributeReferences : public vespalib::ObjectOperation, public vespalib::ObjectPredicate @@ -42,12 +43,12 @@ public: DECLARE_EXPRESSIONNODE(AttributeNode); AttributeNode(); AttributeNode(vespalib::stringref name); - AttributeNode(const search::attribute::IAttributeVector & attribute); + AttributeNode(const attribute::IAttributeVector & attribute); AttributeNode(const AttributeNode & attribute); AttributeNode & operator = (const AttributeNode & attribute); ~AttributeNode() override; - void setDocId(DocId docId) const { _scratchResult->setDocId(docId); } - const search::attribute::IAttributeVector *getAttribute() const { + void setDocId(DocId docId); + const attribute::IAttributeVector *getAttribute() const { return _scratchResult ? _scratchResult->getAttribute() : nullptr; } const vespalib::string & getAttributeName() const { return _attributeName; } @@ -62,21 +63,26 @@ public: virtual void handle(const AttributeResult & r) = 0; }; private: + std::pair<std::unique_ptr<ResultNode>, std::unique_ptr<Handler>> + createResultAndHandler(bool preserveAccurateType, const attribute::IAttributeVector & attribute) const; template <typename V> class IntegerHandler; class FloatHandler; class StringHandler; class EnumHandler; protected: virtual void cleanup(); - void wireAttributes(const search::attribute::IAttributeContext & attrCtx) override; + void wireAttributes(const attribute::IAttributeContext & attrCtx) override; void onPrepare(bool preserveAccurateTypes) override; bool onExecute() const override; - std::unique_ptr<AttributeResult> _scratchResult; - bool _hasMultiValue; - bool _useEnumOptimization; - std::unique_ptr<Handler> _handler; - vespalib::string _attributeName; + std::unique_ptr<AttributeResult> _scratchResult; + const CurrentIndex *_index; + std::unique_ptr<ResultNodeVector> _keepAliveForIndexLookups; + bool _hasMultiValue; + bool _useEnumOptimization; + mutable bool _needExecute; + std::unique_ptr<Handler> _handler; + vespalib::string _attributeName; }; } diff --git a/searchlib/src/vespa/searchlib/expression/currentindex.h b/searchlib/src/vespa/searchlib/expression/currentindex.h new file mode 100644 index 00000000000..98ecdf7252e --- /dev/null +++ b/searchlib/src/vespa/searchlib/expression/currentindex.h @@ -0,0 +1,18 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <cstdint> + +namespace search::expression { + +class CurrentIndex { +public: + CurrentIndex() noexcept : _index(0) {} + uint32_t get() const noexcept { return _index; } + void set(uint32_t index) noexcept { _index = index; } +private: + uint32_t _index; +}; + +} diff --git a/searchlib/src/vespa/searchlib/expression/functionnodes.cpp b/searchlib/src/vespa/searchlib/expression/functionnodes.cpp index 109b4a59a05..60574a355d0 100644 --- a/searchlib/src/vespa/searchlib/expression/functionnodes.cpp +++ b/searchlib/src/vespa/searchlib/expression/functionnodes.cpp @@ -149,21 +149,20 @@ ArithmeticTypeConversion::getType(const ResultNode & arg1, const ResultNode & ar { size_t baseTypeId = getType(getBaseType2(arg1), getBaseType2(arg2)); size_t dimension = std::max(getDimension(arg1), getDimension(arg2)); - ResultNode::UP result; if (dimension == 0) { return ResultNode::UP(static_cast<ResultNode *>(Identifiable::classFromId(baseTypeId)->create())); } else if (dimension == 1) { if (baseTypeId == Int64ResultNode::classId) { - result.reset(new IntegerResultNodeVector()); + return std::make_unique<IntegerResultNodeVector>(); } else if (baseTypeId == FloatResultNode::classId) { - result.reset(new FloatResultNodeVector()); + return std::make_unique<FloatResultNodeVector>(); } else { throw std::runtime_error("We can not handle anything but numbers."); } } else { throw std::runtime_error("We are not able to handle multidimensional arrays"); } - return result; + return ResultNode::UP(); } ResultNode::UP |