summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@yahooinc.com>2022-04-12 14:52:41 +0200
committerTor Egge <Tor.Egge@yahooinc.com>2022-04-12 15:23:08 +0200
commitbc85555178efa7dafe1e790a21906378402c398e (patch)
tree2810b0f4e7ff847283e97aa927dda8403b1cad33 /searchlib
parent063a689b449c9d3e8ba247f2dab95819b4c0ae36 (diff)
Simplify dot product executors for array attribute vectors.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/features/dotproductfeature.cpp42
-rw-r--r--searchlib/src/vespa/searchlib/features/dotproductfeature.h25
2 files changed, 19 insertions, 48 deletions
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
index 0b4e09dde0d..b279cf6ca08 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
@@ -272,9 +272,8 @@ template <typename BaseType>
void DotProductExecutorBase<BaseType>::execute(uint32_t docId) {
auto values = getAttributeValues(docId);
size_t commonRange = std::min(values.size(), _queryVector.size());
- static_assert(std::is_same_v<multivalue::ValueType_t<AT>, BaseType>);
outputs().set_number(0, _multiplier.dotProduct(
- &_queryVector[0], reinterpret_cast<const multivalue::ValueType_t<AT> *>(values.data()), commonRange));
+ &_queryVector[0], values.data(), commonRange));
}
template <typename BaseType>
@@ -288,7 +287,7 @@ template <typename BaseType>
DotProductByArrayReadViewExecutor<BaseType>::~DotProductByArrayReadViewExecutor() = default;
template <typename BaseType>
-vespalib::ConstArrayRef<typename DotProductByArrayReadViewExecutor<BaseType>::AT>
+vespalib::ConstArrayRef<BaseType>
DotProductByArrayReadViewExecutor<BaseType>::getAttributeValues(uint32_t docId)
{
return _array_read_view->get_raw_values(docId);
@@ -326,7 +325,7 @@ template <typename BaseType>
SparseDotProductByArrayReadViewExecutor<BaseType>::~SparseDotProductByArrayReadViewExecutor() = default;
template <typename BaseType>
-vespalib::ConstArrayRef<typename SparseDotProductByArrayReadViewExecutor<BaseType>::AT>
+vespalib::ConstArrayRef<BaseType>
SparseDotProductByArrayReadViewExecutor<BaseType>::getAttributeValues(uint32_t docid)
{
auto allValues = _array_read_view->get_raw_values(docid);
@@ -348,7 +347,7 @@ template <typename A>
DotProductByCopyExecutor<A>::~DotProductByCopyExecutor() = default;
template <typename A>
-vespalib::ConstArrayRef<typename DotProductByCopyExecutor<A>::AT>
+vespalib::ConstArrayRef<typename A::BaseType>
DotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId)
{
size_t count = this->_attribute->getAll(docId, &_copy[0], _copy.size());
@@ -356,7 +355,7 @@ DotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId)
_copy.resize(count);
count = this->_attribute->getAll(docId, &_copy[0], _copy.size());
}
- return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_copy.data()), count);
+ return vespalib::ConstArrayRef(_copy.data(), count);
}
template <typename A>
@@ -371,7 +370,7 @@ template <typename A>
SparseDotProductByCopyExecutor<A>::~SparseDotProductByCopyExecutor() = default;
template <typename A>
-vespalib::ConstArrayRef<typename SparseDotProductByCopyExecutor<A>::AT>
+vespalib::ConstArrayRef<typename A::BaseType>
SparseDotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId)
{
size_t count = this->_attribute->getAll(docId, &_copy[0], _copy.size());
@@ -383,7 +382,7 @@ SparseDotProductByCopyExecutor<A>::getAttributeValues(uint32_t docId)
for (const IV & iv(this->_queryIndexes); (i < iv.size()) && (iv[i] < count); i++) {
_copy[i] = _copy[iv[i]];
}
- return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_copy.data()), i);
+ return vespalib::ConstArrayRef(_copy.data(), i);
}
template <typename BaseType>
@@ -400,29 +399,11 @@ DotProductByContentFillExecutor<BaseType>::DotProductByContentFillExecutor(
template <typename BaseType>
DotProductByContentFillExecutor<BaseType>::~DotProductByContentFillExecutor() = default;
-namespace {
-
-template<typename T> struct IsNonWeightedType : std::true_type {};
-template<typename BaseType> struct IsNonWeightedType<multivalue::WeightedValue<BaseType>> : std::false_type {};
-
-// Compile-time sanity check for type compatibility of gnarly BaseType <-> multivalue::Value
-// reinterpret_cast used by some getAttributeValues calls.
-template <typename BaseType, typename AttributeValueType, typename FillerValueType>
-constexpr void sanity_check_reinterpret_cast_compatibility() {
- static_assert(IsNonWeightedType<AttributeValueType>::value);
- static_assert(sizeof(BaseType) == sizeof(AttributeValueType));
- static_assert(sizeof(BaseType) == sizeof(FillerValueType));
- static_assert(std::is_same_v<BaseType, multivalue::ValueType_t<AttributeValueType>>);
-}
-
-}
-
template <typename BaseType>
-vespalib::ConstArrayRef<typename DotProductByContentFillExecutor<BaseType>::AT>
+vespalib::ConstArrayRef<BaseType>
DotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t docid) {
_filler.fill(*_attribute, docid);
- sanity_check_reinterpret_cast_compatibility<BaseType, AT, decltype(*_filler.data())>();
- return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(_filler.data()), _filler.size());
+ return vespalib::ConstArrayRef(_filler.data(), _filler.size());
}
template <typename BaseType>
@@ -442,7 +423,7 @@ template <typename BaseType>
SparseDotProductByContentFillExecutor<BaseType>::~SparseDotProductByContentFillExecutor() = default;
template <typename BaseType>
-vespalib::ConstArrayRef<typename SparseDotProductByContentFillExecutor<BaseType>::AT>
+vespalib::ConstArrayRef<BaseType>
SparseDotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t docid)
{
_filler.fill(*_attribute, docid);
@@ -454,8 +435,7 @@ SparseDotProductByContentFillExecutor<BaseType>::getAttributeValues(uint32_t doc
data[i] = data[_queryIndexes[i]];
}
- sanity_check_reinterpret_cast_compatibility<BaseType, AT, decltype(*_filler.data())>();
- return vespalib::ConstArrayRef(reinterpret_cast<const AT *>(data), i);
+ return vespalib::ConstArrayRef(data, i);
}
}
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h
index 8ce545bd23d..051a08025d8 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h
@@ -180,12 +180,11 @@ namespace array {
template <typename BaseType>
class DotProductExecutorBase : public fef::FeatureExecutor {
public:
- using AT = BaseType;
using V = std::vector<BaseType>;
private:
const vespalib::hwaccelrated::IAccelrated & _multiplier;
V _queryVector;
- virtual vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) = 0;
+ virtual vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) = 0;
public:
DotProductExecutorBase(const V & queryVector);
~DotProductExecutorBase() override;
@@ -198,13 +197,12 @@ public:
template <typename BaseType>
class DotProductByArrayReadViewExecutor : public DotProductExecutorBase<BaseType> {
public:
- using AT = typename DotProductExecutorBase<BaseType>::AT;
using V = typename DotProductExecutorBase<BaseType>::V;
using ArrayReadView = attribute::IArrayReadView<BaseType>;
protected:
const ArrayReadView* _array_read_view;
private:
- vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override;
+ vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) override;
public:
DotProductByArrayReadViewExecutor(const ArrayReadView* array_read_view, const V & queryVector);
~DotProductByArrayReadViewExecutor();
@@ -216,7 +214,6 @@ public:
template <typename A>
class DotProductExecutor : public DotProductExecutorBase<typename A::BaseType> {
public:
- using AT = typename DotProductExecutorBase<typename A::BaseType>::AT;
using V = typename DotProductExecutorBase<typename A::BaseType>::V;
protected:
const A * _attribute;
@@ -232,8 +229,7 @@ public:
DotProductByCopyExecutor(const A * attribute, const V & queryVector);
~DotProductByCopyExecutor();
private:
- typedef typename DotProductExecutor<A>::AT AT;
- vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override;
+ vespalib::ConstArrayRef<typename A::BaseType> getAttributeValues(uint32_t docid) final override;
std::vector<typename A::BaseType> _copy;
};
@@ -252,13 +248,12 @@ template <typename BaseType>
class DotProductByContentFillExecutor : public DotProductExecutorBase<BaseType> {
public:
using V = typename DotProductExecutorBase<BaseType>::V;
- using AT = typename DotProductExecutorBase<BaseType>::AT;
using ValueFiller = attribute::AttributeContent<BaseType>;
DotProductByContentFillExecutor(const attribute::IAttributeVector * attribute, const V & queryVector);
~DotProductByContentFillExecutor();
private:
- vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override;
+ vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) final override;
const attribute::IAttributeVector* _attribute;
ValueFiller _filler;
@@ -272,9 +267,8 @@ public:
SparseDotProductExecutorBase(const V & queryVector, const IV & queryIndexes);
~SparseDotProductExecutorBase();
protected:
- typedef typename DotProductExecutorBase<BaseType>::AT AT;
IV _queryIndexes;
- std::vector<AT> _scratch;
+ std::vector<BaseType> _scratch;
};
template <typename BaseType>
@@ -288,8 +282,7 @@ public:
SparseDotProductByArrayReadViewExecutor(const ArrayReadView* array_read_view, const V & queryVector, const IV & queryIndexes);
~SparseDotProductByArrayReadViewExecutor();
private:
- typedef typename SparseDotProductExecutorBase<BaseType>::AT AT;
- vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) override;
+ vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) override;
const ArrayReadView* _array_read_view;
};
@@ -301,8 +294,7 @@ public:
SparseDotProductByCopyExecutor(const A * attribute, const V & queryVector, const IV & queryIndexes);
~SparseDotProductByCopyExecutor();
private:
- typedef typename SparseDotProductExecutorBase<typename A::BaseType>::AT AT;
- vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override;
+ vespalib::ConstArrayRef<typename A::BaseType> getAttributeValues(uint32_t docid) final override;
const A* _attribute;
std::vector<typename A::BaseType> _copy;
};
@@ -316,7 +308,6 @@ class SparseDotProductByContentFillExecutor : public DotProductExecutorBase<Base
public:
using IV = std::vector<uint32_t>;
using V = typename DotProductExecutorBase<BaseType>::V;
- using AT = typename DotProductExecutorBase<BaseType>::AT;
using ValueFiller = attribute::AttributeContent<BaseType>;
SparseDotProductByContentFillExecutor(const attribute::IAttributeVector * attribute,
@@ -324,7 +315,7 @@ public:
const IV & queryIndexes);
~SparseDotProductByContentFillExecutor() override;
private:
- vespalib::ConstArrayRef<AT> getAttributeValues(uint32_t docid) final override;
+ vespalib::ConstArrayRef<BaseType> getAttributeValues(uint32_t docid) final override;
const attribute::IAttributeVector* _attribute;
IV _queryIndexes;