summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2019-06-12 12:52:32 +0000
committerHåvard Pettersen <havardpe@oath.com>2019-06-12 12:52:32 +0000
commit7351cd6fa62a169b55806ad41ea98acb76491c45 (patch)
treef878bcb4c8b80a75c19053305c1131015070b5b0 /eval
parent2b9aa7f40171f918fe1a906de8334d637aa8810b (diff)
enable float result type for prod tensor operations
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp6
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp26
-rw-r--r--eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp11
-rw-r--r--eval/src/vespa/eval/tensor/tensor.cpp3
4 files changed, 5 insertions, 41 deletions
diff --git a/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp b/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp
index c8e57f970e3..91a6087ea3a 100644
--- a/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp
+++ b/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp
@@ -20,12 +20,12 @@ TEST("require that dimensions can be combined")
{
EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}), combine({{"a", 3}}, {{"b", 5}}));
EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}), combine({{"a", 3}, {"b", 5}}, {{"b", 5}}));
- EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}), combine({{"a", 3}, {"b", 7}}, {{"b", 5}}));
+ EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}), combine({{"a", 3}, {"b", 5}}, {{"b", 5}}));
EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 11}, {"c", 5}, {"d", 7}, {"e", 17}}),
combine({{"a", 3}, {"c", 5}, {"d", 7}},
- {{"b", 11}, {"c", 13}, {"e", 17}}));
+ {{"b", 11}, {"c", 5}, {"e", 17}}));
EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 11}, {"c", 5}, {"d", 7}, {"e", 17}}),
- combine({{"b", 11}, {"c", 13}, {"e", 17}},
+ combine({{"b", 11}, {"c", 5}, {"e", 17}},
{{"a", 3}, {"c", 5}, {"d", 7}}));
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
index df6ac162d7f..b5c5d9b6a04 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp
@@ -57,31 +57,7 @@ AddressContext::~AddressContext() = default;
eval::ValueType
DenseTensorAddressCombiner::combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs)
{
- // NOTE: both lhs and rhs are sorted according to dimension names.
- std::vector<eval::ValueType::Dimension> result;
- auto lhsItr = lhs.dimensions().cbegin();
- auto rhsItr = rhs.dimensions().cbegin();
- while (lhsItr != lhs.dimensions().end() &&
- rhsItr != rhs.dimensions().end()) {
- if (lhsItr->name == rhsItr->name) {
- result.emplace_back(lhsItr->name, std::min(lhsItr->size, rhsItr->size));
- ++lhsItr;
- ++rhsItr;
- } else if (lhsItr->name < rhsItr->name) {
- result.emplace_back(*lhsItr++);
- } else {
- result.emplace_back(*rhsItr++);
- }
- }
- while (lhsItr != lhs.dimensions().end()) {
- result.emplace_back(*lhsItr++);
- }
- while (rhsItr != rhs.dimensions().end()) {
- result.emplace_back(*rhsItr++);
- }
- return (result.empty() ?
- eval::ValueType::double_type() :
- eval::ValueType::tensor_type(std::move(result)));
+ return eval::ValueType::join(lhs, rhs);
}
}
diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
index ded9310b450..bcfbc851e6d 100644
--- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
+++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp
@@ -64,16 +64,7 @@ SparseTensor::operator==(const SparseTensor &rhs) const
eval::ValueType
SparseTensor::combineDimensionsWith(const SparseTensor &rhs) const
{
- std::vector<eval::ValueType::Dimension> result;
- std::set_union(_type.dimensions().cbegin(), _type.dimensions().cend(),
- rhs._type.dimensions().cbegin(), rhs._type.dimensions().cend(),
- std::back_inserter(result),
- [](const eval::ValueType::Dimension &lhsDim,
- const eval::ValueType::Dimension &rhsDim)
- { return lhsDim.name < rhsDim.name; });
- return (result.empty() ?
- eval::ValueType::double_type() :
- eval::ValueType::tensor_type(std::move(result)));
+ return eval::ValueType::join(_type, rhs._type);
}
const eval::ValueType &
diff --git a/eval/src/vespa/eval/tensor/tensor.cpp b/eval/src/vespa/eval/tensor/tensor.cpp
index 5697458f3ca..51c94aab5b0 100644
--- a/eval/src/vespa/eval/tensor/tensor.cpp
+++ b/eval/src/vespa/eval/tensor/tensor.cpp
@@ -17,9 +17,6 @@ Tensor::supported(TypeList types)
bool sparse = false;
bool dense = false;
for (const eval::ValueType &type: types) {
- if (type.cell_type() != eval::ValueType::CellType::DOUBLE) {
- return false; // non-double cell types not supported
- }
dense = (dense || type.is_double());
for (const auto &dim: type.dimensions()) {
dense = (dense || dim.is_indexed());