// Source path: root/eval/src/tests/tensor/dense_tensor_address_combiner/dense_tensor_address_combiner_test.cpp
// Blob: c8e57f970e378b989897156e10c9372c9c601f90
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/tensor/dense/dense_tensor_address_combiner.h>
#include <vespa/vespalib/test/insertion_operators.h>

using namespace vespalib::tensor;
using vespalib::eval::ValueType;

/**
 * Test helper: combines two dense tensor dimension lists.
 *
 * Each dimension list is wrapped in ValueType::tensor_type(...) and both are
 * forwarded to DenseTensorAddressCombiner::combineDimensions, so the test
 * cases can be written as plain brace-initialized dimension lists.
 *
 * @param lhs dimensions of the left-hand tensor type
 * @param rhs dimensions of the right-hand tensor type
 * @return whatever combineDimensions returns for the two tensor types
 *
 * Declared static because it is only used inside this test translation unit
 * and should not export an external symbol.
 */
static ValueType
combine(const std::vector<ValueType::Dimension> &lhs,
        const std::vector<ValueType::Dimension> &rhs)
{
    return DenseTensorAddressCombiner::combineDimensions(
            ValueType::tensor_type(lhs),
            ValueType::tensor_type(rhs));
}

// The expected types below contain every dimension that appears in either
// input. Where both inputs share a dimension name but disagree on its size
// (b: 7 vs 5, c: 5 vs 13), the smaller size is expected, in both argument
// orders.
TEST("require that dimensions can be combined")
{
    // Disjoint dimensions.
    EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}),
                 combine({{"a", 3}}, {{"b", 5}}));

    // Shared dimension with identical size.
    EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}),
                 combine({{"a", 3}, {"b", 5}}, {{"b", 5}}));

    // Shared dimension with differing sizes.
    EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 5}}),
                 combine({{"a", 3}, {"b", 7}}, {{"b", 5}}));

    // Interleaved dimension names, checked with both argument orders.
    EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 11}, {"c", 5}, {"d", 7}, {"e", 17}}),
                 combine({{"a", 3}, {"c", 5}, {"d", 7}},
                         {{"b", 11}, {"c", 13}, {"e", 17}}));
    EXPECT_EQUAL(ValueType::tensor_type({{"a", 3}, {"b", 11}, {"c", 5}, {"d", 7}, {"e", 17}}),
                 combine({{"b", 11}, {"c", 13}, {"e", 17}},
                         {{"a", 3}, {"c", 5}, {"d", 7}}));
}

// Test-kit entry point; delegates to TEST_RUN_ALL() to run the tests in this file.
TEST_MAIN()
{
    TEST_RUN_ALL();
}