diff options
author | Henning Baldersheim <balder@oath.com> | 2019-04-04 18:16:59 +0000 |
---|---|---|
committer | Henning Baldersheim <balder@oath.com> | 2019-04-04 18:16:59 +0000 |
commit | 39f61af291ed2ba600bdada669e0a2111df80246 (patch) | |
tree | a4ce4a612b905b3fd9b1f8cb033b8ccef8e093dd /searchlib | |
parent | e081ec72d214c7a3322c8ecd62937af5ca49e1e1 (diff) |
Accept a tensor set down in the dedicated '.tensor' field.
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/vespa/searchlib/features/dotproductfeature.cpp | 38 | ||||
-rw-r--r-- | searchlib/src/vespa/searchlib/features/dotproductfeature.h | 2 |
2 files changed, 33 insertions, 7 deletions
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp index dffa3bb28b5..06c653b9a01 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp @@ -12,6 +12,8 @@ #include <type_traits> #include <vespa/log/log.h> +#include <vespa/eval/tensor/serialization/typed_binary_format.h> +#include <vespa/vespalib/objects/nbostream.h> LOG_SETUP(".features.dotproduct"); @@ -340,11 +342,18 @@ ArrayParam<T>::ArrayParam(const Property & prop) { parseVectors(prop, values, indexes); } +template <typename T> +ArrayParam<T>::ArrayParam(vespalib::nbostream & stream) { + vespalib::tensor::TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(stream, values); +} + + // Explicit instantiation since these are inspected by unit tests. // FIXME this feels a bit dirty, consider breaking up ArrayParam to remove dependencies // on templated vector parsing. This is why it's defined in this translation unit as it is. -template struct ArrayParam<int64_t>; +template ArrayParam<int64_t>::ArrayParam(const Property & prop); template struct ArrayParam<double>; +template struct ArrayParam<float>; } // namespace dotproduct @@ -621,9 +630,26 @@ DotProductBlueprint::prepareSharedState(const IQueryEnvironment & env, IObjectSt { attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env)); } + fef::Anything::UP arguments; + if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) { + Property tensorBlob = env.getProperties().lookup(getBaseName(), _queryVector, "tensor"); + if (attribute->isFloatingPointType() && tensorBlob.found() && !tensorBlob.get().empty()) { + const Property::Value & blob = tensorBlob.get(); + vespalib::nbostream stream(blob.data(), blob.size()); + if (attribute->getBasicType() == BasicType::FLOAT) { + arguments = std::make_unique<ArrayParam<float>>(stream); + } else { + arguments = std::make_unique<ArrayParam<double>>(stream); + } + } else { + Property prop = env.getProperties().lookup(getBaseName(), _queryVector); + if (prop.found() && !prop.get().empty()) { + arguments = attemptParseArrayQueryVector(*attribute, prop); + } + } + } Property prop = env.getProperties().lookup(getBaseName(), _queryVector); if (prop.found() && !prop.get().empty()) { - fef::Anything::UP arguments; if (attribute->getCollectionType() == attribute::CollectionType::WSET) { if (attribute->isStringType() && attribute->hasEnum()) { dotproduct::wset::EnumVector vector(attribute); @@ -638,13 +664,11 @@ DotProductBlueprint::prepareSharedState(const IQueryEnvironment & env, IObjectSt } } // TODO actually use the parsed output for wset operations! - } else if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) { - arguments = attemptParseArrayQueryVector(*attribute, prop); - } - if (arguments.get()) { - store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments)); } } + if (arguments) { + store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments)); + } } } diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h index b6107a1a271..ad40edc49e5 100644 --- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h +++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h @@ -10,6 +10,7 @@ #include <vespa/vespalib/stllike/hash_map.hpp> namespace search::fef { class Property; } +namespace vespalib { class nbostream; } namespace search::features { @@ -34,6 +35,7 @@ struct Converter<vespalib::string, const char *> { template <typename T> struct ArrayParam : public fef::Anything { ArrayParam(const fef::Property & prop); + ArrayParam(vespalib::nbostream & stream); std::vector<T> values; std::vector<uint32_t> indexes; }; |