summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2020-10-14 18:21:07 +0000
committerArne Juul <arnej@verizonmedia.com>2020-10-14 19:02:40 +0000
commitf9561ddf045834fc4332367af17f9d1d62690774 (patch)
tree7bec9fa884be3ede4e0f3b35b2a78be6beb17e4d /searchlib
parenta5489bb3c38e9c69cba984723a3e53c61715112d (diff)
avoid using code that should be internal to TypedBinaryFormat
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/features/dotproductfeature.cpp28
1 files changed, 25 insertions, 3 deletions
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
index 37fd98c9f20..83beee98634 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
@@ -10,7 +10,7 @@
#include <vespa/searchlib/attribute/floatbase.h>
#include <vespa/searchlib/attribute/multinumericattribute.h>
#include <vespa/searchlib/attribute/multienumattribute.h>
-#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/eval/eval/engine_or_factory.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/stash.h>
@@ -19,6 +19,8 @@ LOG_SETUP(".features.dotproduct");
using namespace search::attribute;
using namespace search::fef;
+using vespalib::eval::EngineOrFactory;
+using vespalib::eval::TypedCells;
using vespalib::hwaccelrated::IAccelrated;
namespace search::features {
@@ -473,7 +475,19 @@ parseVectors<int8_t, int8_t>(const Property& prop, std::vector<int8_t>& values,
parseVectors<int8_t, int16_t>(prop, values, indexes);
}
-}
+template <typename TCT>
+struct CopyCellsToVector {
+ template<typename ICT>
+ static void invoke(TypedCells source, std::vector<TCT> &target) {
+ target.reserve(source.size);
+ auto cells = source.typify<ICT>();
+ for (auto value : cells) {
+ target.push_back(value);
+ }
+ }
+};
+
+} // namespace <unnamed>
namespace dotproduct {
@@ -484,7 +498,15 @@ ArrayParam<T>::ArrayParam(const Property & prop) {
template <typename T>
ArrayParam<T>::ArrayParam(vespalib::nbostream & stream) {
- vespalib::tensor::TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(stream, values);
+ using vespalib::typify_invoke;
+ using vespalib::eval::TypifyCellType;
+ auto tensor = EngineOrFactory::get().decode(stream);
+ if (tensor->type().is_dense()) {
+ TypedCells cells = tensor->cells();
+ typify_invoke<1,TypifyCellType,CopyCellsToVector<T>>(cells.type, cells, values);
+ } else {
+ LOG(warning, "Expected dense tensor, but got type '%s'", tensor->type().to_spec().c_str());
+ }
}
template <typename T>