diff options
author | Håvard Pettersen <havardpe@oath.com> | 2017-10-23 13:21:18 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2017-10-23 13:21:18 +0000 |
commit | ebfcf420950172ce9a6c65cb2d858a381e5e2485 (patch) | |
tree | 9f0889fa6616895be483d780cf8a4b8b8546ea69 /eval | |
parent | 94e2f9e38d3491391ca6a8c10e151f8f94dbe98b (diff) |
implement new join API
only 'match' is inlined
added new join function to tensor::Tensor
Diffstat (limited to 'eval')
8 files changed, 96 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 35db1ed3b4b..685864d3d5e 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -70,6 +70,11 @@ const Value &to_value(std::unique_ptr<Tensor> tensor, Stash &stash) { return stash.create<DoubleValue>(tensor->sum()); } +template <typename join_fun_t> +const Value &fallback_join(const Value &a, const Value &b, join_fun_t function, Stash &stash) { + return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash); +} + } // namespace vespalib::tensor::<unnamed> const DefaultTensorEngine DefaultTensorEngine::_engine; @@ -212,6 +217,22 @@ struct CellFunctionFunAdapter : tensor::CellFunction { virtual double apply(double value) const override { return fun(value); } }; +struct CellFunctionBindLeftAdapter : tensor::CellFunction { + using join_fun_t = DefaultTensorEngine::join_fun_t; + join_fun_t fun; + double a; + CellFunctionBindLeftAdapter(join_fun_t fun_in, double bound) : fun(fun_in), a(bound) {} + virtual double apply(double b) const override { return fun(a, b); } +}; + +struct CellFunctionBindRightAdapter : tensor::CellFunction { + using join_fun_t = DefaultTensorEngine::join_fun_t; + join_fun_t fun; + double b; + CellFunctionBindRightAdapter(join_fun_t fun_in, double bound) : fun(fun_in), b(bound) {} + virtual double apply(double a) const override { return fun(a, b); } +}; + const eval::Value & DefaultTensorEngine::map(const UnaryOperation &op, const Tensor &a, Stash &stash) const { @@ -320,7 +341,46 @@ DefaultTensorEngine::map(const Value &a, map_fun_t function, Stash &stash) const const Value & DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const { - return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash); + if (a.is_double()) { + if (b.is_double()) { + return stash.create<DoubleValue>(function(a.as_double(), b.as_double())); + } else if (auto tensor_b = b.as_tensor()) { + assert(&tensor_b->engine() == this); + const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b); + if (!tensor::Tensor::supported({my_b.getType()})) { + return fallback_join(a, b, function, stash); + } + CellFunctionBindLeftAdapter cell_function(function, a.as_double()); + return to_value(my_b.apply(cell_function), stash); + } else { // error + return b; + } + } else if (auto tensor_a = a.as_tensor()) { + assert(&tensor_a->engine() == this); + const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor_a); + if (b.is_double()) { + if (!tensor::Tensor::supported({my_a.getType()})) { + return fallback_join(a, b, function, stash); + } + CellFunctionBindRightAdapter cell_function(function, b.as_double()); + return to_value(my_a.apply(cell_function), stash); + } else if (auto tensor_b = b.as_tensor()) { + assert(&tensor_b->engine() == this); + const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b); + if (!tensor::Tensor::supported({my_a.getType(), my_b.getType()})) { + return fallback_join(a, b, function, stash); + } + if ((function == eval::operation::Mul::f) && (my_a.getType() == my_b.getType())) { + return to_value(my_a.match(my_b), stash); + } else { + return to_value(my_a.join(function, my_b), stash); + } + } else { // error + return b; + } + } else { // error + return a; + } } const Value & diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp index 1a76ef6475b..e0babad36aa 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp @@ -336,6 +336,14 @@ DenseTensorView::apply(const eval::BinaryOperation &op, const Tensor &arg) const } Tensor::UP +DenseTensorView::join(join_fun_t function, const Tensor &arg) const +{ + return dense::apply(*this, arg, + [function](double lhsValue, double rhsValue) + { return function(lhsValue, rhsValue); }); +} + +Tensor::UP DenseTensorView::reduce(const eval::BinaryOperation &op, const std::vector<vespalib::string> &dimensions) const { diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h index d5d8c8af821..56f26b4ba7c 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h @@ -59,6 +59,8 @@ public: virtual Tensor::UP sum(const vespalib::string &dimension) const override; virtual Tensor::UP apply(const eval::BinaryOperation &op, const Tensor &arg) const override; + virtual Tensor::UP join(join_fun_t function, + const Tensor &arg) const override; virtual Tensor::UP reduce(const eval::BinaryOperation &op, const std::vector<vespalib::string> &dimensions) const override; diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp index 92f0b40e259..05fb39f03ab 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp @@ -298,6 +298,18 @@ SparseTensor::apply(const eval::BinaryOperation &op, const Tensor &arg) const } Tensor::UP +SparseTensor::join(join_fun_t function, const Tensor &arg) const +{ + const SparseTensor *rhs = dynamic_cast<const SparseTensor *>(&arg); + if (!rhs) { + return Tensor::UP(); + } + return sparse::apply(*this, *rhs, + [function](double lhsValue, double rhsValue) + { return function(lhsValue, rhsValue); }); +} + +Tensor::UP SparseTensor::reduce(const eval::BinaryOperation &op, const std::vector<vespalib::string> &dimensions) const { diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h index f7fd5663dfe..b23cf5fee0d 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h @@ -53,6 +53,8 @@ public: virtual Tensor::UP sum(const vespalib::string &dimension) const override; virtual Tensor::UP apply(const eval::BinaryOperation &op, const Tensor &arg) const override; + virtual Tensor::UP join(join_fun_t function, + const Tensor &arg) const override; virtual Tensor::UP reduce(const eval::BinaryOperation &op, const std::vector<vespalib::string> &dimensions) const override; diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h index 5f657d32df3..8965accf57f 100644 --- a/eval/src/vespa/eval/tensor/tensor.h +++ b/eval/src/vespa/eval/tensor/tensor.h @@ -26,6 +26,7 @@ struct Tensor : public eval::Tensor { typedef std::unique_ptr<Tensor> UP; typedef std::reference_wrapper<const Tensor> CREF; + using join_fun_t = double (*)(double, double); Tensor(); virtual ~Tensor() {} @@ -41,6 +42,8 @@ struct Tensor : public eval::Tensor virtual Tensor::UP sum(const vespalib::string &dimension) const = 0; virtual Tensor::UP apply(const eval::BinaryOperation &op, const Tensor &arg) const = 0; + virtual Tensor::UP join(join_fun_t function, + const Tensor &arg) const = 0; virtual Tensor::UP reduce(const eval::BinaryOperation &op, const std::vector<vespalib::string> &dimensions) const = 0; diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp index 7bf7413bd36..110fb446e95 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp @@ -136,6 +136,13 @@ WrappedSimpleTensor::apply(const eval::BinaryOperation &, const Tensor &) const } Tensor::UP +WrappedSimpleTensor::join(join_fun_t, const Tensor &) const +{ + abort(); + return Tensor::UP(); +} + +Tensor::UP WrappedSimpleTensor::reduce(const eval::BinaryOperation &, const std::vector<vespalib::string> &) const { abort(); diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h index 520ea27c096..68b3b332ef4 100644 --- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h +++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h @@ -46,6 +46,7 @@ public: Tensor::UP apply(const CellFunction &) const override; Tensor::UP sum(const vespalib::string &) const override; Tensor::UP apply(const eval::BinaryOperation &, const Tensor &) const override; + Tensor::UP join(join_fun_t, const Tensor &) const override; Tensor::UP reduce(const eval::BinaryOperation &, const std::vector<vespalib::string> &) const override; }; |