aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2021-03-03 08:47:14 +0100
committerGitHub <noreply@github.com>2021-03-03 08:47:14 +0100
commita6c06485fd2894c88e30c6dc58efcb4669721ccc (patch)
tree870e8c23428e777530438d8cd7e9eb958e43e189
parent898d3f3262cf313e9955b456e3d9cdeb1c5ca87d (diff)
parentfca3ba6823e2d1fa1ffe8489b63e44a271a127b6 (diff)
Merge pull request #16751 from vespa-engine/havardpe/cell-cast-in-function
enable use of cell_cast in expressions
-rw-r--r--eval/src/tests/eval/compiled_function/compiled_function_test.cpp3
-rw-r--r--eval/src/tests/eval/function/function_test.cpp11
-rw-r--r--eval/src/tests/eval/node_types/node_types_test.cpp8
-rw-r--r--eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp6
-rw-r--r--eval/src/tests/eval/reference_operations/reference_operations_test.cpp26
-rw-r--r--eval/src/tests/eval/tensor_function/tensor_function_test.cpp29
-rw-r--r--eval/src/tests/eval/value_type/value_type_test.cpp47
-rw-r--r--eval/src/vespa/eval/eval/cell_type.h5
-rw-r--r--eval/src/vespa/eval/eval/function.cpp15
-rw-r--r--eval/src/vespa/eval/eval/key_gen.cpp111
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.cpp1
-rw-r--r--eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp3
-rw-r--r--eval/src/vespa/eval/eval/make_tensor_function.cpp9
-rw-r--r--eval/src/vespa/eval/eval/node_tools.cpp103
-rw-r--r--eval/src/vespa/eval/eval/node_types.cpp3
-rw-r--r--eval/src/vespa/eval/eval/node_visitor.h238
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.cpp29
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h31
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.cpp19
-rw-r--r--eval/src/vespa/eval/eval/tensor_nodes.h41
-rw-r--r--eval/src/vespa/eval/eval/test/eval_spec.cpp1
-rw-r--r--eval/src/vespa/eval/eval/test/reference_evaluation.cpp7
-rw-r--r--eval/src/vespa/eval/eval/test/reference_operations.cpp2
-rw-r--r--eval/src/vespa/eval/eval/test/tensor_conformance.cpp24
-rw-r--r--eval/src/vespa/eval/eval/value_type.cpp15
-rw-r--r--eval/src/vespa/eval/eval/value_type.h3
-rw-r--r--eval/src/vespa/eval/eval/value_type_spec.cpp18
-rw-r--r--eval/src/vespa/eval/eval/value_type_spec.h4
-rw-r--r--eval/src/vespa/eval/instruction/generic_cell_cast.cpp2
-rw-r--r--eval/src/vespa/eval/instruction/generic_merge.h2
30 files changed, 538 insertions, 278 deletions
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
index a19bff415e1..8e17aadcfff 100644
--- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
+++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp
@@ -55,7 +55,8 @@ std::vector<vespalib::string> unsupported = {
"reduce(",
"rename(",
"tensor(",
- "concat("
+ "concat(",
+ "cell_cast("
};
bool is_unsupported(const vespalib::string &expression) {
diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp
index 3dbc2e1aed2..4dff2934873 100644
--- a/eval/src/tests/eval/function/function_test.cpp
+++ b/eval/src/tests/eval/function/function_test.cpp
@@ -1018,6 +1018,17 @@ TEST("require that tensor concat can be parsed") {
//-----------------------------------------------------------------------------
+TEST("require that tensor cell cast can be parsed") {
+ EXPECT_EQUAL("cell_cast(a,float)", Function::parse({"a"}, "cell_cast(a,float)")->dump());
+ EXPECT_EQUAL("cell_cast(a,double)", Function::parse({"a"}, " cell_cast ( a , double ) ")->dump());
+}
+
+TEST("require that tensor cell cast must have valid cell type") {
+ TEST_DO(verify_error("cell_cast(x,int7)", "[cell_cast(x,int7]...[unknown cell type: 'int7']...[)]"));
+}
+
+//-----------------------------------------------------------------------------
+
struct CheckExpressions : test::EvalSpec::EvalTest {
bool failed = false;
size_t seen_cnt = 0;
diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp
index f595c58ef29..89c37af0a83 100644
--- a/eval/src/tests/eval/node_types/node_types_test.cpp
+++ b/eval/src/tests/eval/node_types/node_types_test.cpp
@@ -292,6 +292,14 @@ TEST("require that tensor concat resolves correct type") {
TEST_DO(verify("concat(tensor<float>(x[2]),double,x)", "tensor<float>(x[3])"));
}
+TEST("require that tensor cell_cast resolves correct type") {
+ TEST_DO(verify("cell_cast(double,float)", "double")); // NB
+ TEST_DO(verify("cell_cast(float,double)", "double"));
+ TEST_DO(verify("cell_cast(tensor<double>(x{},y[5]),float)", "tensor<float>(x{},y[5])"));
+ TEST_DO(verify("cell_cast(tensor<float>(x{},y[5]),double)", "tensor<double>(x{},y[5])"));
+ TEST_DO(verify("cell_cast(tensor<float>(x{},y[5]),float)", "tensor<float>(x{},y[5])"));
+}
+
TEST("require that double only expressions can be detected") {
auto plain_fun = Function::parse("1+2");
auto complex_fun = Function::parse("reduce(a,sum)");
diff --git a/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp b/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp
index 345f04053ac..b6115b78378 100644
--- a/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp
+++ b/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp
@@ -151,6 +151,12 @@ TEST(ReferenceEvaluationTest, concat_expression_works) {
EXPECT_EQ(ref_eval("concat(a,b,x)", {a, b}), expect);
}
+TEST(ReferenceEvaluationTest, cell_cast_expression_works) {
+ auto a = make_val("tensor<double>(x[4]):[1,2,3,4]");
+ auto expect = make_val("tensor<float>(x[4]):[1,2,3,4]");
+ EXPECT_EQ(ref_eval("cell_cast(a,float)", {a}), expect);
+}
+
TEST(ReferenceEvaluationTest, rename_expression_works) {
auto a = make_val("tensor(x[2]):[1,2]");
auto expect = make_val("tensor(y[2]):[1,2]");
diff --git a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp
index 3fcca5e34d8..d191ee064ee 100644
--- a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp
+++ b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp
@@ -1,11 +1,14 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include <vespa/eval/eval/cell_type.h>
#include <vespa/eval/eval/test/reference_operations.h>
+#include <vespa/eval/eval/test/gen_spec.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <iostream>
using namespace vespalib;
using namespace vespalib::eval;
+using vespalib::eval::test::GenSpec;
TensorSpec dense_2d_some_cells(bool square) {
return TensorSpec("tensor(a[3],d[5])")
@@ -48,7 +51,6 @@ TensorSpec sparse_1d_all_two() {
//-----------------------------------------------------------------------------
-
TEST(ReferenceConcatTest, concat_numbers) {
auto a = TensorSpec("double").add({}, 7.0);
auto b = TensorSpec("double").add({}, 4.0);
@@ -132,6 +134,27 @@ TEST(ReferenceConcatTest, concat_mixed_tensors) {
//-----------------------------------------------------------------------------
+TEST(ReferenceCellCastTest, cell_cast_works) {
+ std::vector<GenSpec> gen_list = {
+ GenSpec(42),
+ GenSpec(-3).idx("x", 10),
+ GenSpec(-3).map("x", 10, 1),
+ GenSpec(-3).map("x", 4, 1).idx("y", 4)
+ };
+ for (CellType from_type: CellTypeUtils::list_types()) {
+ for (CellType to_type: CellTypeUtils::list_types()) {
+ for (const auto &gen: gen_list) {
+ TensorSpec input = gen.cpy().cells(from_type);
+ TensorSpec expect = gen.cpy().cells(to_type);
+ auto actual = ReferenceOperations::cell_cast(input, to_type);
+ EXPECT_EQ(actual, expect);
+ }
+ }
+ }
+}
+
+//-----------------------------------------------------------------------------
+
TEST(ReferenceCreateTest, simple_create_works) {
auto a = TensorSpec("double").add({}, 1.5);
auto b = TensorSpec("tensor(z[2])").add({{"z",0}}, 2.0).add({{"z",1}}, 3.0);
@@ -525,4 +548,3 @@ TEST(ReferenceLambdaTest, make_matrix) {
//-----------------------------------------------------------------------------
GTEST_MAIN_RUN_ALL_TESTS()
-
diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
index 4396a6773ea..f1bd900b350 100644
--- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
+++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp
@@ -31,6 +31,9 @@ struct EvalCtx {
tensors.push_back(std::move(tensor));
return id;
}
+ ValueType type_of(size_t idx) {
+ return params[idx].get().type();
+ }
void replace_tensor(size_t idx, Value::UP tensor) {
params[idx] = *tensor;
tensors[idx] = std::move(tensor);
@@ -91,6 +94,14 @@ struct EvalCtx {
.add({{"x", 1}, {"y", 0}}, 3.0)
.add({{"x", 1}, {"y", 1}}, 4.0), factory);
}
+ Value::UP make_float_tensor_matrix() {
+ return value_from_spec(
+ TensorSpec("tensor<float>(x[2],y[2])")
+ .add({{"x", 0}, {"y", 0}}, 1.0)
+ .add({{"x", 0}, {"y", 1}}, 2.0)
+ .add({{"x", 1}, {"y", 0}}, 3.0)
+ .add({{"x", 1}, {"y", 1}}, 4.0), factory);
+ }
Value::UP make_tensor_matrix_renamed() {
return value_from_spec(
TensorSpec("tensor(y[2],z[2])")
@@ -273,6 +284,16 @@ TEST("require that tensor concat works") {
TEST_DO(verify_equal(*expect, ctx.eval(fun)));
}
+TEST("require that tensor cell cast works") {
+ EvalCtx ctx(simple_factory);
+ size_t a_id = ctx.add_tensor(ctx.make_tensor_matrix());
+ Value::UP expect = ctx.make_float_tensor_matrix();
+ const auto &fun = cell_cast(inject(ctx.type_of(a_id), a_id, ctx.stash), CellType::FLOAT, ctx.stash);
+ EXPECT_TRUE(fun.result_is_mutable());
+ EXPECT_EQUAL(expect->type(), fun.result_type());
+ TEST_DO(verify_equal(*expect, ctx.eval(fun)));
+}
+
TEST("require that tensor create works") {
EvalCtx ctx(simple_factory);
size_t a_id = ctx.add_tensor(ctx.make_double(1.0));
@@ -450,6 +471,10 @@ TEST("require that push_children works") {
EXPECT_EQUAL(&refs[10].get().get(), &b);
EXPECT_EQUAL(&refs[11].get().get(), &c);
//-------------------------------------------------------------------------
+ cell_cast(a, CellType::FLOAT, stash).push_children(refs);
+ ASSERT_EQUAL(refs.size(), 13u);
+ EXPECT_EQUAL(&refs[12].get().get(), &a);
+ //-------------------------------------------------------------------------
}
TEST("require that tensor function can be dumped for debugging") {
@@ -459,7 +484,9 @@ TEST("require that tensor function can be dumped for debugging") {
auto my_value_3 = stash.create<DoubleValue>(3.0);
//-------------------------------------------------------------------------
const auto &x5 = inject(ValueType::from_spec("tensor(x[5])"), 0, stash);
- const auto &mapped_x5 = map(x5, operation::Relu::f, stash);
+ const auto &float_x5 = cell_cast(x5, CellType::FLOAT, stash);
+ const auto &double_x5 = cell_cast(float_x5, CellType::DOUBLE, stash);
+ const auto &mapped_x5 = map(double_x5, operation::Relu::f, stash);
const auto &const_1 = const_value(my_value_1, stash);
const auto &joined_x5 = join(mapped_x5, const_1, operation::Mul::f, stash);
//-------------------------------------------------------------------------
diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp
index d58adbbcef0..c1b25d48bf7 100644
--- a/eval/src/tests/eval/value_type/value_type_test.cpp
+++ b/eval/src/tests/eval/value_type/value_type_test.cpp
@@ -501,4 +501,51 @@ TEST("require that cell type is handled correctly for concat") {
TEST_DO(verify_concat(type("tensor<float>(x[3])"), type("double"), "x", type("tensor<float>(x[4])")));
}
+void verify_cell_cast(const ValueType &type) {
+ for (CellType cell_type: CellTypeUtils::list_types()) {
+ auto res_type = type.cell_cast(cell_type);
+ if (type.is_error()) {
+ EXPECT_TRUE(res_type.is_error());
+ EXPECT_EQUAL(res_type, type);
+ } else if (type.is_scalar()) {
+ EXPECT_TRUE(res_type.is_double()); // NB
+ } else {
+ EXPECT_FALSE(res_type.is_error());
+ EXPECT_EQUAL(int(res_type.cell_type()), int(cell_type));
+ EXPECT_TRUE(res_type.dimensions() == type.dimensions());
+ }
+ }
+}
+
+TEST("require that value type cell cast works correctly") {
+ TEST_DO(verify_cell_cast(type("error")));
+ TEST_DO(verify_cell_cast(type("float")));
+ TEST_DO(verify_cell_cast(type("double")));
+ TEST_DO(verify_cell_cast(type("tensor<float>(x[10])")));
+ TEST_DO(verify_cell_cast(type("tensor<double>(x[10])")));
+ TEST_DO(verify_cell_cast(type("tensor<float>(x{})")));
+ TEST_DO(verify_cell_cast(type("tensor<double>(x{})")));
+ TEST_DO(verify_cell_cast(type("tensor<float>(x{},y[5])")));
+ TEST_DO(verify_cell_cast(type("tensor<double>(x{},y[5])")));
+}
+
+TEST("require that actual cell type can be converted to cell type name") {
+ EXPECT_EQUAL(value_type::cell_type_to_name(CellType::FLOAT), "float");
+ EXPECT_EQUAL(value_type::cell_type_to_name(CellType::DOUBLE), "double");
+}
+
+TEST("require that cell type name can be converted to actual cell type") {
+ EXPECT_EQUAL(int(value_type::cell_type_from_name("float").value()), int(CellType::FLOAT));
+ EXPECT_EQUAL(int(value_type::cell_type_from_name("double").value()), int(CellType::DOUBLE));
+ EXPECT_FALSE(value_type::cell_type_from_name("int7").has_value());
+}
+
+TEST("require that cell type name recognition is strict") {
+ EXPECT_FALSE(value_type::cell_type_from_name("Float").has_value());
+ EXPECT_FALSE(value_type::cell_type_from_name(" float").has_value());
+ EXPECT_FALSE(value_type::cell_type_from_name("float ").has_value());
+ EXPECT_FALSE(value_type::cell_type_from_name("f").has_value());
+ EXPECT_FALSE(value_type::cell_type_from_name("").has_value());
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h
index 4c4f7e78413..beb0b32386f 100644
--- a/eval/src/vespa/eval/eval/cell_type.h
+++ b/eval/src/vespa/eval/eval/cell_type.h
@@ -3,6 +3,7 @@
#pragma once
#include <vespa/vespalib/util/typify.h>
+#include <vector>
#include <cstdint>
namespace vespalib::eval {
@@ -43,6 +44,10 @@ struct CellTypeUtils {
default: bad_argument((uint32_t)cell_type);
}
}
+
+ static std::vector<CellType> list_types() {
+ return {CellType::FLOAT, CellType::DOUBLE};
+ }
};
struct TypifyCellType {
diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp
index aeb53eaa6ba..b03c2c1ed24 100644
--- a/eval/src/vespa/eval/eval/function.cpp
+++ b/eval/src/vespa/eval/eval/function.cpp
@@ -830,6 +830,19 @@ void parse_tensor_concat(ParseContext &ctx) {
ctx.push_expression(std::make_unique<nodes::TensorConcat>(std::move(lhs), std::move(rhs), dimension));
}
+void parse_tensor_cell_cast(ParseContext &ctx) {
+ Node_UP child = get_expression(ctx);
+ ctx.eat(',');
+ auto cell_type_name = get_ident(ctx, false);
+ auto cell_type = value_type::cell_type_from_name(cell_type_name);
+ ctx.skip_spaces();
+ if (cell_type.has_value()) {
+ ctx.push_expression(std::make_unique<nodes::TensorCellCast>(std::move(child), cell_type.value()));
+ } else {
+ ctx.fail(make_string("unknown cell type: '%s'", cell_type_name.c_str()));
+ }
+}
+
bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) {
ctx.skip_spaces();
if (ctx.get() == '(') {
@@ -852,6 +865,8 @@ bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) {
parse_tensor_rename(ctx);
} else if (name == "concat") {
parse_tensor_concat(ctx);
+ } else if (name == "cell_cast") {
+ parse_tensor_cell_cast(ctx);
} else {
ctx.fail(make_string("unknown function: '%s'", name.c_str()));
return false;
diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp
index 31167be5fe1..80803d3b2a2 100644
--- a/eval/src/vespa/eval/eval/key_gen.cpp
+++ b/eval/src/vespa/eval/eval/key_gen.cpp
@@ -31,61 +31,62 @@ struct KeyGen : public NodeVisitor, public NodeTraverser {
add_double(node.get_entry(i).get_const_value());
}
}
- void visit(const Neg &) override { add_byte( 5); }
- void visit(const Not &) override { add_byte( 6); }
- void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); }
- void visit(const Error &) override { add_byte( 9); }
- void visit(const TensorMap &) override { add_byte(10); } // lambda should be part of key
- void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key
- void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key
- void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key
- void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key
- void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key
- void visit(const TensorCreate &) override { add_byte(16); } // type/addr should be part of key
- void visit(const TensorLambda &) override { add_byte(17); } // type/lambda should be part of key
- void visit(const TensorPeek &) override { add_byte(18); } // addr should be part of key
- void visit(const Add &) override { add_byte(20); }
- void visit(const Sub &) override { add_byte(21); }
- void visit(const Mul &) override { add_byte(22); }
- void visit(const Div &) override { add_byte(23); }
- void visit(const Mod &) override { add_byte(24); }
- void visit(const Pow &) override { add_byte(25); }
- void visit(const Equal &) override { add_byte(26); }
- void visit(const NotEqual &) override { add_byte(27); }
- void visit(const Approx &) override { add_byte(28); }
- void visit(const Less &) override { add_byte(29); }
- void visit(const LessEqual &) override { add_byte(30); }
- void visit(const Greater &) override { add_byte(31); }
- void visit(const GreaterEqual &) override { add_byte(32); }
- void visit(const And &) override { add_byte(34); }
- void visit(const Or &) override { add_byte(35); }
- void visit(const Cos &) override { add_byte(36); }
- void visit(const Sin &) override { add_byte(37); }
- void visit(const Tan &) override { add_byte(38); }
- void visit(const Cosh &) override { add_byte(39); }
- void visit(const Sinh &) override { add_byte(40); }
- void visit(const Tanh &) override { add_byte(41); }
- void visit(const Acos &) override { add_byte(42); }
- void visit(const Asin &) override { add_byte(43); }
- void visit(const Atan &) override { add_byte(44); }
- void visit(const Exp &) override { add_byte(45); }
- void visit(const Log10 &) override { add_byte(46); }
- void visit(const Log &) override { add_byte(47); }
- void visit(const Sqrt &) override { add_byte(48); }
- void visit(const Ceil &) override { add_byte(49); }
- void visit(const Fabs &) override { add_byte(50); }
- void visit(const Floor &) override { add_byte(51); }
- void visit(const Atan2 &) override { add_byte(52); }
- void visit(const Ldexp &) override { add_byte(53); }
- void visit(const Pow2 &) override { add_byte(54); }
- void visit(const Fmod &) override { add_byte(55); }
- void visit(const Min &) override { add_byte(56); }
- void visit(const Max &) override { add_byte(57); }
- void visit(const IsNan &) override { add_byte(58); }
- void visit(const Relu &) override { add_byte(59); }
- void visit(const Sigmoid &) override { add_byte(60); }
- void visit(const Elu &) override { add_byte(61); }
- void visit(const Erf &) override { add_byte(62); }
+ void visit(const Neg &) override { add_byte( 5); }
+ void visit(const Not &) override { add_byte( 6); }
+ void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); }
+ void visit(const Error &) override { add_byte( 9); }
+ void visit(const TensorMap &) override { add_byte(10); } // lambda should be part of key
+ void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key
+ void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key
+ void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key
+ void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key
+ void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key
+ void visit(const TensorCellCast &) override { add_byte(16); } // cell type should be part of key
+ void visit(const TensorCreate &) override { add_byte(17); } // type/addr should be part of key
+ void visit(const TensorLambda &) override { add_byte(18); } // type/lambda should be part of key
+ void visit(const TensorPeek &) override { add_byte(19); } // addr should be part of key
+ void visit(const Add &) override { add_byte(20); }
+ void visit(const Sub &) override { add_byte(21); }
+ void visit(const Mul &) override { add_byte(22); }
+ void visit(const Div &) override { add_byte(23); }
+ void visit(const Mod &) override { add_byte(24); }
+ void visit(const Pow &) override { add_byte(25); }
+ void visit(const Equal &) override { add_byte(26); }
+ void visit(const NotEqual &) override { add_byte(27); }
+ void visit(const Approx &) override { add_byte(28); }
+ void visit(const Less &) override { add_byte(29); }
+ void visit(const LessEqual &) override { add_byte(30); }
+ void visit(const Greater &) override { add_byte(31); }
+ void visit(const GreaterEqual &) override { add_byte(32); }
+ void visit(const And &) override { add_byte(34); }
+ void visit(const Or &) override { add_byte(35); }
+ void visit(const Cos &) override { add_byte(36); }
+ void visit(const Sin &) override { add_byte(37); }
+ void visit(const Tan &) override { add_byte(38); }
+ void visit(const Cosh &) override { add_byte(39); }
+ void visit(const Sinh &) override { add_byte(40); }
+ void visit(const Tanh &) override { add_byte(41); }
+ void visit(const Acos &) override { add_byte(42); }
+ void visit(const Asin &) override { add_byte(43); }
+ void visit(const Atan &) override { add_byte(44); }
+ void visit(const Exp &) override { add_byte(45); }
+ void visit(const Log10 &) override { add_byte(46); }
+ void visit(const Log &) override { add_byte(47); }
+ void visit(const Sqrt &) override { add_byte(48); }
+ void visit(const Ceil &) override { add_byte(49); }
+ void visit(const Fabs &) override { add_byte(50); }
+ void visit(const Floor &) override { add_byte(51); }
+ void visit(const Atan2 &) override { add_byte(52); }
+ void visit(const Ldexp &) override { add_byte(53); }
+ void visit(const Pow2 &) override { add_byte(54); }
+ void visit(const Fmod &) override { add_byte(55); }
+ void visit(const Min &) override { add_byte(56); }
+ void visit(const Max &) override { add_byte(57); }
+ void visit(const IsNan &) override { add_byte(58); }
+ void visit(const Relu &) override { add_byte(59); }
+ void visit(const Sigmoid &) override { add_byte(60); }
+ void visit(const Elu &) override { add_byte(61); }
+ void visit(const Erf &) override { add_byte(62); }
// traverse
bool open(const Node &node) override { node.accept(*this); return true; }
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
index 043ad248251..4ce4fcd7747 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
@@ -132,6 +132,7 @@ CompiledFunction::detect_issues(const nodes::Node &node)
nodes::TensorReduce,
nodes::TensorRename,
nodes::TensorConcat,
+ nodes::TensorCellCast,
nodes::TensorCreate,
nodes::TensorLambda,
nodes::TensorPeek>(node))
diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
index 8fbb2c5ac09..fce9abb7316 100644
--- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
+++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp
@@ -478,6 +478,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const TensorConcat &node) override {
make_error(node.num_children());
}
+ void visit(const TensorCellCast &node) override {
+ make_error(node.num_children());
+ }
void visit(const TensorCreate &node) override {
make_error(node.num_children());
}
diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp
index 9ee42f164de..15f188db51a 100644
--- a/eval/src/vespa/eval/eval/make_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp
@@ -75,6 +75,12 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
stack.back() = tensor_function::concat(a, b, dimension, stash);
}
+ void make_cell_cast(const Node &, CellType cell_type) {
+ assert(stack.size() >= 1);
+ const auto &a = stack.back().get();
+ stack.back() = tensor_function::cell_cast(a, cell_type, stash);
+ }
+
bool maybe_make_const(const Node &node) {
if (auto create = as<TensorCreate>(node)) {
bool is_const = true;
@@ -213,6 +219,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser {
void visit(const TensorConcat &node) override {
make_concat(node, node.dimension());
}
+ void visit(const TensorCellCast &node) override {
+ make_cell_cast(node, node.cell_type());
+ }
void visit(const TensorCreate &node) override {
make_create(node);
}
diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp
index 1c194736138..e7341bc1755 100644
--- a/eval/src/vespa/eval/eval/node_tools.cpp
+++ b/eval/src/vespa/eval/eval/node_tools.cpp
@@ -126,61 +126,62 @@ struct CopyNode : NodeTraverser, NodeVisitor {
}
// tensor nodes
- void visit(const TensorMap &node) override { not_implemented(node); }
- void visit(const TensorJoin &node) override { not_implemented(node); }
- void visit(const TensorMerge &node) override { not_implemented(node); }
- void visit(const TensorReduce &node) override { not_implemented(node); }
- void visit(const TensorRename &node) override { not_implemented(node); }
- void visit(const TensorConcat &node) override { not_implemented(node); }
- void visit(const TensorCreate &node) override { not_implemented(node); }
- void visit(const TensorLambda &node) override { not_implemented(node); }
- void visit(const TensorPeek &node) override { not_implemented(node); }
+ void visit(const TensorMap &node) override { not_implemented(node); }
+ void visit(const TensorJoin &node) override { not_implemented(node); }
+ void visit(const TensorMerge &node) override { not_implemented(node); }
+ void visit(const TensorReduce &node) override { not_implemented(node); }
+ void visit(const TensorRename &node) override { not_implemented(node); }
+ void visit(const TensorConcat &node) override { not_implemented(node); }
+ void visit(const TensorCellCast &node) override { not_implemented(node); }
+ void visit(const TensorCreate &node) override { not_implemented(node); }
+ void visit(const TensorLambda &node) override { not_implemented(node); }
+ void visit(const TensorPeek &node) override { not_implemented(node); }
// operator nodes
- void visit(const Add &node) override { copy_operator(node); }
- void visit(const Sub &node) override { copy_operator(node); }
- void visit(const Mul &node) override { copy_operator(node); }
- void visit(const Div &node) override { copy_operator(node); }
- void visit(const Mod &node) override { copy_operator(node); }
- void visit(const Pow &node) override { copy_operator(node); }
- void visit(const Equal &node) override { copy_operator(node); }
- void visit(const NotEqual &node) override { copy_operator(node); }
- void visit(const Approx &node) override { copy_operator(node); }
- void visit(const Less &node) override { copy_operator(node); }
- void visit(const LessEqual &node) override { copy_operator(node); }
- void visit(const Greater &node) override { copy_operator(node); }
- void visit(const GreaterEqual &node) override { copy_operator(node); }
- void visit(const And &node) override { copy_operator(node); }
- void visit(const Or &node) override { copy_operator(node); }
+ void visit(const Add &node) override { copy_operator(node); }
+ void visit(const Sub &node) override { copy_operator(node); }
+ void visit(const Mul &node) override { copy_operator(node); }
+ void visit(const Div &node) override { copy_operator(node); }
+ void visit(const Mod &node) override { copy_operator(node); }
+ void visit(const Pow &node) override { copy_operator(node); }
+ void visit(const Equal &node) override { copy_operator(node); }
+ void visit(const NotEqual &node) override { copy_operator(node); }
+ void visit(const Approx &node) override { copy_operator(node); }
+ void visit(const Less &node) override { copy_operator(node); }
+ void visit(const LessEqual &node) override { copy_operator(node); }
+ void visit(const Greater &node) override { copy_operator(node); }
+ void visit(const GreaterEqual &node) override { copy_operator(node); }
+ void visit(const And &node) override { copy_operator(node); }
+ void visit(const Or &node) override { copy_operator(node); }
// call nodes
- void visit(const Cos &node) override { copy_call(node); }
- void visit(const Sin &node) override { copy_call(node); }
- void visit(const Tan &node) override { copy_call(node); }
- void visit(const Cosh &node) override { copy_call(node); }
- void visit(const Sinh &node) override { copy_call(node); }
- void visit(const Tanh &node) override { copy_call(node); }
- void visit(const Acos &node) override { copy_call(node); }
- void visit(const Asin &node) override { copy_call(node); }
- void visit(const Atan &node) override { copy_call(node); }
- void visit(const Exp &node) override { copy_call(node); }
- void visit(const Log10 &node) override { copy_call(node); }
- void visit(const Log &node) override { copy_call(node); }
- void visit(const Sqrt &node) override { copy_call(node); }
- void visit(const Ceil &node) override { copy_call(node); }
- void visit(const Fabs &node) override { copy_call(node); }
- void visit(const Floor &node) override { copy_call(node); }
- void visit(const Atan2 &node) override { copy_call(node); }
- void visit(const Ldexp &node) override { copy_call(node); }
- void visit(const Pow2 &node) override { copy_call(node); }
- void visit(const Fmod &node) override { copy_call(node); }
- void visit(const Min &node) override { copy_call(node); }
- void visit(const Max &node) override { copy_call(node); }
- void visit(const IsNan &node) override { copy_call(node); }
- void visit(const Relu &node) override { copy_call(node); }
- void visit(const Sigmoid &node) override { copy_call(node); }
- void visit(const Elu &node) override { copy_call(node); }
- void visit(const Erf &node) override { copy_call(node); }
+ void visit(const Cos &node) override { copy_call(node); }
+ void visit(const Sin &node) override { copy_call(node); }
+ void visit(const Tan &node) override { copy_call(node); }
+ void visit(const Cosh &node) override { copy_call(node); }
+ void visit(const Sinh &node) override { copy_call(node); }
+ void visit(const Tanh &node) override { copy_call(node); }
+ void visit(const Acos &node) override { copy_call(node); }
+ void visit(const Asin &node) override { copy_call(node); }
+ void visit(const Atan &node) override { copy_call(node); }
+ void visit(const Exp &node) override { copy_call(node); }
+ void visit(const Log10 &node) override { copy_call(node); }
+ void visit(const Log &node) override { copy_call(node); }
+ void visit(const Sqrt &node) override { copy_call(node); }
+ void visit(const Ceil &node) override { copy_call(node); }
+ void visit(const Fabs &node) override { copy_call(node); }
+ void visit(const Floor &node) override { copy_call(node); }
+ void visit(const Atan2 &node) override { copy_call(node); }
+ void visit(const Ldexp &node) override { copy_call(node); }
+ void visit(const Pow2 &node) override { copy_call(node); }
+ void visit(const Fmod &node) override { copy_call(node); }
+ void visit(const Min &node) override { copy_call(node); }
+ void visit(const Max &node) override { copy_call(node); }
+ void visit(const IsNan &node) override { copy_call(node); }
+ void visit(const Relu &node) override { copy_call(node); }
+ void visit(const Sigmoid &node) override { copy_call(node); }
+ void visit(const Elu &node) override { copy_call(node); }
+ void visit(const Erf &node) override { copy_call(node); }
// traverse nodes
bool open(const Node &) override { return !error; }
diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp
index 7d76f2064a3..2df22d5433c 100644
--- a/eval/src/vespa/eval/eval/node_types.cpp
+++ b/eval/src/vespa/eval/eval/node_types.cpp
@@ -177,6 +177,9 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser {
bind(ValueType::concat(type(node.get_child(0)),
type(node.get_child(1)), node.dimension()), node);
}
+ void visit(const TensorCellCast &node) override {
+ bind(type(node.get_child(0)).cell_cast(node.cell_type()), node);
+ }
void visit(const TensorCreate &node) override {
for (size_t i = 0; i < node.num_children(); ++i) {
if (!type(node.get_child(i)).is_double()) {
diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h
index 95a5bec8be7..172cd48fe2a 100644
--- a/eval/src/vespa/eval/eval/node_visitor.h
+++ b/eval/src/vespa/eval/eval/node_visitor.h
@@ -19,71 +19,72 @@ namespace eval {
struct NodeVisitor {
// basic nodes
- virtual void visit(const nodes::Number &) = 0;
- virtual void visit(const nodes::Symbol &) = 0;
- virtual void visit(const nodes::String &) = 0;
- virtual void visit(const nodes::In &) = 0;
- virtual void visit(const nodes::Neg &) = 0;
- virtual void visit(const nodes::Not &) = 0;
- virtual void visit(const nodes::If &) = 0;
- virtual void visit(const nodes::Error &) = 0;
+ virtual void visit(const nodes::Number &) = 0;
+ virtual void visit(const nodes::Symbol &) = 0;
+ virtual void visit(const nodes::String &) = 0;
+ virtual void visit(const nodes::In &) = 0;
+ virtual void visit(const nodes::Neg &) = 0;
+ virtual void visit(const nodes::Not &) = 0;
+ virtual void visit(const nodes::If &) = 0;
+ virtual void visit(const nodes::Error &) = 0;
// tensor nodes
- virtual void visit(const nodes::TensorMap &) = 0;
- virtual void visit(const nodes::TensorJoin &) = 0;
- virtual void visit(const nodes::TensorMerge &) = 0;
- virtual void visit(const nodes::TensorReduce &) = 0;
- virtual void visit(const nodes::TensorRename &) = 0;
- virtual void visit(const nodes::TensorConcat &) = 0;
- virtual void visit(const nodes::TensorCreate &) = 0;
- virtual void visit(const nodes::TensorLambda &) = 0;
- virtual void visit(const nodes::TensorPeek &) = 0;
+ virtual void visit(const nodes::TensorMap &) = 0;
+ virtual void visit(const nodes::TensorJoin &) = 0;
+ virtual void visit(const nodes::TensorMerge &) = 0;
+ virtual void visit(const nodes::TensorReduce &) = 0;
+ virtual void visit(const nodes::TensorRename &) = 0;
+ virtual void visit(const nodes::TensorConcat &) = 0;
+ virtual void visit(const nodes::TensorCellCast &) = 0;
+ virtual void visit(const nodes::TensorCreate &) = 0;
+ virtual void visit(const nodes::TensorLambda &) = 0;
+ virtual void visit(const nodes::TensorPeek &) = 0;
// operator nodes
- virtual void visit(const nodes::Add &) = 0;
- virtual void visit(const nodes::Sub &) = 0;
- virtual void visit(const nodes::Mul &) = 0;
- virtual void visit(const nodes::Div &) = 0;
- virtual void visit(const nodes::Mod &) = 0;
- virtual void visit(const nodes::Pow &) = 0;
- virtual void visit(const nodes::Equal &) = 0;
- virtual void visit(const nodes::NotEqual &) = 0;
- virtual void visit(const nodes::Approx &) = 0;
- virtual void visit(const nodes::Less &) = 0;
- virtual void visit(const nodes::LessEqual &) = 0;
- virtual void visit(const nodes::Greater &) = 0;
- virtual void visit(const nodes::GreaterEqual &) = 0;
- virtual void visit(const nodes::And &) = 0;
- virtual void visit(const nodes::Or &) = 0;
+ virtual void visit(const nodes::Add &) = 0;
+ virtual void visit(const nodes::Sub &) = 0;
+ virtual void visit(const nodes::Mul &) = 0;
+ virtual void visit(const nodes::Div &) = 0;
+ virtual void visit(const nodes::Mod &) = 0;
+ virtual void visit(const nodes::Pow &) = 0;
+ virtual void visit(const nodes::Equal &) = 0;
+ virtual void visit(const nodes::NotEqual &) = 0;
+ virtual void visit(const nodes::Approx &) = 0;
+ virtual void visit(const nodes::Less &) = 0;
+ virtual void visit(const nodes::LessEqual &) = 0;
+ virtual void visit(const nodes::Greater &) = 0;
+ virtual void visit(const nodes::GreaterEqual &) = 0;
+ virtual void visit(const nodes::And &) = 0;
+ virtual void visit(const nodes::Or &) = 0;
// call nodes
- virtual void visit(const nodes::Cos &) = 0;
- virtual void visit(const nodes::Sin &) = 0;
- virtual void visit(const nodes::Tan &) = 0;
- virtual void visit(const nodes::Cosh &) = 0;
- virtual void visit(const nodes::Sinh &) = 0;
- virtual void visit(const nodes::Tanh &) = 0;
- virtual void visit(const nodes::Acos &) = 0;
- virtual void visit(const nodes::Asin &) = 0;
- virtual void visit(const nodes::Atan &) = 0;
- virtual void visit(const nodes::Exp &) = 0;
- virtual void visit(const nodes::Log10 &) = 0;
- virtual void visit(const nodes::Log &) = 0;
- virtual void visit(const nodes::Sqrt &) = 0;
- virtual void visit(const nodes::Ceil &) = 0;
- virtual void visit(const nodes::Fabs &) = 0;
- virtual void visit(const nodes::Floor &) = 0;
- virtual void visit(const nodes::Atan2 &) = 0;
- virtual void visit(const nodes::Ldexp &) = 0;
- virtual void visit(const nodes::Pow2 &) = 0;
- virtual void visit(const nodes::Fmod &) = 0;
- virtual void visit(const nodes::Min &) = 0;
- virtual void visit(const nodes::Max &) = 0;
- virtual void visit(const nodes::IsNan &) = 0;
- virtual void visit(const nodes::Relu &) = 0;
- virtual void visit(const nodes::Sigmoid &) = 0;
- virtual void visit(const nodes::Elu &) = 0;
- virtual void visit(const nodes::Erf &) = 0;
+ virtual void visit(const nodes::Cos &) = 0;
+ virtual void visit(const nodes::Sin &) = 0;
+ virtual void visit(const nodes::Tan &) = 0;
+ virtual void visit(const nodes::Cosh &) = 0;
+ virtual void visit(const nodes::Sinh &) = 0;
+ virtual void visit(const nodes::Tanh &) = 0;
+ virtual void visit(const nodes::Acos &) = 0;
+ virtual void visit(const nodes::Asin &) = 0;
+ virtual void visit(const nodes::Atan &) = 0;
+ virtual void visit(const nodes::Exp &) = 0;
+ virtual void visit(const nodes::Log10 &) = 0;
+ virtual void visit(const nodes::Log &) = 0;
+ virtual void visit(const nodes::Sqrt &) = 0;
+ virtual void visit(const nodes::Ceil &) = 0;
+ virtual void visit(const nodes::Fabs &) = 0;
+ virtual void visit(const nodes::Floor &) = 0;
+ virtual void visit(const nodes::Atan2 &) = 0;
+ virtual void visit(const nodes::Ldexp &) = 0;
+ virtual void visit(const nodes::Pow2 &) = 0;
+ virtual void visit(const nodes::Fmod &) = 0;
+ virtual void visit(const nodes::Min &) = 0;
+ virtual void visit(const nodes::Max &) = 0;
+ virtual void visit(const nodes::IsNan &) = 0;
+ virtual void visit(const nodes::Relu &) = 0;
+ virtual void visit(const nodes::Sigmoid &) = 0;
+ virtual void visit(const nodes::Elu &) = 0;
+ virtual void visit(const nodes::Erf &) = 0;
virtual ~NodeVisitor() {}
};
@@ -93,65 +94,66 @@ struct NodeVisitor {
* of all types not specifically handled.
**/
struct EmptyNodeVisitor : NodeVisitor {
- void visit(const nodes::Number &) override {}
- void visit(const nodes::Symbol &) override {}
- void visit(const nodes::String &) override {}
- void visit(const nodes::In &) override {}
- void visit(const nodes::Neg &) override {}
- void visit(const nodes::Not &) override {}
- void visit(const nodes::If &) override {}
- void visit(const nodes::Error &) override {}
- void visit(const nodes::TensorMap &) override {}
- void visit(const nodes::TensorJoin &) override {}
- void visit(const nodes::TensorMerge &) override {}
- void visit(const nodes::TensorReduce &) override {}
- void visit(const nodes::TensorRename &) override {}
- void visit(const nodes::TensorConcat &) override {}
- void visit(const nodes::TensorCreate &) override {}
- void visit(const nodes::TensorLambda &) override {}
- void visit(const nodes::TensorPeek &) override {}
- void visit(const nodes::Add &) override {}
- void visit(const nodes::Sub &) override {}
- void visit(const nodes::Mul &) override {}
- void visit(const nodes::Div &) override {}
- void visit(const nodes::Mod &) override {}
- void visit(const nodes::Pow &) override {}
- void visit(const nodes::Equal &) override {}
- void visit(const nodes::NotEqual &) override {}
- void visit(const nodes::Approx &) override {}
- void visit(const nodes::Less &) override {}
- void visit(const nodes::LessEqual &) override {}
- void visit(const nodes::Greater &) override {}
- void visit(const nodes::GreaterEqual &) override {}
- void visit(const nodes::And &) override {}
- void visit(const nodes::Or &) override {}
- void visit(const nodes::Cos &) override {}
- void visit(const nodes::Sin &) override {}
- void visit(const nodes::Tan &) override {}
- void visit(const nodes::Cosh &) override {}
- void visit(const nodes::Sinh &) override {}
- void visit(const nodes::Tanh &) override {}
- void visit(const nodes::Acos &) override {}
- void visit(const nodes::Asin &) override {}
- void visit(const nodes::Atan &) override {}
- void visit(const nodes::Exp &) override {}
- void visit(const nodes::Log10 &) override {}
- void visit(const nodes::Log &) override {}
- void visit(const nodes::Sqrt &) override {}
- void visit(const nodes::Ceil &) override {}
- void visit(const nodes::Fabs &) override {}
- void visit(const nodes::Floor &) override {}
- void visit(const nodes::Atan2 &) override {}
- void visit(const nodes::Ldexp &) override {}
- void visit(const nodes::Pow2 &) override {}
- void visit(const nodes::Fmod &) override {}
- void visit(const nodes::Min &) override {}
- void visit(const nodes::Max &) override {}
- void visit(const nodes::IsNan &) override {}
- void visit(const nodes::Relu &) override {}
- void visit(const nodes::Sigmoid &) override {}
- void visit(const nodes::Elu &) override {}
- void visit(const nodes::Erf &) override {}
+ void visit(const nodes::Number &) override {}
+ void visit(const nodes::Symbol &) override {}
+ void visit(const nodes::String &) override {}
+ void visit(const nodes::In &) override {}
+ void visit(const nodes::Neg &) override {}
+ void visit(const nodes::Not &) override {}
+ void visit(const nodes::If &) override {}
+ void visit(const nodes::Error &) override {}
+ void visit(const nodes::TensorMap &) override {}
+ void visit(const nodes::TensorJoin &) override {}
+ void visit(const nodes::TensorMerge &) override {}
+ void visit(const nodes::TensorReduce &) override {}
+ void visit(const nodes::TensorRename &) override {}
+ void visit(const nodes::TensorConcat &) override {}
+ void visit(const nodes::TensorCellCast &) override {}
+ void visit(const nodes::TensorCreate &) override {}
+ void visit(const nodes::TensorLambda &) override {}
+ void visit(const nodes::TensorPeek &) override {}
+ void visit(const nodes::Add &) override {}
+ void visit(const nodes::Sub &) override {}
+ void visit(const nodes::Mul &) override {}
+ void visit(const nodes::Div &) override {}
+ void visit(const nodes::Mod &) override {}
+ void visit(const nodes::Pow &) override {}
+ void visit(const nodes::Equal &) override {}
+ void visit(const nodes::NotEqual &) override {}
+ void visit(const nodes::Approx &) override {}
+ void visit(const nodes::Less &) override {}
+ void visit(const nodes::LessEqual &) override {}
+ void visit(const nodes::Greater &) override {}
+ void visit(const nodes::GreaterEqual &) override {}
+ void visit(const nodes::And &) override {}
+ void visit(const nodes::Or &) override {}
+ void visit(const nodes::Cos &) override {}
+ void visit(const nodes::Sin &) override {}
+ void visit(const nodes::Tan &) override {}
+ void visit(const nodes::Cosh &) override {}
+ void visit(const nodes::Sinh &) override {}
+ void visit(const nodes::Tanh &) override {}
+ void visit(const nodes::Acos &) override {}
+ void visit(const nodes::Asin &) override {}
+ void visit(const nodes::Atan &) override {}
+ void visit(const nodes::Exp &) override {}
+ void visit(const nodes::Log10 &) override {}
+ void visit(const nodes::Log &) override {}
+ void visit(const nodes::Sqrt &) override {}
+ void visit(const nodes::Ceil &) override {}
+ void visit(const nodes::Fabs &) override {}
+ void visit(const nodes::Floor &) override {}
+ void visit(const nodes::Atan2 &) override {}
+ void visit(const nodes::Ldexp &) override {}
+ void visit(const nodes::Pow2 &) override {}
+ void visit(const nodes::Fmod &) override {}
+ void visit(const nodes::Min &) override {}
+ void visit(const nodes::Max &) override {}
+ void visit(const nodes::IsNan &) override {}
+ void visit(const nodes::Relu &) override {}
+ void visit(const nodes::Sigmoid &) override {}
+ void visit(const nodes::Elu &) override {}
+ void visit(const nodes::Erf &) override {}
};
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp
index 19dbe812d3e..fb8d2f4e03c 100644
--- a/eval/src/vespa/eval/eval/tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/tensor_function.cpp
@@ -5,6 +5,7 @@
#include "operation.h"
#include "visit_stuff.h"
#include "string_stuff.h"
+#include "value_type_spec.h"
#include <vespa/eval/instruction/generic_cell_cast.h>
#include <vespa/eval/instruction/generic_concat.h>
#include <vespa/eval/instruction/generic_create.h>
@@ -216,6 +217,21 @@ Concat::visit_self(vespalib::ObjectVisitor &visitor) const
//-----------------------------------------------------------------------------
+InterpretedFunction::Instruction
+CellCast::compile_self(const ValueBuilderFactory &, Stash &stash) const
+{
+ return instruction::GenericCellCast::make_instruction(child().result_type(), cell_type(), stash);
+}
+
+void
+CellCast::visit_self(vespalib::ObjectVisitor &visitor) const
+{
+ Super::visit_self(visitor);
+ visitor.visitString("cell_type", value_type::cell_type_to_name(cell_type()));
+}
+
+//-----------------------------------------------------------------------------
+
void
Create::push_children(std::vector<Child::CREF> &children) const
{
@@ -312,14 +328,6 @@ Lambda::visit_self(vespalib::ObjectVisitor &visitor) const
//-----------------------------------------------------------------------------
-InterpretedFunction::Instruction
-CellCast::compile_self(const ValueBuilderFactory &, Stash &stash) const
-{
- return instruction::GenericCellCast::make_instruction(child().result_type(), result_type().cell_type(), stash);
-}
-
-//-----------------------------------------------------------------------------
-
void
Peek::push_children(std::vector<Child::CREF> &children) const
{
@@ -466,6 +474,11 @@ const TensorFunction &lambda(const ValueType &type, const std::vector<size_t> &b
return stash.create<Lambda>(type, bindings, function, std::move(node_types));
}
+const TensorFunction &cell_cast(const TensorFunction &child, CellType cell_type, Stash &stash) {
+ ValueType result_type = child.result_type().cell_cast(cell_type);
+ return stash.create<CellCast>(result_type, child, cell_type);
+}
+
const TensorFunction &peek(const TensorFunction &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, TensorFunction::CREF>> &spec, Stash &stash) {
std::vector<vespalib::string> dimensions;
for (const auto &dim_spec: spec) {
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index 55532bc4bf7..47610d02aca 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -306,6 +306,22 @@ public:
//-----------------------------------------------------------------------------
+class CellCast : public Op1
+{
+ using Super = Op1;
+private:
+ CellType _cell_type;
+public:
+ CellCast(const ValueType &result_type_in, const TensorFunction &child_in, CellType cell_type)
+ : Super(result_type_in, child_in), _cell_type(cell_type) {}
+ CellType cell_type() const { return _cell_type; }
+ bool result_is_mutable() const override { return true; }
+ InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
+ void visit_self(vespalib::ObjectVisitor &visitor) const override;
+};
+
+//-----------------------------------------------------------------------------
+
class Create : public Node
{
using Super = Node;
@@ -355,20 +371,6 @@ public:
//-----------------------------------------------------------------------------
-class CellCast : public Op1
-{
-private:
- using Super = Op1;
-public:
- CellCast(const TensorFunction &child_in, CellType to_cell_type)
- : Super(ValueType::cell_cast(child_in.result_type(), to_cell_type), child_in)
- {}
- bool result_is_mutable() const override { return true; }
- InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
-};
-
-//-----------------------------------------------------------------------------
-
class Peek : public Node
{
using Super = Node;
@@ -465,6 +467,7 @@ const TensorFunction &merge(const TensorFunction &lhs, const TensorFunction &rhs
const TensorFunction &concat(const TensorFunction &lhs, const TensorFunction &rhs, const vespalib::string &dimension, Stash &stash);
const TensorFunction &create(const ValueType &type, const std::map<TensorSpec::Address, TensorFunction::CREF> &spec, Stash &stash);
const TensorFunction &lambda(const ValueType &type, const std::vector<size_t> &bindings, const Function &function, NodeTypes node_types, Stash &stash);
+const TensorFunction &cell_cast(const TensorFunction &child, CellType cell_type, Stash &stash);
const TensorFunction &peek(const TensorFunction &param, const std::map<vespalib::string, std::variant<TensorSpec::Label, TensorFunction::CREF>> &spec, Stash &stash);
const TensorFunction &rename(const TensorFunction &child, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash);
const TensorFunction &if_node(const TensorFunction &cond, const TensorFunction &true_child, const TensorFunction &false_child, Stash &stash);
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.cpp b/eval/src/vespa/eval/eval/tensor_nodes.cpp
index 5cb064ad127..ca148b4275c 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.cpp
+++ b/eval/src/vespa/eval/eval/tensor_nodes.cpp
@@ -7,15 +7,16 @@ namespace vespalib {
namespace eval {
namespace nodes {
-void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorReduce::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorRename::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorConcat::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorCreate::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorLambda::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
-void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorReduce ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorRename ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorConcat ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorCellCast::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorCreate ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorLambda ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
+void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); }
} // namespace vespalib::eval::nodes
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h
index 618e03f229e..90979953531 100644
--- a/eval/src/vespa/eval/eval/tensor_nodes.h
+++ b/eval/src/vespa/eval/eval/tensor_nodes.h
@@ -7,6 +7,7 @@
#include "tensor_spec.h"
#include "aggr.h"
#include "string_stuff.h"
+#include "value_type_spec.h"
#include <vespa/vespalib/stllike/string.h>
#include <vector>
#include <map>
@@ -67,7 +68,7 @@ public:
str += ")";
return str;
}
- void accept(NodeVisitor &visitor) const override ;
+ void accept(NodeVisitor &visitor) const override;
size_t num_children() const override { return 2; }
const Node &get_child(size_t idx) const override {
assert(idx < 2);
@@ -101,7 +102,7 @@ public:
str += ")";
return str;
}
- void accept(NodeVisitor &visitor) const override ;
+ void accept(NodeVisitor &visitor) const override;
size_t num_children() const override { return 2; }
const Node &get_child(size_t idx) const override {
assert(idx < 2);
@@ -217,7 +218,7 @@ public:
str += ")";
return str;
}
- void accept(NodeVisitor &visitor) const override ;
+ void accept(NodeVisitor &visitor) const override;
size_t num_children() const override { return 2; }
const Node &get_child(size_t idx) const override {
assert(idx < 2);
@@ -229,6 +230,36 @@ public:
}
};
+class TensorCellCast : public Node {
+private:
+ Node_UP _child;
+ CellType _cell_type;
+public:
+ TensorCellCast(Node_UP child, CellType cell_type)
+ : _child(std::move(child)), _cell_type(cell_type) {}
+ const Node &child() const { return *_child; }
+ CellType cell_type() const { return _cell_type; }
+ vespalib::string dump(DumpContext &ctx) const override {
+ vespalib::string str;
+ str += "cell_cast(";
+ str += _child->dump(ctx);
+ str += ",";
+ str += value_type::cell_type_to_name(_cell_type);
+ str += ")";
+ return str;
+ }
+ void accept(NodeVisitor &visitor) const override;
+ size_t num_children() const override { return 1; }
+ const Node &get_child(size_t idx) const override {
+ (void) idx;
+ assert(idx == 0);
+ return *_child;
+ }
+ void detach_children(NodeHandler &handler) override {
+ handler.handle(std::move(_child));
+ }
+};
+
class TensorCreate : public Node {
public:
using Spec = std::map<TensorSpec::Address, Node_UP>;
@@ -259,7 +290,7 @@ public:
str += "}";
return str;
}
- void accept(NodeVisitor &visitor) const override ;
+ void accept(NodeVisitor &visitor) const override;
size_t num_children() const override { return _cells.size(); }
const Node &get_child(size_t idx) const override {
assert(idx < _cells.size());
@@ -365,7 +396,7 @@ public:
str += "}";
return str;
}
- void accept(NodeVisitor &visitor) const override ;
+ void accept(NodeVisitor &visitor) const override;
size_t num_children() const override { return (1 + _expr_dims.size()); }
const Node &get_child(size_t idx) const override {
assert(idx < num_children());
diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp
index 6b80b65df6c..63a3a23d9ae 100644
--- a/eval/src/vespa/eval/eval/test/eval_spec.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp
@@ -181,6 +181,7 @@ EvalSpec::add_tensor_operation_cases() {
add_expression({}, "tensor(x[10],y[10])(x==y)");
add_expression({"a","b"}, "concat(a,b,x)");
add_expression({"a","b"}, "concat(a,b,y)");
+ add_expression({"a"}, "cell_cast(a,float)");
add_expression({}, "tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}");
add_expression({"a"}, "a{x:3}");
}
diff --git a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
index c20d8af32ec..4857b5dabc2 100644
--- a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
+++ b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp
@@ -81,6 +81,10 @@ struct EvalNode : public NodeVisitor {
result = ReferenceOperations::concat(eval_node(a, params), eval_node(b, params), dimension);
}
+ void eval_cell_cast(const Node &a, CellType cell_type) {
+ result = ReferenceOperations::cell_cast(eval_node(a, params), cell_type);
+ }
+
void eval_create(const TensorCreate &node) {
std::map<TensorSpec::Address, size_t> spec;
std::vector<TensorSpec> children;
@@ -193,6 +197,9 @@ struct EvalNode : public NodeVisitor {
void visit(const TensorConcat &node) override {
eval_concat(node.lhs(), node.rhs(), node.dimension());
}
+ void visit(const TensorCellCast &node) override {
+ eval_cell_cast(node.child(), node.cell_type());
+ }
void visit(const TensorCreate &node) override {
eval_create(node);
}
diff --git a/eval/src/vespa/eval/eval/test/reference_operations.cpp b/eval/src/vespa/eval/eval/test/reference_operations.cpp
index 8709150bdc4..024fa14954d 100644
--- a/eval/src/vespa/eval/eval/test/reference_operations.cpp
+++ b/eval/src/vespa/eval/eval/test/reference_operations.cpp
@@ -84,7 +84,7 @@ struct CopyCellsWithCast {
TensorSpec ReferenceOperations::cell_cast(const TensorSpec &a, CellType to) {
ValueType a_type = ValueType::from_spec(a.type());
- ValueType res_type = ValueType::cell_cast(a_type, to);
+ ValueType res_type = a_type.cell_cast(to);
TensorSpec result(res_type.to_spec());
if (res_type.is_error()) {
return result;
diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
index 779ce30663f..117e8c9b149 100644
--- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
+++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp
@@ -7,6 +7,7 @@
#include <vespa/eval/eval/aggr.h>
#include <vespa/eval/eval/value_codec.h>
#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/eval/value_type_spec.h>
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/objects/nbostream.h>
@@ -557,6 +558,28 @@ struct TestContext {
//-------------------------------------------------------------------------
+ void test_cell_cast(const GenSpec &a) {
+ for (CellType cell_type: CellTypeUtils::list_types()) {
+ vespalib::string expr = fmt("cell_cast(a,%s)", value_type::cell_type_to_name(cell_type).c_str());
+ TEST_DO(verify_result(factory, expr, {a}, a.cpy().cells(cell_type)));
+ }
+ }
+
+ void test_cell_cast() {
+ std::vector<GenSpec> gen_list;
+ for (CellType cell_type: CellTypeUtils::list_types()) {
+ gen_list.push_back(GenSpec(-3).cells(cell_type));
+ }
+ for (const auto &gen: gen_list) {
+ TEST_DO(test_cell_cast(gen));
+ TEST_DO(test_cell_cast(gen.cpy().idx("x", 10)));
+ TEST_DO(test_cell_cast(gen.cpy().map("x", 10, 1)));
+ TEST_DO(test_cell_cast(gen.cpy().map("x", 4, 1).idx("y", 4)));
+ }
+ }
+
+ //-------------------------------------------------------------------------
+
void test_rename(const vespalib::string &expr,
const TensorSpec &input,
const TensorSpec &expect)
@@ -735,6 +758,7 @@ struct TestContext {
TEST_DO(test_tensor_apply());
TEST_DO(test_dot_product());
TEST_DO(test_concat());
+ TEST_DO(test_cell_cast());
TEST_DO(test_rename());
TEST_DO(test_tensor_lambda());
TEST_DO(test_tensor_create());
diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp
index 57b101f6eeb..b70edef7153 100644
--- a/eval/src/vespa/eval/eval/value_type.cpp
+++ b/eval/src/vespa/eval/eval/value_type.cpp
@@ -287,6 +287,16 @@ ValueType::rename(const std::vector<vespalib::string> &from,
}
ValueType
+ValueType::cell_cast(CellType to_cell_type) const
+{
+ if (is_error()) {
+ return error_type();
+ }
+ // TODO: return make_type(to_cell_type, _dimensions);
+ return tensor_type(_dimensions, to_cell_type);
+}
+
+ValueType
ValueType::make_type(CellType cell_type, std::vector<Dimension> dimensions_in)
{
sort_dimensions(dimensions_in);
@@ -388,11 +398,6 @@ ValueType::either(const ValueType &one, const ValueType &other) {
return one;
}
-ValueType
-ValueType::cell_cast(const ValueType &from, CellType to_cell_type) {
- return make_type(to_cell_type, from.dimensions());
-}
-
std::ostream &
operator<<(std::ostream &os, const ValueType &type) {
return os << type.to_spec();
diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h
index 4609f6a0b38..247912b274a 100644
--- a/eval/src/vespa/eval/eval/value_type.h
+++ b/eval/src/vespa/eval/eval/value_type.h
@@ -86,6 +86,7 @@ public:
ValueType reduce(const std::vector<vespalib::string> &dimensions_in) const;
ValueType rename(const std::vector<vespalib::string> &from,
const std::vector<vespalib::string> &to) const;
+ ValueType cell_cast(CellType to_cell_type) const;
static ValueType error_type() { return ValueType(); }
static ValueType make_type(CellType cell_type, std::vector<Dimension> dimensions_in);
@@ -105,8 +106,6 @@ public:
static CellType unify_cell_types(const ValueType &a, const ValueType &b);
static ValueType concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension);
static ValueType either(const ValueType &one, const ValueType &other);
- static ValueType cell_cast(const ValueType &from, CellType to_cell_type);
-
};
std::ostream &operator<<(std::ostream &os, const ValueType &type);
diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp
index a4575e33c2f..1eabbd0a9fc 100644
--- a/eval/src/vespa/eval/eval/value_type_spec.cpp
+++ b/eval/src/vespa/eval/eval/value_type_spec.cpp
@@ -8,9 +8,17 @@
namespace vespalib::eval::value_type {
-namespace {
+std::optional<CellType> cell_type_from_name(const vespalib::string &name) {
+ if (name == "double") {
+ return CellType::DOUBLE;
+ } else if (name == "float") {
+ return CellType::FLOAT;
+ } else {
+ return std::nullopt;
+ }
+}
-const char *to_name(CellType cell_type) {
+vespalib::string cell_type_to_name(CellType cell_type) {
switch (cell_type) {
case CellType::DOUBLE: return "double";
case CellType::FLOAT: return "float";
@@ -18,6 +26,8 @@ const char *to_name(CellType cell_type) {
abort();
}
+namespace {
+
class ParseContext
{
public:
@@ -232,11 +242,11 @@ to_spec(const ValueType &type)
if (type.is_error()) {
os << "error";
} else if (type.is_scalar()) {
- os << to_name(type.cell_type());
+ os << cell_type_to_name(type.cell_type());
} else {
os << "tensor";
if (type.cell_type() != CellType::DOUBLE) {
- os << "<" << to_name(type.cell_type()) << ">";
+ os << "<" << cell_type_to_name(type.cell_type()) << ">";
}
os << "(";
for (const auto &d: type.dimensions()) {
diff --git a/eval/src/vespa/eval/eval/value_type_spec.h b/eval/src/vespa/eval/eval/value_type_spec.h
index ff5113c769a..39168b34fcb 100644
--- a/eval/src/vespa/eval/eval/value_type_spec.h
+++ b/eval/src/vespa/eval/eval/value_type_spec.h
@@ -3,9 +3,13 @@
#pragma once
#include "value_type.h"
+#include <optional>
namespace vespalib::eval::value_type {
+std::optional<CellType> cell_type_from_name(const vespalib::string &name);
+vespalib::string cell_type_to_name(CellType cell_type);
+
ValueType parse_spec(const char *pos_in, const char *end_in, const char *&pos_out,
std::vector<ValueType::Dimension> *unsorted = nullptr);
diff --git a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp
index 58221b9c62f..ff84a65857c 100644
--- a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp
+++ b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp
@@ -46,7 +46,7 @@ GenericCellCast::make_instruction(const ValueType &input_type,
Stash &stash)
{
CellType from = input_type.cell_type();
- auto result_type = ValueType::cell_cast(input_type, to_cell_type);
+ auto result_type = input_type.cell_cast(to_cell_type);
auto &param = stash.create<ValueType>(result_type);
CellType to = result_type.cell_type();
auto op = typify_invoke<2,TypifyCellType,SelectGenericCellCastOp>(from, to);
diff --git a/eval/src/vespa/eval/instruction/generic_merge.h b/eval/src/vespa/eval/instruction/generic_merge.h
index 0319f1a929f..e1dfd35e7fd 100644
--- a/eval/src/vespa/eval/instruction/generic_merge.h
+++ b/eval/src/vespa/eval/instruction/generic_merge.h
@@ -15,7 +15,7 @@ struct MergeParam {
const ValueBuilderFactory &factory;
MergeParam(const ValueType &lhs_type, const ValueType &rhs_type,
join_fun_t function_in, const ValueBuilderFactory &factory_in)
- : res_type(ValueType::join(lhs_type, rhs_type)),
+ : res_type(ValueType::merge(lhs_type, rhs_type)),
function(function_in),
num_mapped_dimensions(lhs_type.count_mapped_dimensions()),
dense_subspace_size(lhs_type.dense_subspace_size()),