diff options
author | Håvard Pettersen <havardpe@oath.com> | 2017-10-24 09:46:00 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2017-10-24 09:46:00 +0000 |
commit | b4c4e5161f0b99a28f541b0c5eee9d673053e236 (patch) | |
tree | 3519199bb722bb57da1c5833848b503265ac7b1d /eval | |
parent | c8e3f69ed929b5e5f48176feafa75d4c422147cc (diff) |
implement new 'reduce' API in DefaultTensorEngine
add a singleton instance of ErrorValue, and use that instead of
creating new instances many places
move null pointer checks inside to_value()
simplify using statements
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/vespa/eval/eval/value.cpp | 2 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/value.h | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/tensor/default_tensor_engine.cpp | 84 |
3 files changed, 59 insertions, 28 deletions
diff --git a/eval/src/vespa/eval/eval/value.cpp b/eval/src/vespa/eval/eval/value.cpp index e601c2266e7..d5111187157 100644 --- a/eval/src/vespa/eval/eval/value.cpp +++ b/eval/src/vespa/eval/eval/value.cpp @@ -19,6 +19,8 @@ Value::apply(const BinaryOperation &, const Value &, Stash &stash) const return stash.create<ErrorValue>(); } +ErrorValue ErrorValue::instance; + bool TensorValue::equal(const Value &rhs) const { diff --git a/eval/src/vespa/eval/eval/value.h b/eval/src/vespa/eval/eval/value.h index f78242863d2..45b4f59ecd5 100644 --- a/eval/src/vespa/eval/eval/value.h +++ b/eval/src/vespa/eval/eval/value.h @@ -39,6 +39,7 @@ struct Value { }; struct ErrorValue : public Value { + static ErrorValue instance; bool is_error() const override { return true; } double as_double() const override { return error_value; } bool equal(const Value &) const override { return false; } diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 685864d3d5e..771a457509c 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -18,12 +18,17 @@ namespace vespalib { namespace tensor { -using Value = eval::Value; -using ValueType = eval::ValueType; -using ErrorValue = eval::ErrorValue; -using DoubleValue = eval::DoubleValue; -using TensorValue = eval::TensorValue; -using TensorSpec = eval::TensorSpec; +using eval::Aggr; +using eval::Aggregator; +using eval::DoubleValue; +using eval::ErrorValue; +using eval::TensorSpec; +using eval::TensorValue; +using eval::Value; +using eval::ValueType; + +using map_fun_t = eval::TensorEngine::map_fun_t; +using join_fun_t = eval::TensorEngine::join_fun_t; namespace { @@ -64,17 +69,23 @@ const Value &to_default(const Value &value, Stash &stash) { } const Value &to_value(std::unique_ptr<Tensor> tensor, Stash &stash) { + if (!tensor) { + return ErrorValue::instance; + } if (tensor->getType().is_tensor()) { return stash.create<TensorValue>(std::move(tensor)); } return stash.create<DoubleValue>(tensor->sum()); } -template <typename join_fun_t> const Value &fallback_join(const Value &a, const Value &b, join_fun_t function, Stash &stash) { return to_default(simple_engine().join(to_simple(a, stash), to_simple(b, stash), function, stash), stash); } +const Value &fallback_reduce(const Value &a, eval::Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) { + return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash); +} + } // namespace vespalib::tensor::<unnamed> const DefaultTensorEngine DefaultTensorEngine::_engine; @@ -198,10 +209,7 @@ DefaultTensorEngine::reduce(const Tensor &tensor, const BinaryOperation &op, con } else { result = my_tensor.reduce(op, dimensions); } - if (result) { - return to_value(std::move(result), stash); - } - return stash.create<ErrorValue>(); + return to_value(std::move(result), stash); } struct CellFunctionOpAdapter : tensor::CellFunction { @@ -211,14 +219,12 @@ struct CellFunctionOpAdapter : tensor::CellFunction { }; struct CellFunctionFunAdapter : tensor::CellFunction { - using map_fun_t = DefaultTensorEngine::map_fun_t; map_fun_t fun; CellFunctionFunAdapter(map_fun_t fun_in) : fun(fun_in) {} virtual double apply(double value) const override { return fun(value); } }; struct CellFunctionBindLeftAdapter : tensor::CellFunction { - using join_fun_t = DefaultTensorEngine::join_fun_t; join_fun_t fun; double a; CellFunctionBindLeftAdapter(join_fun_t fun_in, double bound) : fun(fun_in), a(bound) {} @@ -226,7 +232,6 @@ struct CellFunctionBindLeftAdapter : tensor::CellFunction { }; struct CellFunctionBindRightAdapter : tensor::CellFunction { - using join_fun_t = DefaultTensorEngine::join_fun_t; join_fun_t fun; double b; CellFunctionBindRightAdapter(join_fun_t fun_in, double bound) : fun(fun_in), b(bound) {} @@ -293,11 +298,7 @@ DefaultTensorEngine::apply(const BinaryOperation &op, const Tensor &a, const Ten } TensorOperationOverride tensor_override(my_a, my_b); op.accept(tensor_override); - if (tensor_override.result) { - return to_value(std::move(tensor_override.result), stash); - } else { - return stash.create<ErrorValue>(); - } + return to_value(std::move(tensor_override.result), stash); } //----------------------------------------------------------------------------- @@ -333,8 +334,8 @@ DefaultTensorEngine::map(const Value &a, map_fun_t function, Stash &stash) const } CellFunctionFunAdapter cell_function(function); return to_value(my_a.apply(cell_function), stash); - } else { // error - return a; + } else { + return ErrorValue::instance; } } @@ -352,8 +353,8 @@ DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, S } CellFunctionBindLeftAdapter cell_function(function, a.as_double()); return to_value(my_b.apply(cell_function), stash); - } else { // error - return b; + } else { + return ErrorValue::instance; } } else if (auto tensor_a = a.as_tensor()) { assert(&tensor_a->engine() == this); @@ -375,18 +376,45 @@ DefaultTensorEngine::join(const Value &a, const Value &b, join_fun_t function, S } else { return to_value(my_a.join(function, my_b), stash); } - } else { // error - return b; + } else { + return ErrorValue::instance; } - } else { // error - return a; + } else { + return ErrorValue::instance; } } const Value & DefaultTensorEngine::reduce(const Value &a, Aggr aggr, const std::vector<vespalib::string> &dimensions, Stash &stash) const { - return to_default(simple_engine().reduce(to_simple(a, stash), aggr, dimensions, stash), stash); + if (a.is_double()) { + Aggregator &aggregator = Aggregator::create(aggr, stash); + aggregator.first(a.as_double()); + return stash.create<DoubleValue>(aggregator.result()); + } else if (auto tensor = a.as_tensor()) { + assert(&tensor->engine() == this); + const tensor::Tensor &my_a = static_cast<const tensor::Tensor &>(*tensor); + if (!tensor::Tensor::supported({my_a.getType()})) { + return fallback_reduce(a, aggr, dimensions, stash); + } + switch (aggr) { + case Aggr::PROD: return to_value(my_a.reduce(eval::operation::Mul(), dimensions), stash); + case Aggr::SUM: + if (dimensions.empty()) { + return stash.create<eval::DoubleValue>(my_a.sum()); + } else if (dimensions.size() == 1) { + return to_value(my_a.sum(dimensions[0]), stash); + } else { + return to_value(my_a.reduce(eval::operation::Add(), dimensions), stash); + } + case Aggr::MAX: return to_value(my_a.reduce(eval::operation::Max(), dimensions), stash); + case Aggr::MIN: return to_value(my_a.reduce(eval::operation::Min(), dimensions), stash); + default: + return fallback_reduce(a, aggr, dimensions, stash); + } + } else { + return ErrorValue::instance; + } } const Value & |