summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@oath.com>2019-04-04 18:16:59 +0000
committerHenning Baldersheim <balder@oath.com>2019-04-04 18:16:59 +0000
commit39f61af291ed2ba600bdada669e0a2111df80246 (patch)
treea4ce4a612b905b3fd9b1f8cb033b8ccef8e093dd /searchlib
parente081ec72d214c7a3322c8ecd62937af5ca49e1e1 (diff)
Accept a tensor set down in the dedicated '.tensor' field.
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/vespa/searchlib/features/dotproductfeature.cpp38
-rw-r--r--searchlib/src/vespa/searchlib/features/dotproductfeature.h2
2 files changed, 33 insertions, 7 deletions
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
index dffa3bb28b5..06c653b9a01 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
@@ -12,6 +12,8 @@
#include <type_traits>
#include <vespa/log/log.h>
+#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/vespalib/objects/nbostream.h>
LOG_SETUP(".features.dotproduct");
@@ -340,11 +342,18 @@ ArrayParam<T>::ArrayParam(const Property & prop) {
parseVectors(prop, values, indexes);
}
+template <typename T>
+ArrayParam<T>::ArrayParam(vespalib::nbostream & stream) {
+ vespalib::tensor::TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(stream, values);
+}
+
+
// Explicit instantiation since these are inspected by unit tests.
// FIXME this feels a bit dirty, consider breaking up ArrayParam to remove dependencies
// on templated vector parsing. This is why it's defined in this translation unit as it is.
-template struct ArrayParam<int64_t>;
+template ArrayParam<int64_t>::ArrayParam(const Property & prop);
template struct ArrayParam<double>;
+template struct ArrayParam<float>;
} // namespace dotproduct
@@ -621,9 +630,26 @@ DotProductBlueprint::prepareSharedState(const IQueryEnvironment & env, IObjectSt
{
attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env));
}
+ fef::Anything::UP arguments;
+ if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) {
+ Property tensorBlob = env.getProperties().lookup(getBaseName(), _queryVector, "tensor");
+ if (attribute->isFloatingPointType() && tensorBlob.found() && !tensorBlob.get().empty()) {
+ const Property::Value & blob = tensorBlob.get();
+ vespalib::nbostream stream(blob.data(), blob.size());
+ if (attribute->getBasicType() == BasicType::FLOAT) {
+ arguments = std::make_unique<ArrayParam<float>>(stream);
+ } else {
+ arguments = std::make_unique<ArrayParam<double>>(stream);
+ }
+ } else {
+ Property prop = env.getProperties().lookup(getBaseName(), _queryVector);
+ if (prop.found() && !prop.get().empty()) {
+ arguments = attemptParseArrayQueryVector(*attribute, prop);
+ }
+ }
+ }
Property prop = env.getProperties().lookup(getBaseName(), _queryVector);
if (prop.found() && !prop.get().empty()) {
- fef::Anything::UP arguments;
if (attribute->getCollectionType() == attribute::CollectionType::WSET) {
if (attribute->isStringType() && attribute->hasEnum()) {
dotproduct::wset::EnumVector vector(attribute);
@@ -638,13 +664,11 @@ DotProductBlueprint::prepareSharedState(const IQueryEnvironment & env, IObjectSt
}
}
// TODO actually use the parsed output for wset operations!
- } else if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) {
- arguments = attemptParseArrayQueryVector(*attribute, prop);
- }
- if (arguments.get()) {
- store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments));
}
}
+ if (arguments) {
+ store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments));
+ }
}
}
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h
index b6107a1a271..ad40edc49e5 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h
@@ -10,6 +10,7 @@
#include <vespa/vespalib/stllike/hash_map.hpp>
namespace search::fef { class Property; }
+namespace vespalib { class nbostream; }
namespace search::features {
@@ -34,6 +35,7 @@ struct Converter<vespalib::string, const char *> {
template <typename T>
struct ArrayParam : public fef::Anything {
ArrayParam(const fef::Property & prop);
+ ArrayParam(vespalib::nbostream & stream);
std::vector<T> values;
std::vector<uint32_t> indexes;
};