aboutsummaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@broadpark.no>2019-02-25 16:52:24 +0100
committerGitHub <noreply@github.com>2019-02-25 16:52:24 +0100
commit4b46918b47774d15b4882aff7db693699383ca61 (patch)
treed7c6498e7390fb8bc2bf8b3c6290c25c5be7fcb6 /eval
parent50e5898f70d4ea1ece5065b06ea7f3a0755463b9 (diff)
parente3ab5b19197122709f06636001955e8c84345a0f (diff)
Merge pull request #8604 from vespa-engine/geirst/remove-and-modify-for-mixed-tensors
Geirst/remove and modify for mixed tensors
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp51
-rw-r--r--eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp49
-rw-r--r--eval/src/vespa/eval/tensor/cell_values.h4
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.h2
-rw-r--r--eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp72
5 files changed, 156 insertions, 22 deletions
diff --git a/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp b/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp
index ff59b28d60c..31f17b3eed2 100644
--- a/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp
+++ b/eval/src/tests/tensor/tensor_modify_operation/tensor_modify_operation_test.cpp
@@ -1,5 +1,6 @@
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include <vespa/eval/eval/operation.h>
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/tensor/cell_values.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
@@ -13,16 +14,6 @@ using vespalib::eval::TensorSpec;
using vespalib::tensor::test::makeTensor;
using namespace vespalib::tensor;
-namespace {
-
-double
-replace(double, double b)
-{
- return b;
-}
-
-}
-
void
checkUpdate(const TensorSpec &source, const TensorSpec &update, const TensorSpec &expect)
{
@@ -30,7 +21,7 @@ checkUpdate(const TensorSpec &source, const TensorSpec &update, const TensorSpec
auto updateTensor = makeTensor<SparseTensor>(update);
const CellValues cellValues(*updateTensor);
- auto actualTensor = sourceTensor->modify(replace, cellValues);
+ auto actualTensor = sourceTensor->modify(vespalib::eval::operation::Add::f, cellValues);
auto actual = actualTensor->toSpec();
auto expectTensor = makeTensor<Tensor>(expect);
auto expectPadded = expectTensor->toSpec();
@@ -45,7 +36,7 @@ TEST(TensorModifyTest, sparse_tensors_can_be_modified)
TensorSpec("tensor(x{},y{})")
.add({{"x","8"},{"y","9"}}, 2),
TensorSpec("tensor(x{},y{})")
- .add({{"x","8"},{"y","9"}}, 2)
+ .add({{"x","8"},{"y","9"}}, 13)
.add({{"x","9"},{"y","9"}}, 11));
}
@@ -57,10 +48,27 @@ TEST(TensorModifyTest, dense_tensors_can_be_modified)
TensorSpec("tensor(x{},y{})")
.add({{"x","8"},{"y","9"}}, 2),
TensorSpec("tensor(x[10],y[10])")
- .add({{"x",8},{"y",9}}, 2)
+ .add({{"x",8},{"y",9}}, 13)
.add({{"x",9},{"y",9}}, 11));
}
+TEST(TensorModifyTest, mixed_tensors_can_be_modified)
+{
+ checkUpdate(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3)
+ .add({{"x","b"},{"y",0}}, 4)
+ .add({{"x","b"},{"y",1}}, 5),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","a"},{"y","0"}}, 6)
+ .add({{"x","b"},{"y","1"}}, 7),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 8)
+ .add({{"x","a"},{"y",1}}, 3)
+ .add({{"x","b"},{"y",0}}, 4)
+ .add({{"x","b"},{"y",1}}, 12));
+}
+
TEST(TensorModifyTest, sparse_tensors_ignore_updates_to_missing_cells)
{
checkUpdate(TensorSpec("tensor(x{},y{})")
@@ -70,7 +78,7 @@ TEST(TensorModifyTest, sparse_tensors_ignore_updates_to_missing_cells)
.add({{"x","7"},{"y","9"}}, 2)
.add({{"x","8"},{"y","9"}}, 2),
TensorSpec("tensor(x{},y{})")
- .add({{"x","8"},{"y","9"}}, 2)
+ .add({{"x","8"},{"y","9"}}, 13)
.add({{"x","9"},{"y","9"}}, 11));
}
@@ -83,8 +91,21 @@ TEST(TensorModifyTest, dense_tensors_ignore_updates_to_out_of_range_cells)
.add({{"x","8"},{"y","9"}}, 2)
.add({{"x","10"},{"y","9"}}, 2),
TensorSpec("tensor(x[10],y[10])")
- .add({{"x",8},{"y",9}}, 2)
+ .add({{"x",8},{"y",9}}, 13)
.add({{"x",9},{"y",9}}, 11));
}
+TEST(TensorModifyTest, mixed_tensors_ignore_updates_to_missing_or_out_of_range_cells)
+{
+ checkUpdate(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","a"},{"y","2"}}, 4)
+ .add({{"x","c"},{"y","0"}}, 5),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3));
+}
+
GTEST_MAIN_RUN_ALL_TESTS
diff --git a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp
index cb28019c4ee..df21b46691d 100644
--- a/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp
+++ b/eval/src/tests/tensor/tensor_remove_operation/tensor_remove_operation_test.cpp
@@ -43,4 +43,53 @@ TEST(TensorRemoveTest, all_cells_can_be_removed_from_a_sparse_tensor)
TensorSpec("tensor(x{},y{})"));
}
+TEST(TensorRemoveTest, cells_can_be_removed_from_a_mixed_tensor)
+{
+ assertRemove(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3)
+ .add({{"x","b"},{"y",0}}, 4)
+ .add({{"x","b"},{"y",1}}, 5),
+ TensorSpec("tensor(x{})")
+ .add({{"x","b"}}, 1)
+ .add({{"x","c"}}, 1),
+ TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3));
+
+ assertRemove(TensorSpec("tensor(x{},y{},z[2])")
+ .add({{"x","a"},{"y","c"},{"z",0}}, 2)
+ .add({{"x","a"},{"y","c"},{"z",1}}, 3)
+ .add({{"x","b"},{"y","c"},{"z",0}}, 4)
+ .add({{"x","b"},{"y","c"},{"z",1}}, 5),
+ TensorSpec("tensor(x{},y{})")
+ .add({{"x","b"},{"y","c"}}, 1)
+ .add({{"x","c"},{"y","c"}}, 1),
+ TensorSpec("tensor(x{},y{},z[2])")
+ .add({{"x","a"},{"y","c"},{"z",0}}, 2)
+ .add({{"x","a"},{"y","c"},{"z",1}}, 3));
+
+ assertRemove(TensorSpec("tensor(x{},y[1],z[2])")
+ .add({{"x","a"},{"y",0},{"z",0}}, 2)
+ .add({{"x","a"},{"y",0},{"z",1}}, 3)
+ .add({{"x","b"},{"y",0},{"z",0}}, 4)
+ .add({{"x","b"},{"y",0},{"z",1}}, 5),
+ TensorSpec("tensor(x{})")
+ .add({{"x","b"}}, 1)
+ .add({{"x","c"}}, 1),
+ TensorSpec("tensor(x{},y[1],z[2])")
+ .add({{"x","a"},{"y",0},{"z",0}}, 2)
+ .add({{"x","a"},{"y",0},{"z",1}}, 3));
+}
+
+TEST(TensorRemoveTest, all_cells_can_be_removed_from_a_mixed_tensor)
+{
+ assertRemove(TensorSpec("tensor(x{},y[2])")
+ .add({{"x","a"},{"y",0}}, 2)
+ .add({{"x","a"},{"y",1}}, 3),
+ TensorSpec("tensor(x{})")
+ .add({{"x","a"}}, 1),
+ TensorSpec("tensor(x{},y[2])"));
+}
+
GTEST_MAIN_RUN_ALL_TESTS
diff --git a/eval/src/vespa/eval/tensor/cell_values.h b/eval/src/vespa/eval/tensor/cell_values.h
index dabdcde3294..4b8fd33376a 100644
--- a/eval/src/vespa/eval/tensor/cell_values.h
+++ b/eval/src/vespa/eval/tensor/cell_values.h
@@ -23,6 +23,10 @@ public:
void accept(TensorVisitor &visitor) const {
_tensor.accept(visitor);
}
+
+ eval::TensorSpec toSpec() const {
+ return _tensor.toSpec();
+ }
};
}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
index 7eebff1f010..c182c09c6b0 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.h
@@ -2,10 +2,10 @@
#pragma once
+#include "sparse_tensor_address_ref.h"
#include <vespa/eval/tensor/cell_function.h>
#include <vespa/eval/tensor/tensor.h>
#include <vespa/eval/tensor/tensor_address.h>
-#include "sparse_tensor_address_ref.h"
#include <vespa/eval/tensor/types.h>
#include <vespa/vespalib/stllike/hash_map.h>
#include <vespa/vespalib/stllike/string.h>
diff --git a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
index 1268d6fa9cb..a982a4b0fe1 100644
--- a/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/wrapped_simple_tensor.cpp
@@ -1,8 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include "wrapped_simple_tensor.h"
+#include "cell_values.h"
#include "tensor_address_builder.h"
#include "tensor_visitor.h"
+#include "wrapped_simple_tensor.h"
#include <vespa/eval/eval/simple_tensor_engine.h>
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/vespalib/util/stringfmt.h>
@@ -80,10 +81,42 @@ WrappedSimpleTensor::reduce(join_fun_t, const std::vector<vespalib::string> &) c
LOG_ABORT("should not be reached");
}
+namespace {
+
+TensorSpec::Address
+convertToOnlyMappedDimensions(const TensorSpec::Address &address)
+{
+ TensorSpec::Address result;
+ for (const auto &elem : address) {
+ if (elem.second.is_indexed()) {
+ result.emplace(std::make_pair(elem.first,
+ TensorSpec::Label(vespalib::make_string("%zu", elem.second.index))));
+ } else {
+ result.emplace(elem);
+ }
+ }
+ return result;
+}
+
+}
+
std::unique_ptr<Tensor>
-WrappedSimpleTensor::modify(join_fun_t, const CellValues &) const
+WrappedSimpleTensor::modify(join_fun_t op, const CellValues &cellValues) const
{
- LOG_ABORT("should not be reached");
+ TensorSpec oldTensor = toSpec();
+ TensorSpec toModify = cellValues.toSpec();
+ TensorSpec result(type().to_spec());
+
+ for (const auto &cell : oldTensor.cells()) {
+ TensorSpec::Address mappedAddress = convertToOnlyMappedDimensions(cell.first);
+ auto itr = toModify.cells().find(mappedAddress);
+ if (itr != toModify.cells().end()) {
+ result.add(cell.first, op(cell.second, itr->second));
+ } else {
+ result.add(cell.first, cell.second);
+ }
+ }
+ return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result));
}
std::unique_ptr<Tensor>
@@ -114,10 +147,37 @@ WrappedSimpleTensor::add(const Tensor &arg) const
return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result));
}
+namespace {
+
+TensorSpec::Address
+extractMappedDimensions(const TensorSpec::Address &address)
+{
+ TensorSpec::Address result;
+ for (const auto &elem : address) {
+ if (elem.second.is_mapped()) {
+ result.emplace(elem);
+ }
+ }
+ return result;
+}
+
+}
+
std::unique_ptr<Tensor>
-WrappedSimpleTensor::remove(const CellValues &) const
+WrappedSimpleTensor::remove(const CellValues &cellAddresses) const
{
- LOG_ABORT("should not be reached");
+ TensorSpec oldTensor = toSpec();
+ TensorSpec toRemove = cellAddresses.toSpec();
+ TensorSpec result(type().to_spec());
+
+ for (const auto &cell : oldTensor.cells()) {
+ TensorSpec::Address mappedAddress = extractMappedDimensions(cell.first);
+ auto itr = toRemove.cells().find(mappedAddress);
+ if (itr == toRemove.cells().end()) {
+ result.add(cell.first, cell.second);
+ }
+ }
+ return std::make_unique<WrappedSimpleTensor>(SimpleTensor::create(result));
}
-} // namespace vespalib::tensor
+}