author    Arne H Juul <arnej27959@users.noreply.github.com>  2019-05-03 14:59:39 +0200
committer GitHub <noreply@github.com>  2019-05-03 14:59:39 +0200
commit    a3edaae250c407de5445823851b4182b9e2e4d5b (patch)
tree      c4fb865892ab484cfbb049125725be4d86f2e5b6 /eval
parent    99f41741ca2784640ce1bec9e673355ab92e9d42 (diff)
parent    fe283c00dd00c91548f2b8ef8ef9c7a465234587 (diff)
Merge pull request #9271 from vespa-engine/havardpe/talk-about-float-cell-types
Havardpe/talk about float cell types
Diffstat (limited to 'eval')
-rw-r--r--  eval/src/tests/eval/node_types/node_types_test.cpp | 18
-rw-r--r--  eval/src/tests/eval/value_type/value_type_test.cpp | 405
-rw-r--r--  eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp | 4
-rw-r--r--  eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp | 15
-rw-r--r--  eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp | 5
-rw-r--r--  eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp | 8
-rw-r--r--  eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp | 5
-rw-r--r--  eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp | 5
-rw-r--r--  eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp | 8
-rw-r--r--  eval/src/vespa/eval/eval/function.cpp | 9
-rw-r--r--  eval/src/vespa/eval/eval/node_types.cpp | 6
-rw-r--r--  eval/src/vespa/eval/eval/value_type.cpp | 33
-rw-r--r--  eval/src/vespa/eval/eval/value_type.h | 17
-rw-r--r--  eval/src/vespa/eval/eval/value_type_spec.cpp | 73
-rw-r--r--  eval/src/vespa/eval/tensor/default_tensor_engine.cpp | 16
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp | 3
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp | 19
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp | 5
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp | 5
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp | 3
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp | 3
-rw-r--r--  eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp | 3
-rw-r--r--  eval/src/vespa/eval/tensor/tensor.cpp | 3
23 files changed, 404 insertions, 267 deletions
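
This merge threads an explicit cell type (float or double) through ValueType: the type spec syntax gains an optional tensor<float>(...) / tensor<double>(...) qualifier, with double remaining the default. A minimal round-trip sketch, assuming the eval headers touched below (and that vespalib::string compares against string literals, as it does elsewhere in these tests):

#include <vespa/eval/eval/value_type.h>
#include <cassert>

using vespalib::eval::ValueType;

int main() {
    // the cell type is parsed from and printed back into the type spec
    ValueType t = ValueType::from_spec("tensor<float>(x{},y[10])");
    assert(t.cell_type() == ValueType::CellType::FLOAT);
    assert(t.to_spec() == "tensor<float>(x{},y[10])");
    // 'double' is the default cell type and is normalized away on output
    assert(ValueType::from_spec("tensor<double>(y[10])").to_spec() == "tensor(y[10])");
}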
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index c18470887b2..256c7b85f72 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -76,8 +76,9 @@ TEST("require that leaf constants have appropriate type") {
TEST("require that input parameters preserve their type") {
TEST_DO(verify("error", "error"));
TEST_DO(verify("double", "double"));
- TEST_DO(verify("tensor", "double"));
+ TEST_DO(verify("tensor()", "double"));
TEST_DO(verify("tensor(x{},y[10],z[5])", "tensor(x{},y[10],z[5])"));
+ TEST_DO(verify("tensor<float>(x{},y[10],z[5])", "tensor<float>(x{},y[10],z[5])"));
}
TEST("require that if resolves to the appropriate type") {
@@ -88,6 +89,8 @@ TEST("require that if resolves to the appropriate type") {
TEST_DO(verify("if(tensor(x[10]),1,2)", "double"));
TEST_DO(verify("if(double,tensor(a{}),tensor(a{}))", "tensor(a{})"));
TEST_DO(verify("if(double,tensor(a[2]),tensor(a[2]))", "tensor(a[2])"));
+ TEST_DO(verify("if(double,tensor<float>(a[2]),tensor<float>(a[2]))", "tensor<float>(a[2])"));
+ TEST_DO(verify("if(double,tensor(a[2]),tensor<float>(a[2]))", "error"));
TEST_DO(verify("if(double,tensor(a[2]),tensor(a[3]))", "error"));
TEST_DO(verify("if(double,tensor(a[2]),tensor(a{}))", "error"));
TEST_DO(verify("if(double,tensor(a{}),tensor(b{}))", "error"));
@@ -105,6 +108,9 @@ TEST("require that reduce resolves correct type") {
TEST_DO(verify("reduce(tensor(x{},y{},z{}),sum,y,z,x)", "double"));
TEST_DO(verify("reduce(tensor(x{},y{},z{}),sum,w)", "error"));
TEST_DO(verify("reduce(tensor(x{}),sum,x)", "double"));
+ TEST_DO(verify("reduce(tensor<float>(x{},y{},z{}),sum,x,z)", "tensor<float>(y{})"));
+ TEST_DO(verify("reduce(tensor<float>(x{}),sum,x)", "double"));
+ TEST_DO(verify("reduce(tensor<float>(x{}),sum)", "double"));
}
TEST("require that rename resolves correct type") {
@@ -119,6 +125,7 @@ TEST("require that rename resolves correct type") {
TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,y,z),(z,y,x))", "tensor(z{},y[1],x[5])"));
TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,z),(z,x))", "tensor(z{},y[1],x[5])"));
TEST_DO(verify("rename(tensor(x{},y[1],z[5]),(x,y,z),(a,b,c))", "tensor(a{},b[1],c[5])"));
+ TEST_DO(verify("rename(tensor<float>(x{},y[1],z[5]),(x,y,z),(a,b,c))", "tensor<float>(a{},b[1],c[5])"));
}
vespalib::string strfmt(const char *pattern, const char *a) {
@@ -133,6 +140,7 @@ void verify_op1(const char *pattern) {
TEST_DO(verify(strfmt(pattern, "error"), "error"));
TEST_DO(verify(strfmt(pattern, "double"), "double"));
TEST_DO(verify(strfmt(pattern, "tensor(x{},y[10],z[1])"), "tensor(x{},y[10],z[1])"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x{},y[10],z[1])"), "tensor<float>(x{},y[10],z[1])"));
}
void verify_op2(const char *pattern) {
@@ -150,6 +158,9 @@ void verify_op2(const char *pattern) {
TEST_DO(verify(strfmt(pattern, "tensor(x[3])", "tensor(x[5])"), "error"));
TEST_DO(verify(strfmt(pattern, "tensor(x[5])", "tensor(x[3])"), "error"));
TEST_DO(verify(strfmt(pattern, "tensor(x{})", "tensor(x[5])"), "error"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor<float>(x[5])"), "tensor<float>(x[5])"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "tensor(x[5])"), "tensor(x[5])"));
+ TEST_DO(verify(strfmt(pattern, "tensor<float>(x[5])", "double"), "tensor<float>(x[5])"));
}
TEST("require that various operations resolve appropriate type") {
@@ -213,6 +224,8 @@ TEST("require that lambda tensor resolves correct type") {
TEST_DO(verify("tensor(x[5])(1.0)", "tensor(x[5])", false));
TEST_DO(verify("tensor(x[5],y[10])(1.0)", "tensor(x[5],y[10])", false));
TEST_DO(verify("tensor(x[5],y[10],z[15])(1.0)", "tensor(x[5],y[10],z[15])", false));
+ TEST_DO(verify("tensor<double>(x[5],y[10],z[15])(1.0)", "tensor(x[5],y[10],z[15])", false));
+ TEST_DO(verify("tensor<float>(x[5],y[10],z[15])(1.0)", "tensor<float>(x[5],y[10],z[15])", false));
}
TEST("require that tensor concat resolves correct type") {
@@ -222,6 +235,9 @@ TEST("require that tensor concat resolves correct type") {
TEST_DO(verify("concat(tensor(x[2]),tensor(x[3]),y)", "error"));
TEST_DO(verify("concat(tensor(x[2]),tensor(x{}),x)", "error"));
TEST_DO(verify("concat(tensor(x[2]),tensor(y{}),x)", "tensor(x[3],y{})"));
+ TEST_DO(verify("concat(tensor<float>(x[2]),tensor<float>(x[3]),x)", "tensor<float>(x[5])"));
+ TEST_DO(verify("concat(tensor<float>(x[2]),tensor(x[3]),x)", "tensor(x[5])"));
+ TEST_DO(verify("concat(tensor<float>(x[2]),double,x)", "tensor<float>(x[3])"));
}
TEST("require that double only expressions can be detected") {
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index a755eac965f..cf61da00cca 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -8,16 +8,32 @@
using namespace vespalib::eval;
+using CellType = ValueType::CellType;
+
const size_t npos = ValueType::Dimension::npos;
+ValueType type(const vespalib::string &type_str) {
+ ValueType ret = ValueType::from_spec(type_str);
+ ASSERT_TRUE(!ret.is_error() || (type_str == "error"));
+ return ret;
+}
+
+std::vector<vespalib::string> str_list(const std::vector<vespalib::string> &list) {
+ return list;
+}
+
+//-----------------------------------------------------------------------------
+
TEST("require that ERROR value type can be created") {
ValueType t = ValueType::error_type();
+ EXPECT_TRUE(t.cell_type() == CellType::DOUBLE);
EXPECT_TRUE(t.type() == ValueType::Type::ERROR);
EXPECT_EQUAL(t.dimensions().size(), 0u);
}
TEST("require that DOUBLE value type can be created") {
ValueType t = ValueType::double_type();
+ EXPECT_TRUE(t.cell_type() == CellType::DOUBLE);
EXPECT_TRUE(t.type() == ValueType::Type::DOUBLE);
EXPECT_EQUAL(t.dimensions().size(), 0u);
}
@@ -25,6 +41,18 @@ TEST("require that DOUBLE value type can be created") {
TEST("require that TENSOR value type can be created") {
ValueType t = ValueType::tensor_type({{"x", 10},{"y"}});
EXPECT_TRUE(t.type() == ValueType::Type::TENSOR);
+ EXPECT_TRUE(t.cell_type() == CellType::DOUBLE);
+ ASSERT_EQUAL(t.dimensions().size(), 2u);
+ EXPECT_EQUAL(t.dimensions()[0].name, "x");
+ EXPECT_EQUAL(t.dimensions()[0].size, 10u);
+ EXPECT_EQUAL(t.dimensions()[1].name, "y");
+ EXPECT_EQUAL(t.dimensions()[1].size, npos);
+}
+
+TEST("require that float TENSOR value type can be created") {
+ ValueType t = ValueType::tensor_type({{"x", 10},{"y"}}, CellType::FLOAT);
+ EXPECT_TRUE(t.type() == ValueType::Type::TENSOR);
+ EXPECT_TRUE(t.cell_type() == CellType::FLOAT);
ASSERT_EQUAL(t.dimensions().size(), 2u);
EXPECT_EQUAL(t.dimensions()[0].name, "x");
EXPECT_EQUAL(t.dimensions()[0].size, 10u);
@@ -35,6 +63,7 @@ TEST("require that TENSOR value type can be created") {
TEST("require that TENSOR value type sorts dimensions") {
ValueType t = ValueType::tensor_type({{"x", 10}, {"z", 30}, {"y"}});
EXPECT_TRUE(t.type() == ValueType::Type::TENSOR);
+ EXPECT_TRUE(t.cell_type() == CellType::DOUBLE);
ASSERT_EQUAL(t.dimensions().size(), 3u);
EXPECT_EQUAL(t.dimensions()[0].name, "x");
EXPECT_EQUAL(t.dimensions()[0].size, 10u);
@@ -44,26 +73,23 @@ TEST("require that TENSOR value type sorts dimensions") {
EXPECT_EQUAL(t.dimensions()[2].size, 30u);
}
-TEST("require that dimension names can be obtained") {
- EXPECT_EQUAL(ValueType::double_type().dimension_names(),
- std::vector<vespalib::string>({}));
- EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}, {"x", 30}}).dimension_names(),
- std::vector<vespalib::string>({"x", "y"}));
- EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}, {"x", 30}, {"z"}}).dimension_names(),
- std::vector<vespalib::string>({"x", "y", "z"}));
+TEST("require that 'tensor<float>()' is normalized to 'double'") {
+ ValueType t = ValueType::tensor_type({}, CellType::FLOAT);
+ EXPECT_TRUE(t.cell_type() == CellType::DOUBLE);
+ EXPECT_TRUE(t.type() == ValueType::Type::DOUBLE);
+ EXPECT_EQUAL(t.dimensions().size(), 0u);
}
-TEST("require that dimension index can be obtained") {
- EXPECT_EQUAL(ValueType::error_type().dimension_index("x"), ValueType::Dimension::npos);
- EXPECT_EQUAL(ValueType::double_type().dimension_index("x"), ValueType::Dimension::npos);
- EXPECT_EQUAL(ValueType::tensor_type({}).dimension_index("x"), ValueType::Dimension::npos);
- auto my_type = ValueType::tensor_type({{"y", 10}, {"x"}, {"z", 5}});
- EXPECT_EQUAL(my_type.dimension_index("x"), 0u);
- EXPECT_EQUAL(my_type.dimension_index("y"), 1u);
- EXPECT_EQUAL(my_type.dimension_index("z"), 2u);
- EXPECT_EQUAL(my_type.dimension_index("w"), ValueType::Dimension::npos);
+TEST("require that use of unbound dimensions result in error types") {
+ EXPECT_TRUE(ValueType::tensor_type({{"x", 0}}).is_error());
}
+TEST("require that duplicate dimension names result in error types") {
+ EXPECT_TRUE(ValueType::tensor_type({{"x"}, {"x"}}).is_error());
+}
+
+//-----------------------------------------------------------------------------
+
void verify_equal(const ValueType &a, const ValueType &b) {
EXPECT_EQUAL(a, b);
EXPECT_EQUAL(b, a);
@@ -94,149 +120,46 @@ TEST("require that value types can be compared") {
TEST_DO(verify_equal(ValueType::tensor_type({{"x", 10}, {"y", 20}}), ValueType::tensor_type({{"y", 20}, {"x", 10}})));
TEST_DO(verify_not_equal(ValueType::tensor_type({{"x", 10}, {"y", 20}}), ValueType::tensor_type({{"x", 10}, {"y", 10}})));
TEST_DO(verify_not_equal(ValueType::tensor_type({{"x", 10}}), ValueType::tensor_type({{"x"}})));
+ TEST_DO(verify_equal(ValueType::tensor_type({{"x", 10}}, CellType::FLOAT), ValueType::tensor_type({{"x", 10}}, CellType::FLOAT)));
+ TEST_DO(verify_not_equal(ValueType::tensor_type({{"x", 10}}, CellType::DOUBLE), ValueType::tensor_type({{"x", 10}}, CellType::FLOAT)));
}
-void verify_predicates(const ValueType &type,
- bool expect_error, bool expect_double, bool expect_tensor,
- bool expect_sparse, bool expect_dense)
-{
- EXPECT_EQUAL(type.is_error(), expect_error);
- EXPECT_EQUAL(type.is_double(), expect_double);
- EXPECT_EQUAL(type.is_tensor(), expect_tensor);
- EXPECT_EQUAL(type.is_sparse(), expect_sparse);
- EXPECT_EQUAL(type.is_dense(), expect_dense);
-}
-
-TEST("require that type-related predicate functions work as expected") {
- TEST_DO(verify_predicates(ValueType::error_type(), true, false, false, false, false));
- TEST_DO(verify_predicates(ValueType::double_type(), false, true, false, false, false));
- TEST_DO(verify_predicates(ValueType::tensor_type({}), false, true, false, false, false));
- TEST_DO(verify_predicates(ValueType::tensor_type({{"x"}}), false, false, true, true, false));
- TEST_DO(verify_predicates(ValueType::tensor_type({{"x"},{"y"}}), false, false, true, true, false));
- TEST_DO(verify_predicates(ValueType::tensor_type({{"x", 5}}), false, false, true, false, true));
- TEST_DO(verify_predicates(ValueType::tensor_type({{"x", 5},{"y", 10}}), false, false, true, false, true));
- TEST_DO(verify_predicates(ValueType::tensor_type({{"x", 5}, {"y"}}), false, false, true, false, false));
-}
-
-TEST("require that dimension predicates work as expected") {
- ValueType::Dimension x("x");
- ValueType::Dimension y("y", 10);
- ValueType::Dimension z("z", 0);
- EXPECT_TRUE(x.is_mapped());
- EXPECT_TRUE(!x.is_indexed());
- EXPECT_TRUE(!x.is_bound());
- EXPECT_TRUE(!y.is_mapped());
- EXPECT_TRUE(y.is_indexed());
- EXPECT_TRUE(y.is_bound());
- EXPECT_TRUE(!z.is_mapped());
- EXPECT_TRUE(z.is_indexed());
- EXPECT_TRUE(!z.is_bound());
-}
-
-TEST("require that use of unbound dimensions result in error types") {
- EXPECT_TRUE(ValueType::tensor_type({{"x", 0}}).is_error());
-}
-
-TEST("require that duplicate dimension names result in error types") {
- EXPECT_TRUE(ValueType::tensor_type({{"x"}, {"x"}}).is_error());
-}
-
-TEST("require that removing dimensions from non-tensor types gives error type") {
- EXPECT_TRUE(ValueType::error_type().reduce({"x"}).is_error());
- EXPECT_TRUE(ValueType::double_type().reduce({"x"}).is_error());
-}
-
-TEST("require that dimensions can be removed from tensor value types") {
- ValueType type = ValueType::tensor_type({{"x", 10}, {"y", 20}, {"z", 30}});
- EXPECT_EQUAL(ValueType::tensor_type({{"y", 20}, {"z", 30}}), type.reduce({"x"}));
- EXPECT_EQUAL(ValueType::tensor_type({{"x", 10}, {"z", 30}}), type.reduce({"y"}));
- EXPECT_EQUAL(ValueType::tensor_type({{"x", 10}, {"y", 20}}), type.reduce({"z"}));
- EXPECT_EQUAL(ValueType::tensor_type({{"y", 20}}), type.reduce({"x", "z"}));
- EXPECT_EQUAL(ValueType::tensor_type({{"y", 20}}), type.reduce({"z", "x"}));
-}
-
-TEST("require that removing an empty set of dimensions means removing them all") {
- EXPECT_EQUAL(ValueType::tensor_type({{"x", 10}, {"y", 20}, {"z", 30}}).reduce({}), ValueType::double_type());
-}
-
-TEST("require that removing non-existing dimensions gives error type") {
- EXPECT_TRUE(ValueType::tensor_type({{"y"}}).reduce({"x"}).is_error());
- EXPECT_TRUE(ValueType::tensor_type({{"y", 10}}).reduce({"x"}).is_error());
-}
-
-TEST("require that removing all dimensions gives double type") {
- ValueType type = ValueType::tensor_type({{"x", 10}, {"y", 20}, {"z", 30}});
- EXPECT_EQUAL(ValueType::double_type(), type.reduce({"x", "y", "z"}));
-}
-
-TEST("require that dimensions can be combined for value types") {
- ValueType tensor_type_xy = ValueType::tensor_type({{"x"}, {"y"}});
- ValueType tensor_type_yz = ValueType::tensor_type({{"y"}, {"z"}});
- ValueType tensor_type_xyz = ValueType::tensor_type({{"x"}, {"y"}, {"z"}});
- ValueType tensor_type_y = ValueType::tensor_type({{"y"}});
- ValueType tensor_type_a10 = ValueType::tensor_type({{"a", 10}});
- ValueType tensor_type_a10xyz = ValueType::tensor_type({{"a", 10}, {"x"}, {"y"}, {"z"}});
- ValueType scalar = ValueType::double_type();
- EXPECT_EQUAL(ValueType::join(scalar, scalar), scalar);
- EXPECT_EQUAL(ValueType::join(tensor_type_xy, tensor_type_yz), tensor_type_xyz);
- EXPECT_EQUAL(ValueType::join(tensor_type_yz, tensor_type_xy), tensor_type_xyz);
- EXPECT_EQUAL(ValueType::join(tensor_type_y, tensor_type_y), tensor_type_y);
- EXPECT_EQUAL(ValueType::join(scalar, tensor_type_y), tensor_type_y);
- EXPECT_EQUAL(ValueType::join(tensor_type_a10, tensor_type_a10), tensor_type_a10);
- EXPECT_EQUAL(ValueType::join(tensor_type_a10, scalar), tensor_type_a10);
- EXPECT_EQUAL(ValueType::join(tensor_type_xyz, tensor_type_a10), tensor_type_a10xyz);
-}
-
-void verify_not_combinable(const ValueType &a, const ValueType &b) {
- EXPECT_TRUE(ValueType::join(a, b).is_error());
- EXPECT_TRUE(ValueType::join(b, a).is_error());
-}
-
-TEST("require that mapped and indexed dimensions are not combinable") {
- verify_not_combinable(ValueType::tensor_type({{"x", 10}}), ValueType::tensor_type({{"x"}}));
-}
-
-TEST("require that indexed dimensions of different sizes are not combinable") {
- verify_not_combinable(ValueType::tensor_type({{"x", 10}}), ValueType::tensor_type({{"x", 20}}));
-}
-
-TEST("require that error type combined with anything produces error type") {
- verify_not_combinable(ValueType::error_type(), ValueType::error_type());
- verify_not_combinable(ValueType::error_type(), ValueType::double_type());
- verify_not_combinable(ValueType::error_type(), ValueType::tensor_type({{"x"}}));
- verify_not_combinable(ValueType::error_type(), ValueType::tensor_type({{"x", 10}}));
-}
+//-----------------------------------------------------------------------------
TEST("require that value type can make spec") {
EXPECT_EQUAL("error", ValueType::error_type().to_spec());
EXPECT_EQUAL("double", ValueType::double_type().to_spec());
EXPECT_EQUAL("double", ValueType::tensor_type({}).to_spec());
+ EXPECT_EQUAL("double", ValueType::tensor_type({}, CellType::FLOAT).to_spec());
EXPECT_EQUAL("tensor(x{})", ValueType::tensor_type({{"x"}}).to_spec());
EXPECT_EQUAL("tensor(y[10])", ValueType::tensor_type({{"y", 10}}).to_spec());
EXPECT_EQUAL("tensor(x{},y[10],z[5])", ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}).to_spec());
+ EXPECT_EQUAL("tensor<float>(x{})", ValueType::tensor_type({{"x"}}, CellType::FLOAT).to_spec());
+ EXPECT_EQUAL("tensor<float>(y[10])", ValueType::tensor_type({{"y", 10}}, CellType::FLOAT).to_spec());
+ EXPECT_EQUAL("tensor<float>(x{},y[10],z[5])", ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}, CellType::FLOAT).to_spec());
}
+//-----------------------------------------------------------------------------
+
TEST("require that value type spec can be parsed") {
EXPECT_EQUAL(ValueType::double_type(), ValueType::from_spec("double"));
- EXPECT_EQUAL(ValueType::tensor_type({}), ValueType::from_spec("tensor"));
EXPECT_EQUAL(ValueType::tensor_type({}), ValueType::from_spec("tensor()"));
EXPECT_EQUAL(ValueType::tensor_type({{"x"}}), ValueType::from_spec("tensor(x{})"));
EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec("tensor(y[10])"));
EXPECT_EQUAL(ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}), ValueType::from_spec("tensor(x{},y[10],z[5])"));
EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec("tensor<double>(y[10])"));
- EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec("tensor<float>(y[10])"));
+ EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}, CellType::FLOAT), ValueType::from_spec("tensor<float>(y[10])"));
}
TEST("require that value type spec can be parsed with extra whitespace") {
EXPECT_EQUAL(ValueType::double_type(), ValueType::from_spec(" double "));
- EXPECT_EQUAL(ValueType::tensor_type({}), ValueType::from_spec(" tensor "));
EXPECT_EQUAL(ValueType::tensor_type({}), ValueType::from_spec(" tensor ( ) "));
EXPECT_EQUAL(ValueType::tensor_type({{"x"}}), ValueType::from_spec(" tensor ( x { } ) "));
EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec(" tensor ( y [ 10 ] ) "));
EXPECT_EQUAL(ValueType::tensor_type({{"x"}, {"y", 10}, {"z", 5}}),
ValueType::from_spec(" tensor ( x { } , y [ 10 ] , z [ 5 ] ) "));
EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec(" tensor < double > ( y [ 10 ] ) "));
- EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}), ValueType::from_spec(" tensor < float > ( y [ 10 ] ) "));
+ EXPECT_EQUAL(ValueType::tensor_type({{"y", 10}}, CellType::FLOAT), ValueType::from_spec(" tensor < float > ( y [ 10 ] ) "));
}
TEST("require that malformed value type spec is parsed as error") {
@@ -244,7 +167,9 @@ TEST("require that malformed value type spec is parsed as error") {
EXPECT_TRUE(ValueType::from_spec(" ").is_error());
EXPECT_TRUE(ValueType::from_spec("error").is_error());
EXPECT_TRUE(ValueType::from_spec("any").is_error());
- EXPECT_TRUE(ValueType::from_spec("tensor tensor").is_error());
+ EXPECT_TRUE(ValueType::from_spec("tensor").is_error());
+ EXPECT_TRUE(ValueType::from_spec("tensor<double>").is_error());
+ EXPECT_TRUE(ValueType::from_spec("tensor() tensor()").is_error());
EXPECT_TRUE(ValueType::from_spec("tensor(x{10})").is_error());
EXPECT_TRUE(ValueType::from_spec("tensor(x{},)").is_error());
EXPECT_TRUE(ValueType::from_spec("tensor(,x{})").is_error());
@@ -277,9 +202,8 @@ ParseResult::ParseResult(const vespalib::string &spec_in)
pos(spec.data()),
end(pos + spec.size()),
after(nullptr),
- type(value_type::parse_spec(pos, end, after))
-{ }
-ParseResult::~ParseResult() { }
+ type(value_type::parse_spec(pos, end, after)) {}
+ParseResult::~ParseResult() = default;
TEST("require that we can parse a partial string into a type with the low-level API") {
ParseResult result("tensor(a[5]) , ");
@@ -297,56 +221,177 @@ TEST("require that 'error' is the valid representation of the error type") {
EXPECT_TRUE(invalid.after == nullptr); // parse not ok
}
+//-----------------------------------------------------------------------------
+
+TEST("require that value types preserve cell type") {
+ EXPECT_TRUE(type("tensor(x[10])").cell_type() == CellType::DOUBLE);
+ EXPECT_TRUE(type("tensor<double>(x[10])").cell_type() == CellType::DOUBLE);
+ EXPECT_TRUE(type("tensor<float>(x[10])").cell_type() == CellType::FLOAT);
+}
+
+TEST("require that dimension names can be obtained") {
+ EXPECT_EQUAL(type("double").dimension_names(), str_list({}));
+ EXPECT_EQUAL(type("tensor(y[30],x[10])").dimension_names(), str_list({"x", "y"}));
+ EXPECT_EQUAL(type("tensor<float>(y[10],x[30],z{})").dimension_names(), str_list({"x", "y", "z"}));
+}
+
+TEST("require that dimension index can be obtained") {
+ EXPECT_EQUAL(type("error").dimension_index("x"), ValueType::Dimension::npos);
+ EXPECT_EQUAL(type("double").dimension_index("x"), ValueType::Dimension::npos);
+ EXPECT_EQUAL(type("tensor()").dimension_index("x"), ValueType::Dimension::npos);
+ EXPECT_EQUAL(type("tensor(y[10],x{},z[5])").dimension_index("x"), 0u);
+ EXPECT_EQUAL(type("tensor<float>(y[10],x{},z[5])").dimension_index("y"), 1u);
+ EXPECT_EQUAL(type("tensor(y[10],x{},z[5])").dimension_index("z"), 2u);
+ EXPECT_EQUAL(type("tensor(y[10],x{},z[5])").dimension_index("w"), ValueType::Dimension::npos);
+}
+
+void verify_predicates(const ValueType &type,
+ bool expect_error, bool expect_double, bool expect_tensor,
+ bool expect_sparse, bool expect_dense)
+{
+ EXPECT_EQUAL(type.is_error(), expect_error);
+ EXPECT_EQUAL(type.is_double(), expect_double);
+ EXPECT_EQUAL(type.is_tensor(), expect_tensor);
+ EXPECT_EQUAL(type.is_sparse(), expect_sparse);
+ EXPECT_EQUAL(type.is_dense(), expect_dense);
+}
+
+TEST("require that type-related predicate functions work as expected") {
+ TEST_DO(verify_predicates(type("error"), true, false, false, false, false));
+ TEST_DO(verify_predicates(type("double"), false, true, false, false, false));
+ TEST_DO(verify_predicates(type("tensor()"), false, true, false, false, false));
+ TEST_DO(verify_predicates(type("tensor(x{})"), false, false, true, true, false));
+ TEST_DO(verify_predicates(type("tensor(x{},y{})"), false, false, true, true, false));
+ TEST_DO(verify_predicates(type("tensor(x[5])"), false, false, true, false, true));
+ TEST_DO(verify_predicates(type("tensor(x[5],y[10])"), false, false, true, false, true));
+ TEST_DO(verify_predicates(type("tensor(x[5],y{})"), false, false, true, false, false));
+ TEST_DO(verify_predicates(type("tensor<float>(x{})"), false, false, true, true, false));
+ TEST_DO(verify_predicates(type("tensor<float>(x[5])"), false, false, true, false, true));
+ TEST_DO(verify_predicates(type("tensor<float>(x[5],y{})"), false, false, true, false, false));
+}
+
+TEST("require that dimension predicates work as expected") {
+ ValueType::Dimension x("x");
+ ValueType::Dimension y("y", 10);
+ ValueType::Dimension z("z", 0);
+ EXPECT_TRUE(x.is_mapped());
+ EXPECT_TRUE(!x.is_indexed());
+ EXPECT_TRUE(!x.is_bound());
+ EXPECT_TRUE(!y.is_mapped());
+ EXPECT_TRUE(y.is_indexed());
+ EXPECT_TRUE(y.is_bound());
+ EXPECT_TRUE(!z.is_mapped());
+ EXPECT_TRUE(z.is_indexed());
+ EXPECT_TRUE(!z.is_bound());
+}
+
+TEST("require that removing dimensions from non-tensor types gives error type") {
+ EXPECT_TRUE(type("error").reduce({"x"}).is_error());
+ EXPECT_TRUE(type("double").reduce({"x"}).is_error());
+}
+
+TEST("require that dimensions can be removed from tensor value types") {
+ EXPECT_EQUAL(type("tensor(x[10],y[20],z[30])").reduce({"x"}), type("tensor(y[20],z[30])"));
+ EXPECT_EQUAL(type("tensor(x[10],y[20],z[30])").reduce({"y"}), type("tensor(x[10],z[30])"));
+ EXPECT_EQUAL(type("tensor<float>(x[10],y[20],z[30])").reduce({"z"}), type("tensor<float>(x[10],y[20])"));
+ EXPECT_EQUAL(type("tensor(x[10],y[20],z[30])").reduce({"x", "z"}), type("tensor(y[20])"));
+ EXPECT_EQUAL(type("tensor<float>(x[10],y[20],z[30])").reduce({"z", "x"}), type("tensor<float>(y[20])"));
+}
+
+TEST("require that removing an empty set of dimensions means removing them all") {
+ EXPECT_EQUAL(type("tensor(x[10],y[20],z[30])").reduce({}), type("double"));
+ EXPECT_EQUAL(type("tensor<float>(x[10],y[20],z[30])").reduce({}), type("double"));
+}
+
+TEST("require that removing non-existing dimensions gives error type") {
+ EXPECT_TRUE(type("tensor(y{})").reduce({"x"}).is_error());
+ EXPECT_TRUE(type("tensor<float>(y[10])").reduce({"x"}).is_error());
+}
+
+TEST("require that removing all dimensions gives double type") {
+ EXPECT_EQUAL(type("tensor(x[10],y[20],z[30])").reduce({"x", "y", "z"}), type("double"));
+ EXPECT_EQUAL(type("tensor<float>(x[10],y[20],z[30])").reduce({"x", "y", "z"}), type("double"));
+}
+
+void verify_join(const ValueType &a, const ValueType b, const ValueType &res) {
+ EXPECT_EQUAL(ValueType::join(a, b), res);
+ EXPECT_EQUAL(ValueType::join(b, a), res);
+}
+
+TEST("require that dimensions can be combined for value types") {
+ TEST_DO(verify_join(type("double"), type("double"), type("double")));
+ TEST_DO(verify_join(type("tensor(x{},y{})"), type("tensor(y{},z{})"), type("tensor(x{},y{},z{})")));
+ TEST_DO(verify_join(type("tensor(y{})"), type("tensor(y{})"), type("tensor(y{})")));
+ TEST_DO(verify_join(type("tensor(y{})"), type("double"), type("tensor(y{})")));
+ TEST_DO(verify_join(type("tensor(a[10])"), type("tensor(a[10])"), type("tensor(a[10])")));
+ TEST_DO(verify_join(type("tensor(a[10])"), type("double"), type("tensor(a[10])")));
+ TEST_DO(verify_join(type("tensor(a[10])"), type("tensor(x{},y{},z{})"), type("tensor(a[10],x{},y{},z{})")));
+}
+
+TEST("require that cell type is handled correctly for join") {
+ TEST_DO(verify_join(type("tensor(x{})"), type("tensor<float>(y{})"), type("tensor(x{},y{})")));
+ TEST_DO(verify_join(type("tensor<float>(x{})"), type("tensor<float>(y{})"), type("tensor<float>(x{},y{})")));
+ TEST_DO(verify_join(type("tensor<float>(x{})"), type("double"), type("tensor<float>(x{})")));
+}
+
+void verify_not_joinable(const ValueType &a, const ValueType &b) {
+ EXPECT_TRUE(ValueType::join(a, b).is_error());
+ EXPECT_TRUE(ValueType::join(b, a).is_error());
+}
+
+TEST("require that mapped and indexed dimensions are not joinable") {
+ verify_not_joinable(type("tensor(x[10])"), type("tensor(x{})"));
+}
+
+TEST("require that indexed dimensions of different sizes are not joinable") {
+ verify_not_joinable(type("tensor(x[10])"), type("tensor(x[20])"));
+}
+
+TEST("require that error type combined with anything produces error type") {
+ verify_not_joinable(type("error"), type("error"));
+ verify_not_joinable(type("error"), type("double"));
+ verify_not_joinable(type("error"), type("tensor(x{})"));
+ verify_not_joinable(type("error"), type("tensor(x[10])"));
+}
+
TEST("require that tensor dimensions can be renamed") {
- EXPECT_EQUAL(ValueType::from_spec("tensor(x{})").rename({"x"}, {"y"}),
- ValueType::from_spec("tensor(y{})"));
- EXPECT_EQUAL(ValueType::from_spec("tensor(x{},y[5])").rename({"x","y"}, {"y","x"}),
- ValueType::from_spec("tensor(y{},x[5])"));
- EXPECT_EQUAL(ValueType::from_spec("tensor(x{})").rename({"x"}, {"x"}),
- ValueType::from_spec("tensor(x{})"));
- EXPECT_EQUAL(ValueType::from_spec("tensor(x{})").rename({}, {}), ValueType::error_type());
- EXPECT_EQUAL(ValueType::double_type().rename({}, {}), ValueType::error_type());
- EXPECT_EQUAL(ValueType::from_spec("tensor(x{},y{})").rename({"x"}, {"y","z"}), ValueType::error_type());
- EXPECT_EQUAL(ValueType::from_spec("tensor(x{},y{})").rename({"x","y"}, {"z"}), ValueType::error_type());
- EXPECT_EQUAL(ValueType::double_type().rename({"a"}, {"b"}), ValueType::error_type());
- EXPECT_EQUAL(ValueType::error_type().rename({"a"}, {"b"}), ValueType::error_type());
+ EXPECT_EQUAL(type("tensor(x{})").rename({"x"}, {"y"}), type("tensor(y{})"));
+ EXPECT_EQUAL(type("tensor(x{},y[5])").rename({"x","y"}, {"y","x"}), type("tensor(y{},x[5])"));
+ EXPECT_EQUAL(type("tensor(x{})").rename({"x"}, {"x"}), type("tensor(x{})"));
+ EXPECT_EQUAL(type("tensor(x{})").rename({}, {}), type("error"));
+ EXPECT_EQUAL(type("double").rename({}, {}), type("error"));
+ EXPECT_EQUAL(type("tensor(x{},y{})").rename({"x"}, {"y","z"}), type("error"));
+ EXPECT_EQUAL(type("tensor(x{},y{})").rename({"x","y"}, {"z"}), type("error"));
+ EXPECT_EQUAL(type("double").rename({"a"}, {"b"}), type("error"));
+ EXPECT_EQUAL(type("error").rename({"a"}, {"b"}), type("error"));
+}
+
+void verify_concat(const ValueType &a, const ValueType b, const vespalib::string &dim, const ValueType &res) {
+ EXPECT_EQUAL(ValueType::concat(a, b, dim), res);
+ EXPECT_EQUAL(ValueType::concat(b, a, dim), res);
}
TEST("require that types can be concatenated") {
- ValueType error = ValueType::error_type();
- ValueType scalar = ValueType::double_type();
- ValueType vx_2 = ValueType::from_spec("tensor(x[2])");
- ValueType vx_m = ValueType::from_spec("tensor(x{})");
- ValueType vx_3 = ValueType::from_spec("tensor(x[3])");
- ValueType vx_5 = ValueType::from_spec("tensor(x[5])");
- ValueType vy_7 = ValueType::from_spec("tensor(y[7])");
- ValueType mxy_22 = ValueType::from_spec("tensor(x[2],y[2])");
- ValueType mxy_52 = ValueType::from_spec("tensor(x[5],y[2])");
- ValueType mxy_29 = ValueType::from_spec("tensor(x[2],y[9])");
- ValueType cxyz_572 = ValueType::from_spec("tensor(x[5],y[7],z[2])");
- ValueType cxyz_m72 = ValueType::from_spec("tensor(x{},y[7],z[2])");
-
- EXPECT_EQUAL(ValueType::concat(error, vx_2, "x"), error);
- EXPECT_EQUAL(ValueType::concat(vx_2, error, "x"), error);
- EXPECT_EQUAL(ValueType::concat(vx_m, vx_2, "x"), error);
- EXPECT_EQUAL(ValueType::concat(vx_2, vx_m, "x"), error);
- EXPECT_EQUAL(ValueType::concat(vx_m, vx_m, "x"), error);
- EXPECT_EQUAL(ValueType::concat(vx_m, scalar, "x"), error);
- EXPECT_EQUAL(ValueType::concat(scalar, vx_m, "x"), error);
- EXPECT_EQUAL(ValueType::concat(vx_2, vx_3, "y"), error);
- EXPECT_EQUAL(ValueType::concat(vy_7, vx_m, "z"), cxyz_m72);
- EXPECT_EQUAL(ValueType::concat(scalar, scalar, "x"), vx_2);
- EXPECT_EQUAL(ValueType::concat(vx_2, scalar, "x"), vx_3);
- EXPECT_EQUAL(ValueType::concat(scalar, vx_2, "x"), vx_3);
- EXPECT_EQUAL(ValueType::concat(vx_2, vx_3, "x"), vx_5);
- EXPECT_EQUAL(ValueType::concat(scalar, vx_2, "y"), mxy_22);
- EXPECT_EQUAL(ValueType::concat(vx_2, scalar, "y"), mxy_22);
- EXPECT_EQUAL(ValueType::concat(vx_2, vx_2, "y"), mxy_22);
- EXPECT_EQUAL(ValueType::concat(mxy_22, vx_3, "x"), mxy_52);
- EXPECT_EQUAL(ValueType::concat(vx_3, mxy_22, "x"), mxy_52);
- EXPECT_EQUAL(ValueType::concat(mxy_22, vy_7, "y"), mxy_29);
- EXPECT_EQUAL(ValueType::concat(vy_7, mxy_22, "y"), mxy_29);
- EXPECT_EQUAL(ValueType::concat(vx_5, vy_7, "z"), cxyz_572);
+ TEST_DO(verify_concat(type("error"), type("tensor(x[2])"), "x", type("error")));
+ TEST_DO(verify_concat(type("tensor(x{})"), type("tensor(x[2])"), "x", type("error")));
+ TEST_DO(verify_concat(type("tensor(x{})"), type("tensor(x{})"), "x", type("error")));
+ TEST_DO(verify_concat(type("tensor(x{})"), type("double"), "x", type("error")));
+ TEST_DO(verify_concat(type("tensor(x[3])"), type("tensor(x[2])"), "y", type("error")));
+ TEST_DO(verify_concat(type("tensor(y[7])"), type("tensor(x{})"), "z", type("tensor(x{},y[7],z[2])")));
+ TEST_DO(verify_concat(type("double"), type("double"), "x", type("tensor(x[2])")));
+ TEST_DO(verify_concat(type("tensor(x[2])"), type("double"), "x", type("tensor(x[3])")));
+ TEST_DO(verify_concat(type("tensor(x[3])"), type("tensor(x[2])"), "x", type("tensor(x[5])")));
+ TEST_DO(verify_concat(type("tensor(x[2])"), type("double"), "y", type("tensor(x[2],y[2])")));
+ TEST_DO(verify_concat(type("tensor(x[2])"), type("tensor(x[2])"), "y", type("tensor(x[2],y[2])")));
+ TEST_DO(verify_concat(type("tensor(x[2],y[2])"), type("tensor(x[3])"), "x", type("tensor(x[5],y[2])")));
+ TEST_DO(verify_concat(type("tensor(x[2],y[2])"), type("tensor(y[7])"), "y", type("tensor(x[2],y[9])")));
+ TEST_DO(verify_concat(type("tensor(x[5])"), type("tensor(y[7])"), "z", type("tensor(x[5],y[7],z[2])")));
+}
+
+TEST("require that cell type is handled correctly for concat") {
+ TEST_DO(verify_concat(type("tensor<float>(x[3])"), type("tensor(x[2])"), "x", type("tensor(x[5])")));
+ TEST_DO(verify_concat(type("tensor<float>(x[3])"), type("tensor<float>(x[2])"), "x", type("tensor<float>(x[5])")));
+ TEST_DO(verify_concat(type("tensor<float>(x[3])"), type("double"), "x", type("tensor<float>(x[4])")));
}
TEST_MAIN() { TEST_RUN_ALL(); }
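
Two details in the rewritten tests are worth calling out: a dimension-less tensor<float>() collapses to plain double, and concat unifies cell types with a special case for double scalars, which adopt the tensor's cell type instead of demoting it (this is what unify_cell_type() in value_type.cpp below implements). A sketch:

#include <vespa/eval/eval/value_type.h>
#include <cassert>

using vespalib::eval::ValueType;

int main() {
    // no dimensions: the float cell type is normalized away
    ValueType empty = ValueType::tensor_type({}, ValueType::CellType::FLOAT);
    assert(empty.is_double() && (empty.cell_type() == ValueType::CellType::DOUBLE));

    ValueType f3 = ValueType::from_spec("tensor<float>(x[3])");
    // a double scalar does not demote a float tensor...
    assert(ValueType::concat(f3, ValueType::double_type(), "x") ==
           ValueType::from_spec("tensor<float>(x[4])"));
    // ...but a double tensor does
    assert(ValueType::concat(f3, ValueType::from_spec("tensor(x[2])"), "x") ==
           ValueType::from_spec("tensor(x[5])"));
}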
diff --git a/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp b/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
index 9b46fc3393a..eaf4623afea 100644
--- a/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_add_dimension_optimizer/dense_add_dimension_optimizer_test.cpp
@@ -99,4 +99,8 @@ TEST("require that dimension addition optimization requires unit constant tensor
TEST_DO(verify_not_optimized("tensor(x[2])(1)*tensor(y[2])(1)"));
}
+TEST("require that optimization is disabled for tensors with non-double cells") {
+ TEST_DO(verify_not_optimized("x5*tensor<float>(a[1],b[1],c[1])(1)"));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
index fae5db75618..b5a1cac6d0f 100644
--- a/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -103,6 +103,7 @@ EvalFixture::ParamRepo make_params() {
.add("v04_y3", spec({y(3)}, MyVecSeq(10)))
.add("v05_x5", spec({x(5)}, MyVecSeq(6.0)))
.add("v06_x5", spec({x(5)}, MyVecSeq(7.0)))
+ .add("v07_x5f", spec({x(5)}, MyVecSeq(7.0)), "tensor<float>(x[5])")
.add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0)))
.add("m02_x3y3", spec({x(3),y(3)}, MyVecSeq(2.0)));
}
@@ -183,11 +184,17 @@ void verify_not_compatible(const vespalib::string &a, const vespalib::string &b)
TEST("require that type compatibility test is appropriate") {
TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[5])"));
+ TEST_DO(verify_not_compatible("tensor(x[5])", "tensor<float>(x[5])"));
+ TEST_DO(verify_not_compatible("tensor<float>(x[5])", "tensor<float>(x[5])"));
+ TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(x[6])"));
TEST_DO(verify_not_compatible("tensor(x[5])", "tensor(y[5])"));
- TEST_DO(verify_compatible("tensor(x[5])", "tensor(x[3])"));
- TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[5],y[7],z[9])"));
- TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[5],z[9])"));
- TEST_DO(verify_not_compatible("tensor(x[5],y[7],z[9])", "tensor(x[5],y[7],z[5])"));
+ TEST_DO(verify_compatible("tensor(x[3],y[7],z[9])", "tensor(x[3],y[7],z[9])"));
+ TEST_DO(verify_not_compatible("tensor(x[3],y[7],z[9])", "tensor(x[5],y[7],z[9])"));
+ TEST_DO(verify_not_compatible("tensor(x[9],y[7],z[5])", "tensor(x[5],y[7],z[9])"));
+}
+
+TEST("require that optimization is disabled for tensors with non-double cells") {
+ TEST_DO(assertNotOptimized("reduce(v05_x5*v07_x5f,sum)"));
}
//-----------------------------------------------------------------------------
diff --git a/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp b/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
index 10b4c622a0a..773381b4c77 100644
--- a/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_fast_rename_optimizer/dense_fast_rename_optimizer_test.cpp
@@ -25,6 +25,7 @@ const TensorEngine &prod_engine = DefaultTensorEngine::ref();
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("x5", spec({x(5)}, N()))
+ .add("x5f", spec({x(5)}, N()), "tensor<float>(x[5])")
.add("x_m", spec({x({"a", "b", "c"})}, N()))
.add("x5y3", spec({x(5),y(3)}, N()));
}
@@ -71,4 +72,8 @@ TEST("require that chained optimized renames are compacted into a single operati
TEST_DO(verify_optimized("rename(rename(x5,x,y),y,z)"));
}
+TEST("require that optimization is disabled for tensors with non-double cells") {
+ TEST_DO(verify_not_optimized("rename(x5f,x,y)"));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
index 7ee603e1763..c9e581e6b21 100644
--- a/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
+++ b/eval/src/tests/tensor/dense_inplace_join_function/dense_inplace_join_function_test.cpp
@@ -45,6 +45,8 @@ EvalFixture::ParamRepo make_params() {
.add_mutable("mut_x5_A", spec({x(5)}, seq))
.add_mutable("mut_x5_B", spec({x(5)}, seq))
.add_mutable("mut_x5_C", spec({x(5)}, seq))
+ .add_mutable("mut_x5f_D", spec({x(5)}, seq), "tensor<float>(x[5])")
+ .add_mutable("mut_x5f_E", spec({x(5)}, seq), "tensor<float>(x[5])")
.add_mutable("mut_x5y3_A", spec({x(5),y(3)}, seq))
.add_mutable("mut_x5y3_B", spec({x(5),y(3)}, seq))
.add_mutable("mut_x_sparse", spec({x({"a", "b", "c"})}, seq));
@@ -142,4 +144,10 @@ TEST("require that inplace join can be debug dumped") {
fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}
+TEST("require that optimization is disabled for tensors with non-double cells") {
+ TEST_DO(verify_not_optimized("mut_x5_A-mut_x5f_D"));
+ TEST_DO(verify_not_optimized("mut_x5f_D-mut_x5_A"));
+ TEST_DO(verify_not_optimized("mut_x5f_D-mut_x5f_E"));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp b/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
index a17b7e02eb8..36ebdec028b 100644
--- a/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
+++ b/eval/src/tests/tensor/dense_inplace_map_function/dense_inplace_map_function_test.cpp
@@ -26,6 +26,7 @@ EvalFixture::ParamRepo make_params() {
.add("x5", spec({x(5)}, N()))
.add_mutable("_d", spec(5.0))
.add_mutable("_x5", spec({x(5)}, N()))
+ .add_mutable("_x5f", spec({x(5)}, N()), "tensor<float>(x[5])")
.add_mutable("_x5y3", spec({x(5),y(3)}, N()))
.add_mutable("_x_m", spec({x({"a", "b", "c"})}, N()));
}
@@ -71,4 +72,8 @@ TEST("require that mapped tensors are not optimized") {
TEST_DO(verify_not_optimized("map(_x_m,f(x)(x+10))"));
}
+TEST("require that optimization is disabled for tensors with non-double cells") {
+ TEST_DO(verify_not_optimized("map(_x5f,f(x)(x+10))"));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp b/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
index ac451d10b50..65208aedb4b 100644
--- a/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
+++ b/eval/src/tests/tensor/dense_remove_dimension_optimizer/dense_remove_dimension_optimizer_test.cpp
@@ -25,6 +25,7 @@ const TensorEngine &prod_engine = DefaultTensorEngine::ref();
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("x1y5z1", spec({x(1),y(5),z(1)}, N()))
+ .add("x1y5z1f", spec({x(1),y(5),z(1)}, N()), "tensor<float>(x[1],y[5],z[1])")
.add("x1y1z1", spec({x(1),y(1),z(1)}, N()))
.add("x1y5z_m", spec({x(1),y(5),z({"a"})}, N()));
}
@@ -77,4 +78,8 @@ TEST("require that inappropriate tensor types cannot be optimized") {
TEST_DO(verify_not_optimized("reduce(x1y5z_m,sum,z)"));
}
+TEST("require that optimization is disabled for tensors with non-double cells") {
+ TEST_DO(verify_not_optimized("reduce(x1y5z1f,avg,x)"));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
index b55e223ab07..5c976073732 100644
--- a/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
+++ b/eval/src/tests/tensor/dense_xw_product_function/dense_xw_product_function_test.cpp
@@ -39,11 +39,13 @@ EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("y1", spec({y(1)}, MyVecSeq()))
.add("y3", spec({y(3)}, MyVecSeq()))
+ .add("y3f", spec({y(3)}, MyVecSeq()), "tensor<float>(y[3])")
.add("y5", spec({y(5)}, MyVecSeq()))
.add("y16", spec({y(16)}, MyVecSeq()))
.add("x1y1", spec({x(1),y(1)}, MyMatSeq()))
.add("y1z1", spec({y(1),z(1)}, MyMatSeq()))
.add("x2y3", spec({x(2),y(3)}, MyMatSeq()))
+ .add("x2y3f", spec({x(2),y(3)}, MyMatSeq()), "tensor<float>(x[2],y[3])")
.add("x2z3", spec({x(2),z(3)}, MyMatSeq()))
.add("y3z2", spec({y(3),z(2)}, MyMatSeq()))
.add("x8y5", spec({x(8),y(5)}, MyMatSeq()))
@@ -117,4 +119,10 @@ TEST("require that xw product can be debug dumped") {
fprintf(stderr, "%s\n", info[0]->as_string().c_str());
}
+TEST("require that optimization is disabled for tensors with non-double cells") {
+ TEST_DO(verify_not_optimized("reduce(y3f*x2y3,sum,y)"));
+ TEST_DO(verify_not_optimized("reduce(y3*x2y3f,sum,y)"));
+ TEST_DO(verify_not_optimized("reduce(y3f*x2y3f,sum,y)"));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
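
All the dense optimizer tests above follow the same recipe: register a float-celled variant of a parameter by passing an explicit type string to the ParamRepo add / add_mutable calls, then assert that the expression is no longer optimized. On the production side the guard is the same few lines in every optimizer; condensed here, mirroring dense_add_dimension_optimizer.cpp and friends below:

#include <vespa/eval/eval/value_type.h>

using vespalib::eval::ValueType;

// shared shape of the new check: only dense tensors with double cells
// remain eligible for the optimized kernels
bool is_concrete_dense_tensor(const ValueType &type) {
    if (type.cell_type() != ValueType::CellType::DOUBLE) {
        return false; // non-double cell types not supported
    }
    return type.is_dense();
}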
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index c4f91067260..8b9e440318d 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -558,7 +558,7 @@ void parse_tensor_rename(ParseContext &ctx) {
}
void parse_tensor_lambda(ParseContext &ctx) {
- vespalib::string type_spec("tensor(");
+ vespalib::string type_spec("tensor");
while(!ctx.eos() && (ctx.get() != ')')) {
type_spec.push_back(ctx.get());
ctx.next();
@@ -576,6 +576,7 @@ void parse_tensor_lambda(ParseContext &ctx) {
ctx.skip_spaces();
ctx.eat('(');
parse_expression(ctx);
+ ctx.eat(')');
ctx.pop_resolve_context();
Function lambda(ctx.pop_expression(), std::move(param_names));
ctx.push_expression(std::make_unique<nodes::TensorLambda>(std::move(type), std::move(lambda)));
@@ -611,8 +612,6 @@ bool try_parse_call(ParseContext &ctx, const vespalib::string &name) {
parse_tensor_reduce(ctx);
} else if (name == "rename") {
parse_tensor_rename(ctx);
- } else if (name == "tensor") {
- parse_tensor_lambda(ctx);
} else if (name == "concat") {
parse_tensor_concat(ctx);
} else {
@@ -634,7 +633,9 @@ size_t parse_symbol(ParseContext &ctx, vespalib::string &name, ParseContext::Inp
void parse_symbol_or_call(ParseContext &ctx) {
ParseContext::InputMark before_name = ctx.get_input_mark();
vespalib::string name = get_ident(ctx, true);
- if (!try_parse_call(ctx, name)) {
+ if (name == "tensor") {
+ parse_tensor_lambda(ctx);
+ } else if (!try_parse_call(ctx, name)) {
size_t id = parse_symbol(ctx, name, before_name);
if (name.empty()) {
ctx.fail("missing value");
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 94e69aadf55..29ae02b9e65 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -4,8 +4,7 @@
#include "node_traverser.h"
#include "node_types.h"
-namespace vespalib {
-namespace eval {
+namespace vespalib::eval {
namespace nodes {
namespace {
@@ -208,5 +207,4 @@ NodeTypes::get_type(const nodes::Node &node) const
return pos->second;
}
-} // namespace vespalib::eval
-} // namespace vespalib
+}
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index 269f17b71c5..fc0f3cc5414 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -8,9 +8,27 @@ namespace vespalib::eval {
namespace {
+using CellType = ValueType::CellType;
using Dimension = ValueType::Dimension;
using DimensionList = std::vector<Dimension>;
+CellType unify(CellType a, CellType b) {
+ if (a == b) {
+ return a;
+ } else {
+ return CellType::DOUBLE;
+ }
+}
+
+CellType unify_cell_type(const ValueType &a, const ValueType &b) {
+ if (a.is_double()) {
+ return b.cell_type();
+ } else if (b.is_double()) {
+ return a.cell_type();
+ }
+ return unify(a.cell_type(), b.cell_type());
+}
+
size_t my_dimension_index(const std::vector<Dimension> &list, const vespalib::string &name) {
for (size_t idx = 0; idx < list.size(); ++idx) {
if (list[idx].name == name) {
@@ -184,7 +202,7 @@ ValueType::reduce(const std::vector<vespalib::string> &dimensions_in) const
if (removed != dimensions_in.size()) {
return error_type();
}
- return tensor_type(std::move(result));
+ return tensor_type(std::move(result), _cell_type);
}
ValueType
@@ -202,11 +220,11 @@ ValueType::rename(const std::vector<vespalib::string> &from,
if (!renamer.matched_all()) {
return error_type();
}
- return tensor_type(dim_list);
+ return tensor_type(dim_list, _cell_type);
}
ValueType
-ValueType::tensor_type(std::vector<Dimension> dimensions_in)
+ValueType::tensor_type(std::vector<Dimension> dimensions_in, CellType cell_type)
{
if (dimensions_in.empty()) {
return double_type();
@@ -215,7 +233,7 @@ ValueType::tensor_type(std::vector<Dimension> dimensions_in)
if (!verify_dimensions(dimensions_in)) {
return error_type();
}
- return ValueType(Type::TENSOR, std::move(dimensions_in));
+ return ValueType(Type::TENSOR, cell_type, std::move(dimensions_in));
}
ValueType
@@ -244,7 +262,7 @@ ValueType::join(const ValueType &lhs, const ValueType &rhs)
if (result.mismatch) {
return error_type();
}
- return tensor_type(std::move(result.dimensions));
+ return tensor_type(std::move(result.dimensions), unify(lhs._cell_type, rhs._cell_type));
}
ValueType
@@ -260,12 +278,11 @@ ValueType::concat(const ValueType &lhs, const ValueType &rhs, const vespalib::st
if (!find_dimension(result.dimensions, dimension)) {
result.dimensions.emplace_back(dimension, 2);
}
- return tensor_type(std::move(result.dimensions));
+ return tensor_type(std::move(result.dimensions), unify_cell_type(lhs, rhs));
}
ValueType
-ValueType::either(const ValueType &one, const ValueType &other)
-{
+ValueType::either(const ValueType &one, const ValueType &other) {
if (one != other) {
return error_type();
}
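
With reduce() and rename() forwarding _cell_type instead of implicitly resetting it, the cell type now survives every shape-only operation; only reducing away all dimensions lands back on double. A sketch:

#include <vespa/eval/eval/value_type.h>
#include <cassert>

using vespalib::eval::ValueType;

int main() {
    ValueType f = ValueType::from_spec("tensor<float>(x[10],y[20])");
    assert(f.reduce({"x"}) == ValueType::from_spec("tensor<float>(y[20])"));
    assert(f.rename({"x"}, {"z"}) == ValueType::from_spec("tensor<float>(z[10],y[20])"));
    assert(f.reduce({"x", "y"}) == ValueType::double_type()); // scalar: plain double again
}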
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index 6e30a5c0a47..81788c933d7 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -16,6 +16,7 @@ class ValueType
{
public:
enum class Type { ERROR, DOUBLE, TENSOR };
+ enum class CellType { FLOAT, DOUBLE };
struct Dimension {
using size_type = uint32_t;
static constexpr size_type npos = -1;
@@ -35,14 +36,15 @@ public:
};
private:
- Type _type;
+ Type _type;
+ CellType _cell_type;
std::vector<Dimension> _dimensions;
ValueType(Type type_in)
- : _type(type_in), _dimensions() {}
+ : _type(type_in), _cell_type(CellType::DOUBLE), _dimensions() {}
- ValueType(Type type_in, std::vector<Dimension> &&dimensions_in)
- : _type(type_in), _dimensions(std::move(dimensions_in)) {}
+ ValueType(Type type_in, CellType cell_type_in, std::vector<Dimension> &&dimensions_in)
+ : _type(type_in), _cell_type(cell_type_in), _dimensions(std::move(dimensions_in)) {}
public:
ValueType(ValueType &&) = default;
@@ -51,6 +53,7 @@ public:
ValueType &operator=(const ValueType &) = default;
~ValueType();
Type type() const { return _type; }
+ CellType cell_type() const { return _cell_type; }
bool is_error() const { return (_type == Type::ERROR); }
bool is_double() const { return (_type == Type::DOUBLE); }
bool is_tensor() const { return (_type == Type::TENSOR); }
@@ -60,7 +63,9 @@ public:
size_t dimension_index(const vespalib::string &name) const;
std::vector<vespalib::string> dimension_names() const;
bool operator==(const ValueType &rhs) const {
- return ((_type == rhs._type) && (_dimensions == rhs._dimensions));
+ return ((_type == rhs._type) &&
+ (_cell_type == rhs._cell_type) &&
+ (_dimensions == rhs._dimensions));
}
bool operator!=(const ValueType &rhs) const { return !(*this == rhs); }
@@ -70,7 +75,7 @@ public:
static ValueType error_type() { return ValueType(Type::ERROR); }
static ValueType double_type() { return ValueType(Type::DOUBLE); }
- static ValueType tensor_type(std::vector<Dimension> dimensions_in);
+ static ValueType tensor_type(std::vector<Dimension> dimensions_in, CellType cell_type = CellType::DOUBLE);
static ValueType from_spec(const vespalib::string &spec);
vespalib::string to_spec() const;
static ValueType join(const ValueType &lhs, const ValueType &rhs);
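
Since operator== now compares _cell_type as well, a float tensor and a double tensor over identical dimensions are distinct types wherever equality is used (either(), the simplified dot-product compatibility check below, and the test fixtures). A sketch:

#include <vespa/eval/eval/value_type.h>
#include <cassert>

using vespalib::eval::ValueType;

int main() {
    assert(ValueType::from_spec("tensor<float>(x[5])") !=
           ValueType::from_spec("tensor(x[5])"));
}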
diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp
index 737943f902e..bbfa6f4fa28 100644
--- a/eval/src/vespa/eval/eval/value_type_spec.cpp
+++ b/eval/src/vespa/eval/eval/value_type_spec.cpp
@@ -8,8 +8,18 @@
namespace vespalib::eval::value_type {
+using CellType = ValueType::CellType;
+
namespace {
+const char *to_name(CellType cell_type) {
+ switch (cell_type) {
+ case CellType::DOUBLE: return "double";
+ case CellType::FLOAT: return "float";
+ }
+ abort();
+}
+
class ParseContext
{
public:
@@ -130,23 +140,21 @@ ValueType::Dimension parse_dimension(ParseContext &ctx) {
std::vector<ValueType::Dimension> parse_dimension_list(ParseContext &ctx) {
std::vector<ValueType::Dimension> list;
ctx.skip_spaces();
- if (ctx.get() == '(') {
- ctx.eat('(');
- ctx.skip_spaces();
- while (!ctx.eos() && (ctx.get() != ')')) {
- if (!list.empty()) {
- ctx.eat(',');
- }
- list.push_back(parse_dimension(ctx));
- ctx.skip_spaces();
+ ctx.eat('(');
+ ctx.skip_spaces();
+ while (!ctx.eos() && (ctx.get() != ')')) {
+ if (!list.empty()) {
+ ctx.eat(',');
}
- ctx.eat(')');
+ list.push_back(parse_dimension(ctx));
+ ctx.skip_spaces();
}
+ ctx.eat(')');
ctx.skip_spaces();
return list;
}
-vespalib::string parse_cell_type(ParseContext &ctx) {
+CellType parse_cell_type(ParseContext &ctx) {
auto mark = ctx.mark();
ctx.skip_spaces();
ctx.eat('<');
@@ -155,9 +163,14 @@ vespalib::string parse_cell_type(ParseContext &ctx) {
ctx.eat('>');
if (ctx.failed()) {
ctx.revert(mark);
- cell_type = "double";
+ return CellType::DOUBLE;
+ }
+ if (cell_type == "float") {
+ return CellType::FLOAT;
+ } else if (cell_type != "double") {
+ ctx.fail();
}
- return cell_type;
+ return CellType::DOUBLE;
}
} // namespace vespalib::eval::value_type::<anonymous>
@@ -172,13 +185,10 @@ parse_spec(const char *pos_in, const char *end_in, const char *&pos_out)
} else if (type_name == "double") {
return ValueType::double_type();
} else if (type_name == "tensor") {
- vespalib::string cell_type = parse_cell_type(ctx);
- if ((cell_type != "double") && (cell_type != "float")) {
- ctx.fail();
- }
+ ValueType::CellType cell_type = parse_cell_type(ctx);
std::vector<ValueType::Dimension> list = parse_dimension_list(ctx);
if (!ctx.failed()) {
- return ValueType::tensor_type(std::move(list));
+ return ValueType::tensor_type(std::move(list), cell_type);
}
} else {
ctx.fail();
@@ -212,22 +222,21 @@ to_spec(const ValueType &type)
break;
case ValueType::Type::TENSOR:
os << "tensor";
- if (!type.dimensions().empty()) {
- os << "(";
- for (const auto &d: type.dimensions()) {
- if (cnt++ > 0) {
- os << ",";
- }
- if (d.size == ValueType::Dimension::npos) {
- os << d.name << "{}";
- } else if (d.size == 0) {
- os << d.name << "[]";
- } else {
- os << d.name << "[" << d.size << "]";
- }
+ if (type.cell_type() != CellType::DOUBLE) {
+ os << "<" << to_name(type.cell_type()) << ">";
+ }
+ os << "(";
+ for (const auto &d: type.dimensions()) {
+ if (cnt++ > 0) {
+ os << ",";
+ }
+ if (d.size == ValueType::Dimension::npos) {
+ os << d.name << "{}";
+ } else {
+ os << d.name << "[" << d.size << "]";
}
- os << ")";
}
+ os << ")";
break;
}
return os.str();
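
With the dimension list now mandatory, the bare word tensor (with or without a cell-type qualifier) is no longer a valid spec, while an empty list still normalizes to double; and since to_spec() always emits the parentheses, parse/print round-trips are unambiguous. A sketch:

#include <vespa/eval/eval/value_type.h>
#include <cassert>

using vespalib::eval::ValueType;

int main() {
    assert(ValueType::from_spec("tensor").is_error());          // previously meant tensor()
    assert(ValueType::from_spec("tensor<double>").is_error());  // dimension list required
    assert(ValueType::from_spec("tensor()").to_spec() == "double");
    assert(ValueType::from_spec("tensor<float>(x[5])").to_spec() == "tensor<float>(x[5])");
}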
diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
index 5a16511fe71..263d77e5d1e 100644
--- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
+++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp
@@ -125,19 +125,9 @@ Value::UP
DefaultTensorEngine::from_spec(const TensorSpec &spec) const
{
ValueType type = ValueType::from_spec(spec.type());
- bool is_dense = false;
- bool is_sparse = false;
- for (const auto &dimension: type.dimensions()) {
- if (dimension.is_mapped()) {
- is_sparse = true;
- }
- if (dimension.is_indexed()) {
- is_dense = true;
- }
- }
- if (is_dense && is_sparse) {
+ if (!tensor::Tensor::supported({type})) {
return std::make_unique<WrappedSimpleTensor>(eval::SimpleTensor::create(spec));
- } else if (is_dense) {
+ } else if (type.is_dense()) {
DenseTensorBuilder builder;
std::map<vespalib::string,DenseTensorBuilder::Dimension> dimension_map;
for (const auto &dimension: type.dimensions()) {
@@ -151,7 +141,7 @@ DefaultTensorEngine::from_spec(const TensorSpec &spec) const
builder.addCell(cell.second);
}
return builder.build();
- } else if (is_sparse) {
+ } else if (type.is_sparse()) {
DefaultTensor::builder builder;
std::map<vespalib::string,DefaultTensor::builder::Dimension> dimension_map;
for (const auto &dimension: type.dimensions()) {
diff --git a/eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp
index f55566fe199..842e064de43 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_add_dimension_optimizer.cpp
@@ -20,6 +20,9 @@ using namespace eval::operation;
namespace {
bool is_concrete_dense_tensor(const ValueType &type) {
+ if (type.cell_type() != ValueType::CellType::DOUBLE) {
+ return false; // non-double cell types not supported
+ }
return type.is_dense();
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
index 859a7092ce2..988edba7d55 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_dot_product_function.cpp
@@ -51,23 +51,12 @@ DenseDotProductFunction::compile_self(Stash &) const
bool
DenseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs)
{
- if (!res.is_double() || !lhs.is_dense() || !rhs.is_dense() ||
- (lhs.dimensions().size() != rhs.dimensions().size()) ||
- (lhs.dimensions().empty()))
+ if (lhs.cell_type() != ValueType::CellType::DOUBLE ||
+ rhs.cell_type() != ValueType::CellType::DOUBLE)
{
- return false;
+ return false; // non-double cell types not supported
}
- for (size_t i = 0; i < lhs.dimensions().size(); ++i) {
- const auto &ldim = lhs.dimensions()[i];
- const auto &rdim = rhs.dimensions()[i];
- bool first = (i == 0);
- bool name_mismatch = (ldim.name != rdim.name);
- bool size_mismatch = ((ldim.size != rdim.size) || !ldim.is_bound());
- if (name_mismatch || (!first && size_mismatch)) {
- return false;
- }
- }
- return true;
+ return (res.is_double() && lhs.is_dense() && (rhs == lhs));
}
const TensorFunction &
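
Note that the explicit DOUBLE checks above are not made redundant by (rhs == lhs): two float tensors of the same shape compare equal to each other, so without the guard the optimized kernel would still accept them. A sketch:

#include <vespa/eval/eval/value_type.h>
#include <cassert>

using vespalib::eval::ValueType;

int main() {
    ValueType a = ValueType::from_spec("tensor<float>(x[5])");
    ValueType b = ValueType::from_spec("tensor<float>(x[5])");
    assert(a == b); // equal types, yet rejected by the cell-type guard above
}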
diff --git a/eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp
index b32a1efa234..09977df25b7 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_fast_rename_optimizer.cpp
@@ -22,6 +22,11 @@ bool is_concrete_dense_stable_rename(const ValueType &from_type, const ValueType
const std::vector<vespalib::string> &from,
const std::vector<vespalib::string> &to)
{
+ if (from_type.cell_type() != ValueType::CellType::DOUBLE ||
+ to_type.cell_type() != ValueType::CellType::DOUBLE)
+ {
+ return false; // non-double cell types not supported
+ }
if (!from_type.is_dense() ||
!to_type.is_dense() ||
(from.size() != to.size()))
diff --git a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
index 4828808683a..ce6b1743951 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_inplace_join_function.cpp
@@ -41,6 +41,11 @@ void my_inplace_join_op(eval::InterpretedFunction::State &state, uint64_t param)
}
bool sameShapeConcreteDenseTensors(const ValueType &a, const ValueType &b) {
+ if (a.cell_type() != ValueType::CellType::DOUBLE ||
+ b.cell_type() != ValueType::CellType::DOUBLE)
+ {
+ return false; // non-double cell types not supported
+ }
return (a.is_dense() && (a == b));
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp
index bac1e336292..c72889ca0ed 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_inplace_map_function.cpp
@@ -30,6 +30,9 @@ void my_inplace_map_op(eval::InterpretedFunction::State &state, uint64_t param)
}
bool isConcreteDenseTensor(const ValueType &type) {
+ if (type.cell_type() != ValueType::CellType::DOUBLE) {
+ return false; // non-double cell types not supported
+ }
return type.is_dense();
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp
index d6716e8ad1a..3c58320a6e6 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_remove_dimension_optimizer.cpp
@@ -15,6 +15,9 @@ using namespace eval::tensor_function;
namespace {
bool is_concrete_dense_tensor(const ValueType &type) {
+ if (type.cell_type() != ValueType::CellType::DOUBLE) {
+ return false; // non-double cell types not supported
+ }
return type.is_dense();
}
diff --git a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
index 660e5e3e0b7..a3056311fab 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_xw_product_function.cpp
@@ -76,6 +76,9 @@ void my_xw_product_op(eval::InterpretedFunction::State &state, uint64_t param) {
}
bool isConcreteDenseTensor(const ValueType &type, size_t d) {
+ if (type.cell_type() != ValueType::CellType::DOUBLE) {
+ return false; // non-double cell types not supported
+ }
return (type.is_dense() && (type.dimensions().size() == d));
}
diff --git a/eval/src/vespa/eval/tensor/tensor.cpp b/eval/src/vespa/eval/tensor/tensor.cpp
index 51c94aab5b0..5697458f3ca 100644
--- a/eval/src/vespa/eval/tensor/tensor.cpp
+++ b/eval/src/vespa/eval/tensor/tensor.cpp
@@ -17,6 +17,9 @@ Tensor::supported(TypeList types)
bool sparse = false;
bool dense = false;
for (const eval::ValueType &type: types) {
+ if (type.cell_type() != eval::ValueType::CellType::DOUBLE) {
+ return false; // non-double cell types not supported
+ }
dense = (dense || type.is_double());
for (const auto &dim: type.dimensions()) {
dense = (dense || dim.is_indexed());