diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-03-03 08:47:14 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-03-03 08:47:14 +0100 |
commit | a6c06485fd2894c88e30c6dc58efcb4669721ccc (patch) | |
tree | 870e8c23428e777530438d8cd7e9eb958e43e189 | |
parent | 898d3f3262cf313e9955b456e3d9cdeb1c5ca87d (diff) | |
parent | fca3ba6823e2d1fa1ffe8489b63e44a271a127b6 (diff) |
Merge pull request #16751 from vespa-engine/havardpe/cell-cast-in-function
enable use of cell_cast in expressions
30 files changed, 538 insertions, 278 deletions
diff --git a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp index a19bff415e1..8e17aadcfff 100644 --- a/eval/src/tests/eval/compiled_function/compiled_function_test.cpp +++ b/eval/src/tests/eval/compiled_function/compiled_function_test.cpp @@ -55,7 +55,8 @@ std::vector<vespalib::string> unsupported = { "reduce(", "rename(", "tensor(", - "concat(" + "concat(", + "cell_cast(" }; bool is_unsupported(const vespalib::string &expression) { diff --git a/eval/src/tests/eval/function/function_test.cpp b/eval/src/tests/eval/function/function_test.cpp index 3dbc2e1aed2..4dff2934873 100644 --- a/eval/src/tests/eval/function/function_test.cpp +++ b/eval/src/tests/eval/function/function_test.cpp @@ -1018,6 +1018,17 @@ TEST("require that tensor concat can be parsed") { //----------------------------------------------------------------------------- +TEST("require that tensor cell cast can be parsed") { + EXPECT_EQUAL("cell_cast(a,float)", Function::parse({"a"}, "cell_cast(a,float)")->dump()); + EXPECT_EQUAL("cell_cast(a,double)", Function::parse({"a"}, " cell_cast ( a , double ) ")->dump()); +} + +TEST("require that tensor cell cast must have valid cell type") { + TEST_DO(verify_error("cell_cast(x,int7)", "[cell_cast(x,int7]...[unknown cell type: 'int7']...[)]")); +} + +//----------------------------------------------------------------------------- + struct CheckExpressions : test::EvalSpec::EvalTest { bool failed = false; size_t seen_cnt = 0; diff --git a/eval/src/tests/eval/node_types/node_types_test.cpp b/eval/src/tests/eval/node_types/node_types_test.cpp index f595c58ef29..89c37af0a83 100644 --- a/eval/src/tests/eval/node_types/node_types_test.cpp +++ b/eval/src/tests/eval/node_types/node_types_test.cpp @@ -292,6 +292,14 @@ TEST("require that tensor concat resolves correct type") { TEST_DO(verify("concat(tensor<float>(x[2]),double,x)", "tensor<float>(x[3])")); } +TEST("require that tensor cell_cast resolves correct type") { + TEST_DO(verify("cell_cast(double,float)", "double")); // NB + TEST_DO(verify("cell_cast(float,double)", "double")); + TEST_DO(verify("cell_cast(tensor<double>(x{},y[5]),float)", "tensor<float>(x{},y[5])")); + TEST_DO(verify("cell_cast(tensor<float>(x{},y[5]),double)", "tensor<double>(x{},y[5])")); + TEST_DO(verify("cell_cast(tensor<float>(x{},y[5]),float)", "tensor<float>(x{},y[5])")); +} + TEST("require that double only expressions can be detected") { auto plain_fun = Function::parse("1+2"); auto complex_fun = Function::parse("reduce(a,sum)"); diff --git a/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp b/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp index 345f04053ac..b6115b78378 100644 --- a/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp +++ b/eval/src/tests/eval/reference_evaluation/reference_evaluation_test.cpp @@ -151,6 +151,12 @@ TEST(ReferenceEvaluationTest, concat_expression_works) { EXPECT_EQ(ref_eval("concat(a,b,x)", {a, b}), expect); } +TEST(ReferenceEvaluationTest, cell_cast_expression_works) { + auto a = make_val("tensor<double>(x[4]):[1,2,3,4]"); + auto expect = make_val("tensor<float>(x[4]):[1,2,3,4]"); + EXPECT_EQ(ref_eval("cell_cast(a,float)", {a}), expect); +} + TEST(ReferenceEvaluationTest, rename_expression_works) { auto a = make_val("tensor(x[2]):[1,2]"); auto expect = make_val("tensor(y[2]):[1,2]"); diff --git a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp index 3fcca5e34d8..d191ee064ee 100644 --- a/eval/src/tests/eval/reference_operations/reference_operations_test.cpp +++ b/eval/src/tests/eval/reference_operations/reference_operations_test.cpp @@ -1,11 +1,14 @@ // Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +#include <vespa/eval/eval/cell_type.h> #include <vespa/eval/eval/test/reference_operations.h> +#include <vespa/eval/eval/test/gen_spec.h> #include <vespa/vespalib/gtest/gtest.h> #include <iostream> using namespace vespalib; using namespace vespalib::eval; +using vespalib::eval::test::GenSpec; TensorSpec dense_2d_some_cells(bool square) { return TensorSpec("tensor(a[3],d[5])") @@ -48,7 +51,6 @@ TensorSpec sparse_1d_all_two() { //----------------------------------------------------------------------------- - TEST(ReferenceConcatTest, concat_numbers) { auto a = TensorSpec("double").add({}, 7.0); auto b = TensorSpec("double").add({}, 4.0); @@ -132,6 +134,27 @@ TEST(ReferenceConcatTest, concat_mixed_tensors) { //----------------------------------------------------------------------------- +TEST(ReferenceCellCastTest, cell_cast_works) { + std::vector<GenSpec> gen_list = { + GenSpec(42), + GenSpec(-3).idx("x", 10), + GenSpec(-3).map("x", 10, 1), + GenSpec(-3).map("x", 4, 1).idx("y", 4) + }; + for (CellType from_type: CellTypeUtils::list_types()) { + for (CellType to_type: CellTypeUtils::list_types()) { + for (const auto &gen: gen_list) { + TensorSpec input = gen.cpy().cells(from_type); + TensorSpec expect = gen.cpy().cells(to_type); + auto actual = ReferenceOperations::cell_cast(input, to_type); + EXPECT_EQ(actual, expect); + } + } + } +} + +//----------------------------------------------------------------------------- + TEST(ReferenceCreateTest, simple_create_works) { auto a = TensorSpec("double").add({}, 1.5); auto b = TensorSpec("tensor(z[2])").add({{"z",0}}, 2.0).add({{"z",1}}, 3.0); @@ -525,4 +548,3 @@ TEST(ReferenceLambdaTest, make_matrix) { //----------------------------------------------------------------------------- GTEST_MAIN_RUN_ALL_TESTS() - diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp index 4396a6773ea..f1bd900b350 100644 --- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp +++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp @@ -31,6 +31,9 @@ struct EvalCtx { tensors.push_back(std::move(tensor)); return id; } + ValueType type_of(size_t idx) { + return params[idx].get().type(); + } void replace_tensor(size_t idx, Value::UP tensor) { params[idx] = *tensor; tensors[idx] = std::move(tensor); @@ -91,6 +94,14 @@ struct EvalCtx { .add({{"x", 1}, {"y", 0}}, 3.0) .add({{"x", 1}, {"y", 1}}, 4.0), factory); } + Value::UP make_float_tensor_matrix() { + return value_from_spec( + TensorSpec("tensor<float>(x[2],y[2])") + .add({{"x", 0}, {"y", 0}}, 1.0) + .add({{"x", 0}, {"y", 1}}, 2.0) + .add({{"x", 1}, {"y", 0}}, 3.0) + .add({{"x", 1}, {"y", 1}}, 4.0), factory); + } Value::UP make_tensor_matrix_renamed() { return value_from_spec( TensorSpec("tensor(y[2],z[2])") @@ -273,6 +284,16 @@ TEST("require that tensor concat works") { TEST_DO(verify_equal(*expect, ctx.eval(fun))); } +TEST("require that tensor cell cast works") { + EvalCtx ctx(simple_factory); + size_t a_id = ctx.add_tensor(ctx.make_tensor_matrix()); + Value::UP expect = ctx.make_float_tensor_matrix(); + const auto &fun = cell_cast(inject(ctx.type_of(a_id), a_id, ctx.stash), CellType::FLOAT, ctx.stash); + EXPECT_TRUE(fun.result_is_mutable()); + EXPECT_EQUAL(expect->type(), fun.result_type()); + TEST_DO(verify_equal(*expect, ctx.eval(fun))); +} + TEST("require that tensor create works") { EvalCtx ctx(simple_factory); size_t a_id = ctx.add_tensor(ctx.make_double(1.0)); @@ -450,6 +471,10 @@ TEST("require that push_children works") { EXPECT_EQUAL(&refs[10].get().get(), &b); EXPECT_EQUAL(&refs[11].get().get(), &c); //------------------------------------------------------------------------- + cell_cast(a, CellType::FLOAT, stash).push_children(refs); + ASSERT_EQUAL(refs.size(), 13u); + EXPECT_EQUAL(&refs[12].get().get(), &a); + //------------------------------------------------------------------------- } TEST("require that tensor function can be dumped for debugging") { @@ -459,7 +484,9 @@ TEST("require that tensor function can be dumped for debugging") { auto my_value_3 = stash.create<DoubleValue>(3.0); //------------------------------------------------------------------------- const auto &x5 = inject(ValueType::from_spec("tensor(x[5])"), 0, stash); - const auto &mapped_x5 = map(x5, operation::Relu::f, stash); + const auto &float_x5 = cell_cast(x5, CellType::FLOAT, stash); + const auto &double_x5 = cell_cast(float_x5, CellType::DOUBLE, stash); + const auto &mapped_x5 = map(double_x5, operation::Relu::f, stash); const auto &const_1 = const_value(my_value_1, stash); const auto &joined_x5 = join(mapped_x5, const_1, operation::Mul::f, stash); //------------------------------------------------------------------------- diff --git a/eval/src/tests/eval/value_type/value_type_test.cpp b/eval/src/tests/eval/value_type/value_type_test.cpp index d58adbbcef0..c1b25d48bf7 100644 --- a/eval/src/tests/eval/value_type/value_type_test.cpp +++ b/eval/src/tests/eval/value_type/value_type_test.cpp @@ -501,4 +501,51 @@ TEST("require that cell type is handled correctly for concat") { TEST_DO(verify_concat(type("tensor<float>(x[3])"), type("double"), "x", type("tensor<float>(x[4])"))); } +void verify_cell_cast(const ValueType &type) { + for (CellType cell_type: CellTypeUtils::list_types()) { + auto res_type = type.cell_cast(cell_type); + if (type.is_error()) { + EXPECT_TRUE(res_type.is_error()); + EXPECT_EQUAL(res_type, type); + } else if (type.is_scalar()) { + EXPECT_TRUE(res_type.is_double()); // NB + } else { + EXPECT_FALSE(res_type.is_error()); + EXPECT_EQUAL(int(res_type.cell_type()), int(cell_type)); + EXPECT_TRUE(res_type.dimensions() == type.dimensions()); + } + } +} + +TEST("require that value type cell cast works correctly") { + TEST_DO(verify_cell_cast(type("error"))); + TEST_DO(verify_cell_cast(type("float"))); + TEST_DO(verify_cell_cast(type("double"))); + TEST_DO(verify_cell_cast(type("tensor<float>(x[10])"))); + TEST_DO(verify_cell_cast(type("tensor<double>(x[10])"))); + TEST_DO(verify_cell_cast(type("tensor<float>(x{})"))); + TEST_DO(verify_cell_cast(type("tensor<double>(x{})"))); + TEST_DO(verify_cell_cast(type("tensor<float>(x{},y[5])"))); + TEST_DO(verify_cell_cast(type("tensor<double>(x{},y[5])"))); +} + +TEST("require that actual cell type can be converted to cell type name") { + EXPECT_EQUAL(value_type::cell_type_to_name(CellType::FLOAT), "float"); + EXPECT_EQUAL(value_type::cell_type_to_name(CellType::DOUBLE), "double"); +} + +TEST("require that cell type name can be converted to actual cell type") { + EXPECT_EQUAL(int(value_type::cell_type_from_name("float").value()), int(CellType::FLOAT)); + EXPECT_EQUAL(int(value_type::cell_type_from_name("double").value()), int(CellType::DOUBLE)); + EXPECT_FALSE(value_type::cell_type_from_name("int7").has_value()); +} + +TEST("require that cell type name recognition is strict") { + EXPECT_FALSE(value_type::cell_type_from_name("Float").has_value()); + EXPECT_FALSE(value_type::cell_type_from_name(" float").has_value()); + EXPECT_FALSE(value_type::cell_type_from_name("float ").has_value()); + EXPECT_FALSE(value_type::cell_type_from_name("f").has_value()); + EXPECT_FALSE(value_type::cell_type_from_name("").has_value()); +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/vespa/eval/eval/cell_type.h b/eval/src/vespa/eval/eval/cell_type.h index 4c4f7e78413..beb0b32386f 100644 --- a/eval/src/vespa/eval/eval/cell_type.h +++ b/eval/src/vespa/eval/eval/cell_type.h @@ -3,6 +3,7 @@ #pragma once #include <vespa/vespalib/util/typify.h> +#include <vector> #include <cstdint> namespace vespalib::eval { @@ -43,6 +44,10 @@ struct CellTypeUtils { default: bad_argument((uint32_t)cell_type); } } + + static std::vector<CellType> list_types() { + return {CellType::FLOAT, CellType::DOUBLE}; + } }; struct TypifyCellType { diff --git a/eval/src/vespa/eval/eval/function.cpp b/eval/src/vespa/eval/eval/function.cpp index aeb53eaa6ba..b03c2c1ed24 100644 --- a/eval/src/vespa/eval/eval/function.cpp +++ b/eval/src/vespa/eval/eval/function.cpp @@ -830,6 +830,19 @@ void parse_tensor_concat(ParseContext &ctx) { ctx.push_expression(std::make_unique<nodes::TensorConcat>(std::move(lhs), std::move(rhs), dimension)); } +void parse_tensor_cell_cast(ParseContext &ctx) { + Node_UP child = get_expression(ctx); + ctx.eat(','); + auto cell_type_name = get_ident(ctx, false); + auto cell_type = value_type::cell_type_from_name(cell_type_name); + ctx.skip_spaces(); + if (cell_type.has_value()) { + ctx.push_expression(std::make_unique<nodes::TensorCellCast>(std::move(child), cell_type.value())); + } else { + ctx.fail(make_string("unknown cell type: '%s'", cell_type_name.c_str())); + } +} + bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) { ctx.skip_spaces(); if (ctx.get() == '(') { @@ -852,6 +865,8 @@ bool maybe_parse_call(ParseContext &ctx, const vespalib::string &name) { parse_tensor_rename(ctx); } else if (name == "concat") { parse_tensor_concat(ctx); + } else if (name == "cell_cast") { + parse_tensor_cell_cast(ctx); } else { ctx.fail(make_string("unknown function: '%s'", name.c_str())); return false; diff --git a/eval/src/vespa/eval/eval/key_gen.cpp b/eval/src/vespa/eval/eval/key_gen.cpp index 31167be5fe1..80803d3b2a2 100644 --- a/eval/src/vespa/eval/eval/key_gen.cpp +++ b/eval/src/vespa/eval/eval/key_gen.cpp @@ -31,61 +31,62 @@ struct KeyGen : public NodeVisitor, public NodeTraverser { add_double(node.get_entry(i).get_const_value()); } } - void visit(const Neg &) override { add_byte( 5); } - void visit(const Not &) override { add_byte( 6); } - void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); } - void visit(const Error &) override { add_byte( 9); } - void visit(const TensorMap &) override { add_byte(10); } // lambda should be part of key - void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key - void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key - void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key - void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key - void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key - void visit(const TensorCreate &) override { add_byte(16); } // type/addr should be part of key - void visit(const TensorLambda &) override { add_byte(17); } // type/lambda should be part of key - void visit(const TensorPeek &) override { add_byte(18); } // addr should be part of key - void visit(const Add &) override { add_byte(20); } - void visit(const Sub &) override { add_byte(21); } - void visit(const Mul &) override { add_byte(22); } - void visit(const Div &) override { add_byte(23); } - void visit(const Mod &) override { add_byte(24); } - void visit(const Pow &) override { add_byte(25); } - void visit(const Equal &) override { add_byte(26); } - void visit(const NotEqual &) override { add_byte(27); } - void visit(const Approx &) override { add_byte(28); } - void visit(const Less &) override { add_byte(29); } - void visit(const LessEqual &) override { add_byte(30); } - void visit(const Greater &) override { add_byte(31); } - void visit(const GreaterEqual &) override { add_byte(32); } - void visit(const And &) override { add_byte(34); } - void visit(const Or &) override { add_byte(35); } - void visit(const Cos &) override { add_byte(36); } - void visit(const Sin &) override { add_byte(37); } - void visit(const Tan &) override { add_byte(38); } - void visit(const Cosh &) override { add_byte(39); } - void visit(const Sinh &) override { add_byte(40); } - void visit(const Tanh &) override { add_byte(41); } - void visit(const Acos &) override { add_byte(42); } - void visit(const Asin &) override { add_byte(43); } - void visit(const Atan &) override { add_byte(44); } - void visit(const Exp &) override { add_byte(45); } - void visit(const Log10 &) override { add_byte(46); } - void visit(const Log &) override { add_byte(47); } - void visit(const Sqrt &) override { add_byte(48); } - void visit(const Ceil &) override { add_byte(49); } - void visit(const Fabs &) override { add_byte(50); } - void visit(const Floor &) override { add_byte(51); } - void visit(const Atan2 &) override { add_byte(52); } - void visit(const Ldexp &) override { add_byte(53); } - void visit(const Pow2 &) override { add_byte(54); } - void visit(const Fmod &) override { add_byte(55); } - void visit(const Min &) override { add_byte(56); } - void visit(const Max &) override { add_byte(57); } - void visit(const IsNan &) override { add_byte(58); } - void visit(const Relu &) override { add_byte(59); } - void visit(const Sigmoid &) override { add_byte(60); } - void visit(const Elu &) override { add_byte(61); } - void visit(const Erf &) override { add_byte(62); } + void visit(const Neg &) override { add_byte( 5); } + void visit(const Not &) override { add_byte( 6); } + void visit(const If &node) override { add_byte( 7); add_double(node.p_true()); } + void visit(const Error &) override { add_byte( 9); } + void visit(const TensorMap &) override { add_byte(10); } // lambda should be part of key + void visit(const TensorJoin &) override { add_byte(11); } // lambda should be part of key + void visit(const TensorMerge &) override { add_byte(12); } // lambda should be part of key + void visit(const TensorReduce &) override { add_byte(13); } // aggr/dimensions should be part of key + void visit(const TensorRename &) override { add_byte(14); } // dimensions should be part of key + void visit(const TensorConcat &) override { add_byte(15); } // dimension should be part of key + void visit(const TensorCellCast &) override { add_byte(16); } // cell type should be part of key + void visit(const TensorCreate &) override { add_byte(17); } // type/addr should be part of key + void visit(const TensorLambda &) override { add_byte(18); } // type/lambda should be part of key + void visit(const TensorPeek &) override { add_byte(19); } // addr should be part of key + void visit(const Add &) override { add_byte(20); } + void visit(const Sub &) override { add_byte(21); } + void visit(const Mul &) override { add_byte(22); } + void visit(const Div &) override { add_byte(23); } + void visit(const Mod &) override { add_byte(24); } + void visit(const Pow &) override { add_byte(25); } + void visit(const Equal &) override { add_byte(26); } + void visit(const NotEqual &) override { add_byte(27); } + void visit(const Approx &) override { add_byte(28); } + void visit(const Less &) override { add_byte(29); } + void visit(const LessEqual &) override { add_byte(30); } + void visit(const Greater &) override { add_byte(31); } + void visit(const GreaterEqual &) override { add_byte(32); } + void visit(const And &) override { add_byte(34); } + void visit(const Or &) override { add_byte(35); } + void visit(const Cos &) override { add_byte(36); } + void visit(const Sin &) override { add_byte(37); } + void visit(const Tan &) override { add_byte(38); } + void visit(const Cosh &) override { add_byte(39); } + void visit(const Sinh &) override { add_byte(40); } + void visit(const Tanh &) override { add_byte(41); } + void visit(const Acos &) override { add_byte(42); } + void visit(const Asin &) override { add_byte(43); } + void visit(const Atan &) override { add_byte(44); } + void visit(const Exp &) override { add_byte(45); } + void visit(const Log10 &) override { add_byte(46); } + void visit(const Log &) override { add_byte(47); } + void visit(const Sqrt &) override { add_byte(48); } + void visit(const Ceil &) override { add_byte(49); } + void visit(const Fabs &) override { add_byte(50); } + void visit(const Floor &) override { add_byte(51); } + void visit(const Atan2 &) override { add_byte(52); } + void visit(const Ldexp &) override { add_byte(53); } + void visit(const Pow2 &) override { add_byte(54); } + void visit(const Fmod &) override { add_byte(55); } + void visit(const Min &) override { add_byte(56); } + void visit(const Max &) override { add_byte(57); } + void visit(const IsNan &) override { add_byte(58); } + void visit(const Relu &) override { add_byte(59); } + void visit(const Sigmoid &) override { add_byte(60); } + void visit(const Elu &) override { add_byte(61); } + void visit(const Erf &) override { add_byte(62); } // traverse bool open(const Node &node) override { node.accept(*this); return true; } diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp index 043ad248251..4ce4fcd7747 100644 --- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp +++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp @@ -132,6 +132,7 @@ CompiledFunction::detect_issues(const nodes::Node &node) nodes::TensorReduce, nodes::TensorRename, nodes::TensorConcat, + nodes::TensorCellCast, nodes::TensorCreate, nodes::TensorLambda, nodes::TensorPeek>(node)) diff --git a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp index 8fbb2c5ac09..fce9abb7316 100644 --- a/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp +++ b/eval/src/vespa/eval/eval/llvm/llvm_wrapper.cpp @@ -478,6 +478,9 @@ struct FunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const TensorConcat &node) override { make_error(node.num_children()); } + void visit(const TensorCellCast &node) override { + make_error(node.num_children()); + } void visit(const TensorCreate &node) override { make_error(node.num_children()); } diff --git a/eval/src/vespa/eval/eval/make_tensor_function.cpp b/eval/src/vespa/eval/eval/make_tensor_function.cpp index 9ee42f164de..15f188db51a 100644 --- a/eval/src/vespa/eval/eval/make_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/make_tensor_function.cpp @@ -75,6 +75,12 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { stack.back() = tensor_function::concat(a, b, dimension, stash); } + void make_cell_cast(const Node &, CellType cell_type) { + assert(stack.size() >= 1); + const auto &a = stack.back().get(); + stack.back() = tensor_function::cell_cast(a, cell_type, stash); + } + bool maybe_make_const(const Node &node) { if (auto create = as<TensorCreate>(node)) { bool is_const = true; @@ -213,6 +219,9 @@ struct TensorFunctionBuilder : public NodeVisitor, public NodeTraverser { void visit(const TensorConcat &node) override { make_concat(node, node.dimension()); } + void visit(const TensorCellCast &node) override { + make_cell_cast(node, node.cell_type()); + } void visit(const TensorCreate &node) override { make_create(node); } diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp index 1c194736138..e7341bc1755 100644 --- a/eval/src/vespa/eval/eval/node_tools.cpp +++ b/eval/src/vespa/eval/eval/node_tools.cpp @@ -126,61 +126,62 @@ struct CopyNode : NodeTraverser, NodeVisitor { } // tensor nodes - void visit(const TensorMap &node) override { not_implemented(node); } - void visit(const TensorJoin &node) override { not_implemented(node); } - void visit(const TensorMerge &node) override { not_implemented(node); } - void visit(const TensorReduce &node) override { not_implemented(node); } - void visit(const TensorRename &node) override { not_implemented(node); } - void visit(const TensorConcat &node) override { not_implemented(node); } - void visit(const TensorCreate &node) override { not_implemented(node); } - void visit(const TensorLambda &node) override { not_implemented(node); } - void visit(const TensorPeek &node) override { not_implemented(node); } + void visit(const TensorMap &node) override { not_implemented(node); } + void visit(const TensorJoin &node) override { not_implemented(node); } + void visit(const TensorMerge &node) override { not_implemented(node); } + void visit(const TensorReduce &node) override { not_implemented(node); } + void visit(const TensorRename &node) override { not_implemented(node); } + void visit(const TensorConcat &node) override { not_implemented(node); } + void visit(const TensorCellCast &node) override { not_implemented(node); } + void visit(const TensorCreate &node) override { not_implemented(node); } + void visit(const TensorLambda &node) override { not_implemented(node); } + void visit(const TensorPeek &node) override { not_implemented(node); } // operator nodes - void visit(const Add &node) override { copy_operator(node); } - void visit(const Sub &node) override { copy_operator(node); } - void visit(const Mul &node) override { copy_operator(node); } - void visit(const Div &node) override { copy_operator(node); } - void visit(const Mod &node) override { copy_operator(node); } - void visit(const Pow &node) override { copy_operator(node); } - void visit(const Equal &node) override { copy_operator(node); } - void visit(const NotEqual &node) override { copy_operator(node); } - void visit(const Approx &node) override { copy_operator(node); } - void visit(const Less &node) override { copy_operator(node); } - void visit(const LessEqual &node) override { copy_operator(node); } - void visit(const Greater &node) override { copy_operator(node); } - void visit(const GreaterEqual &node) override { copy_operator(node); } - void visit(const And &node) override { copy_operator(node); } - void visit(const Or &node) override { copy_operator(node); } + void visit(const Add &node) override { copy_operator(node); } + void visit(const Sub &node) override { copy_operator(node); } + void visit(const Mul &node) override { copy_operator(node); } + void visit(const Div &node) override { copy_operator(node); } + void visit(const Mod &node) override { copy_operator(node); } + void visit(const Pow &node) override { copy_operator(node); } + void visit(const Equal &node) override { copy_operator(node); } + void visit(const NotEqual &node) override { copy_operator(node); } + void visit(const Approx &node) override { copy_operator(node); } + void visit(const Less &node) override { copy_operator(node); } + void visit(const LessEqual &node) override { copy_operator(node); } + void visit(const Greater &node) override { copy_operator(node); } + void visit(const GreaterEqual &node) override { copy_operator(node); } + void visit(const And &node) override { copy_operator(node); } + void visit(const Or &node) override { copy_operator(node); } // call nodes - void visit(const Cos &node) override { copy_call(node); } - void visit(const Sin &node) override { copy_call(node); } - void visit(const Tan &node) override { copy_call(node); } - void visit(const Cosh &node) override { copy_call(node); } - void visit(const Sinh &node) override { copy_call(node); } - void visit(const Tanh &node) override { copy_call(node); } - void visit(const Acos &node) override { copy_call(node); } - void visit(const Asin &node) override { copy_call(node); } - void visit(const Atan &node) override { copy_call(node); } - void visit(const Exp &node) override { copy_call(node); } - void visit(const Log10 &node) override { copy_call(node); } - void visit(const Log &node) override { copy_call(node); } - void visit(const Sqrt &node) override { copy_call(node); } - void visit(const Ceil &node) override { copy_call(node); } - void visit(const Fabs &node) override { copy_call(node); } - void visit(const Floor &node) override { copy_call(node); } - void visit(const Atan2 &node) override { copy_call(node); } - void visit(const Ldexp &node) override { copy_call(node); } - void visit(const Pow2 &node) override { copy_call(node); } - void visit(const Fmod &node) override { copy_call(node); } - void visit(const Min &node) override { copy_call(node); } - void visit(const Max &node) override { copy_call(node); } - void visit(const IsNan &node) override { copy_call(node); } - void visit(const Relu &node) override { copy_call(node); } - void visit(const Sigmoid &node) override { copy_call(node); } - void visit(const Elu &node) override { copy_call(node); } - void visit(const Erf &node) override { copy_call(node); } + void visit(const Cos &node) override { copy_call(node); } + void visit(const Sin &node) override { copy_call(node); } + void visit(const Tan &node) override { copy_call(node); } + void visit(const Cosh &node) override { copy_call(node); } + void visit(const Sinh &node) override { copy_call(node); } + void visit(const Tanh &node) override { copy_call(node); } + void visit(const Acos &node) override { copy_call(node); } + void visit(const Asin &node) override { copy_call(node); } + void visit(const Atan &node) override { copy_call(node); } + void visit(const Exp &node) override { copy_call(node); } + void visit(const Log10 &node) override { copy_call(node); } + void visit(const Log &node) override { copy_call(node); } + void visit(const Sqrt &node) override { copy_call(node); } + void visit(const Ceil &node) override { copy_call(node); } + void visit(const Fabs &node) override { copy_call(node); } + void visit(const Floor &node) override { copy_call(node); } + void visit(const Atan2 &node) override { copy_call(node); } + void visit(const Ldexp &node) override { copy_call(node); } + void visit(const Pow2 &node) override { copy_call(node); } + void visit(const Fmod &node) override { copy_call(node); } + void visit(const Min &node) override { copy_call(node); } + void visit(const Max &node) override { copy_call(node); } + void visit(const IsNan &node) override { copy_call(node); } + void visit(const Relu &node) override { copy_call(node); } + void visit(const Sigmoid &node) override { copy_call(node); } + void visit(const Elu &node) override { copy_call(node); } + void visit(const Erf &node) override { copy_call(node); } // traverse nodes bool open(const Node &) override { return !error; } diff --git a/eval/src/vespa/eval/eval/node_types.cpp b/eval/src/vespa/eval/eval/node_types.cpp index 7d76f2064a3..2df22d5433c 100644 --- a/eval/src/vespa/eval/eval/node_types.cpp +++ b/eval/src/vespa/eval/eval/node_types.cpp @@ -177,6 +177,9 @@ struct TypeResolver : public NodeVisitor, public NodeTraverser { bind(ValueType::concat(type(node.get_child(0)), type(node.get_child(1)), node.dimension()), node); } + void visit(const TensorCellCast &node) override { + bind(type(node.get_child(0)).cell_cast(node.cell_type()), node); + } void visit(const TensorCreate &node) override { for (size_t i = 0; i < node.num_children(); ++i) { if (!type(node.get_child(i)).is_double()) { diff --git a/eval/src/vespa/eval/eval/node_visitor.h b/eval/src/vespa/eval/eval/node_visitor.h index 95a5bec8be7..172cd48fe2a 100644 --- a/eval/src/vespa/eval/eval/node_visitor.h +++ b/eval/src/vespa/eval/eval/node_visitor.h @@ -19,71 +19,72 @@ namespace eval { struct NodeVisitor { // basic nodes - virtual void visit(const nodes::Number &) = 0; - virtual void visit(const nodes::Symbol &) = 0; - virtual void visit(const nodes::String &) = 0; - virtual void visit(const nodes::In &) = 0; - virtual void visit(const nodes::Neg &) = 0; - virtual void visit(const nodes::Not &) = 0; - virtual void visit(const nodes::If &) = 0; - virtual void visit(const nodes::Error &) = 0; + virtual void visit(const nodes::Number &) = 0; + virtual void visit(const nodes::Symbol &) = 0; + virtual void visit(const nodes::String &) = 0; + virtual void visit(const nodes::In &) = 0; + virtual void visit(const nodes::Neg &) = 0; + virtual void visit(const nodes::Not &) = 0; + virtual void visit(const nodes::If &) = 0; + virtual void visit(const nodes::Error &) = 0; // tensor nodes - virtual void visit(const nodes::TensorMap &) = 0; - virtual void visit(const nodes::TensorJoin &) = 0; - virtual void visit(const nodes::TensorMerge &) = 0; - virtual void visit(const nodes::TensorReduce &) = 0; - virtual void visit(const nodes::TensorRename &) = 0; - virtual void visit(const nodes::TensorConcat &) = 0; - virtual void visit(const nodes::TensorCreate &) = 0; - virtual void visit(const nodes::TensorLambda &) = 0; - virtual void visit(const nodes::TensorPeek &) = 0; + virtual void visit(const nodes::TensorMap &) = 0; + virtual void visit(const nodes::TensorJoin &) = 0; + virtual void visit(const nodes::TensorMerge &) = 0; + virtual void visit(const nodes::TensorReduce &) = 0; + virtual void visit(const nodes::TensorRename &) = 0; + virtual void visit(const nodes::TensorConcat &) = 0; + virtual void visit(const nodes::TensorCellCast &) = 0; + virtual void visit(const nodes::TensorCreate &) = 0; + virtual void visit(const nodes::TensorLambda &) = 0; + virtual void visit(const nodes::TensorPeek &) = 0; // operator nodes - virtual void visit(const nodes::Add &) = 0; - virtual void visit(const nodes::Sub &) = 0; - virtual void visit(const nodes::Mul &) = 0; - virtual void visit(const nodes::Div &) = 0; - virtual void visit(const nodes::Mod &) = 0; - virtual void visit(const nodes::Pow &) = 0; - virtual void visit(const nodes::Equal &) = 0; - virtual void visit(const nodes::NotEqual &) = 0; - virtual void visit(const nodes::Approx &) = 0; - virtual void visit(const nodes::Less &) = 0; - virtual void visit(const nodes::LessEqual &) = 0; - virtual void visit(const nodes::Greater &) = 0; - virtual void visit(const nodes::GreaterEqual &) = 0; - virtual void visit(const nodes::And &) = 0; - virtual void visit(const nodes::Or &) = 0; + virtual void visit(const nodes::Add &) = 0; + virtual void visit(const nodes::Sub &) = 0; + virtual void visit(const nodes::Mul &) = 0; + virtual void visit(const nodes::Div &) = 0; + virtual void visit(const nodes::Mod &) = 0; + virtual void visit(const nodes::Pow &) = 0; + virtual void visit(const nodes::Equal &) = 0; + virtual void visit(const nodes::NotEqual &) = 0; + virtual void visit(const nodes::Approx &) = 0; + virtual void visit(const nodes::Less &) = 0; + virtual void visit(const nodes::LessEqual &) = 0; + virtual void visit(const nodes::Greater &) = 0; + virtual void visit(const nodes::GreaterEqual &) = 0; + virtual void visit(const nodes::And &) = 0; + virtual void visit(const nodes::Or &) = 0; // call nodes - virtual void visit(const nodes::Cos &) = 0; - virtual void visit(const nodes::Sin &) = 0; - virtual void visit(const nodes::Tan &) = 0; - virtual void visit(const nodes::Cosh &) = 0; - virtual void visit(const nodes::Sinh &) = 0; - virtual void visit(const nodes::Tanh &) = 0; - virtual void visit(const nodes::Acos &) = 0; - virtual void visit(const nodes::Asin &) = 0; - virtual void visit(const nodes::Atan &) = 0; - virtual void visit(const nodes::Exp &) = 0; - virtual void visit(const nodes::Log10 &) = 0; - virtual void visit(const nodes::Log &) = 0; - virtual void visit(const nodes::Sqrt &) = 0; - virtual void visit(const nodes::Ceil &) = 0; - virtual void visit(const nodes::Fabs &) = 0; - virtual void visit(const nodes::Floor &) = 0; - virtual void visit(const nodes::Atan2 &) = 0; - virtual void visit(const nodes::Ldexp &) = 0; - virtual void visit(const nodes::Pow2 &) = 0; - virtual void visit(const nodes::Fmod &) = 0; - virtual void visit(const nodes::Min &) = 0; - virtual void visit(const nodes::Max &) = 0; - virtual void visit(const nodes::IsNan &) = 0; - virtual void visit(const nodes::Relu &) = 0; - virtual void visit(const nodes::Sigmoid &) = 0; - virtual void visit(const nodes::Elu &) = 0; - virtual void visit(const nodes::Erf &) = 0; + virtual void visit(const nodes::Cos &) = 0; + virtual void visit(const nodes::Sin &) = 0; + virtual void visit(const nodes::Tan &) = 0; + virtual void visit(const nodes::Cosh &) = 0; + virtual void visit(const nodes::Sinh &) = 0; + virtual void visit(const nodes::Tanh &) = 0; + virtual void visit(const nodes::Acos &) = 0; + virtual void visit(const nodes::Asin &) = 0; + virtual void visit(const nodes::Atan &) = 0; + virtual void visit(const nodes::Exp &) = 0; + virtual void visit(const nodes::Log10 &) = 0; + virtual void visit(const nodes::Log &) = 0; + virtual void visit(const nodes::Sqrt &) = 0; + virtual void visit(const nodes::Ceil &) = 0; + virtual void visit(const nodes::Fabs &) = 0; + virtual void visit(const nodes::Floor &) = 0; + virtual void visit(const nodes::Atan2 &) = 0; + virtual void visit(const nodes::Ldexp &) = 0; + virtual void visit(const nodes::Pow2 &) = 0; + virtual void visit(const nodes::Fmod &) = 0; + virtual void visit(const nodes::Min &) = 0; + virtual void visit(const nodes::Max &) = 0; + virtual void visit(const nodes::IsNan &) = 0; + virtual void visit(const nodes::Relu &) = 0; + virtual void visit(const nodes::Sigmoid &) = 0; + virtual void visit(const nodes::Elu &) = 0; + virtual void visit(const nodes::Erf &) = 0; virtual ~NodeVisitor() {} }; @@ -93,65 +94,66 @@ struct NodeVisitor { * of all types not specifically handled. **/ struct EmptyNodeVisitor : NodeVisitor { - void visit(const nodes::Number &) override {} - void visit(const nodes::Symbol &) override {} - void visit(const nodes::String &) override {} - void visit(const nodes::In &) override {} - void visit(const nodes::Neg &) override {} - void visit(const nodes::Not &) override {} - void visit(const nodes::If &) override {} - void visit(const nodes::Error &) override {} - void visit(const nodes::TensorMap &) override {} - void visit(const nodes::TensorJoin &) override {} - void visit(const nodes::TensorMerge &) override {} - void visit(const nodes::TensorReduce &) override {} - void visit(const nodes::TensorRename &) override {} - void visit(const nodes::TensorConcat &) override {} - void visit(const nodes::TensorCreate &) override {} - void visit(const nodes::TensorLambda &) override {} - void visit(const nodes::TensorPeek &) override {} - void visit(const nodes::Add &) override {} - void visit(const nodes::Sub &) override {} - void visit(const nodes::Mul &) override {} - void visit(const nodes::Div &) override {} - void visit(const nodes::Mod &) override {} - void visit(const nodes::Pow &) override {} - void visit(const nodes::Equal &) override {} - void visit(const nodes::NotEqual &) override {} - void visit(const nodes::Approx &) override {} - void visit(const nodes::Less &) override {} - void visit(const nodes::LessEqual &) override {} - void visit(const nodes::Greater &) override {} - void visit(const nodes::GreaterEqual &) override {} - void visit(const nodes::And &) override {} - void visit(const nodes::Or &) override {} - void visit(const nodes::Cos &) override {} - void visit(const nodes::Sin &) override {} - void visit(const nodes::Tan &) override {} - void visit(const nodes::Cosh &) override {} - void visit(const nodes::Sinh &) override {} - void visit(const nodes::Tanh &) override {} - void visit(const nodes::Acos &) override {} - void visit(const nodes::Asin &) override {} - void visit(const nodes::Atan &) override {} - void visit(const nodes::Exp &) override {} - void visit(const nodes::Log10 &) override {} - void visit(const nodes::Log &) override {} - void visit(const nodes::Sqrt &) override {} - void visit(const nodes::Ceil &) override {} - void visit(const nodes::Fabs &) override {} - void visit(const nodes::Floor &) override {} - void visit(const nodes::Atan2 &) override {} - void visit(const nodes::Ldexp &) override {} - void visit(const nodes::Pow2 &) override {} - void visit(const nodes::Fmod &) override {} - void visit(const nodes::Min &) override {} - void visit(const nodes::Max &) override {} - void visit(const nodes::IsNan &) override {} - void visit(const nodes::Relu &) override {} - void visit(const nodes::Sigmoid &) override {} - void visit(const nodes::Elu &) override {} - void visit(const nodes::Erf &) override {} + void visit(const nodes::Number &) override {} + void visit(const nodes::Symbol &) override {} + void visit(const nodes::String &) override {} + void visit(const nodes::In &) override {} + void visit(const nodes::Neg &) override {} + void visit(const nodes::Not &) override {} + void visit(const nodes::If &) override {} + void visit(const nodes::Error &) override {} + void visit(const nodes::TensorMap &) override {} + void visit(const nodes::TensorJoin &) override {} + void visit(const nodes::TensorMerge &) override {} + void visit(const nodes::TensorReduce &) override {} + void visit(const nodes::TensorRename &) override {} + void visit(const nodes::TensorConcat &) override {} + void visit(const nodes::TensorCellCast &) override {} + void visit(const nodes::TensorCreate &) override {} + void visit(const nodes::TensorLambda &) override {} + void visit(const nodes::TensorPeek &) override {} + void visit(const nodes::Add &) override {} + void visit(const nodes::Sub &) override {} + void visit(const nodes::Mul &) override {} + void visit(const nodes::Div &) override {} + void visit(const nodes::Mod &) override {} + void visit(const nodes::Pow &) override {} + void visit(const nodes::Equal &) override {} + void visit(const nodes::NotEqual &) override {} + void visit(const nodes::Approx &) override {} + void visit(const nodes::Less &) override {} + void visit(const nodes::LessEqual &) override {} + void visit(const nodes::Greater &) override {} + void visit(const nodes::GreaterEqual &) override {} + void visit(const nodes::And &) override {} + void visit(const nodes::Or &) override {} + void visit(const nodes::Cos &) override {} + void visit(const nodes::Sin &) override {} + void visit(const nodes::Tan &) override {} + void visit(const nodes::Cosh &) override {} + void visit(const nodes::Sinh &) override {} + void visit(const nodes::Tanh &) override {} + void visit(const nodes::Acos &) override {} + void visit(const nodes::Asin &) override {} + void visit(const nodes::Atan &) override {} + void visit(const nodes::Exp &) override {} + void visit(const nodes::Log10 &) override {} + void visit(const nodes::Log &) override {} + void visit(const nodes::Sqrt &) override {} + void visit(const nodes::Ceil &) override {} + void visit(const nodes::Fabs &) override {} + void visit(const nodes::Floor &) override {} + void visit(const nodes::Atan2 &) override {} + void visit(const nodes::Ldexp &) override {} + void visit(const nodes::Pow2 &) override {} + void visit(const nodes::Fmod &) override {} + void visit(const nodes::Min &) override {} + void visit(const nodes::Max &) override {} + void visit(const nodes::IsNan &) override {} + void visit(const nodes::Relu &) override {} + void visit(const nodes::Sigmoid &) override {} + void visit(const nodes::Elu &) override {} + void visit(const nodes::Erf &) override {} }; } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index 19dbe812d3e..fb8d2f4e03c 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -5,6 +5,7 @@ #include "operation.h" #include "visit_stuff.h" #include "string_stuff.h" +#include "value_type_spec.h" #include <vespa/eval/instruction/generic_cell_cast.h> #include <vespa/eval/instruction/generic_concat.h> #include <vespa/eval/instruction/generic_create.h> @@ -216,6 +217,21 @@ Concat::visit_self(vespalib::ObjectVisitor &visitor) const //----------------------------------------------------------------------------- +InterpretedFunction::Instruction +CellCast::compile_self(const ValueBuilderFactory &, Stash &stash) const +{ + return instruction::GenericCellCast::make_instruction(child().result_type(), cell_type(), stash); +} + +void +CellCast::visit_self(vespalib::ObjectVisitor &visitor) const +{ + Super::visit_self(visitor); + visitor.visitString("cell_type", value_type::cell_type_to_name(cell_type())); +} + +//----------------------------------------------------------------------------- + void Create::push_children(std::vector<Child::CREF> &children) const { @@ -312,14 +328,6 @@ Lambda::visit_self(vespalib::ObjectVisitor &visitor) const //----------------------------------------------------------------------------- -InterpretedFunction::Instruction -CellCast::compile_self(const ValueBuilderFactory &, Stash &stash) const -{ - return instruction::GenericCellCast::make_instruction(child().result_type(), result_type().cell_type(), stash); -} - -//----------------------------------------------------------------------------- - void Peek::push_children(std::vector<Child::CREF> &children) const { @@ -466,6 +474,11 @@ const TensorFunction &lambda(const ValueType &type, const std::vector<size_t> &b return stash.create<Lambda>(type, bindings, function, std::move(node_types)); } +const TensorFunction &cell_cast(const TensorFunction &child, CellType cell_type, Stash &stash) { + ValueType result_type = child.result_type().cell_cast(cell_type); + return stash.create<CellCast>(result_type, child, cell_type); +} + const TensorFunction &peek(const TensorFunction ¶m, const std::map<vespalib::string, std::variant<TensorSpec::Label, TensorFunction::CREF>> &spec, Stash &stash) { std::vector<vespalib::string> dimensions; for (const auto &dim_spec: spec) { diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index 55532bc4bf7..47610d02aca 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -306,6 +306,22 @@ public: //----------------------------------------------------------------------------- +class CellCast : public Op1 +{ + using Super = Op1; +private: + CellType _cell_type; +public: + CellCast(const ValueType &result_type_in, const TensorFunction &child_in, CellType cell_type) + : Super(result_type_in, child_in), _cell_type(cell_type) {} + CellType cell_type() const { return _cell_type; } + bool result_is_mutable() const override { return true; } + InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override; + void visit_self(vespalib::ObjectVisitor &visitor) const override; +}; + +//----------------------------------------------------------------------------- + class Create : public Node { using Super = Node; @@ -355,20 +371,6 @@ public: //----------------------------------------------------------------------------- -class CellCast : public Op1 -{ -private: - using Super = Op1; -public: - CellCast(const TensorFunction &child_in, CellType to_cell_type) - : Super(ValueType::cell_cast(child_in.result_type(), to_cell_type), child_in) - {} - bool result_is_mutable() const override { return true; } - InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override; -}; - -//----------------------------------------------------------------------------- - class Peek : public Node { using Super = Node; @@ -465,6 +467,7 @@ const TensorFunction &merge(const TensorFunction &lhs, const TensorFunction &rhs const TensorFunction &concat(const TensorFunction &lhs, const TensorFunction &rhs, const vespalib::string &dimension, Stash &stash); const TensorFunction &create(const ValueType &type, const std::map<TensorSpec::Address, TensorFunction::CREF> &spec, Stash &stash); const TensorFunction &lambda(const ValueType &type, const std::vector<size_t> &bindings, const Function &function, NodeTypes node_types, Stash &stash); +const TensorFunction &cell_cast(const TensorFunction &child, CellType cell_type, Stash &stash); const TensorFunction &peek(const TensorFunction ¶m, const std::map<vespalib::string, std::variant<TensorSpec::Label, TensorFunction::CREF>> &spec, Stash &stash); const TensorFunction &rename(const TensorFunction &child, const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to, Stash &stash); const TensorFunction &if_node(const TensorFunction &cond, const TensorFunction &true_child, const TensorFunction &false_child, Stash &stash); diff --git a/eval/src/vespa/eval/eval/tensor_nodes.cpp b/eval/src/vespa/eval/eval/tensor_nodes.cpp index 5cb064ad127..ca148b4275c 100644 --- a/eval/src/vespa/eval/eval/tensor_nodes.cpp +++ b/eval/src/vespa/eval/eval/tensor_nodes.cpp @@ -7,15 +7,16 @@ namespace vespalib { namespace eval { namespace nodes { -void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorReduce::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorRename::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorConcat::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorCreate::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorLambda::accept(NodeVisitor &visitor) const { visitor.visit(*this); } -void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorMap ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorJoin ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorMerge ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorReduce ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorRename ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorConcat ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorCellCast::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorCreate ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorLambda ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } +void TensorPeek ::accept(NodeVisitor &visitor) const { visitor.visit(*this); } } // namespace vespalib::eval::nodes } // namespace vespalib::eval diff --git a/eval/src/vespa/eval/eval/tensor_nodes.h b/eval/src/vespa/eval/eval/tensor_nodes.h index 618e03f229e..90979953531 100644 --- a/eval/src/vespa/eval/eval/tensor_nodes.h +++ b/eval/src/vespa/eval/eval/tensor_nodes.h @@ -7,6 +7,7 @@ #include "tensor_spec.h" #include "aggr.h" #include "string_stuff.h" +#include "value_type_spec.h" #include <vespa/vespalib/stllike/string.h> #include <vector> #include <map> @@ -67,7 +68,7 @@ public: str += ")"; return str; } - void accept(NodeVisitor &visitor) const override ; + void accept(NodeVisitor &visitor) const override; size_t num_children() const override { return 2; } const Node &get_child(size_t idx) const override { assert(idx < 2); @@ -101,7 +102,7 @@ public: str += ")"; return str; } - void accept(NodeVisitor &visitor) const override ; + void accept(NodeVisitor &visitor) const override; size_t num_children() const override { return 2; } const Node &get_child(size_t idx) const override { assert(idx < 2); @@ -217,7 +218,7 @@ public: str += ")"; return str; } - void accept(NodeVisitor &visitor) const override ; + void accept(NodeVisitor &visitor) const override; size_t num_children() const override { return 2; } const Node &get_child(size_t idx) const override { assert(idx < 2); @@ -229,6 +230,36 @@ public: } }; +class TensorCellCast : public Node { +private: + Node_UP _child; + CellType _cell_type; +public: + TensorCellCast(Node_UP child, CellType cell_type) + : _child(std::move(child)), _cell_type(cell_type) {} + const Node &child() const { return *_child; } + CellType cell_type() const { return _cell_type; } + vespalib::string dump(DumpContext &ctx) const override { + vespalib::string str; + str += "cell_cast("; + str += _child->dump(ctx); + str += ","; + str += value_type::cell_type_to_name(_cell_type); + str += ")"; + return str; + } + void accept(NodeVisitor &visitor) const override; + size_t num_children() const override { return 1; } + const Node &get_child(size_t idx) const override { + (void) idx; + assert(idx == 0); + return *_child; + } + void detach_children(NodeHandler &handler) override { + handler.handle(std::move(_child)); + } +}; + class TensorCreate : public Node { public: using Spec = std::map<TensorSpec::Address, Node_UP>; @@ -259,7 +290,7 @@ public: str += "}"; return str; } - void accept(NodeVisitor &visitor) const override ; + void accept(NodeVisitor &visitor) const override; size_t num_children() const override { return _cells.size(); } const Node &get_child(size_t idx) const override { assert(idx < _cells.size()); @@ -365,7 +396,7 @@ public: str += "}"; return str; } - void accept(NodeVisitor &visitor) const override ; + void accept(NodeVisitor &visitor) const override; size_t num_children() const override { return (1 + _expr_dims.size()); } const Node &get_child(size_t idx) const override { assert(idx < num_children()); diff --git a/eval/src/vespa/eval/eval/test/eval_spec.cpp b/eval/src/vespa/eval/eval/test/eval_spec.cpp index 6b80b65df6c..63a3a23d9ae 100644 --- a/eval/src/vespa/eval/eval/test/eval_spec.cpp +++ b/eval/src/vespa/eval/eval/test/eval_spec.cpp @@ -181,6 +181,7 @@ EvalSpec::add_tensor_operation_cases() { add_expression({}, "tensor(x[10],y[10])(x==y)"); add_expression({"a","b"}, "concat(a,b,x)"); add_expression({"a","b"}, "concat(a,b,y)"); + add_expression({"a"}, "cell_cast(a,float)"); add_expression({}, "tensor(x[3]):{{x:0}:0,{x:1}:1,{x:2}:2}"); add_expression({"a"}, "a{x:3}"); } diff --git a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp index c20d8af32ec..4857b5dabc2 100644 --- a/eval/src/vespa/eval/eval/test/reference_evaluation.cpp +++ b/eval/src/vespa/eval/eval/test/reference_evaluation.cpp @@ -81,6 +81,10 @@ struct EvalNode : public NodeVisitor { result = ReferenceOperations::concat(eval_node(a, params), eval_node(b, params), dimension); } + void eval_cell_cast(const Node &a, CellType cell_type) { + result = ReferenceOperations::cell_cast(eval_node(a, params), cell_type); + } + void eval_create(const TensorCreate &node) { std::map<TensorSpec::Address, size_t> spec; std::vector<TensorSpec> children; @@ -193,6 +197,9 @@ struct EvalNode : public NodeVisitor { void visit(const TensorConcat &node) override { eval_concat(node.lhs(), node.rhs(), node.dimension()); } + void visit(const TensorCellCast &node) override { + eval_cell_cast(node.child(), node.cell_type()); + } void visit(const TensorCreate &node) override { eval_create(node); } diff --git a/eval/src/vespa/eval/eval/test/reference_operations.cpp b/eval/src/vespa/eval/eval/test/reference_operations.cpp index 8709150bdc4..024fa14954d 100644 --- a/eval/src/vespa/eval/eval/test/reference_operations.cpp +++ b/eval/src/vespa/eval/eval/test/reference_operations.cpp @@ -84,7 +84,7 @@ struct CopyCellsWithCast { TensorSpec ReferenceOperations::cell_cast(const TensorSpec &a, CellType to) { ValueType a_type = ValueType::from_spec(a.type()); - ValueType res_type = ValueType::cell_cast(a_type, to); + ValueType res_type = a_type.cell_cast(to); TensorSpec result(res_type.to_spec()); if (res_type.is_error()) { return result; diff --git a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp index 779ce30663f..117e8c9b149 100644 --- a/eval/src/vespa/eval/eval/test/tensor_conformance.cpp +++ b/eval/src/vespa/eval/eval/test/tensor_conformance.cpp @@ -7,6 +7,7 @@ #include <vespa/eval/eval/aggr.h> #include <vespa/eval/eval/value_codec.h> #include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/eval/value_type_spec.h> #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/vespalib/util/stringfmt.h> #include <vespa/vespalib/objects/nbostream.h> @@ -557,6 +558,28 @@ struct TestContext { //------------------------------------------------------------------------- + void test_cell_cast(const GenSpec &a) { + for (CellType cell_type: CellTypeUtils::list_types()) { + vespalib::string expr = fmt("cell_cast(a,%s)", value_type::cell_type_to_name(cell_type).c_str()); + TEST_DO(verify_result(factory, expr, {a}, a.cpy().cells(cell_type))); + } + } + + void test_cell_cast() { + std::vector<GenSpec> gen_list; + for (CellType cell_type: CellTypeUtils::list_types()) { + gen_list.push_back(GenSpec(-3).cells(cell_type)); + } + for (const auto &gen: gen_list) { + TEST_DO(test_cell_cast(gen)); + TEST_DO(test_cell_cast(gen.cpy().idx("x", 10))); + TEST_DO(test_cell_cast(gen.cpy().map("x", 10, 1))); + TEST_DO(test_cell_cast(gen.cpy().map("x", 4, 1).idx("y", 4))); + } + } + + //------------------------------------------------------------------------- + void test_rename(const vespalib::string &expr, const TensorSpec &input, const TensorSpec &expect) @@ -735,6 +758,7 @@ struct TestContext { TEST_DO(test_tensor_apply()); TEST_DO(test_dot_product()); TEST_DO(test_concat()); + TEST_DO(test_cell_cast()); TEST_DO(test_rename()); TEST_DO(test_tensor_lambda()); TEST_DO(test_tensor_create()); diff --git a/eval/src/vespa/eval/eval/value_type.cpp b/eval/src/vespa/eval/eval/value_type.cpp index 57b101f6eeb..b70edef7153 100644 --- a/eval/src/vespa/eval/eval/value_type.cpp +++ b/eval/src/vespa/eval/eval/value_type.cpp @@ -287,6 +287,16 @@ ValueType::rename(const std::vector<vespalib::string> &from, } ValueType +ValueType::cell_cast(CellType to_cell_type) const +{ + if (is_error()) { + return error_type(); + } + // TODO: return make_type(to_cell_type, _dimensions); + return tensor_type(_dimensions, to_cell_type); +} + +ValueType ValueType::make_type(CellType cell_type, std::vector<Dimension> dimensions_in) { sort_dimensions(dimensions_in); @@ -388,11 +398,6 @@ ValueType::either(const ValueType &one, const ValueType &other) { return one; } -ValueType -ValueType::cell_cast(const ValueType &from, CellType to_cell_type) { - return make_type(to_cell_type, from.dimensions()); -} - std::ostream & operator<<(std::ostream &os, const ValueType &type) { return os << type.to_spec(); diff --git a/eval/src/vespa/eval/eval/value_type.h b/eval/src/vespa/eval/eval/value_type.h index 4609f6a0b38..247912b274a 100644 --- a/eval/src/vespa/eval/eval/value_type.h +++ b/eval/src/vespa/eval/eval/value_type.h @@ -86,6 +86,7 @@ public: ValueType reduce(const std::vector<vespalib::string> &dimensions_in) const; ValueType rename(const std::vector<vespalib::string> &from, const std::vector<vespalib::string> &to) const; + ValueType cell_cast(CellType to_cell_type) const; static ValueType error_type() { return ValueType(); } static ValueType make_type(CellType cell_type, std::vector<Dimension> dimensions_in); @@ -105,8 +106,6 @@ public: static CellType unify_cell_types(const ValueType &a, const ValueType &b); static ValueType concat(const ValueType &lhs, const ValueType &rhs, const vespalib::string &dimension); static ValueType either(const ValueType &one, const ValueType &other); - static ValueType cell_cast(const ValueType &from, CellType to_cell_type); - }; std::ostream &operator<<(std::ostream &os, const ValueType &type); diff --git a/eval/src/vespa/eval/eval/value_type_spec.cpp b/eval/src/vespa/eval/eval/value_type_spec.cpp index a4575e33c2f..1eabbd0a9fc 100644 --- a/eval/src/vespa/eval/eval/value_type_spec.cpp +++ b/eval/src/vespa/eval/eval/value_type_spec.cpp @@ -8,9 +8,17 @@ namespace vespalib::eval::value_type { -namespace { +std::optional<CellType> cell_type_from_name(const vespalib::string &name) { + if (name == "double") { + return CellType::DOUBLE; + } else if (name == "float") { + return CellType::FLOAT; + } else { + return std::nullopt; + } +} -const char *to_name(CellType cell_type) { +vespalib::string cell_type_to_name(CellType cell_type) { switch (cell_type) { case CellType::DOUBLE: return "double"; case CellType::FLOAT: return "float"; @@ -18,6 +26,8 @@ const char *to_name(CellType cell_type) { abort(); } +namespace { + class ParseContext { public: @@ -232,11 +242,11 @@ to_spec(const ValueType &type) if (type.is_error()) { os << "error"; } else if (type.is_scalar()) { - os << to_name(type.cell_type()); + os << cell_type_to_name(type.cell_type()); } else { os << "tensor"; if (type.cell_type() != CellType::DOUBLE) { - os << "<" << to_name(type.cell_type()) << ">"; + os << "<" << cell_type_to_name(type.cell_type()) << ">"; } os << "("; for (const auto &d: type.dimensions()) { diff --git a/eval/src/vespa/eval/eval/value_type_spec.h b/eval/src/vespa/eval/eval/value_type_spec.h index ff5113c769a..39168b34fcb 100644 --- a/eval/src/vespa/eval/eval/value_type_spec.h +++ b/eval/src/vespa/eval/eval/value_type_spec.h @@ -3,9 +3,13 @@ #pragma once #include "value_type.h" +#include <optional> namespace vespalib::eval::value_type { +std::optional<CellType> cell_type_from_name(const vespalib::string &name); +vespalib::string cell_type_to_name(CellType cell_type); + ValueType parse_spec(const char *pos_in, const char *end_in, const char *&pos_out, std::vector<ValueType::Dimension> *unsorted = nullptr); diff --git a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp index 58221b9c62f..ff84a65857c 100644 --- a/eval/src/vespa/eval/instruction/generic_cell_cast.cpp +++ b/eval/src/vespa/eval/instruction/generic_cell_cast.cpp @@ -46,7 +46,7 @@ GenericCellCast::make_instruction(const ValueType &input_type, Stash &stash) { CellType from = input_type.cell_type(); - auto result_type = ValueType::cell_cast(input_type, to_cell_type); + auto result_type = input_type.cell_cast(to_cell_type); auto ¶m = stash.create<ValueType>(result_type); CellType to = result_type.cell_type(); auto op = typify_invoke<2,TypifyCellType,SelectGenericCellCastOp>(from, to); diff --git a/eval/src/vespa/eval/instruction/generic_merge.h b/eval/src/vespa/eval/instruction/generic_merge.h index 0319f1a929f..e1dfd35e7fd 100644 --- a/eval/src/vespa/eval/instruction/generic_merge.h +++ b/eval/src/vespa/eval/instruction/generic_merge.h @@ -15,7 +15,7 @@ struct MergeParam { const ValueBuilderFactory &factory; MergeParam(const ValueType &lhs_type, const ValueType &rhs_type, join_fun_t function_in, const ValueBuilderFactory &factory_in) - : res_type(ValueType::join(lhs_type, rhs_type)), + : res_type(ValueType::merge(lhs_type, rhs_type)), function(function_in), num_mapped_dimensions(lhs_type.count_mapped_dimensions()), dense_subspace_size(lhs_type.dense_subspace_size()), |