diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-06-12 12:52:32 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-06-12 12:52:32 +0000 |
commit | 7351cd6fa62a169b55806ad41ea98acb76491c45 (patch) | |
tree | f878bcb4c8b80a75c19053305c1131015070b5b0 /eval | |
parent | 2b9aa7f40171f918fe1a906de8334d637aa8810b (diff) |
enable float result type for prod tensor operations
Diffstat (limited to 'eval')
4 files changed, 5 insertions, 41 deletions
diff --git a/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp b/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp index c8e57f970e3..91a6087ea3a 100644 --- a/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp +++ b/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp @@ -20,12 +20,12 @@ TEST("require that dimensions can be combined") { EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}), combine({{"a", 3}}, {{"b", 5}})); EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}), combine({{"a", 3}, {"b", 5}}, {{"b", 5}})); - EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}), combine({{"a", 3}, {"b", 7}}, {{"b", 5}})); + EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}), combine({{"a", 3}, {"b", 5}}, {{"b", 5}})); EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 11}, {"c", 5}, {"d", 7}, {"e", 17}}), combine({{"a", 3}, {"c", 5}, {"d", 7}}, - {{"b", 11}, {"c", 13}, {"e", 17}})); + {{"b", 11}, {"c", 5}, {"e", 17}})); EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 11}, {"c", 5}, {"d", 7}, {"e", 17}}), - combine({{"b", 11}, {"c", 13}, {"e", 17}}, + combine({{"b", 11}, {"c", 5}, {"e", 17}}, {{"a", 3}, {"c", 5}, {"d", 7}})); } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp index df6ac162d7f..b5c5d9b6a04 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_address_combiner.cpp @@ -57,31 +57,7 @@ AddressContext::~AddressContext() = default; eval::ValueType DenseTensorAddressCombiner::combineDimensions(const eval::ValueType &lhs, const eval::ValueType &rhs) { - // NOTE: both lhs and rhs are sorted according to dimension names. - std::vector<eval::ValueType::Dimension> result; - auto lhsItr = lhs.dimensions().cbegin(); - auto rhsItr = rhs.dimensions().cbegin(); - while (lhsItr != lhs.dimensions().end() && - rhsItr != rhs.dimensions().end()) { - if (lhsItr->name == rhsItr->name) { - result.emplace_back(lhsItr->name, std::min(lhsItr->size, rhsItr->size)); - ++lhsItr; - ++rhsItr; - } else if (lhsItr->name < rhsItr->name) { - result.emplace_back(*lhsItr++); - } else { - result.emplace_back(*rhsItr++); - } - } - while (lhsItr != lhs.dimensions().end()) { - result.emplace_back(*lhsItr++); - } - while (rhsItr != rhs.dimensions().end()) { - result.emplace_back(*rhsItr++); - } - return (result.empty() ? - eval::ValueType::double_type() : - eval::ValueType::tensor_type(std::move(result))); + return eval::ValueType::join(lhs, rhs); } } diff --git a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp index ded9310b450..bcfbc851e6d 100644 --- a/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp +++ b/eval/src/vespa/eval/tensor/sparse/sparse_tensor.cpp @@ -64,16 +64,7 @@ SparseTensor::operator==(const SparseTensor &rhs) const eval::ValueType SparseTensor::combineDimensionsWith(const SparseTensor &rhs) const { - std::vector<eval::ValueType::Dimension> result; - std::set_union(_type.dimensions().cbegin(), _type.dimensions().cend(), - rhs._type.dimensions().cbegin(), rhs._type.dimensions().cend(), - std::back_inserter(result), - [](const eval::ValueType::Dimension &lhsDim, - const eval::ValueType::Dimension &rhsDim) - { return lhsDim.name < rhsDim.name; }); - return (result.empty() ? - eval::ValueType::double_type() : - eval::ValueType::tensor_type(std::move(result))); + return eval::ValueType::join(_type, rhs._type); } const eval::ValueType & diff --git a/eval/src/vespa/eval/tensor/tensor.cpp b/eval/src/vespa/eval/tensor/tensor.cpp index 5697458f3ca..51c94aab5b0 100644 --- a/eval/src/vespa/eval/tensor/tensor.cpp +++ b/eval/src/vespa/eval/tensor/tensor.cpp @@ -17,9 +17,6 @@ Tensor::supported(TypeList types) bool sparse = false; bool dense = false; for (const eval::ValueType &type: types) { - if (type.cell_type() != eval::ValueType::CellType::DOUBLE) { - return false; // non-double cell types not supported - } dense = (dense || type.is_double()); for (const auto &dim: type.dimensions()) { dense = (dense || dim.is_indexed()); |