diff options
Diffstat (limited to 'vespalib')
8 files changed, 97 insertions, 0 deletions
diff --git a/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp b/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp index 120c7207100..ceb72854e66 100644 --- a/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp +++ b/vespalib/src/vespa/vespalib/eval/simple_tensor.cpp @@ -394,6 +394,13 @@ public: constexpr size_t TensorSpec::Label::npos; constexpr size_t SimpleTensor::Label::npos; +SimpleTensor::SimpleTensor(double value) + : Tensor(SimpleTensorEngine::ref()), + _type(ValueType::double_type()), + _cells({Cell({},value)}) +{ +} + SimpleTensor::SimpleTensor(const ValueType &type_in, Cells &&cells_in) : Tensor(SimpleTensorEngine::ref()), _type(type_in), diff --git a/vespalib/src/vespa/vespalib/eval/simple_tensor.h b/vespalib/src/vespa/vespalib/eval/simple_tensor.h index 1485a79fd7a..0d1538ffb18 100644 --- a/vespalib/src/vespa/vespalib/eval/simple_tensor.h +++ b/vespalib/src/vespa/vespalib/eval/simple_tensor.h @@ -70,6 +70,7 @@ private: Cells _cells; public: + explicit SimpleTensor(double value); SimpleTensor(const ValueType &type_in, Cells &&cells_in); const ValueType &type() const { return _type; } const Cells &cells() const { return _cells; } diff --git a/vespalib/src/vespa/vespalib/eval/simple_tensor_engine.cpp b/vespalib/src/vespa/vespalib/eval/simple_tensor_engine.cpp index 9b77f7dfa99..3ba427becb9 100644 --- a/vespalib/src/vespa/vespalib/eval/simple_tensor_engine.cpp +++ b/vespalib/src/vespa/vespalib/eval/simple_tensor_engine.cpp @@ -78,6 +78,15 @@ SimpleTensorEngine::to_spec(const Tensor &tensor) const return spec; } +const SimpleTensor &to_simple(const Value &value, Stash &stash) { + auto tensor = value.as_tensor(); + if (tensor) { + assert(&tensor->engine() == &SimpleTensorEngine::ref()); + return static_cast<const SimpleTensor &>(*tensor); + } + return stash.create<SimpleTensor>(value.as_double()); +} + std::unique_ptr<eval::Tensor> SimpleTensorEngine::create(const TensorSpec &spec) const { @@ -117,5 +126,14 @@ SimpleTensorEngine::apply(const BinaryOperation &op, const eval::Tensor &a, cons return stash.create<TensorValue>(std::move(result)); } +const Value & +SimpleTensorEngine::concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const +{ + const SimpleTensor &simple_a = to_simple(a, stash); + const SimpleTensor &simple_b = to_simple(b, stash); + auto result = SimpleTensor::concat(simple_a, simple_b, dimension); + return stash.create<TensorValue>(std::move(result)); +} + } // namespace vespalib::eval } // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/eval/simple_tensor_engine.h b/vespalib/src/vespa/vespalib/eval/simple_tensor_engine.h index c3207c440fb..b8791606084 100644 --- a/vespalib/src/vespa/vespalib/eval/simple_tensor_engine.h +++ b/vespalib/src/vespa/vespalib/eval/simple_tensor_engine.h @@ -28,6 +28,8 @@ public: const Value &reduce(const Tensor &tensor, const BinaryOperation &op, const std::vector<vespalib::string> &dimensions, Stash &stash) const override; const Value &map(const UnaryOperation &op, const Tensor &a, Stash &stash) const override; const Value &apply(const BinaryOperation &op, const Tensor &a, const Tensor &b, Stash &stash) const override; + + const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const override; }; } // namespace vespalib::eval diff --git a/vespalib/src/vespa/vespalib/eval/tensor_engine.h b/vespalib/src/vespa/vespalib/eval/tensor_engine.h index 2458da7ff8b..25a38fed69c 100644 --- a/vespalib/src/vespa/vespalib/eval/tensor_engine.h +++ b/vespalib/src/vespa/vespalib/eval/tensor_engine.h @@ -49,6 +49,10 @@ struct TensorEngine virtual const Value &reduce(const Tensor &tensor, const BinaryOperation &op, const std::vector<vespalib::string> &dimensions, Stash &stash) const = 0; virtual const Value &map(const UnaryOperation &op, const Tensor &a, Stash &stash) const = 0; virtual const Value &apply(const BinaryOperation &op, const Tensor &a, const Tensor &b, Stash &stash) const = 0; + + // havardpe: new API, WIP + virtual const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const = 0; + virtual ~TensorEngine() {} }; diff --git a/vespalib/src/vespa/vespalib/eval/test/tensor_conformance.cpp b/vespalib/src/vespa/vespalib/eval/test/tensor_conformance.cpp index 716311d818b..b05d9c5b2ad 100644 --- a/vespalib/src/vespa/vespalib/eval/test/tensor_conformance.cpp +++ b/vespalib/src/vespa/vespalib/eval/test/tensor_conformance.cpp @@ -376,6 +376,14 @@ struct Expr_TT : Eval { } }; +const Value &make_value(const TensorEngine &engine, const TensorSpec &spec, Stash &stash) { + if (spec.type() == "double") { + double number = spec.cells().empty() ? 0.0 : spec.cells().begin()->second.value; + return stash.create<DoubleValue>(number); + } + return stash.create<TensorValue>(engine.create(spec)); +} + // evaluate tensor reduce operation using tensor engine immediate api struct ImmediateReduce : Eval { const BinaryOperation &op; @@ -409,6 +417,18 @@ struct ImmediateApply : Eval { } }; +// evaluate tensor concat operation using tensor engine immediate api +struct ImmediateConcat : Eval { + vespalib::string dimension; + ImmediateConcat(const vespalib::string &dimension_in) : dimension(dimension_in) {} + Result eval(const TensorEngine &engine, const TensorSpec &a, const TensorSpec &b) const override { + Stash stash; + const auto &lhs = make_value(engine, a, stash); + const auto &rhs = make_value(engine, b, stash); + return Result(engine.concat(lhs, rhs, dimension, stash)); + } +}; + const size_t tensor_id_a = 11; const size_t tensor_id_b = 12; const size_t map_operation_id = 22; @@ -1013,6 +1033,38 @@ struct TestContext { //------------------------------------------------------------------------- + void test_concat(const TensorSpec &expect, + const TensorSpec &a, + const TensorSpec &b, + const vespalib::string &dimension) + { + ImmediateConcat eval(dimension); + EXPECT_EQUAL(eval.eval(engine, a, b).tensor(), expect); + } + + void test_concat() { + TEST_DO(test_concat(spec(x(2), Seq({10.0, 20.0})), spec(10.0), spec(20.0), "x")); + TEST_DO(test_concat(spec(x(2), Seq({10.0, 20.0})), spec(x(1), Seq({10.0})), spec(20.0), "x")); + TEST_DO(test_concat(spec(x(2), Seq({10.0, 20.0})), spec(10.0), spec(x(1), Seq({20.0})), "x")); + TEST_DO(test_concat(spec(x(5), Seq({1.0, 2.0, 3.0, 4.0, 5.0})), + spec(x(3), Seq({1.0, 2.0, 3.0})), + spec(x(2), Seq({4.0, 5.0})), "x")); + TEST_DO(test_concat(spec({x(2),y(4)}, Seq({1.0, 2.0, 5.0, 6.0, 3.0, 4.0, 5.0, 6.0})), + spec({x(2),y(2)}, Seq({1.0, 2.0, 3.0, 4.0})), + spec(y(2), Seq({5.0, 6.0})), "y")); + TEST_DO(test_concat(spec({x(4),y(2)}, Seq({1.0, 2.0, 3.0, 4.0, 5.0, 5.0, 6.0, 6.0})), + spec({x(2),y(2)}, Seq({1.0, 2.0, 3.0, 4.0})), + spec(x(2), Seq({5.0, 6.0})), "x")); + TEST_DO(test_concat(spec({x(2),y(2),z(3)}, Seq({1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0})), + spec(z(3), Seq({1.0, 2.0, 3.0})), + spec(y(2), Seq({4.0, 5.0})), "x")); + TEST_DO(test_concat(spec({x(2), y(2)}, Seq({1.0, 2.0, 4.0, 5.0})), + spec(y(3), Seq({1.0, 2.0, 3.0})), + spec(y(2), Seq({4.0, 5.0})), "x")); + } + + //------------------------------------------------------------------------- + void run_tests() { TEST_DO(test_tensor_create_type()); TEST_DO(test_tensor_equality()); @@ -1021,6 +1073,7 @@ struct TestContext { TEST_DO(test_tensor_map()); TEST_DO(test_tensor_apply()); TEST_DO(test_dot_product()); + TEST_DO(test_concat()); } }; diff --git a/vespalib/src/vespa/vespalib/tensor/default_tensor_engine.cpp b/vespalib/src/vespa/vespalib/tensor/default_tensor_engine.cpp index bf1645f848e..13e645a2b48 100644 --- a/vespalib/src/vespa/vespalib/tensor/default_tensor_engine.cpp +++ b/vespalib/src/vespa/vespalib/tensor/default_tensor_engine.cpp @@ -13,6 +13,7 @@ namespace vespalib { namespace tensor { +using Value = eval::Value; using ErrorValue = eval::ErrorValue; using DoubleValue = eval::DoubleValue; using TensorValue = eval::TensorValue; @@ -220,5 +221,14 @@ DefaultTensorEngine::apply(const BinaryOperation &op, const Tensor &a, const Ten } } +const Value & +DefaultTensorEngine::concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const +{ + (void) a; + (void) b; + (void) dimension; + return stash.create<ErrorValue>(); +} + } // namespace vespalib::tensor } // namespace vespalib diff --git a/vespalib/src/vespa/vespalib/tensor/default_tensor_engine.h b/vespalib/src/vespa/vespalib/tensor/default_tensor_engine.h index 8e6ea39a625..44e4532a6d5 100644 --- a/vespalib/src/vespa/vespalib/tensor/default_tensor_engine.h +++ b/vespalib/src/vespa/vespalib/tensor/default_tensor_engine.h @@ -30,6 +30,8 @@ public: const Value &reduce(const Tensor &tensor, const BinaryOperation &op, const std::vector<vespalib::string> &dimensions, Stash &stash) const override; const Value &map(const UnaryOperation &op, const Tensor &a, Stash &stash) const override; const Value &apply(const BinaryOperation &op, const Tensor &a, const Tensor &b, Stash &stash) const override; + + const Value &concat(const Value &a, const Value &b, const vespalib::string &dimension, Stash &stash) const override; }; } // namespace vespalib::tensor |