summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-10-23 13:21:18 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-10-23 13:21:18 +0000
commitebfcf420950172ce9a6c65cb2d858a381e5e2485 (patch)
tree9f0889fa6616895be483d780cf8a4b8b8546ea69 /eval
parent94e2f9e38d3491391ca6a8c10e151f8f94dbe98b (diff)
implement new join API
only 'match' is inlined added new join function to tensor::Tensor
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp62
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp8
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h2
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp12
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.h2
-rw-r--r--eval/src/vespa/eval/tensor/tensor.h3
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp7
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.h1
8 files changed, 96 insertions, 1 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 35db1ed3b4b..685864d3d5e 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -70,6 +70,11 @@ const Value &to_value(std::unique_ptr<Tensor> tensor, Stash &stash) {
return stash.create<DoubleValue>(tensor->sum());
}
+template <typename join_fun_t>
+const Value &fallback_join(const Value &a, const Value &b, join_fun_t function, Stash &stash) {
+ return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash);
+}
+
} // namespace vespalib::tensor::<unnamed>
const DefaultTensorEngine DefaultTensorEngine::_engine;
@@ -212,6 +217,22 @@ struct CellFunctionFunAdapter : tensor::CellFunction {
virtual double apply(double value) const override { return fun(value); }
};
+struct CellFunctionBindLeftAdapter : tensor::CellFunction {
+ using join_fun_t = DefaultTensorEngine::join_fun_t;
+ join_fun_t fun;
+ double a;
+ CellFunctionBindLeftAdapter(join_fun_t fun_in, double bound) : fun(fun_in), a(bound) {}
+ virtual double apply(double b) const override { return fun(a, b); }
+};
+
+struct CellFunctionBindRightAdapter : tensor::CellFunction {
+ using join_fun_t = DefaultTensorEngine::join_fun_t;
+ join_fun_t fun;
+ double b;
+ CellFunctionBindRightAdapter(join_fun_t fun_in, double bound) : fun(fun_in), b(bound) {}
+ virtual double apply(double a) const override { return fun(a, b); }
+};
+
const eval::Value &
DefaultTensorEngine::map(const UnaryOperation &op, const Tensor &a, Stash &stash) const
{
@@ -320,7 +341,46 @@ DefaultTensorEngine::map(const Value &a, map_fun_t function, Stash &stash) const
const Value &
DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, Stash &stash) const
{
- return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash);
+ if (a.is_double()) {
+ if (b.is_double()) {
+ return stash.create<DoubleValue>(function(a.as_double(), b.as_double()));
+ } else if (auto tensor_b = b.as_tensor()) {
+ assert(&tensor_b->engine() == this);
+ const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b);
+ if (!tensor::Tensor::supported({my_b.getType()})) {
+ return fallback_join(a, b, function, stash);
+ }
+ CellFunctionBindLeftAdapter cell_function(function, a.as_double());
+ return to_value(my_b.apply(cell_function), stash);
+ } else { // error
+ return b;
+ }
+ } else if (auto tensor_a = a.as_tensor()) {
+ assert(&tensor_a->engine() == this);
+ const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor_a);
+ if (b.is_double()) {
+ if (!tensor::Tensor::supported({my_a.getType()})) {
+ return fallback_join(a, b, function, stash);
+ }
+ CellFunctionBindRightAdapter cell_function(function, b.as_double());
+ return to_value(my_a.apply(cell_function), stash);
+ } else if (auto tensor_b = b.as_tensor()) {
+ assert(&tensor_b->engine() == this);
+ const tensor::Tensor &my_b = static_cast<const tensor::Tensor &>(*tensor_b);
+ if (!tensor::Tensor::supported({my_a.getType(), my_b.getType()})) {
+ return fallback_join(a, b, function, stash);
+ }
+ if ((function == eval::operation::Mul::f) && (my_a.getType() == my_b.getType())) {
+ return to_value(my_a.match(my_b), stash);
+ } else {
+ return to_value(my_a.join(function, my_b), stash);
+ }
+ } else { // error
+ return b;
+ }
+ } else { // error
+ return a;
+ }
}
const Value &
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
index 1a76ef6475b..e0babad36aa 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
@@ -336,6 +336,14 @@ DenseTensorView::apply(const eval::BinaryOperation &op, const Tensor &arg) const
}
Tensor::UP
+DenseTensorView::join(join_fun_t function, const Tensor &arg) const
+{
+ return dense::apply(*this, arg,
+ [function](double lhsValue, double rhsValue)
+ { return function(lhsValue, rhsValue); });
+}
+
+Tensor::UP
DenseTensorView::reduce(const eval::BinaryOperation &op,
const std::vector<vespalib::string> &dimensions) const
{
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index d5d8c8af821..56f26b4ba7c 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -59,6 +59,8 @@ public:
virtual Tensor::UP sum(const vespalib::string &dimension) const override;
virtual Tensor::UP apply(const eval::BinaryOperation &op,
const Tensor &arg) const override;
+ virtual Tensor::UP join(join_fun_t function,
+ const Tensor &arg) const override;
virtual Tensor::UP reduce(const eval::BinaryOperation &op,
const std::vector<vespalib::string> &dimensions)
const override;
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
index 92f0b40e259..05fb39f03ab 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
@@ -298,6 +298,18 @@ SparseTensor::apply(const eval::BinaryOperation &op, const Tensor &arg) const
}
Tensor::UP
+SparseTensor::join(join_fun_t function, const Tensor &arg) const
+{
+ const SparseTensor *rhs = dynamic_cast<const SparseTensor *>(&arg);
+ if (!rhs) {
+ return Tensor::UP();
+ }
+ return sparse::apply(*this, *rhs,
+ [function](double lhsValue, double rhsValue)
+ { return function(lhsValue, rhsValue); });
+}
+
+Tensor::UP
SparseTensor::reduce(const eval::BinaryOperation &op,
const std::vector<vespalib::string> &dimensions) const
{
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
index f7fd5663dfe..b23cf5fee0d 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
@@ -53,6 +53,8 @@ public:
virtual Tensor::UP sum(const vespalib::string &dimension) const override;
virtual Tensor::UP apply(const eval::BinaryOperation &op,
const Tensor &arg) const override;
+ virtual Tensor::UP join(join_fun_t function,
+ const Tensor &arg) const override;
virtual Tensor::UP reduce(const eval::BinaryOperation &op,
const std::vector<vespalib::string> &dimensions)
const override;
diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h
index 5f657d32df3..8965accf57f 100644
--- a/eval/src/vespa/eval/tensor/tensor.h
+++ b/eval/src/vespa/eval/tensor/tensor.h
@@ -26,6 +26,7 @@ struct Tensor : public eval::Tensor
{
typedef std::unique_ptr<Tensor> UP;
typedef std::reference_wrapper<const Tensor> CREF;
+ using join_fun_t = double (*)(double, double);
Tensor();
virtual ~Tensor() {}
@@ -41,6 +42,8 @@ struct Tensor : public eval::Tensor
virtual Tensor::UP sum(const vespalib::string &dimension) const = 0;
virtual Tensor::UP apply(const eval::BinaryOperation &op,
const Tensor &arg) const = 0;
+ virtual Tensor::UP join(join_fun_t function,
+ const Tensor &arg) const = 0;
virtual Tensor::UP reduce(const eval::BinaryOperation &op,
const std::vector<vespalib::string> &dimensions)
const = 0;
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index 7bf7413bd36..110fb446e95 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -136,6 +136,13 @@ WrappedSimpleTensor::apply(const eval::BinaryOperation &, const Tensor &) const
}
Tensor::UP
+WrappedSimpleTensor::join(join_fun_t, const Tensor &) const
+{
+ abort();
+ return Tensor::UP();
+}
+
+Tensor::UP
WrappedSimpleTensor::reduce(const eval::BinaryOperation &, const std::vector<vespalib::string> &) const
{
abort();
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
index 520ea27c096..68b3b332ef4 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
@@ -46,6 +46,7 @@ public:
Tensor::UP apply(const CellFunction &) const override;
Tensor::UP sum(const vespalib::string &) const override;
Tensor::UP apply(const eval::BinaryOperation &, const Tensor &) const override;
+ Tensor::UP join(join_fun_t, const Tensor &) const override;
Tensor::UP reduce(const eval::BinaryOperation &, const std::vector<vespalib::string> &) const override;
};