diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2021-02-05 16:15:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2021-02-05 16:15:04 +0100 |
commit | 0e5959f756feb6f0883388215152772fa5d8181f (patch) | |
tree | 1fd9a52d1df8d362a38960aa8952e3ec40bd2593 /eval | |
parent | 4c6758eea33c9b28a1667730f12d6106a60cec67 (diff) | |
parent | dd2cc138985dde640bddee367e896ed1ca6679e7 (diff) |
Merge pull request #16390 from vespa-engine/arnej/refactor-sparse-merge-optimiser
Arnej/refactor sparse merge optimiser
Diffstat (limited to 'eval')
10 files changed, 301 insertions, 98 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index 23127cc12b5..0cba519bf88 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -66,6 +66,7 @@ vespa_define_module( src/tests/instruction/pow_as_map_optimizer src/tests/instruction/remove_trivial_dimension_optimizer src/tests/instruction/sparse_dot_product_function + src/tests/instruction/sparse_merge_function src/tests/instruction/sum_max_dot_product_function src/tests/instruction/vector_from_doubles_function src/tests/streamed/value diff --git a/eval/src/tests/instruction/sparse_merge_function/CMakeLists.txt b/eval/src/tests/instruction/sparse_merge_function/CMakeLists.txt new file mode 100644 index 00000000000..f905bdd8c1b --- /dev/null +++ b/eval/src/tests/instruction/sparse_merge_function/CMakeLists.txt @@ -0,0 +1,9 @@ +# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_sparse_merge_function_test_app TEST + SOURCES + sparse_merge_function_test.cpp + DEPENDS + vespaeval + GTest::GTest +) +vespa_add_test(NAME eval_sparse_merge_function_test_app COMMAND eval_sparse_merge_function_test_app) diff --git a/eval/src/tests/instruction/sparse_merge_function/sparse_merge_function_test.cpp b/eval/src/tests/instruction/sparse_merge_function/sparse_merge_function_test.cpp new file mode 100644 index 00000000000..e175286e18c --- /dev/null +++ b/eval/src/tests/instruction/sparse_merge_function/sparse_merge_function_test.cpp @@ -0,0 +1,82 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include <vespa/eval/eval/fast_value.h> +#include <vespa/eval/eval/simple_value.h> +#include <vespa/eval/instruction/sparse_merge_function.h> +#include <vespa/eval/eval/test/eval_fixture.h> +#include <vespa/eval/eval/test/gen_spec.h> +#include <vespa/vespalib/gtest/gtest.h> + +using namespace vespalib::eval; +using namespace vespalib::eval::test; + +const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); +const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get(); + +//----------------------------------------------------------------------------- + +EvalFixture::ParamRepo make_params() { + return EvalFixture::ParamRepo() + .add("scalar1", GenSpec(1.0).gen()) + .add("scalar2", GenSpec(2.0).gen()) + .add_variants("v1_x", GenSpec(3.0).map("x", 32, 1)) + .add_variants("v2_x", GenSpec(4.0).map("x", 16, 2)) + .add_variants("v3_xz", GenSpec(5.0).map("x", 16, 2).idx("z", 1)) + .add("dense", GenSpec(6.0).idx("x", 10).gen()) + .add("m1_xy", GenSpec(7.0).map("x", 32, 1).map("y", 16, 2).gen()) + .add("m2_xy", GenSpec(8.0).map("x", 16, 2).map("y", 32, 1).gen()) + .add("mixed", GenSpec(9.0).map("x", 8, 1).idx("y", 5).gen()); +} +EvalFixture::ParamRepo param_repo = make_params(); + +void assert_optimized(const vespalib::string &expr) { + EvalFixture fast_fixture(prod_factory, expr, param_repo, true); + EvalFixture test_fixture(test_factory, expr, param_repo, true); + EvalFixture slow_fixture(prod_factory, expr, param_repo, false); + EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(test_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(slow_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(fast_fixture.find_all<SparseMergeFunction>().size(), 1u); + EXPECT_EQ(test_fixture.find_all<SparseMergeFunction>().size(), 1u); + EXPECT_EQ(slow_fixture.find_all<SparseMergeFunction>().size(), 0u); +} + +void assert_not_optimized(const vespalib::string &expr) { + EvalFixture 
fast_fixture(prod_factory, expr, param_repo, true); + EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); + EXPECT_EQ(fast_fixture.find_all<SparseMergeFunction>().size(), 0u); +} + +//----------------------------------------------------------------------------- + +TEST(SparseMerge, expression_can_be_optimized) +{ + assert_optimized("merge(v1_x,v2_x,f(x,y)(x+y))"); + assert_optimized("merge(v1_x,v2_x,f(x,y)(max(x,y)))"); + assert_optimized("merge(v1_x,v2_x,f(x,y)(x+y+1))"); + assert_optimized("merge(v1_x_f,v2_x_f,f(x,y)(x+y))"); + assert_optimized("merge(v3_xz,v3_xz,f(x,y)(x+y))"); +} + +TEST(SparseMerge, multi_dimensional_expression_can_be_optimized) +{ + assert_optimized("merge(m1_xy,m2_xy,f(x,y)(x+y))"); + assert_optimized("merge(m1_xy,m2_xy,f(x,y)(x*y))"); +} + +TEST(SparseMerge, similar_expressions_are_not_optimized) +{ + assert_not_optimized("merge(scalar1,scalar2,f(x,y)(x+y))"); + assert_not_optimized("merge(dense,dense,f(x,y)(x+y))"); + assert_not_optimized("merge(mixed,mixed,f(x,y)(x+y))"); +} + +TEST(SparseMerge, mixed_cell_types_are_not_optimized) +{ + assert_not_optimized("merge(v1_x,v2_x_f,f(x,y)(x+y))"); + assert_not_optimized("merge(v1_x_f,v2_x,f(x,y)(x+y))"); +} + +//----------------------------------------------------------------------------- + +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/vespa/eval/eval/fast_value.hpp b/eval/src/vespa/eval/eval/fast_value.hpp index 88319df7590..6673494ccd2 100644 --- a/eval/src/vespa/eval/eval/fast_value.hpp +++ b/eval/src/vespa/eval/eval/fast_value.hpp @@ -155,12 +155,6 @@ struct FastValueIndex final : Value::Index { const std::vector<JoinAddrSource> &addr_sources, ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash); - template <typename LCT, typename RCT, typename OCT, typename Fun> - static const Value &sparse_only_merge(const ValueType &res_type, const Fun &fun, - const FastValueIndex &lhs, const FastValueIndex &rhs, - ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> 
rhs_cells, - Stash &stash) __attribute((noinline)); - size_t size() const override { return map.size(); } std::unique_ptr<View> create_view(const std::vector<size_t> &dims) const override; }; @@ -429,32 +423,4 @@ FastValueIndex::sparse_no_overlap_join(const ValueType &res_type, const Fun &fun //----------------------------------------------------------------------------- -template <typename LCT, typename RCT, typename OCT, typename Fun> -const Value & -FastValueIndex::sparse_only_merge(const ValueType &res_type, const Fun &fun, - const FastValueIndex &lhs, const FastValueIndex &rhs, - ConstArrayRef<LCT> lhs_cells, ConstArrayRef<RCT> rhs_cells, Stash &stash) -{ - size_t guess_size = lhs.map.size() + rhs.map.size(); - auto &result = stash.create<FastValue<OCT,true>>(res_type, lhs.map.addr_size(), 1, guess_size); - lhs.map.each_map_entry([&](auto lhs_subspace, auto hash) - { - result.add_mapping(lhs.map.get_addr(lhs_subspace), hash); - result.my_cells.push_back_fast(lhs_cells[lhs_subspace]); - }); - rhs.map.each_map_entry([&](auto rhs_subspace, auto hash) - { - auto rhs_addr = rhs.map.get_addr(rhs_subspace); - auto result_subspace = result.my_index.map.lookup(rhs_addr, hash); - if (result_subspace == FastAddrMap::npos()) { - result.add_mapping(rhs_addr, hash); - result.my_cells.push_back_fast(rhs_cells[rhs_subspace]); - } else { - OCT &out_cell = *result.my_cells.get(result_subspace); - out_cell = fun(out_cell, rhs_cells[rhs_subspace]); - } - }); - return result; -} - } diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp index 25612b8d5fd..196e8a98679 100644 --- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp @@ -6,6 +6,7 @@ #include <vespa/eval/instruction/dense_dot_product_function.h> #include <vespa/eval/instruction/sparse_dot_product_function.h> +#include <vespa/eval/instruction/sparse_merge_function.h> #include 
<vespa/eval/instruction/mixed_inner_product_function.h> #include <vespa/eval/instruction/sum_max_dot_product_function.h> #include <vespa/eval/instruction/dense_xw_product_function.h> @@ -72,6 +73,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const Te child.set(MixedSimpleJoinFunction::optimize(child.get(), stash)); child.set(JoinWithNumberFunction::optimize(child.get(), stash)); child.set(DenseSingleReduceFunction::optimize(child.get(), stash)); + child.set(SparseMergeFunction::optimize(child.get(), stash)); nodes.pop_back(); } } diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt index cac69d23640..3def8907ac8 100644 --- a/eval/src/vespa/eval/instruction/CMakeLists.txt +++ b/eval/src/vespa/eval/instruction/CMakeLists.txt @@ -33,6 +33,7 @@ vespa_add_library(eval_instruction OBJECT remove_trivial_dimension_optimizer.cpp replace_type_function.cpp sparse_dot_product_function.cpp + sparse_merge_function.cpp sum_max_dot_product_function.cpp vector_from_doubles_function.cpp ) diff --git a/eval/src/vespa/eval/instruction/generic_merge.cpp b/eval/src/vespa/eval/instruction/generic_merge.cpp index b40388aa547..9b098db7763 100644 --- a/eval/src/vespa/eval/instruction/generic_merge.cpp +++ b/eval/src/vespa/eval/instruction/generic_merge.cpp @@ -17,37 +17,6 @@ namespace vespalib::eval::instruction { using State = InterpretedFunction::State; using Instruction = InterpretedFunction::Instruction; -namespace { - -//----------------------------------------------------------------------------- - -struct MergeParam { - const ValueType res_type; - const join_fun_t function; - const size_t num_mapped_dimensions; - const size_t dense_subspace_size; - std::vector<size_t> all_view_dims; - const ValueBuilderFactory &factory; - MergeParam(const ValueType &lhs_type, const ValueType &rhs_type, - join_fun_t function_in, const ValueBuilderFactory &factory_in) - : res_type(ValueType::join(lhs_type, 
rhs_type)), - function(function_in), - num_mapped_dimensions(lhs_type.count_mapped_dimensions()), - dense_subspace_size(lhs_type.dense_subspace_size()), - all_view_dims(num_mapped_dimensions), - factory(factory_in) - { - assert(!res_type.is_error()); - assert(num_mapped_dimensions == rhs_type.count_mapped_dimensions()); - assert(num_mapped_dimensions == res_type.count_mapped_dimensions()); - assert(dense_subspace_size == rhs_type.dense_subspace_size()); - assert(dense_subspace_size == res_type.dense_subspace_size()); - for (size_t i = 0; i < num_mapped_dimensions; ++i) { - all_view_dims[i] = i; - } - } - ~MergeParam(); -}; MergeParam::~MergeParam() = default; //----------------------------------------------------------------------------- @@ -108,39 +77,14 @@ generic_mixed_merge(const Value &a, const Value &b, return builder->build(std::move(builder)); } -template <typename LCT, typename RCT, typename OCT, typename Fun> -void my_mixed_merge_op(State &state, uint64_t param_in) { - const auto &param = unwrap_param<MergeParam>(param_in); - const Value &lhs = state.peek(1); - const Value &rhs = state.peek(0); - auto up = generic_mixed_merge<LCT, RCT, OCT, Fun>(lhs, rhs, param); - auto &result = state.stash.create<std::unique_ptr<Value>>(std::move(up)); - const Value &result_ref = *(result.get()); - state.pop_pop_push(result_ref); -}; + +namespace { template <typename LCT, typename RCT, typename OCT, typename Fun> -void my_sparse_merge_op(State &state, uint64_t param_in) { +void my_mixed_merge_op(State &state, uint64_t param_in) { const auto &param = unwrap_param<MergeParam>(param_in); const Value &lhs = state.peek(1); const Value &rhs = state.peek(0); - if (auto indexes = detect_type<FastValueIndex>(lhs.index(), rhs.index())) { - auto lhs_cells = lhs.cells().typify<LCT>(); - auto rhs_cells = rhs.cells().typify<RCT>(); - if (lhs_cells.size() < rhs_cells.size()) { - return state.pop_pop_push( - FastValueIndex::sparse_only_merge<RCT,LCT,OCT,Fun>( - param.res_type, 
Fun(param.function), - indexes.get<1>(), indexes.get<0>(), - rhs_cells, lhs_cells, state.stash)); - } else { - return state.pop_pop_push( - FastValueIndex::sparse_only_merge<LCT,RCT,OCT,Fun>( - param.res_type, Fun(param.function), - indexes.get<0>(), indexes.get<1>(), - lhs_cells, rhs_cells, state.stash)); - } - } auto up = generic_mixed_merge<LCT, RCT, OCT, Fun>(lhs, rhs, param); auto &result = state.stash.create<std::unique_ptr<Value>>(std::move(up)); const Value &result_ref = *(result.get()); @@ -148,10 +92,7 @@ void my_sparse_merge_op(State &state, uint64_t param_in) { }; struct SelectGenericMergeOp { - template <typename LCT, typename RCT, typename OCT, typename Fun> static auto invoke(const MergeParam &param) { - if (param.dense_subspace_size == 1) { - return my_sparse_merge_op<LCT,RCT,OCT,Fun>; - } + template <typename LCT, typename RCT, typename OCT, typename Fun> static auto invoke() { return my_mixed_merge_op<LCT,RCT,OCT,Fun>; } }; @@ -167,7 +108,7 @@ GenericMerge::make_instruction(const ValueType &lhs_type, const ValueType &rhs_t const ValueBuilderFactory &factory, Stash &stash) { const auto &param = stash.create<MergeParam>(lhs_type, rhs_type, function, factory); - auto fun = typify_invoke<4,MergeTypify,SelectGenericMergeOp>(lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), function, param); + auto fun = typify_invoke<4,MergeTypify,SelectGenericMergeOp>(lhs_type.cell_type(), rhs_type.cell_type(), param.res_type.cell_type(), function); return Instruction(fun, wrap_param<MergeParam>(param)); } diff --git a/eval/src/vespa/eval/instruction/generic_merge.h b/eval/src/vespa/eval/instruction/generic_merge.h index 2b2964366cc..0319f1a929f 100644 --- a/eval/src/vespa/eval/instruction/generic_merge.h +++ b/eval/src/vespa/eval/instruction/generic_merge.h @@ -6,6 +6,39 @@ namespace vespalib::eval::instruction { +struct MergeParam { + const ValueType res_type; + const join_fun_t function; + const size_t num_mapped_dimensions; + const size_t 
dense_subspace_size; + std::vector<size_t> all_view_dims; + const ValueBuilderFactory &factory; + MergeParam(const ValueType &lhs_type, const ValueType &rhs_type, + join_fun_t function_in, const ValueBuilderFactory &factory_in) + : res_type(ValueType::join(lhs_type, rhs_type)), + function(function_in), + num_mapped_dimensions(lhs_type.count_mapped_dimensions()), + dense_subspace_size(lhs_type.dense_subspace_size()), + all_view_dims(num_mapped_dimensions), + factory(factory_in) + { + assert(!res_type.is_error()); + assert(num_mapped_dimensions == rhs_type.count_mapped_dimensions()); + assert(num_mapped_dimensions == res_type.count_mapped_dimensions()); + assert(dense_subspace_size == rhs_type.dense_subspace_size()); + assert(dense_subspace_size == res_type.dense_subspace_size()); + for (size_t i = 0; i < num_mapped_dimensions; ++i) { + all_view_dims[i] = i; + } + } + ~MergeParam(); +}; + +template <typename LCT, typename RCT, typename OCT, typename Fun> +std::unique_ptr<Value> +generic_mixed_merge(const Value &a, const Value &b, + const MergeParam &params); + struct GenericMerge { static InterpretedFunction::Instruction make_instruction(const ValueType &lhs_type, const ValueType &rhs_type, diff --git a/eval/src/vespa/eval/instruction/sparse_merge_function.cpp b/eval/src/vespa/eval/instruction/sparse_merge_function.cpp new file mode 100644 index 00000000000..924c4d69fe9 --- /dev/null +++ b/eval/src/vespa/eval/instruction/sparse_merge_function.cpp @@ -0,0 +1,146 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+ +#include "sparse_merge_function.h" +#include "generic_merge.h" +#include <vespa/eval/eval/fast_value.hpp> +#include <vespa/vespalib/util/typify.h> + +namespace vespalib::eval { + +using namespace tensor_function; +using namespace operation; +using namespace instruction; + +namespace { + +template <typename CT, bool single_dim, typename Fun> +const Value& my_fast_sparse_merge(const FastAddrMap &a_map, const FastAddrMap &b_map, + const CT *a_cells, const CT *b_cells, + const MergeParam &params, + Stash &stash) +{ + Fun fun(params.function); + size_t guess_size = a_map.size() + b_map.size(); + auto &result = stash.create<FastValue<CT,true>>(params.res_type, params.num_mapped_dimensions, 1u, guess_size); + if constexpr (single_dim) { + string_id cur_label; + ConstArrayRef<string_id> addr(&cur_label, 1); + const auto &a_labels = a_map.labels(); + for (size_t i = 0; i < a_labels.size(); ++i) { + cur_label = a_labels[i]; + result.add_mapping(addr, cur_label.hash()); + result.my_cells.push_back_fast(a_cells[i]); + } + const auto &b_labels = b_map.labels(); + for (size_t i = 0; i < b_labels.size(); ++i) { + cur_label = b_labels[i]; + auto result_subspace = result.my_index.map.lookup_singledim(cur_label); + if (result_subspace == FastAddrMap::npos()) { + result.add_mapping(addr, cur_label.hash()); + result.my_cells.push_back_fast(b_cells[i]); + } else { + CT *out_cell = result.my_cells.get(result_subspace); + out_cell[0] = fun(out_cell[0], b_cells[i]); + } + } + } else { + a_map.each_map_entry([&](auto lhs_subspace, auto hash) + { + result.add_mapping(a_map.get_addr(lhs_subspace), hash); + result.my_cells.push_back_fast(a_cells[lhs_subspace]); + }); + b_map.each_map_entry([&](auto rhs_subspace, auto hash) + { + auto rhs_addr = b_map.get_addr(rhs_subspace); + auto result_subspace = result.my_index.map.lookup(rhs_addr, hash); + if (result_subspace == FastAddrMap::npos()) { + result.add_mapping(rhs_addr, hash); + result.my_cells.push_back_fast(b_cells[rhs_subspace]); + } else { 
+ CT *out_cell = result.my_cells.get(result_subspace); + out_cell[0] = fun(out_cell[0], b_cells[rhs_subspace]); + } + }); + } + return result; +} + +template <typename CT, bool single_dim, typename Fun> +void my_sparse_merge_op(InterpretedFunction::State &state, uint64_t param_in) { + const auto &param = unwrap_param<MergeParam>(param_in); + assert(param.dense_subspace_size == 1u); + const Value &a = state.peek(1); + const Value &b = state.peek(0); + const auto &a_idx = a.index(); + const auto &b_idx = b.index(); + if (__builtin_expect(are_fast(a_idx, b_idx), true)) { + auto a_cells = a.cells().typify<CT>(); + auto b_cells = b.cells().typify<CT>(); + const Value &v = my_fast_sparse_merge<CT,single_dim,Fun>(as_fast(a_idx).map, as_fast(b_idx).map, + a_cells.cbegin(), b_cells.cbegin(), + param, state.stash); + state.pop_pop_push(v); + } else { + auto up = generic_mixed_merge<CT,CT,CT,Fun>(a, b, param); + state.pop_pop_push(*state.stash.create<std::unique_ptr<Value>>(std::move(up))); + } +} + +struct SelectSparseMergeOp { + template <typename CT, typename SINGLE_DIM, typename Fun> + static auto invoke() { return my_sparse_merge_op<CT,SINGLE_DIM::value,Fun>; } +}; + +using MyTypify = TypifyValue<TypifyCellType,TypifyBool,operation::TypifyOp2>; + +} // namespace <unnamed> + +SparseMergeFunction::SparseMergeFunction(const tensor_function::Merge &original) + : tensor_function::Merge(original.result_type(), + original.lhs(), + original.rhs(), + original.function()) +{ + assert(compatible_types(result_type(), lhs().result_type(), rhs().result_type())); +} + +InterpretedFunction::Instruction +SparseMergeFunction::compile_self(const ValueBuilderFactory &factory, Stash &stash) const +{ + const auto &param = stash.create<MergeParam>(lhs().result_type(), rhs().result_type(), + function(), factory); + size_t num_dims = result_type().count_mapped_dimensions(); + auto op = typify_invoke<3,MyTypify,SelectSparseMergeOp>(result_type().cell_type(), + num_dims == 1, + function()); + return 
InterpretedFunction::Instruction(op, wrap_param<MergeParam>(param)); +} + +bool +SparseMergeFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) +{ + if ((lhs.cell_type() == rhs.cell_type()) + && (lhs.count_mapped_dimensions() > 0) + && (lhs.dense_subspace_size() == 1)) + { + assert(res == lhs); + assert(res == rhs); + return true; + } + return false; +} + +const TensorFunction & +SparseMergeFunction::optimize(const TensorFunction &expr, Stash &stash) +{ + if (auto merge = as<Merge>(expr)) { + const TensorFunction &lhs = merge->lhs(); + const TensorFunction &rhs = merge->rhs(); + if (compatible_types(expr.result_type(), lhs.result_type(), rhs.result_type())) { + return stash.create<SparseMergeFunction>(*merge); + } + } + return expr; +} + +} // namespace diff --git a/eval/src/vespa/eval/instruction/sparse_merge_function.h b/eval/src/vespa/eval/instruction/sparse_merge_function.h new file mode 100644 index 00000000000..d2b26196ed6 --- /dev/null +++ b/eval/src/vespa/eval/instruction/sparse_merge_function.h @@ -0,0 +1,22 @@ +// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +#pragma once + +#include <vespa/eval/eval/tensor_function.h> + +namespace vespalib::eval { + +/** + * Tensor function for merging two sparse tensors. + */ +class SparseMergeFunction : public tensor_function::Merge +{ +public: + SparseMergeFunction(const tensor_function::Merge &original); + InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override; + bool result_is_mutable() const override { return true; } + static bool compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs); + static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash); +}; + +} // namespace |