summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Henning Baldersheim <balder@yahoo-inc.com> 2019-04-05 12:35:17 +0200
committer GitHub <noreply@github.com> 2019-04-05 12:35:17 +0200
commit b138df1052b2d6458ef882bdb1da32896d92b510 (patch)
tree 3349fa8db0ca57e916a0a3d8c01302477eaa7522
parent fcc5d2470ee7f091bf08cb952de6f606fcb2fa6b (diff)
parent 64669cbe556cbdf4e7c7c084bbd0c89f1923ab07 (diff)
Merge pull request #9020 from vespa-engine/balder/accept-tensor-in-the-dot-tensor-extension
Accept a tensor supplied in the dedicated '.tensor' field.
-rw-r--r-- searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp 68
-rw-r--r-- searchlib/src/vespa/searchlib/features/dotproductfeature.cpp 87
-rw-r--r-- searchlib/src/vespa/searchlib/features/dotproductfeature.h 8
3 files changed, 121 insertions, 42 deletions
diff --git a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
index 42ce9725f91..54c77fb25a7 100644
--- a/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
+++ b/searchlib/src/tests/features/imported_dot_product/imported_dot_product_test.cpp
@@ -6,6 +6,10 @@
#include <vespa/searchlib/fef/test/ftlib.h>
#include <vespa/searchlib/fef/test/rankresult.h>
#include <vespa/searchlib/fef/test/dummy_dependency_handler.h>
+#include <vespa/eval/tensor/tensor.h>
+#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/vespalib/objects/nbostream.h>
+#include <vespa/eval/tensor/dense/dense_tensor.h>
using namespace search;
using namespace search::attribute;
@@ -104,7 +108,26 @@ struct ArrayFixture : FixtureBase {
}
template <typename ExpectedType>
- void check_prepare_state_output(const vespalib::string& input_vector) {
+ void check_prepare_state_output(const vespalib::tensor::Tensor & tensor, vespalib::tensor::SerializeFormat format, const ExpectedType & expected) {
+ vespalib::nbostream os;
+ vespalib::tensor::TypedBinaryFormat::serialize(os, tensor, format);
+ vespalib::string input_vector(os.c_str(), os.size());
+ check_prepare_state_output(".tensor", input_vector, expected);
+ }
+
+ template <typename ExpectedType>
+ void check_prepare_state_output(const vespalib::string& input_vector, const ExpectedType & expected) {
+ check_prepare_state_output("", input_vector, expected);
+ }
+ template <typename T>
+ static void verify(const dotproduct::ArrayParam<T> & a, const dotproduct::ArrayParam<T> & b) {
+ ASSERT_EQUAL(a.values.size(), b.values.size());
+ for (size_t i(0); i < a.values.size(); i++) {
+ EXPECT_EQUAL(a.values[i], b.values[i]);
+ }
+ }
+ template <typename ExpectedType>
+ void check_prepare_state_output(const vespalib::string & postfix, const vespalib::string& input_vector, const ExpectedType & expected) {
FtFeatureTest feature(_factory, "");
DotProductBlueprint bp;
DummyDependencyHandler dependency_handler(bp);
@@ -116,7 +139,7 @@ struct ArrayFixture : FixtureBase {
FieldType::ATTRIBUTE, schema::CollectionType::ARRAY, imported_attr->getName());
bp.setup(feature.getIndexEnv(), params);
- feature.getQueryEnv().getProperties().add("dotProduct.fancyvector", input_vector);
+ feature.getQueryEnv().getProperties().add("dotProduct.fancyvector" + postfix, input_vector);
auto& obj_store = feature.getQueryEnv().getObjectStore();
bp.prepareSharedState(feature.getQueryEnv(), obj_store);
// Resulting name is very implementation defined. But at least the tests will break if it changes.
@@ -124,13 +147,12 @@ struct ArrayFixture : FixtureBase {
ASSERT_TRUE(parsed != nullptr);
const auto* as_object = dynamic_cast<const ExpectedType*>(parsed);
ASSERT_TRUE(as_object != nullptr);
- // We don't test the parsed output values here; that's the responsibility of other tests.
+ verify(expected, *as_object);
}
- void check_all_float_executions(feature_t expected,
- const vespalib::string& vector,
- DocId doc_id,
- const vespalib::string& shared_param = "") {
+ void check_all_float_executions(feature_t expected, const vespalib::string& vector,
+ DocId doc_id, const vespalib::string& shared_param = "")
+ {
check_executions<double>([this](auto float_type){ this->setup_float_mappings(float_type); },
{{BasicType::FLOAT, BasicType::DOUBLE}},
expected, vector, doc_id, shared_param);
@@ -155,22 +177,46 @@ TEST_F("Zero-length float/double array query vector evaluates to zero", ArrayFix
TEST_F("prepareSharedState emits i64 vector for i32 imported attribute", ArrayFixture) {
f.setup_integer_mappings(BasicType::INT32);
- f.template check_prepare_state_output<dotproduct::ArrayParam<int64_t>>("[101 202 303]");
+ f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303}));
}
TEST_F("prepareSharedState emits i64 vector for i64 imported attribute", ArrayFixture) {
f.setup_integer_mappings(BasicType::INT64);
- f.template check_prepare_state_output<dotproduct::ArrayParam<int64_t>>("[101 202 303]");
+ f.template check_prepare_state_output("[101 202 303]", dotproduct::ArrayParam<int64_t>({101, 202, 303}));
}
TEST_F("prepareSharedState emits double vector for float imported attribute", ArrayFixture) {
f.setup_float_mappings(BasicType::FLOAT);
- f.template check_prepare_state_output<dotproduct::ArrayParam<double>>("[10.1 20.2 30.3]");
+ f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
}
TEST_F("prepareSharedState emits double vector for double imported attribute", ArrayFixture) {
f.setup_float_mappings(BasicType::DOUBLE);
- f.template check_prepare_state_output<dotproduct::ArrayParam<double>>("[10.1 20.2 30.3]");
+ f.template check_prepare_state_output("[10.1 20.2 30.3]", dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
+}
+
+TEST_F("prepareSharedState handles tensor as float from tensor for double imported attribute", ArrayFixture) {
+ f.setup_float_mappings(BasicType::DOUBLE);
+ vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3});
+ f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::FLOAT, dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
+}
+
+TEST_F("prepareSharedState handles tensor as double from tensor for double imported attribute", ArrayFixture) {
+ f.setup_float_mappings(BasicType::DOUBLE);
+ vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3});
+ f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::DOUBLE, dotproduct::ArrayParam<double>({10.1, 20.2, 30.3}));
+}
+
+TEST_F("prepareSharedState handles tensor as float from tensor for float imported attribute", ArrayFixture) {
+ f.setup_float_mappings(BasicType::FLOAT);
+ vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3});
+ f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::FLOAT, dotproduct::ArrayParam<float>({10.1, 20.2, 30.3}));
+}
+
+TEST_F("prepareSharedState handles tensor as double from tensor for float imported attribute", ArrayFixture) {
+ f.setup_float_mappings(BasicType::FLOAT);
+ vespalib::tensor::DenseTensor tensor(vespalib::eval::ValueType::from_spec("tensor(x[3])"), {10.1, 20.2, 30.3});
+ f.template check_prepare_state_output(tensor, vespalib::tensor::SerializeFormat::DOUBLE, dotproduct::ArrayParam<float>({10.1, 20.2, 30.3}));
}
TEST_F("Dense i32/i64 array dot product can be evaluated with pre-parsed object parameter", ArrayFixture) {
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
index dffa3bb28b5..1dcd3e35580 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.cpp
@@ -12,6 +12,8 @@
#include <type_traits>
#include <vespa/log/log.h>
+#include <vespa/eval/tensor/serialization/typed_binary_format.h>
+#include <vespa/vespalib/objects/nbostream.h>
LOG_SETUP(".features.dotproduct");
@@ -340,11 +342,21 @@ ArrayParam<T>::ArrayParam(const Property & prop) {
parseVectors(prop, values, indexes);
}
+template <typename T>
+ArrayParam<T>::ArrayParam(vespalib::nbostream & stream) {
+ vespalib::tensor::TypedBinaryFormat::deserializeCellsOnlyFromDenseTensors(stream, values);
+}
+
+template <typename T>
+ArrayParam<T>::~ArrayParam() = default;
+
+
// Explicit instantiation since these are inspected by unit tests.
// FIXME this feels a bit dirty, consider breaking up ArrayParam to remove dependencies
// on templated vector parsing. This is why it's defined in this translation unit as it is.
-template struct ArrayParam<int64_t>;
+template ArrayParam<int64_t>::ArrayParam(const Property & prop);
template struct ArrayParam<double>;
+template struct ArrayParam<float>;
} // namespace dotproduct
@@ -609,43 +621,63 @@ fef::Anything::UP attemptParseArrayQueryVector(const IAttributeVector & attribut
} // anon ns
+const IAttributeVector *
+DotProductBlueprint::upgradeIfNecessary(const IAttributeVector * attribute, const IQueryEnvironment & env) const {
+ if ((attribute->getCollectionType() == attribute::CollectionType::WSET) &&
+ attribute->hasEnum() &&
+ (attribute->isStringType() || attribute->isIntegerType()))
+ {
+ attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env));
+ }
+ return attribute;
+}
+
void
DotProductBlueprint::prepareSharedState(const IQueryEnvironment & env, IObjectStore & store) const
{
_attribute = env.getAttributeContext().getAttribute(getAttribute(env));
const IAttributeVector * attribute = _attribute;
- if (attribute != nullptr) {
- if ((attribute->getCollectionType() == attribute::CollectionType::WSET) &&
- attribute->hasEnum() &&
- (attribute->isStringType() || attribute->isIntegerType()))
- {
- attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env));
+ if (attribute == nullptr) return;
+
+ attribute = upgradeIfNecessary(attribute, env);
+ fef::Anything::UP arguments;
+ if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) {
+ Property tensorBlob = env.getProperties().lookup(getBaseName(), _queryVector, "tensor");
+ if (attribute->isFloatingPointType() && tensorBlob.found() && !tensorBlob.get().empty()) {
+ const Property::Value & blob = tensorBlob.get();
+ vespalib::nbostream stream(blob.data(), blob.size());
+ if (attribute->getBasicType() == BasicType::FLOAT) {
+ arguments = std::make_unique<ArrayParam<float>>(stream);
+ } else {
+ arguments = std::make_unique<ArrayParam<double>>(stream);
+ }
+ } else {
+ Property prop = env.getProperties().lookup(getBaseName(), _queryVector);
+ if (prop.found() && !prop.get().empty()) {
+ arguments = attemptParseArrayQueryVector(*attribute, prop);
+ }
}
+ } else if (attribute->getCollectionType() == attribute::CollectionType::WSET) {
Property prop = env.getProperties().lookup(getBaseName(), _queryVector);
if (prop.found() && !prop.get().empty()) {
- fef::Anything::UP arguments;
- if (attribute->getCollectionType() == attribute::CollectionType::WSET) {
- if (attribute->isStringType() && attribute->hasEnum()) {
+ if (attribute->isStringType() && attribute->hasEnum()) {
+ dotproduct::wset::EnumVector vector(attribute);
+ WeightedSetParser::parse(prop.get(), vector);
+ } else if (attribute->isIntegerType()) {
+ if (attribute->hasEnum()) {
dotproduct::wset::EnumVector vector(attribute);
WeightedSetParser::parse(prop.get(), vector);
- } else if (attribute->isIntegerType()) {
- if (attribute->hasEnum()) {
- dotproduct::wset::EnumVector vector(attribute);
- WeightedSetParser::parse(prop.get(), vector);
- } else {
- dotproduct::wset::IntegerVector vector;
- WeightedSetParser::parse(prop.get(), vector);
- }
+ } else {
+ dotproduct::wset::IntegerVector vector;
+ WeightedSetParser::parse(prop.get(), vector);
}
- // TODO actually use the parsed output for wset operations!
- } else if (attribute->getCollectionType() == attribute::CollectionType::ARRAY) {
- arguments = attemptParseArrayQueryVector(*attribute, prop);
- }
- if (arguments.get()) {
- store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments));
}
+ // TODO actually use the parsed output for wset operations!
}
}
+ if (arguments) {
+ store.add(getBaseName() + "." + _queryVector + "." + OBJECT, std::move(arguments));
+ }
}
FeatureExecutor &
@@ -657,12 +689,7 @@ DotProductBlueprint::createExecutor(const IQueryEnvironment & env, vespalib::Sta
getAttribute(env).c_str());
return stash.create<SingleZeroValueExecutor>();
}
- if ((attribute->getCollectionType() == attribute::CollectionType::WSET) &&
- attribute->hasEnum() &&
- (attribute->isStringType() || attribute->isIntegerType()))
- {
- attribute = env.getAttributeContext().getAttributeStableEnum(getAttribute(env));
- }
+ attribute = upgradeIfNecessary(attribute, env);
const fef::Anything * argument = env.getObjectStore().get(getBaseName() + "." + _queryVector + "." + OBJECT);
if (argument != nullptr) {
return createFromObject(attribute, *argument, stash);
diff --git a/searchlib/src/vespa/searchlib/features/dotproductfeature.h b/searchlib/src/vespa/searchlib/features/dotproductfeature.h
index b6107a1a271..089066cb5f6 100644
--- a/searchlib/src/vespa/searchlib/features/dotproductfeature.h
+++ b/searchlib/src/vespa/searchlib/features/dotproductfeature.h
@@ -10,6 +10,7 @@
#include <vespa/vespalib/stllike/hash_map.hpp>
namespace search::fef { class Property; }
+namespace vespalib { class nbostream; }
namespace search::features {
@@ -34,6 +35,9 @@ struct Converter<vespalib::string, const char *> {
template <typename T>
struct ArrayParam : public fef::Anything {
ArrayParam(const fef::Property & prop);
+ ArrayParam(vespalib::nbostream & stream);
+ ArrayParam(std::vector<T> v) : values(std::move(v)) {}
+ ~ArrayParam() override;
std::vector<T> values;
std::vector<uint32_t> indexes;
};
@@ -260,12 +264,14 @@ private:
*/
class DotProductBlueprint : public fef::Blueprint {
private:
+ using IAttributeVector = attribute::IAttributeVector;
vespalib::string _defaultAttribute;
vespalib::string _queryVector;
- mutable const attribute::IAttributeVector * _attribute;
+ mutable const IAttributeVector * _attribute;
vespalib::string getAttribute(const fef::IQueryEnvironment & env) const;
+ const IAttributeVector * upgradeIfNecessary(const IAttributeVector * attribute, const fef::IQueryEnvironment & env) const;
public:
DotProductBlueprint();