summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@oath.com>2019-07-03 12:03:18 +0000
committerHåvard Pettersen <havardpe@oath.com>2019-07-03 12:03:18 +0000
commit8b4a8f5234bc3a2f1f548ac57f7143ce23453ef3 (patch)
tree68c632719a8de51d85d349a003698aa79dd3975b /eval
parent0f88dd437ce1d6833c601b5d4cb80fa35d546935 (diff)
extend c++ specific conformance test with float cases
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp4
-rw-r--r--eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp2
-rw-r--r--eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp4
-rw-r--r--eval/src/vespa/eval/eval/tensor_spec.h1
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.cpp2
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.h20
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_conformance.cpp81
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_model.hpp23
11 files changed, 95 insertions, 48 deletions
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index d970b9dad30..356625417d8 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -102,7 +102,7 @@ EvalFixture::ParamRepo make_params() {
.add("v04_y3", spec({y(3)}, MyVecSeq(10)))
.add("v05_x5", spec({x(5)}, MyVecSeq(6.0)))
.add("v06_x5", spec({x(5)}, MyVecSeq(7.0)))
- .add("v07_x5f", spec({x(5)}, MyVecSeq(7.0)), "tensor<float>(x[5])")
+ .add("v07_x5f", spec(float_cells({x(5)}), MyVecSeq(7.0)))
.add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0)))
.add("m02_x3y3", spec({x(3),y(3)}, MyVecSeq(2.0)));
}
diff --git a/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp b/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
index 773381b4c77..4995ea89735 100644
--- a/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
@@ -25,7 +25,7 @@ const TensorEngine &prod_engine = DefaultTensorEngine::ref();
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("x5", spec({x(5)}, N()))
- .add("x5f", spec({x(5)}, N()), "tensor<float>(x[5])")
+ .add("x5f", spec(float_cells({x(5)}), N()))
.add("x_m", spec({x({"a", "b", "c"})}, N()))
.add("x5y3", spec({x(5),y(3)}, N()));
}
diff --git a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
index c9e581e6b21..083ed1c7071 100644
--- a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
+++ b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
@@ -45,8 +45,8 @@ EvalFixture::ParamRepo make_params() {
.add_mutable("mut_x5_A", spec({x(5)}, seq))
.add_mutable("mut_x5_B", spec({x(5)}, seq))
.add_mutable("mut_x5_C", spec({x(5)}, seq))
- .add_mutable("mut_x5f_D", spec({x(5)}, seq), "tensor<float>(x[5])")
- .add_mutable("mut_x5f_E", spec({x(5)}, seq), "tensor<float>(x[5])")
+ .add_mutable("mut_x5f_D", spec(float_cells({x(5)}), seq))
+ .add_mutable("mut_x5f_E", spec(float_cells({x(5)}), seq))
.add_mutable("mut_x5y3_A", spec({x(5),y(3)}, seq))
.add_mutable("mut_x5y3_B", spec({x(5),y(3)}, seq))
.add_mutable("mut_x_sparse", spec({x({"a", "b", "c"})}, seq));
diff --git a/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp b/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
index 36ebdec028b..314d3a6186c 100644
--- a/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
+++ b/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
@@ -26,7 +26,7 @@ EvalFixture::ParamRepo make_params() {
.add("x5", spec({x(5)}, N()))
.add_mutable("_d", spec(5.0))
.add_mutable("_x5", spec({x(5)}, N()))
- .add_mutable("_x5f", spec({x(5)}, N()), "tensor<float>(x[5])")
+ .add_mutable("_x5f", spec(float_cells({x(5)}), N()))
.add_mutable("_x5y3", spec({x(5),y(3)}, N()))
.add_mutable("_x_m", spec({x({"a", "b", "c"})}, N()));
}
diff --git a/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp b/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
index 65208aedb4b..7856775ae30 100644
--- a/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
@@ -25,7 +25,7 @@ const TensorEngine &prod_engine = DefaultTensorEngine::ref();
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("x1y5z1", spec({x(1),y(5),z(1)}, N()))
- .add("x1y5z1f", spec({x(1),y(5),z(1)}, N()), "tensor<float>(x[1],y[5],z[1])")
+ .add("x1y5z1f", spec(float_cells({x(1),y(5),z(1)}), N()))
.add("x1y1z1", spec({x(1),y(1),z(1)}, N()))
.add("x1y5z_m", spec({x(1),y(5),z({"a"})}, N()));
}
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index 8045958d9ba..335aa4791a4 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -38,13 +38,13 @@ EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("y1", spec({y(1)}, MyVecSeq()))
.add("y3", spec({y(3)}, MyVecSeq()))
- .add("y3f", spec({y(3)}, MyVecSeq()), "tensor<float>(y[3])")
+ .add("y3f", spec(float_cells({y(3)}), MyVecSeq()))
.add("y5", spec({y(5)}, MyVecSeq()))
.add("y16", spec({y(16)}, MyVecSeq()))
.add("x1y1", spec({x(1),y(1)}, MyMatSeq()))
.add("y1z1", spec({y(1),z(1)}, MyMatSeq()))
.add("x2y3", spec({x(2),y(3)}, MyMatSeq()))
- .add("x2y3f", spec({x(2),y(3)}, MyMatSeq()), "tensor<float>(x[2],y[3])")
+ .add("x2y3f", spec(float_cells({x(2),y(3)}), MyMatSeq()))
.add("x2z3", spec({x(2),z(3)}, MyMatSeq()))
.add("y3z2", spec({y(3),z(2)}, MyMatSeq()))
.add("x8y5", spec({x(8),y(5)}, MyMatSeq()))
diff --git a/eval/src/vespa/eval/eval/tensor_spec.h b/eval/src/vespa/eval/eval/tensor_spec.h
index 32dc1c82fcb..25af4c7a93c 100644
--- a/eval/src/vespa/eval/eval/tensor_spec.h
+++ b/eval/src/vespa/eval/eval/tensor_spec.h
@@ -73,7 +73,6 @@ public:
}
return *this;
}
- void override_type(const vespalib::string &new_type) { _type = new_type; }
const vespalib::string &type() const { return _type; }
const Cells &cells() const { return _cells; }
vespalib::string to_string() const;
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 321b472a3fa..3f5fa4d72bb 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -14,7 +14,7 @@ NodeTypes get_types(const Function &function, const ParamRepo &param_repo) {
for (size_t i = 0; i < function.num_params(); ++i) {
auto pos = param_repo.map.find(function.param_name(i));
ASSERT_TRUE(pos != param_repo.map.end());
- param_types.push_back(ValueType::from_spec(pos->second.type));
+ param_types.push_back(ValueType::from_spec(pos->second.value.type()));
ASSERT_TRUE(!param_types.back().is_error());
}
return NodeTypes(function, param_types);
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.h b/eval/src/vespa/eval/eval/test/eval_fixture.h
index 8c7d15e7416..9c793c01861 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.h
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.h
@@ -18,32 +18,24 @@ class EvalFixture
public:
struct Param {
TensorSpec value; // actual parameter value
- vespalib::string type; // pre-defined type (could be abstract)
bool is_mutable; // input will be mutable (if allow_mutable is true)
- Param(TensorSpec value_in, const vespalib::string &type_in, bool is_mutable_in)
- : value(std::move(value_in)), type(type_in), is_mutable(is_mutable_in) {}
+ Param(TensorSpec value_in, bool is_mutable_in)
+ : value(std::move(value_in)), is_mutable(is_mutable_in) {}
~Param() {}
};
struct ParamRepo {
std::map<vespalib::string,Param> map;
ParamRepo() : map() {}
- ParamRepo &add(const vespalib::string &name, TensorSpec value_in, const vespalib::string &type_in, bool is_mutable_in) {
- value_in.override_type(type_in);
- map.insert_or_assign(name, Param(std::move(value_in), type_in, is_mutable_in));
+ ParamRepo &add(const vespalib::string &name, TensorSpec value_in, bool is_mutable_in) {
+ map.insert_or_assign(name, Param(std::move(value_in), is_mutable_in));
return *this;
}
- ParamRepo &add(const vespalib::string &name, TensorSpec value, const vespalib::string &type) {
- return add(name, value, type, false);
- }
- ParamRepo &add_mutable(const vespalib::string &name, TensorSpec value, const vespalib::string &type) {
- return add(name, value, type, true);
- }
ParamRepo &add(const vespalib::string &name, const TensorSpec &value) {
- return add(name, value, value.type(), false);
+ return add(name, value, false);
}
ParamRepo &add_mutable(const vespalib::string &name, const TensorSpec &value) {
- return add(name, value, value.type(), true);
+ return add(name, value, true);
}
~ParamRepo() {}
};
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index 1e1bd828d41..16005970817 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -301,10 +301,13 @@ struct TestContext {
TEST_DO(verify_create_type("double"));
TEST_DO(verify_create_type("tensor(x{})"));
TEST_DO(verify_create_type("tensor(x{},y{})"));
+ TEST_DO(verify_create_type("tensor<float>(x{},y{})"));
TEST_DO(verify_create_type("tensor(x[5])"));
TEST_DO(verify_create_type("tensor(x[5],y[10])"));
+ TEST_DO(verify_create_type("tensor<float>(x[5],y[10])"));
TEST_DO(verify_create_type("tensor(x{},y[10])"));
- TEST_DO(verify_create_type("tensor(x[5],y{})"));
+ TEST_DO(verify_create_type("tensor(x[5],y{})"));
+ TEST_DO(verify_create_type("tensor<float>(x[5],y{})"));
}
//-------------------------------------------------------------------------
@@ -318,11 +321,14 @@ struct TestContext {
{x(3)},
{x(3),y(5)},
{x(3),y(5),z(7)},
+ float_cells({x(3),y(5),z(7)}),
{x({"a","b","c"})},
{x({"a","b","c"}),y({"foo","bar"})},
{x({"a","b","c"}),y({"foo","bar"}),z({"i","j","k","l"})},
+ float_cells({x({"a","b","c"}),y({"foo","bar"}),z({"i","j","k","l"})}),
{x(3),y({"foo", "bar"}),z(7)},
- {x({"a","b","c"}),y(5),z({"i","j","k","l"})}
+ {x({"a","b","c"}),y(5),z({"i","j","k","l"})},
+ float_cells({x({"a","b","c"}),y(5),z({"i","j","k","l"})})
};
for (const Layout &layout: layouts) {
TensorSpec input = spec(layout, seq);
@@ -363,11 +369,14 @@ struct TestContext {
{x(3)},
{x(3),y(5)},
{x(3),y(5),z(7)},
+ float_cells({x(3),y(5),z(7)}),
{x({"a","b","c"})},
{x({"a","b","c"}),y({"foo","bar"})},
{x({"a","b","c"}),y({"foo","bar"}),z({"i","j","k","l"})},
+ float_cells({x({"a","b","c"}),y({"foo","bar"}),z({"i","j","k","l"})}),
{x(3),y({"foo", "bar"}),z(7)},
- {x({"a","b","c"}),y(5),z({"i","j","k","l"})}
+ {x({"a","b","c"}),y(5),z({"i","j","k","l"})},
+ float_cells({x({"a","b","c"}),y(5),z({"i","j","k","l"})})
};
for (const Layout &layout: layouts) {
TEST_DO(verify_result(eval.eval(engine, spec(layout, seq)), spec(layout, OpSeq(seq, ref_op))));
@@ -612,20 +621,29 @@ struct TestContext {
void test_apply_op(const Eval &eval, join_fun_t op, const Sequence &seq) {
std::vector<Layout> layouts = {
- {}, {},
- {x(5)}, {x(5)},
- {x(5)}, {y(5)},
- {x(5)}, {x(5),y(5)},
- {y(3)}, {x(2),z(3)},
- {x(3),y(5)}, {y(5),z(7)},
- {x({"a","b","c"})}, {x({"a","b","c"})},
- {x({"a","b","c"})}, {x({"a","b"})},
- {x({"a","b","c"})}, {y({"foo","bar","baz"})},
- {x({"a","b","c"})}, {x({"a","b","c"}),y({"foo","bar","baz"})},
- {x({"a","b"}),y({"foo","bar","baz"})}, {x({"a","b","c"}),y({"foo","bar"})},
- {x({"a","b"}),y({"foo","bar","baz"})}, {y({"foo","bar"}),z({"i","j","k","l"})},
- {x(3),y({"foo", "bar"})}, {y({"foo", "bar"}),z(7)},
- {x({"a","b","c"}),y(5)}, {y(5),z({"i","j","k","l"})}
+ {}, {},
+ {x(5)}, {x(5)},
+ {x(5)}, {y(5)},
+ {x(5)}, {x(5),y(5)},
+ {y(3)}, {x(2),z(3)},
+ {x(3),y(5)}, {y(5),z(7)},
+ float_cells({x(3),y(5)}), {y(5),z(7)},
+ {x(3),y(5)}, float_cells({y(5),z(7)}),
+ float_cells({x(3),y(5)}), float_cells({y(5),z(7)}),
+ {x({"a","b","c"})}, {x({"a","b","c"})},
+ {x({"a","b","c"})}, {x({"a","b"})},
+ {x({"a","b","c"})}, {y({"foo","bar","baz"})},
+ {x({"a","b","c"})}, {x({"a","b","c"}),y({"foo","bar","baz"})},
+ {x({"a","b"}),y({"foo","bar","baz"})}, {x({"a","b","c"}),y({"foo","bar"})},
+ {x({"a","b"}),y({"foo","bar","baz"})}, {y({"foo","bar"}),z({"i","j","k","l"})},
+ float_cells({x({"a","b"}),y({"foo","bar","baz"})}), {y({"foo","bar"}),z({"i","j","k","l"})},
+ {x({"a","b"}),y({"foo","bar","baz"})}, float_cells({y({"foo","bar"}),z({"i","j","k","l"})}),
+ float_cells({x({"a","b"}),y({"foo","bar","baz"})}), float_cells({y({"foo","bar"}),z({"i","j","k","l"})}),
+ {x(3),y({"foo", "bar"})}, {y({"foo", "bar"}),z(7)},
+ {x({"a","b","c"}),y(5)}, {y(5),z({"i","j","k","l"})},
+ float_cells({x({"a","b","c"}),y(5)}), {y(5),z({"i","j","k","l"})},
+ {x({"a","b","c"}),y(5)}, float_cells({y(5),z({"i","j","k","l"})}),
+ float_cells({x({"a","b","c"}),y(5)}), float_cells({y(5),z({"i","j","k","l"})})
};
ASSERT_TRUE((layouts.size() % 2) == 0);
for (size_t i = 0; i < layouts.size(); i += 2) {
@@ -681,10 +699,20 @@ struct TestContext {
TEST_DO(verify_result(safe(eval).eval(engine, lhs, rhs), spec(expect)));
}
+ void test_dot_product(double expect,
+ const Layout &lhs, const Seq &lhs_seq,
+ const Layout &rhs, const Seq &rhs_seq)
+ {
+ TEST_DO(test_dot_product(expect, spec(lhs, lhs_seq), spec(rhs, rhs_seq)));
+ TEST_DO(test_dot_product(expect, spec(float_cells(lhs), lhs_seq), spec(rhs, rhs_seq)));
+ TEST_DO(test_dot_product(expect, spec(lhs, lhs_seq), spec(float_cells(rhs), rhs_seq)));
+ TEST_DO(test_dot_product(expect, spec(float_cells(lhs), lhs_seq), spec(float_cells(rhs), rhs_seq)));
+ }
+
void test_dot_product() {
TEST_DO(test_dot_product(((2 * 7) + (3 * 11) + (5 * 13)),
- spec(x(3), Seq({ 2, 3, 5 })),
- spec(x(3), Seq({ 7, 11, 13 }))));
+ {x(3)}, Seq({ 2, 3, 5 }),
+ {x(3)}, Seq({ 7, 11, 13 })));
}
//-------------------------------------------------------------------------
@@ -714,6 +742,16 @@ struct TestContext {
spec({x(2),y(2),z(3)}, Seq({1.0, 2.0, 3.0, 1.0, 2.0, 3.0, 4.0, 4.0, 4.0, 5.0, 5.0, 5.0}))));
TEST_DO(test_concat(spec(y(2), Seq({1.0, 2.0})), spec(y(2), Seq({4.0, 5.0})), "x",
spec({x(2), y(2)}, Seq({1.0, 2.0, 4.0, 5.0}))));
+
+ TEST_DO(test_concat(spec(float_cells({x(1)}), Seq({10.0})), spec(20.0), "x", spec(float_cells({x(2)}), Seq({10.0, 20.0}))));
+ TEST_DO(test_concat(spec(10.0), spec(float_cells({x(1)}), Seq({20.0})), "x", spec(float_cells({x(2)}), Seq({10.0, 20.0}))));
+
+ TEST_DO(test_concat(spec(float_cells({x(3)}), Seq({1.0, 2.0, 3.0})), spec(x(2), Seq({4.0, 5.0})), "x",
+ spec(x(5), Seq({1.0, 2.0, 3.0, 4.0, 5.0}))));
+ TEST_DO(test_concat(spec(x(3), Seq({1.0, 2.0, 3.0})), spec(float_cells({x(2)}), Seq({4.0, 5.0})), "x",
+ spec(x(5), Seq({1.0, 2.0, 3.0, 4.0, 5.0}))));
+ TEST_DO(test_concat(spec(float_cells({x(3)}), Seq({1.0, 2.0, 3.0})), spec(float_cells({x(2)}), Seq({4.0, 5.0})), "x",
+ spec(float_cells({x(5)}), Seq({1.0, 2.0, 3.0, 4.0, 5.0}))));
}
//-------------------------------------------------------------------------
@@ -732,6 +770,7 @@ struct TestContext {
void test_rename() {
TEST_DO(test_rename("rename(a,x,y)", spec(x(5), N()), {"x"}, {"y"}, spec(y(5), N())));
TEST_DO(test_rename("rename(a,y,x)", spec({y(5),z(5)}, N()), {"y"}, {"x"}, spec({x(5),z(5)}, N())));
+ TEST_DO(test_rename("rename(a,y,x)", spec(float_cells({y(5),z(5)}), N()), {"y"}, {"x"}, spec(float_cells({x(5),z(5)}), N())));
TEST_DO(test_rename("rename(a,z,x)", spec({y(5),z(5)}, N()), {"z"}, {"x"}, spec({y(5),x(5)}, N())));
TEST_DO(test_rename("rename(a,x,z)", spec({x(5),y(5)}, N()), {"x"}, {"z"}, spec({z(5),y(5)}, N())));
TEST_DO(test_rename("rename(a,y,z)", spec({x(5),y(5)}, N()), {"y"}, {"z"}, spec({x(5),z(5)}, N())));
@@ -746,6 +785,7 @@ struct TestContext {
void test_tensor_lambda() {
TEST_DO(test_tensor_lambda("tensor(x[10])(x+1)", spec(x(10), N())));
+ TEST_DO(test_tensor_lambda("tensor<float>(x[10])(x+1)", spec(float_cells({x(10)}), N())));
TEST_DO(test_tensor_lambda("tensor(x[5],y[4])(x*4+(y+1))", spec({x(5),y(4)}, N())));
TEST_DO(test_tensor_lambda("tensor(x[5],y[4])(x==y)", spec({x(5),y(4)},
Seq({ 1.0, 0.0, 0.0, 0.0,
@@ -818,11 +858,14 @@ struct TestContext {
TEST_DO(verify_encode_decode(spec({x(3)}, N())));
TEST_DO(verify_encode_decode(spec({x(3),y(5)}, N())));
TEST_DO(verify_encode_decode(spec({x(3),y(5),z(7)}, N())));
+ TEST_DO(verify_encode_decode(spec(float_cells({x(3),y(5),z(7)}), N())));
TEST_DO(verify_encode_decode(spec({x({"a","b","c"})}, N())));
TEST_DO(verify_encode_decode(spec({x({"a","b","c"}),y({"foo","bar"})}, N())));
TEST_DO(verify_encode_decode(spec({x({"a","b","c"}),y({"foo","bar"}),z({"i","j","k","l"})}, N())));
+ TEST_DO(verify_encode_decode(spec(float_cells({x({"a","b","c"}),y({"foo","bar"}),z({"i","j","k","l"})}), N())));
TEST_DO(verify_encode_decode(spec({x(3),y({"foo", "bar"}),z(7)}, N())));
TEST_DO(verify_encode_decode(spec({x({"a","b","c"}),y(5),z({"i","j","k","l"})}, N())));
+ TEST_DO(verify_encode_decode(spec(float_cells({x({"a","b","c"}),y(5),z({"i","j","k","l"})}), N())));
}
//-------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/eval/test/tensor_model.hpp b/eval/src/vespa/eval/eval/test/tensor_model.hpp
index 50a7b6a639a..4fad2820cf7 100644
--- a/eval/src/vespa/eval/eval/test/tensor_model.hpp
+++ b/eval/src/vespa/eval/eval/test/tensor_model.hpp
@@ -10,6 +10,7 @@ namespace vespalib {
namespace eval {
namespace test {
+using CellType = ValueType::CellType;
using map_fun_t = TensorEngine::map_fun_t;
using join_fun_t = TensorEngine::join_fun_t;
@@ -146,7 +147,22 @@ struct Domain {
Domain::Domain(const Domain &) = default;
Domain::~Domain() {}
-using Layout = std::vector<Domain>;
+struct Layout {
+ CellType cell_type;
+ std::vector<Domain> domains;
+ Layout(std::initializer_list<Domain> domains_in)
+ : cell_type(CellType::DOUBLE), domains(domains_in) {}
+ Layout(CellType cell_type_in, std::vector<Domain> domains_in)
+ : cell_type(cell_type_in), domains(std::move(domains_in)) {}
+ auto begin() const { return domains.begin(); }
+ auto end() const { return domains.end(); }
+ auto size() const { return domains.size(); }
+ auto operator[](size_t idx) const { return domains[idx]; }
+};
+
+Layout float_cells(const Layout &layout) {
+ return Layout(CellType::FLOAT, layout.domains);
+}
Domain x() { return Domain("x", {}); }
Domain x(size_t size) { return Domain("x", size); }
@@ -162,9 +178,6 @@ Domain z(const std::vector<vespalib::string> &keys) { return Domain("z", keys);
// Infer the tensor type spanned by the given spaces
vespalib::string infer_type(const Layout &layout) {
- if (layout.empty()) {
- return "double";
- }
std::vector<ValueType::Dimension> dimensions;
for (const auto &domain: layout) {
if (domain.size == 0) {
@@ -173,7 +186,7 @@ vespalib::string infer_type(const Layout &layout) {
dimensions.emplace_back(domain.dimension, domain.size); // indexed
}
}
- return ValueType::tensor_type(dimensions).to_spec();
+ return ValueType::tensor_type(dimensions, layout.cell_type).to_spec();
}
// Wrapper for the things needed to generate a tensor