summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHaavard <havardpe@yahoo-inc.com>2017-02-10 12:30:29 +0000
committerHaavard <havardpe@yahoo-inc.com>2017-02-10 16:07:49 +0000
commit6f4eb0728a0ce1182dcc8d9434ff0d6b812d5e31 (patch)
tree4f86fd4b48da16b45433302418fe6bef15385d92 /eval
parentb75ca29597c0ccbde95ec55a4cbd9daf4d79d1de (diff)
allow simple tensors to represent error values
Diffstat (limited to 'eval')
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.cpp19
-rw-r--r--eval/src/vespa/eval/eval/simple_tensor.h1
2 files changed, 20 insertions, 0 deletions
diff --git a/eval/src/vespa/eval/eval/simple_tensor.cpp b/eval/src/vespa/eval/eval/simple_tensor.cpp
index 477b38725e6..fbe2278a0c9 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.cpp
+++ b/eval/src/vespa/eval/eval/simple_tensor.cpp
@@ -407,6 +407,13 @@ public:
constexpr size_t TensorSpec::Label::npos;
constexpr size_t SimpleTensor::Label::npos;
+SimpleTensor::SimpleTensor()
+ : Tensor(SimpleTensorEngine::ref()),
+ _type(ValueType::error_type()),
+ _cells()
+{
+}
+
SimpleTensor::SimpleTensor(double value)
: Tensor(SimpleTensorEngine::ref()),
_type(ValueType::double_type()),
@@ -441,6 +448,9 @@ std::unique_ptr<SimpleTensor>
SimpleTensor::reduce(Aggregator &aggr, const std::vector<vespalib::string> &dimensions) const
{
ValueType result_type = _type.reduce(dimensions);
+ if (result_type.is_error()) {
+ return std::make_unique<SimpleTensor>();
+ }
Builder builder(result_type);
IndexList selector = TypeAnalyzer(_type, result_type).overlap_a;
View view(*this, selector);
@@ -459,6 +469,9 @@ std::unique_ptr<SimpleTensor>
SimpleTensor::rename(const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to) const
{
ValueType result_type = _type.rename(from, to);
+ if (result_type.is_error()) {
+ return std::make_unique<SimpleTensor>();
+ }
Builder builder(result_type);
IndexList selector;
for (const auto &dim: result_type.dimensions()) {
@@ -511,6 +524,9 @@ std::unique_ptr<SimpleTensor>
SimpleTensor::join(const SimpleTensor &a, const SimpleTensor &b, const std::function<double(double,double)> &function)
{
ValueType result_type = ValueType::join(a.type(), b.type());
+ if (result_type.is_error()) {
+ return std::make_unique<SimpleTensor>();
+ }
Builder builder(result_type);
TypeAnalyzer type_info(a.type(), b.type());
View view_a(a, type_info.overlap_a);
@@ -530,6 +546,9 @@ std::unique_ptr<SimpleTensor>
SimpleTensor::concat(const SimpleTensor &a, const SimpleTensor &b, const vespalib::string &dimension)
{
ValueType result_type = ValueType::concat(a.type(), b.type(), dimension);
+ if (result_type.is_error()) {
+ return std::make_unique<SimpleTensor>();
+ }
Builder builder(result_type);
TypeAnalyzer type_info(a.type(), b.type(), dimension);
View view_a(a, type_info.overlap_a);
diff --git a/eval/src/vespa/eval/eval/simple_tensor.h b/eval/src/vespa/eval/eval/simple_tensor.h
index a48f3025f34..a42413233af 100644
--- a/eval/src/vespa/eval/eval/simple_tensor.h
+++ b/eval/src/vespa/eval/eval/simple_tensor.h
@@ -71,6 +71,7 @@ private:
Cells _cells;
public:
+ SimpleTensor();
explicit SimpleTensor(double value);
SimpleTensor(const ValueType &type_in, Cells &&cells_in);
const ValueType &type() const { return _type; }