summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2017-11-28 14:59:09 +0000
committerHåvard Pettersen <havardpe@oath.com>2017-11-28 14:59:09 +0000
commitc9b2eaee09ac5ce0a9e1f1b7f3196e13ed192750 (patch)
tree10314ffc37d0a728f20daf535fb5cda90c495f12 /eval
parent4c6182ee90541283361c08b24d80f8aaf3d843c2 (diff)
internalize tensor operation special handling
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp12
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp35
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_view.h3
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp37
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.h3
-rw-r--r--eval/src/vespa/eval/tensor/tensor.h10
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp21
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.h3
8 files changed, 22 insertions, 102 deletions
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 4d9ee7cb6f5..efdcfa47b56 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -260,15 +260,7 @@ DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, S
if (!tensor::Tensor::supported({my_a.type(), my_b.type()})) {
return fallback_join(a, b, function, stash);
}
- if (function == eval::operation::Mul::f) {
- if (my_a.type() == my_b.type()) {
- return to_value(my_a.match(my_b), stash);
- } else {
- return to_value(my_a.multiply(my_b), stash);
- }
- } else {
- return to_value(my_a.join(function, my_b), stash);
- }
+ return to_value(my_a.join(function, my_b), stash);
} else {
return ErrorValue::instance;
}
@@ -299,8 +291,6 @@ DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespali
case Aggr::SUM:
if (dimensions.empty()) {
return stash.create<eval::DoubleValue>(my_a.as_double());
- } else if (dimensions.size() == 1) {
- return to_value(my_a.sum(dimensions[0]), stash);
} else {
return to_value(my_a.reduce(eval::operation::Add::f, dimensions), stash);
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
index 3f647975154..eba2452c41d 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.cpp
@@ -161,22 +161,6 @@ DenseTensorView::as_double() const
}
Tensor::UP
-DenseTensorView::multiply(const Tensor &arg) const
-{
- return dense::apply(*this, arg,
- [](double lhsValue, double rhsValue)
- { return lhsValue * rhsValue; });
-}
-
-Tensor::UP
-DenseTensorView::match(const Tensor &arg) const
-{
- return joinDenseTensors(*this, arg, "match",
- [](double lhsValue, double rhsValue)
- { return (lhsValue * rhsValue); });
-}
-
-Tensor::UP
DenseTensorView::apply(const CellFunction &func) const
{
Cells newCells(_cellsRef.size());
@@ -189,14 +173,6 @@ DenseTensorView::apply(const CellFunction &func) const
return std::make_unique<DenseTensor>(_typeRef, std::move(newCells));
}
-Tensor::UP
-DenseTensorView::sum(const vespalib::string &dimension) const
-{
- return dense::reduce(*this, { dimension },
- [](double lhsValue, double rhsValue)
- { return lhsValue + rhsValue; });
-}
-
bool
DenseTensorView::equals(const Tensor &arg) const
{
@@ -265,6 +241,17 @@ DenseTensorView::accept(TensorVisitor &visitor) const
Tensor::UP
DenseTensorView::join(join_fun_t function, const Tensor &arg) const
{
+ if (function == eval::operation::Mul::f) {
+ if (fast_type() == arg.type()) {
+ return joinDenseTensors(*this, arg, "match",
+ [](double lhsValue, double rhsValue)
+ { return (lhsValue * rhsValue); });
+ } else {
+ return dense::apply(*this, arg,
+ [](double lhsValue, double rhsValue)
+ { return lhsValue * rhsValue; });
+ }
+ }
return dense::apply(*this, arg, function);
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
index 95480210d44..5a59594667d 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_view.h
@@ -49,10 +49,7 @@ public:
virtual const eval::ValueType &type() const override;
virtual double as_double() const override;
- virtual Tensor::UP multiply(const Tensor &arg) const override;
- virtual Tensor::UP match(const Tensor &arg) const override;
virtual Tensor::UP apply(const CellFunction &func) const override;
- virtual Tensor::UP sum(const vespalib::string &dimension) const override;
virtual Tensor::UP join(join_fun_t function,
const Tensor &arg) const override;
virtual Tensor::UP reduce(join_fun_t op,
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
index 52d171813ea..4762f1eceb4 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
@@ -93,40 +93,11 @@ SparseTensor::as_double() const
}
Tensor::UP
-SparseTensor::multiply(const Tensor &arg) const
-{
- const SparseTensor *rhs = dynamic_cast<const SparseTensor *>(&arg);
- if (!rhs) {
- return Tensor::UP();
- }
- return sparse::apply(*this, *rhs, [](double lhsValue, double rhsValue)
- { return lhsValue * rhsValue; });
-}
-
-Tensor::UP
-SparseTensor::match(const Tensor &arg) const
-{
- const SparseTensor *rhs = dynamic_cast<const SparseTensor *>(&arg);
- if (!rhs) {
- return Tensor::UP();
- }
- return SparseTensorMatch(*this, *rhs).result();
-}
-
-Tensor::UP
SparseTensor::apply(const CellFunction &func) const
{
return TensorApply<SparseTensor>(*this, func).result();
}
-Tensor::UP
-SparseTensor::sum(const vespalib::string &dimension) const
-{
- return sparse::reduce(*this, { dimension },
- [](double lhsValue, double rhsValue)
- { return lhsValue + rhsValue; });
-}
-
bool
SparseTensor::equals(const Tensor &arg) const
{
@@ -203,6 +174,14 @@ SparseTensor::join(join_fun_t function, const Tensor &arg) const
if (!rhs) {
return Tensor::UP();
}
+ if (function == eval::operation::Mul::f) {
+ if (fast_type() == rhs->fast_type()) {
+ return SparseTensorMatch(*this, *rhs).result();
+ } else {
+ return sparse::apply(*this, *rhs, [](double lhsValue, double rhsValue)
+ { return lhsValue * rhsValue; });
+ }
+ }
return sparse::apply(*this, *rhs, function);
}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
index c1c17906fc5..c7c38f0a182 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
@@ -43,10 +43,7 @@ public:
virtual const eval::ValueType &type() const override;
virtual double as_double() const override;
- virtual Tensor::UP multiply(const Tensor &arg) const override;
- virtual Tensor::UP match(const Tensor &arg) const override;
virtual Tensor::UP apply(const CellFunction &func) const override;
- virtual Tensor::UP sum(const vespalib::string &dimension) const override;
virtual Tensor::UP join(join_fun_t function,
const Tensor &arg) const override;
virtual Tensor::UP reduce(join_fun_t op,
diff --git a/eval/src/vespa/eval/tensor/tensor.h b/eval/src/vespa/eval/tensor/tensor.h
index 5972f3a7957..8e31448e026 100644
--- a/eval/src/vespa/eval/tensor/tensor.h
+++ b/eval/src/vespa/eval/tensor/tensor.h
@@ -30,15 +30,9 @@ struct Tensor : public eval::Tensor
Tensor();
virtual ~Tensor() {}
- virtual Tensor::UP multiply(const Tensor &arg) const = 0;
- virtual Tensor::UP match(const Tensor &arg) const = 0;
virtual Tensor::UP apply(const CellFunction &func) const = 0;
- virtual Tensor::UP sum(const vespalib::string &dimension) const = 0;
- virtual Tensor::UP join(join_fun_t function,
- const Tensor &arg) const = 0;
- virtual Tensor::UP reduce(join_fun_t op,
- const std::vector<vespalib::string> &dimensions)
- const = 0;
+ virtual Tensor::UP join(join_fun_t function, const Tensor &arg) const = 0;
+ virtual Tensor::UP reduce(join_fun_t op, const std::vector<vespalib::string> &dimensions) const = 0;
virtual bool equals(const Tensor &arg) const = 0; // want to remove, but needed by document
virtual Tensor::UP clone() const = 0; // want to remove, but needed by document
virtual eval::TensorSpec toSpec() const = 0;
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index 50c567fa6ea..463105b7c1f 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -57,20 +57,6 @@ WrappedSimpleTensor::clone() const
//-----------------------------------------------------------------------------
Tensor::UP
-WrappedSimpleTensor::multiply(const Tensor &) const
-{
- abort();
- return Tensor::UP();
-}
-
-Tensor::UP
-WrappedSimpleTensor::match(const Tensor &) const
-{
- abort();
- return Tensor::UP();
-}
-
-Tensor::UP
WrappedSimpleTensor::apply(const CellFunction &) const
{
abort();
@@ -78,13 +64,6 @@ WrappedSimpleTensor::apply(const CellFunction &) const
}
Tensor::UP
-WrappedSimpleTensor::sum(const vespalib::string &) const
-{
- abort();
- return Tensor::UP();
-}
-
-Tensor::UP
WrappedSimpleTensor::join(join_fun_t, const Tensor &) const
{
abort();
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
index fdaf86459da..ae7907845e1 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.h
@@ -35,10 +35,7 @@ public:
void accept(TensorVisitor &visitor) const override;
Tensor::UP clone() const override;
// functions below should not be used for this implementation
- Tensor::UP multiply(const Tensor &) const override;
- Tensor::UP match(const Tensor &) const override;
Tensor::UP apply(const CellFunction &) const override;
- Tensor::UP sum(const vespalib::string &) const override;
Tensor::UP join(join_fun_t, const Tensor &) const override;
Tensor::UP reduce(join_fun_t, const std::vector<vespalib::string> &) const override;
};