summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArnstein Ressem <aressem@gmail.com>2017-10-24 20:28:32 +0200
committerGitHub <noreply@github.com>2017-10-24 20:28:32 +0200
commitbfb1dbd6340fc65afb7037d5caa3dfc50bf8febd (patch)
tree643761d03cc82bb8a93bbbb8951e05fac536304f
parent3048921b8f9160ecd1164086abc11030b3a2c788 (diff)
parent0f63374903e65817942b46f090013a2613426bc8 (diff)
Merge pull request #3870 from vespa-engine/revert-3869-revert-3859-havardpe/less-fallback-in-reduce
Revert "Revert "implement new 'reduce' API in DefaultTensorEngine""
-rw-r--r--eval/src/vespa/eval/eval/value.cpp2
-rw-r--r--eval/src/vespa/eval/eval/value.h1
-rw-r--r--eval/src/vespa/eval/tensor/default_tensor_engine.cpp84
3 files changed, 59 insertions, 28 deletions
diff --git a/eval/src/vespa/eval/eval/value.cpp b/eval/src/vespa/eval/eval/value.cpp
index e601c2266e7..d5111187157 100644
--- a/eval/src/vespa/eval/eval/value.cpp
+++ b/eval/src/vespa/eval/eval/value.cpp
@@ -19,6 +19,8 @@ Value::apply(const BinaryOperation &, const Value &, Stash &stash) const
return stash.create<ErrorValue>();
}
+ErrorValue ErrorValue::instance;
+
bool
TensorValue::equal(const Value &rhs) const
{
diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h
index f78242863d2..45b4f59ecd5 100644
--- a/eval/src/vespa/eval/eval/value.h
+++ b/eval/src/vespa/eval/eval/value.h
@@ -39,6 +39,7 @@ struct Value {
};
struct ErrorValue : public Value {
+ static ErrorValue instance;
bool is_error() const override { return true; }
double as_double() const override { return error_value; }
bool equal(const Value &) const override { return false; }
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 685864d3d5e..771a457509c 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -18,12 +18,17 @@
namespace vespalib {
namespace tensor {
-using Value = eval::Value;
-using ValueType = eval::ValueType;
-using ErrorValue = eval::ErrorValue;
-using DoubleValue = eval::DoubleValue;
-using TensorValue = eval::TensorValue;
-using TensorSpec = eval::TensorSpec;
+using eval::Aggr;
+using eval::Aggregator;
+using eval::DoubleValue;
+using eval::ErrorValue;
+using eval::TensorSpec;
+using eval::TensorValue;
+using eval::Value;
+using eval::ValueType;
+
+using map_fun_t = eval::TensorEngine::map_fun_t;
+using join_fun_t = eval::TensorEngine::join_fun_t;
namespace {
@@ -64,17 +69,23 @@ const Value &to_default(const Value &value, Stash &stash) {
}
const Value &to_value(std::unique_ptr<Tensor> tensor, Stash &stash) {
+ if (!tensor) {
+ return ErrorValue::instance;
+ }
if (tensor->getType().is_tensor()) {
return stash.create<TensorValue>(std::move(tensor));
}
return stash.create<DoubleValue>(tensor->sum());
}
-template <typename join_fun_t>
const Value &fallback_join(const Value &a, const Value &b, join_fun_t function, Stash &stash) {
return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash);
}
+const Value &fallback_reduce(const Value &a, eval::Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) {
+ return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash);
+}
+
} // namespace vespalib::tensor::<unnamed>
const DefaultTensorEngine DefaultTensorEngine::_engine;
@@ -198,10 +209,7 @@ DefaultTensorEngine::reduce(const Tensor &tensor, const BinaryOperation &op, con
} else {
result = my_tensor.reduce(op, dimensions);
}
- if (result) {
- return to_value(std::move(result), stash);
- }
- return stash.create<ErrorValue>();
+ return to_value(std::move(result), stash);
}
struct CellFunctionOpAdapter : tensor::CellFunction {
@@ -211,14 +219,12 @@ struct CellFunctionOpAdapter : tensor::CellFunction {
};
struct CellFunctionFunAdapter : tensor::CellFunction {
- using map_fun_t = DefaultTensorEngine::map_fun_t;
map_fun_t fun;
CellFunctionFunAdapter(map_fun_t fun_in) : fun(fun_in) {}
virtual double apply(double value) const override { return fun(value); }
};
struct CellFunctionBindLeftAdapter : tensor::CellFunction {
- using join_fun_t = DefaultTensorEngine::join_fun_t;
join_fun_t fun;
double a;
CellFunctionBindLeftAdapter(join_fun_t fun_in, double bound) : fun(fun_in), a(bound) {}
@@ -226,7 +232,6 @@ struct CellFunctionBindLeftAdapter : tensor::CellFunction {
};
struct CellFunctionBindRightAdapter : tensor::CellFunction {
- using join_fun_t = DefaultTensorEngine::join_fun_t;
join_fun_t fun;
double b;
CellFunctionBindRightAdapter(join_fun_t fun_in, double bound) : fun(fun_in), b(bound) {}
@@ -293,11 +298,7 @@ DefaultTensorEngine::apply(const BinaryOperation &op, const Tensor &a, const Ten
}
TensorOperationOverride tensor_override(my_a, my_b);
op.accept(tensor_override);
- if (tensor_override.result) {
- return to_value(std::move(tensor_override.result), stash);
- } else {
- return stash.create<ErrorValue>();
- }
+ return to_value(std::move(tensor_override.result), stash);
}
//-----------------------------------------------------------------------------
@@ -333,8 +334,8 @@ DefaultTensorEngine::map(const Value &a, map_fun_t function, Stash &stash) const
}
CellFunctionFunAdapter cell_function(function);
return to_value(my_a.apply(cell_function), stash);
- } else { // error
- return a;
+ } else {
+ return ErrorValue::instance;
}
}
@@ -352,8 +353,8 @@ DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, S
}
CellFunctionBindLeftAdapter cell_function(function, a.as_double());
return to_value(my_b.apply(cell_function), stash);
- } else { // error
- return b;
+ } else {
+ return ErrorValue::instance;
}
} else if (auto tensor_a = a.as_tensor()) {
assert(&tensor_a->engine() == this);
@@ -375,18 +376,45 @@ DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, S
} else {
return to_value(my_a.join(function, my_b), stash);
}
- } else { // error
- return b;
+ } else {
+ return ErrorValue::instance;
}
- } else { // error
- return a;
+ } else {
+ return ErrorValue::instance;
}
}
const Value &
DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const
{
- return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash);
+ if (a.is_double()) {
+ Aggregator &aggregator = Aggregator::create(aggr, stash);
+ aggregator.first(a.as_double());
+ return stash.create<DoubleValue>(aggregator.result());
+ } else if (auto tensor = a.as_tensor()) {
+ assert(&tensor->engine() == this);
+ const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor);
+ if (!tensor::Tensor::supported({my_a.getType()})) {
+ return fallback_reduce(a, aggr, dimensions, stash);
+ }
+ switch (aggr) {
+ case Aggr::PROD: return to_value(my_a.reduce(eval::operation::Mul(), dimensions), stash);
+ case Aggr::SUM:
+ if (dimensions.empty()) {
+ return stash.create<eval::DoubleValue>(my_a.sum());
+ } else if (dimensions.size() == 1) {
+ return to_value(my_a.sum(dimensions[0]), stash);
+ } else {
+ return to_value(my_a.reduce(eval::operation::Add(), dimensions), stash);
+ }
+ case Aggr::MAX: return to_value(my_a.reduce(eval::operation::Max(), dimensions), stash);
+ case Aggr::MIN: return to_value(my_a.reduce(eval::operation::Min(), dimensions), stash);
+ default:
+ return fallback_reduce(a, aggr, dimensions, stash);
+ }
+ } else {
+ return ErrorValue::instance;
+ }
}
const Value &