summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2023-06-12 12:06:43 +0200
committerGitHub <noreply@github.com>2023-06-12 12:06:43 +0200
commit212677f8d36e5b41414b0f03dec370dbfe749448 (patch)
tree58900f5afe63295c5fa24ddee20957427adac9a0
parent39f44c787daa051d8308ce6871e346a61a0d2f8d (diff)
parentc63710e6b1da936e773bce6a3da6b11539478172 (diff)
Merge pull request #27363 from vespa-engine/balder/refactor-attributenode-for-lookups
Prepare AttributeNode to handle both full vector extraction and singl…
-rw-r--r--searchlib/src/vespa/searchlib/aggregation/grouping.cpp12
-rw-r--r--searchlib/src/vespa/searchlib/expression/attributenode.cpp213
-rw-r--r--searchlib/src/vespa/searchlib/expression/attributenode.h28
-rw-r--r--searchlib/src/vespa/searchlib/expression/currentindex.h18
-rw-r--r--searchlib/src/vespa/searchlib/expression/functionnodes.cpp7
5 files changed, 171 insertions, 107 deletions
diff --git a/searchlib/src/vespa/searchlib/aggregation/grouping.cpp b/searchlib/src/vespa/searchlib/aggregation/grouping.cpp
index daff16e9029..b61cb585899 100644
--- a/searchlib/src/vespa/searchlib/aggregation/grouping.cpp
+++ b/searchlib/src/vespa/searchlib/aggregation/grouping.cpp
@@ -220,7 +220,7 @@ void Grouping::aggregate(const RankedHit * rankedHit, unsigned int len)
preAggregate(isOrdered);
HitsAggregationResult::SetOrdered pred;
select(pred, pred);
- if (_clock == NULL) {
+ if (_clock == nullptr) {
aggregateWithoutClock(rankedHit, getMaxN(len));
} else {
aggregateWithClock(rankedHit, getMaxN(len));
@@ -231,14 +231,14 @@ void Grouping::aggregate(const RankedHit * rankedHit, unsigned int len)
void Grouping::aggregate(const RankedHit * rankedHit, unsigned int len, const BitVector * bVec)
{
preAggregate(false);
- if (_clock == NULL) {
+ if (_clock == nullptr) {
aggregateWithoutClock(rankedHit, getMaxN(len));
} else {
aggregateWithClock(rankedHit, getMaxN(len));
}
- if (bVec != NULL) {
+ if (bVec != nullptr) {
unsigned int sz(bVec->size());
- if (_clock == NULL) {
+ if (_clock == nullptr) {
if (getTopN() > 0) {
for(DocId d(bVec->getFirstTrueBit()), i(0), m(getMaxN(sz)); (d < sz) && (i < m); d = bVec->getNextTrueBit(d+1), i++) {
aggregate(d, 0.0);
@@ -291,12 +291,12 @@ void Grouping::sortById()
void Grouping::configureStaticStuff(const ConfigureStaticParams & params)
{
- if (params._attrCtx != NULL) {
+ if (params._attrCtx != nullptr) {
AttributeNode::Configure confAttr(*params._attrCtx);
select(confAttr, confAttr);
}
- if (params._docType != NULL) {
+ if (params._docType != nullptr) {
DocumentAccessorNode::Configure confDoc(*params._docType);
select(confDoc, confDoc);
}
diff --git a/searchlib/src/vespa/searchlib/expression/attributenode.cpp b/searchlib/src/vespa/searchlib/expression/attributenode.cpp
index 7e46de934f0..abdc09d5e78 100644
--- a/searchlib/src/vespa/searchlib/expression/attributenode.cpp
+++ b/searchlib/src/vespa/searchlib/expression/attributenode.cpp
@@ -4,6 +4,7 @@
#include "resultvector.h"
#include "enumattributeresult.h"
#include <vespa/searchcommon/attribute/iattributecontext.h>
+#include <cassert>
namespace search::expression {
@@ -85,13 +86,30 @@ createResult(const IAttributeVector * attribute)
return std::make_unique<EnumAttributeResult>(enumRefs, attribute, 0);
}
+template<typename T>
+std::pair<std::unique_ptr<ResultNode>, std::unique_ptr<AttributeNode::Handler>>
+createSingle() {
+ return { std::make_unique<T>(), std::unique_ptr<AttributeNode::Handler>()};
+}
+
+template<typename T, typename H>
+std::pair<std::unique_ptr<ResultNode>, std::unique_ptr<AttributeNode::Handler>>
+createMulti() {
+ auto result = std::make_unique<T>();
+ auto handler = std::make_unique<H>(*result);
+ return { std::move(result), std::move(handler)};
+}
+
}
AttributeNode::AttributeNode() :
FunctionNode(),
_scratchResult(std::make_unique<AttributeResult>()),
+ _index(nullptr),
+ _keepAliveForIndexLookups(),
_hasMultiValue(false),
_useEnumOptimization(false),
+ _needExecute(true),
_handler(),
_attributeName()
{}
@@ -101,16 +119,22 @@ AttributeNode::~AttributeNode() = default;
AttributeNode::AttributeNode(vespalib::stringref name) :
FunctionNode(),
_scratchResult(std::make_unique<AttributeResult>()),
+ _index(nullptr),
+ _keepAliveForIndexLookups(),
_hasMultiValue(false),
_useEnumOptimization(false),
+ _needExecute(true),
_handler(),
_attributeName(name)
{}
AttributeNode::AttributeNode(const IAttributeVector & attribute) :
FunctionNode(),
_scratchResult(createResult(&attribute)),
+ _index(nullptr),
+ _keepAliveForIndexLookups(),
_hasMultiValue(attribute.hasMultiValue()),
_useEnumOptimization(false),
+ _needExecute(true),
_handler(),
_attributeName(attribute.getName())
{}
@@ -118,8 +142,11 @@ AttributeNode::AttributeNode(const IAttributeVector & attribute) :
AttributeNode::AttributeNode(const AttributeNode & attribute) :
FunctionNode(attribute),
_scratchResult(attribute._scratchResult->clone()),
+ _index(nullptr),
+ _keepAliveForIndexLookups(),
_hasMultiValue(attribute._hasMultiValue),
_useEnumOptimization(attribute._useEnumOptimization),
+ _needExecute(true),
_handler(),
_attributeName(attribute._attributeName)
{
@@ -136,105 +163,105 @@ AttributeNode::operator = (const AttributeNode & attr)
_useEnumOptimization = attr._useEnumOptimization;
_scratchResult.reset(attr._scratchResult->clone());
_scratchResult->setDocId(0);
+ _handler.reset();
+ _index = nullptr;
+ _keepAliveForIndexLookups.reset();
+ _needExecute = true;
}
return *this;
}
-void
-AttributeNode::onPrepare(bool preserveAccurateTypes)
-{
- const IAttributeVector * attribute = _scratchResult->getAttribute();
- if (attribute != nullptr) {
- BasicType::Type basicType = attribute->getBasicType();
- if (attribute->isIntegerType()) {
- if (_hasMultiValue) {
- if (basicType == BasicType::BOOL) {
- setResultType(std::make_unique<BoolResultNodeVector>());
- _handler = std::make_unique<IntegerHandler<BoolResultNodeVector>>(updateResult());
- } else if (preserveAccurateTypes) {
- switch (basicType) {
- case BasicType::INT8:
- setResultType(std::make_unique<Int8ResultNodeVector>());
- _handler = std::make_unique<IntegerHandler<Int8ResultNodeVector>>(updateResult());
- break;
- case BasicType::INT16:
- setResultType(std::make_unique<Int16ResultNodeVector>());
- _handler = std::make_unique<IntegerHandler<Int16ResultNodeVector>>(updateResult());
- break;
- case BasicType::INT32:
- setResultType(std::make_unique<Int32ResultNodeVector>());
- _handler = std::make_unique<IntegerHandler<Int32ResultNodeVector>>(updateResult());
- break;
- case BasicType::INT64:
- setResultType(std::make_unique<Int64ResultNodeVector>());
- _handler = std::make_unique<IntegerHandler<Int64ResultNodeVector>>(updateResult());
- break;
- default:
- throw std::runtime_error("This is no valid integer attribute " + attribute->getName());
- break;
- }
- } else {
- setResultType(std::make_unique<IntegerResultNodeVector>());
- _handler = std::make_unique<IntegerHandler<IntegerResultNodeVector>>(updateResult());
+std::pair<std::unique_ptr<ResultNode>, std::unique_ptr<AttributeNode::Handler>>
+AttributeNode::createResultAndHandler(bool preserveAccurateTypes, const attribute::IAttributeVector & attribute) const {
+ BasicType::Type basicType = attribute.getBasicType();
+ if (attribute.isIntegerType()) {
+ if (_hasMultiValue) {
+ if (basicType == BasicType::BOOL) {
+ return createMulti<BoolResultNodeVector, IntegerHandler<BoolResultNodeVector>>();
+ } else if (preserveAccurateTypes) {
+ switch (basicType) {
+ case BasicType::INT8:
+ return createMulti<Int8ResultNodeVector, IntegerHandler<Int8ResultNodeVector>>();
+ case BasicType::INT16:
+ return createMulti<Int16ResultNodeVector, IntegerHandler<Int16ResultNodeVector>>();
+ case BasicType::INT32:
+ return createMulti<Int32ResultNodeVector, IntegerHandler<Int32ResultNodeVector>>();
+ case BasicType::INT64:
+ return createMulti<Int64ResultNodeVector, IntegerHandler<Int64ResultNodeVector>>();
+ default:
+ throw std::runtime_error("This is no valid integer attribute " + attribute.getName());
}
} else {
- if (basicType == BasicType::BOOL) {
- setResultType(std::make_unique<BoolResultNode>());
- } else if (preserveAccurateTypes) {
- switch (basicType) {
- case BasicType::INT8:
- setResultType(std::make_unique<Int8ResultNode>());
- break;
- case BasicType::INT16:
- setResultType(std::make_unique<Int16ResultNode>());
- break;
- case BasicType::INT32:
- setResultType(std::make_unique<Int32ResultNode>());
- break;
- case BasicType::INT64:
- setResultType(std::make_unique<Int64ResultNode>());
- break;
- default:
- throw std::runtime_error("This is no valid integer attribute " + attribute->getName());
- break;
- }
- } else {
- setResultType(std::make_unique<Int64ResultNode>());
- }
+ return createMulti<IntegerResultNodeVector, IntegerHandler<IntegerResultNodeVector>>();
}
- } else if (attribute->isFloatingPointType()) {
- if (_hasMultiValue) {
- setResultType(std::make_unique<FloatResultNodeVector>());
- _handler = std::make_unique<FloatHandler>(updateResult());
+ } else {
+ if (basicType == BasicType::BOOL) {
+ return createSingle<BoolResultNode>();
+ } else if (preserveAccurateTypes) {
+ switch (basicType) {
+ case BasicType::INT8:
+ return createSingle<Int8ResultNode>();
+ case BasicType::INT16:
+ return createSingle<Int16ResultNode>();
+ case BasicType::INT32:
+ return createSingle<Int32ResultNode>();
+ case BasicType::INT64:
+ return createSingle<Int64ResultNode>();
+ default:
+ throw std::runtime_error("This is no valid integer attribute " + attribute.getName());
+ }
} else {
- setResultType(std::make_unique<FloatResultNode>());
+ return createSingle<Int64ResultNode>();
}
- } else if (attribute->isStringType()) {
- if (_hasMultiValue) {
- if (_useEnumOptimization) {
- setResultType(std::make_unique<EnumResultNodeVector>());
- _handler = std::make_unique<EnumHandler>(updateResult());
- } else {
- setResultType(std::make_unique<StringResultNodeVector>());
- _handler = std::make_unique<StringHandler>(updateResult());
- }
+ }
+ } else if (attribute.isFloatingPointType()) {
+ if (_hasMultiValue) {
+ return createMulti<FloatResultNodeVector, FloatHandler>();
+ } else {
+ return createSingle<FloatResultNode>();
+ }
+ } else if (attribute.isStringType()) {
+ if (_hasMultiValue) {
+ if (_useEnumOptimization) {
+ return createMulti<EnumResultNodeVector, EnumHandler>();
} else {
- if (_useEnumOptimization) {
- setResultType(std::make_unique<EnumResultNode>());
- } else {
- setResultType(std::make_unique<StringResultNode>());
- }
+ return createMulti<StringResultNodeVector, StringHandler>();
}
- } else if (attribute->is_raw_type()) {
- if (_hasMultiValue) {
- throw std::runtime_error(make_string("Does not support multivalue raw attribute vector '%s'",
- attribute->getName().c_str()));
+ } else {
+ if (_useEnumOptimization) {
+ return createSingle<EnumResultNode>();
} else {
- setResultType(std::make_unique<RawResultNode>());
+ return createSingle<StringResultNode>();
}
+ }
+ } else if (attribute.is_raw_type()) {
+ if (_hasMultiValue) {
+ throw std::runtime_error(make_string("Does not support multivalue raw attribute vector '%s'",
+ attribute.getName().c_str()));
+ } else {
+ return createSingle<RawResultNode>();
+ }
+ } else {
+ throw std::runtime_error(make_string("Can not deduce correct resultclass for attribute vector '%s'",
+ attribute.getName().c_str()));
+ }
+}
+
+void
+AttributeNode::onPrepare(bool preserveAccurateTypes)
+{
+ const IAttributeVector * attribute = getAttribute();
+ if (attribute != nullptr) {
+ auto[result, handler] = createResultAndHandler(preserveAccurateTypes, *attribute);
+ _handler = std::move(handler);
+ if (_index == nullptr) {
+ setResultType(std::move(result));
} else {
- throw std::runtime_error(make_string("Can not deduce correct resultclass for attribute vector '%s'",
- attribute->getName().c_str()));
+ assert(_hasMultiValue);
+ assert(_handler);
+ setResultType(result->createBaseType());
+ assert(result->inherits(ResultNodeVector::classId));
+ _keepAliveForIndexLookups.reset(dynamic_cast<ResultNodeVector *>(result.release()));
}
}
}
@@ -287,10 +314,24 @@ void AttributeNode::EnumHandler::handle(const AttributeResult & r)
}
}
+void
+AttributeNode::setDocId(DocId docId) {
+ _scratchResult->setDocId(docId);
+ _needExecute = true;
+}
+
bool AttributeNode::onExecute() const
{
if (_handler) {
- _handler->handle(*_scratchResult);
+ if (_needExecute) {
+ _handler->handle(*_scratchResult);
+ _needExecute = false;
+ }
+ if (_index != nullptr) {
+ assert(_hasMultiValue);
+ assert(_keepAliveForIndexLookups);
+ updateResult().set(_keepAliveForIndexLookups->get(_index->get()));
+ }
} else {
updateResult().set(*_scratchResult);
}
diff --git a/searchlib/src/vespa/searchlib/expression/attributenode.h b/searchlib/src/vespa/searchlib/expression/attributenode.h
index 67ec6a3302f..abb46240d65 100644
--- a/searchlib/src/vespa/searchlib/expression/attributenode.h
+++ b/searchlib/src/vespa/searchlib/expression/attributenode.h
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#pragma once
+#include "currentindex.h"
#include "functionnode.h"
#include "attributeresult.h"
#include <vespa/vespalib/objects/objectoperation.h>
@@ -19,7 +20,7 @@ public:
class Configure : public vespalib::ObjectOperation, public vespalib::ObjectPredicate
{
public:
- Configure(const search::attribute::IAttributeContext & attrCtx) : _attrCtx(attrCtx) { }
+ Configure(const attribute::IAttributeContext & attrCtx) : _attrCtx(attrCtx) { }
private:
void execute(vespalib::Identifiable &obj) override {
static_cast<ExpressionNode &>(obj).wireAttributes(_attrCtx);
@@ -28,7 +29,7 @@ public:
bool check(const vespalib::Identifiable &obj) const override {
return obj.inherits(ExpressionNode::classId);
}
- const search::attribute::IAttributeContext & _attrCtx;
+ const attribute::IAttributeContext & _attrCtx;
};
class CleanupAttributeReferences : public vespalib::ObjectOperation, public vespalib::ObjectPredicate
@@ -42,12 +43,12 @@ public:
DECLARE_EXPRESSIONNODE(AttributeNode);
AttributeNode();
AttributeNode(vespalib::stringref name);
- AttributeNode(const search::attribute::IAttributeVector & attribute);
+ AttributeNode(const attribute::IAttributeVector & attribute);
AttributeNode(const AttributeNode & attribute);
AttributeNode & operator = (const AttributeNode & attribute);
~AttributeNode() override;
- void setDocId(DocId docId) const { _scratchResult->setDocId(docId); }
- const search::attribute::IAttributeVector *getAttribute() const {
+ void setDocId(DocId docId);
+ const attribute::IAttributeVector *getAttribute() const {
return _scratchResult ? _scratchResult->getAttribute() : nullptr;
}
const vespalib::string & getAttributeName() const { return _attributeName; }
@@ -62,21 +63,26 @@ public:
virtual void handle(const AttributeResult & r) = 0;
};
private:
+ std::pair<std::unique_ptr<ResultNode>, std::unique_ptr<Handler>>
+ createResultAndHandler(bool preserveAccurateType, const attribute::IAttributeVector & attribute) const;
template <typename V> class IntegerHandler;
class FloatHandler;
class StringHandler;
class EnumHandler;
protected:
virtual void cleanup();
- void wireAttributes(const search::attribute::IAttributeContext & attrCtx) override;
+ void wireAttributes(const attribute::IAttributeContext & attrCtx) override;
void onPrepare(bool preserveAccurateTypes) override;
bool onExecute() const override;
- std::unique_ptr<AttributeResult> _scratchResult;
- bool _hasMultiValue;
- bool _useEnumOptimization;
- std::unique_ptr<Handler> _handler;
- vespalib::string _attributeName;
+ std::unique_ptr<AttributeResult> _scratchResult;
+ const CurrentIndex *_index;
+ std::unique_ptr<ResultNodeVector> _keepAliveForIndexLookups;
+ bool _hasMultiValue;
+ bool _useEnumOptimization;
+ mutable bool _needExecute;
+ std::unique_ptr<Handler> _handler;
+ vespalib::string _attributeName;
};
}
diff --git a/searchlib/src/vespa/searchlib/expression/currentindex.h b/searchlib/src/vespa/searchlib/expression/currentindex.h
new file mode 100644
index 00000000000..98ecdf7252e
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/expression/currentindex.h
@@ -0,0 +1,18 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <cstdint>
+
+namespace search::expression {
+
+class CurrentIndex {
+public:
+ CurrentIndex() noexcept : _index(0) {}
+ uint32_t get() const noexcept { return _index; }
+ void set(uint32_t index) noexcept { _index = index; }
+private:
+ uint32_t _index;
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/expression/functionnodes.cpp b/searchlib/src/vespa/searchlib/expression/functionnodes.cpp
index 109b4a59a05..60574a355d0 100644
--- a/searchlib/src/vespa/searchlib/expression/functionnodes.cpp
+++ b/searchlib/src/vespa/searchlib/expression/functionnodes.cpp
@@ -149,21 +149,20 @@ ArithmeticTypeConversion::getType(const ResultNode & arg1, const ResultNode & ar
{
size_t baseTypeId = getType(getBaseType2(arg1), getBaseType2(arg2));
size_t dimension = std::max(getDimension(arg1), getDimension(arg2));
- ResultNode::UP result;
if (dimension == 0) {
return ResultNode::UP(static_cast<ResultNode *>(Identifiable::classFromId(baseTypeId)->create()));
} else if (dimension == 1) {
if (baseTypeId == Int64ResultNode::classId) {
- result.reset(new IntegerResultNodeVector());
+ return std::make_unique<IntegerResultNodeVector>();
} else if (baseTypeId == FloatResultNode::classId) {
- result.reset(new FloatResultNodeVector());
+ return std::make_unique<FloatResultNodeVector>();
} else {
throw std::runtime_error("We can not handle anything but numbers.");
}
} else {
throw std::runtime_error("We are not able to handle multidimensional arrays");
}
- return result;
+ return ResultNode::UP();
}
ResultNode::UP