| author | Håvard Pettersen <havardpe@yahooinc.com> | 2023-11-09 16:08:21 +0000 |
|---|---|---|
| committer | Håvard Pettersen <havardpe@yahooinc.com> | 2023-11-09 16:10:09 +0000 |
| commit | 1642f0e8f1b3195fcce3cc2fc476f2b69d2b18d4 (patch) | |
| tree | cbefd82d9f1a0bc8c8720be95120f2c90c7a4a39 /eval | |
| parent | c5b365a96a3a9c06685bf81f704cd94ced5ec28a (diff) | |
detect unpack bits in more cases
- with multiple dimensions
- inside map_subspaces
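
For context: the optimization extended here rewrites tensor-lambda expressions of the form `bit(packed{...},...)`, which expand a packed int8 tensor into one cell per bit. A minimal standalone sketch of that byte-to-bit expansion (plain C++ for illustration only, not Vespa's internal API; the function name is made up):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Expand each packed int8 cell into 8 cells holding single bit values.
// big_bitorder=true extracts the most significant bit first, matching
// the '7-x%8' bit index form; false matches the plain 'x%8' form.
std::vector<int8_t> unpack_bits(const std::vector<int8_t> &packed, bool big_bitorder) {
    std::vector<int8_t> out;
    out.reserve(packed.size() * 8);
    for (size_t i = 0; i < packed.size() * 8; ++i) {
        uint8_t byte = static_cast<uint8_t>(packed[i / 8]);
        size_t bit = big_bitorder ? (7 - (i % 8)) : (i % 8);
        out.push_back(static_cast<int8_t>((byte >> bit) & 1));
    }
    return out;
}

int main() {
    // same byte sequence the tests below use for their input tensors
    auto bits = unpack_bits({-128, -43, 85, 127}, true);
    for (int8_t b : bits) printf("%d", b);
    printf("\n"); // 32 bits, MSB first within each byte
}
```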
Diffstat (limited to 'eval')
-rw-r--r-- | eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp | 109 |
-rw-r--r-- | eval/src/vespa/eval/instruction/unpack_bits_function.cpp | 125 |
2 files changed, 155 insertions(+), 79 deletions(-)
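
The two newly detected cases can be read directly off the updated tests below: unpacking the innermost of several indexed dimensions, and the same expressions wrapped inside a map_subspaces lambda (the wrapped form shown here is how the new assert_impl helper builds it; `a` is the test parameter name):

```
tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x),y:(y/8)},7-y%8))
map_subspaces(a,f(a)(tensor<int8>(x[64])(bit(a{x:(x/8)},7-x%8))))
```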
diff --git a/eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp b/eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp
index 1125ef4272a..a6f44831043 100644
--- a/eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp
+++ b/eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp
@@ -16,16 +16,15 @@ const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get();
 
 auto my_seq = Seq({-128, -43, 85, 127});
 
-EvalFixture::ParamRepo make_params() {
-    return EvalFixture::ParamRepo()
-        .add("full", GenSpec(-128).idx("x", 256).cells(CellType::INT8))
-        .add("vx8", GenSpec().seq(my_seq).idx("x", 8).cells(CellType::INT8))
-        .add("vy8", GenSpec().seq(my_seq).idx("y", 8).cells(CellType::INT8))
-        .add("vxf", GenSpec().seq(my_seq).idx("x", 8).cells(CellType::FLOAT));
-}
-EvalFixture::ParamRepo param_repo = make_params();
-
-void assert_optimized(const vespalib::string &expr) {
+auto full = GenSpec(-128).idx("x", 32).cells(CellType::INT8);
+auto vx8 = GenSpec().seq(my_seq).idx("x", 8).cells(CellType::INT8);
+auto vy8 = GenSpec().seq(my_seq).idx("y", 8).cells(CellType::INT8);
+auto vxf = GenSpec().seq(my_seq).idx("x", 8).cells(CellType::FLOAT);
+auto tmxy8 = GenSpec().seq(my_seq).idx("t", 1).idx("x", 3).idx("y", 4).cells(CellType::INT8);
+
+void assert_expr(const GenSpec &spec, const vespalib::string &expr, bool optimized) {
+    EvalFixture::ParamRepo param_repo;
+    param_repo.add("a", spec);
     EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
     EvalFixture test_fixture(test_factory, expr, param_repo, true);
     EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
@@ -33,62 +32,94 @@ void assert_optimized(const vespalib::string &expr) {
     EXPECT_EQ(fast_fixture.result(), expect);
     EXPECT_EQ(test_fixture.result(), expect);
     EXPECT_EQ(slow_fixture.result(), expect);
-    EXPECT_EQ(fast_fixture.find_all<UnpackBitsFunction>().size(), 1u);
-    EXPECT_EQ(test_fixture.find_all<UnpackBitsFunction>().size(), 1u);
+    EXPECT_EQ(fast_fixture.find_all<UnpackBitsFunction>().size(), optimized ? 1u : 0u);
+    EXPECT_EQ(test_fixture.find_all<UnpackBitsFunction>().size(), optimized ? 1u : 0u);
     EXPECT_EQ(slow_fixture.find_all<UnpackBitsFunction>().size(), 0u);
 }
 
-void assert_not_optimized(const vespalib::string &expr) {
-    EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
-    EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo));
-    EXPECT_EQ(fast_fixture.find_all<UnpackBitsFunction>().size(), 0u);
+void assert_impl(const GenSpec &spec, const vespalib::string &expr, bool optimized) {
+    assert_expr(spec, expr, optimized);
+    vespalib::string wrapped_expr("map_subspaces(a,f(a)(");
+    wrapped_expr.append(expr);
+    wrapped_expr.append("))");
+    assert_expr(spec, wrapped_expr, optimized);
+    assert_expr(spec.cpy().map("m", {"foo", "bar", "baz"}), wrapped_expr, optimized);
+}
+
+void assert_optimized(const GenSpec &spec, const vespalib::string &expr) {
+    assert_impl(spec, expr, true);
+}
+
+void assert_not_optimized(const GenSpec &spec, const vespalib::string &expr) {
+    assert_impl(spec, expr, false);
 }
 
 //-----------------------------------------------------------------------------
 
 TEST(UnpackBitsTest, expression_can_be_optimized_with_big_bitorder) {
-    assert_optimized("tensor<int8>(x[2048])(bit(full{x:(x/8)},7-x%8))");
-    assert_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},7-x%8))");
+    assert_optimized(full, "tensor<int8>(x[256])(bit(a{x:(x/8)},7-x%8))");
+    assert_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},7-x%8))");
 }
 
 TEST(UnpackBitsTest, expression_can_be_optimized_with_small_bitorder) {
-    assert_optimized("tensor<int8>(x[2048])(bit(full{x:(x/8)},x%8))");
-    assert_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},x%8))");
+    assert_optimized(full, "tensor<int8>(x[256])(bit(a{x:(x/8)},x%8))");
+    assert_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},x%8))");
 }
 
-TEST(UnpackBitsTest, unpack_bits_can_rename_dimension) {
-    assert_optimized("tensor<int8>(x[64])(bit(vy8{y:(x/8)},7-x%8))");
-    assert_optimized("tensor<int8>(x[64])(bit(vy8{y:(x/8)},x%8))");
+TEST(UnpackBitsTest, result_may_have_other_cell_types_than_int8) {
+    assert_optimized(vx8, "tensor<bfloat16>(x[64])(bit(a{x:(x/8)},7-x%8))");
+    assert_optimized(vx8, "tensor<float>(x[64])(bit(a{x:(x/8)},7-x%8))");
+    assert_optimized(vx8, "tensor<double>(x[64])(bit(a{x:(x/8)},7-x%8))");
+
+    assert_optimized(vx8, "tensor<bfloat16>(x[64])(bit(a{x:(x/8)},x%8))");
+    assert_optimized(vx8, "tensor<float>(x[64])(bit(a{x:(x/8)},x%8))");
+    assert_optimized(vx8, "tensor<double>(x[64])(bit(a{x:(x/8)},x%8))");
 }
 
-TEST(UnpackBitsTest, result_may_have_other_cell_types_than_int8) {
-    assert_optimized("tensor<bfloat16>(x[64])(bit(vx8{x:(x/8)},7-x%8))");
-    assert_optimized("tensor<float>(x[64])(bit(vx8{x:(x/8)},7-x%8))");
-    assert_optimized("tensor<double>(x[64])(bit(vx8{x:(x/8)},7-x%8))");
+TEST(UnpackBitsTest, unpack_bits_can_have_multiple_dimensions) {
+    assert_optimized(tmxy8, "tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x),y:(y/8)},7-y%8))");
+    assert_optimized(tmxy8, "tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x),y:(y/8)},y%8))");
+}
 
-    assert_optimized("tensor<bfloat16>(x[64])(bit(vx8{x:(x/8)},x%8))");
-    assert_optimized("tensor<float>(x[64])(bit(vx8{x:(x/8)},x%8))");
-    assert_optimized("tensor<double>(x[64])(bit(vx8{x:(x/8)},x%8))");
+TEST(UnpackBitsTest, unpack_bits_can_rename_dimensions) {
+    assert_optimized(tmxy8, "tensor<int8>(e[1],f[3],g[32])(bit(a{t:(e),x:(f),y:(g/8)},7-g%8))");
+    assert_optimized(tmxy8, "tensor<int8>(e[1],f[3],g[32])(bit(a{t:(e),x:(f),y:(g/8)},g%8))");
 }
 
 //-----------------------------------------------------------------------------
 
 TEST(UnpackBitsTest, source_must_be_int8) {
-    assert_not_optimized("tensor<int8>(x[64])(bit(vxf{x:(x/8)},7-x%8))");
+    assert_not_optimized(vxf, "tensor<int8>(x[64])(bit(a{x:(x/8)},7-x%8))");
 }
 
 TEST(UnpackBitsTest, dimension_sizes_must_be_appropriate) {
-    assert_not_optimized("tensor<int8>(x[60])(bit(vx8{x:(x/8)},7-x%8))");
-    assert_not_optimized("tensor<int8>(x[68])(bit(vx8{x:(x/8)},7-x%8))");
+    assert_not_optimized(vx8, "tensor<int8>(x[60])(bit(a{x:(x/8)},7-x%8))");
+    assert_not_optimized(vx8, "tensor<int8>(x[68])(bit(a{x:(x/8)},7-x%8))");
+    assert_not_optimized(tmxy8, "tensor<int8>(e[1],f[2],g[32])(bit(a{t:(e),x:(f),y:(g/8)},7-g%8))");
+    assert_not_optimized(tmxy8, "tensor<int8>(e[2],f[3],g[32])(bit(a{t:(e),x:(f),y:(g/8)},7-g%8))");
+}
+
+TEST(UnpackBitsTest, must_unpack_inner_dimension) {
+    assert_not_optimized(tmxy8, "tensor<int8>(t[1],x[24],y[4])(bit(a{t:(t),x:(x/8),y:(y)},7-x%8))");
+}
+
+TEST(UnpackBitsTest, cannot_transpose_even_trivial_dimensions) {
+    assert_not_optimized(tmxy8, "tensor<int8>(f[1],e[3],g[32])(bit(a{t:(f),x:(e),y:(g/8)},7-g%8))");
+    assert_not_optimized(tmxy8, "tensor<int8>(f[1],e[3],g[32])(bit(a{t:(f),x:(e),y:(g/8)},g%8))");
+}
+
+TEST(UnpackBitsTest, outer_dimensions_must_be_dimension_index_directly) {
+    assert_not_optimized(tmxy8, "tensor<int8>(t[1],x[3],y[32])(bit(a{t:0,x:(x),y:(y/8)},7-y%8))");
+    assert_not_optimized(tmxy8, "tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x+1-1),y:(y/8)},7-y%8))");
 }
 
 TEST(UnpackBitsTest, similar_expressions_are_not_optimized) {
-    assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x*8)},7-x%8))");
-    assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/9)},7-x%8))");
-    assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},8-x%8))");
-    assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},7+x%8))");
-    assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},7-x/8))");
-    assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},7-x%9))");
+    assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x*8)},7-x%8))");
+    assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/9)},7-x%8))");
+    assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},8-x%8))");
+    assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},7+x%8))");
+    assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},7-x/8))");
+    assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},7-x%9))");
 }
 
 //-----------------------------------------------------------------------------
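
In the implementation diff that follows, is_byte_peek now accepts peeks over any number of dense dimensions, but only when each outer dimension is indexed directly by its own lambda parameter and only the innermost index is divided by 8. That restriction keeps the operation a single linear pass over the packed cells: in row-major dense cell order, dividing the innermost index by 8 is equivalent to dividing the flat cell index by 8. A small self-contained check of that identity (dimension sizes borrowed from the tmxy8 tests; illustrative code, not Vespa's):

```cpp
#include <cassert>
#include <cstddef>

int main() {
    // unpacked dims (t[1],x[3],y[32]) vs packed dims (t[1],x[3],y[4])
    const size_t T = 1, X = 3, Y = 32;
    for (size_t t = 0; t < T; ++t) {
        for (size_t x = 0; x < X; ++x) {
            for (size_t y = 0; y < Y; ++y) {
                size_t flat_unpacked = (t * X + x) * Y + y;
                size_t flat_packed = (t * X + x) * (Y / 8) + (y / 8);
                // peeking {t:(t),x:(x),y:(y/8)} reads exactly the byte
                // that holds bit number flat_unpacked
                assert(flat_packed == flat_unpacked / 8);
            }
        }
    }
    return 0;
}
```

Transposing dimensions (even trivial ones of size 1) or computing an outer index with anything but the identity breaks this flat-index correspondence, which is why those cases are rejected by the tests above.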
diff --git a/eval/src/vespa/eval/instruction/unpack_bits_function.cpp b/eval/src/vespa/eval/instruction/unpack_bits_function.cpp
index b568dd9a220..117bac79fe3 100644
--- a/eval/src/vespa/eval/instruction/unpack_bits_function.cpp
+++ b/eval/src/vespa/eval/instruction/unpack_bits_function.cpp
@@ -9,11 +9,13 @@
 #include <vespa/eval/eval/call_nodes.h>
 #include <vespa/eval/eval/tensor_nodes.h>
 #include <vespa/eval/eval/wrap_param.h>
+#include <cassert>
 
 namespace vespalib::eval {
 
 using namespace vespalib::eval::nodes;
 using tensor_function::Lambda;
+using tensor_function::MapSubspaces;
 using tensor_function::wrap_param;
 using tensor_function::unwrap_param;
 using tensor_function::inject;
@@ -39,7 +41,7 @@ void my_unpack_bits_op(InterpretedFunction::State &state, uint64_t param) {
             }
         }
     }
-    Value &result_ref = state.stash.create<DenseValueView>(res_type, TypedCells(unpacked_cells));
+    Value &result_ref = state.stash.create<ValueView>(res_type, state.peek(0).index(), TypedCells(unpacked_cells));
     state.pop_push(result_ref);
 }
 
@@ -55,63 +57,79 @@ using MyTypify = TypifyValue<TypifyCellType,TypifyBool>;
 
 //-----------------------------------------------------------------------------
 
-bool valid_lambda_params(const Lambda &lambda) {
-    return ((lambda.lambda().num_params() == 2) &&
-            (lambda.bindings().size() == 1));
-}
-
-bool valid_type(const ValueType &type, bool must_be_int8) {
-    return ((type.is_dense()) &&
-            (type.dimensions().size() == 1) &&
-            (!must_be_int8 || (type.cell_type() == CellType::INT8)));
-}
-
 bool compatible_types(const ValueType &packed, const ValueType &unpacked) {
-    return (valid_type(packed, true) && valid_type(unpacked, false) &&
-            (unpacked.dimensions()[0].size == (packed.dimensions()[0].size * 8)));
+    const auto &pdims = packed.dimensions();
+    const auto &udims = unpacked.dimensions();
+    if ((pdims.size() > 0) &&
+        (packed.cell_type() == CellType::INT8) &&
+        (packed.is_dense() && unpacked.is_dense()) &&
+        (pdims.size() == udims.size()))
+    {
+        for (size_t i = 0; i < pdims.size() - 1; ++i) {
+            if (udims[i].size != pdims[i].size) {
+                return false;
+            }
+        }
+        return udims.back().size == (pdims.back().size * 8);
+    }
+    return false;
 }
 
-bool is_little_bit_expr(const Node &node) {
+bool is_little_bit_expr(const Node &node, size_t wanted_param) {
     // 'x%8'
     if (auto mod = as<Mod>(node)) {
         if (auto param = as<Symbol>(mod->lhs())) {
            if (auto eight = as<Number>(mod->rhs())) {
-                return ((param->id() == 0) && (eight->value() == 8.0));
+                return ((param->id() == wanted_param) && (eight->value() == 8.0));
            }
        }
    }
    return false;
 }
 
-bool is_big_bit_expr(const Node &node) {
+bool is_big_bit_expr(const Node &node, size_t wanted_param) {
     // '7-(x%8)'
     if (auto sub = as<Sub>(node)) {
         if (auto seven = as<Number>(sub->lhs())) {
-            return ((seven->value() == 7.0) && is_little_bit_expr(sub->rhs()));
+            return ((seven->value() == 7.0) && is_little_bit_expr(sub->rhs(), wanted_param));
        }
    }
    return false;
 }
 
-bool is_byte_expr(const Node &node) {
+bool is_ident_expr(const Node &node, size_t wanted_param) {
+    // 'x'
+    if (auto param = as<Symbol>(node)) {
+        return (param->id() == wanted_param);
+    }
+    return false;
+}
+
+bool is_byte_expr(const Node &node, size_t wanted_param) {
     // 'x/8'
     if (auto div = as<Div>(node)) {
         if (auto param = as<Symbol>(div->lhs())) {
            if (auto eight = as<Number>(div->rhs())) {
-                return ((param->id() == 0) && (eight->value() == 8.0));
+                return ((param->id() == wanted_param) && (eight->value() == 8.0));
            }
        }
    }
    return false;
 }
 
-bool is_byte_peek(const TensorPeek &peek) {
+bool is_byte_peek(const TensorPeek &peek, size_t dim_cnt) {
     if (auto param = as<Symbol>(peek.param())) {
-        if ((param->id() == 1) &&
-            (peek.dim_list().size() == 1) &&
-            (peek.num_children() == 2))
+        if ((dim_cnt > 0) &&
+            (param->id() == dim_cnt) &&
+            (peek.dim_list().size() == dim_cnt) &&
+            (peek.num_children() == (dim_cnt + 1)))
         {
-            return is_byte_expr(peek.get_child(1));
+            for (size_t i = 0; i < dim_cnt - 1; ++i) {
+                if (!is_ident_expr(peek.get_child(i + 1), i)) {
+                    return false;
+                }
+            }
+            return is_byte_expr(peek.get_child(dim_cnt), dim_cnt - 1);
         }
     }
     return false;
@@ -119,6 +137,36 @@ bool is_byte_peek(const TensorPeek &peek) {
 
 //-----------------------------------------------------------------------------
 
+struct Result {
+    const bool is_unpack_bits;
+    const bool is_big_bitorder;
+    const ValueType &src_type;
+};
+
+Result detect_unpack_bits(const ValueType &dst_type,
+                          size_t num_bindings,
+                          const Function &lambda,
+                          const NodeTypes &types)
+{
+    size_t dim_cnt = dst_type.count_indexed_dimensions();
+    if ((num_bindings == 1) && (lambda.num_params() == (dim_cnt + 1))) {
+        if (auto bit = as<Bit>(lambda.root())) {
+            if (auto peek = as<TensorPeek>(bit->get_child(0))) {
+                const ValueType &src_type = types.get_type(peek->param());
+                if (compatible_types(src_type, dst_type) && is_byte_peek(*peek, dim_cnt)) {
+                    assert(dim_cnt > 0);
+                    if (is_big_bit_expr(bit->get_child(1), dim_cnt - 1)) {
+                        return {true, true, src_type};
+                    } else if (is_little_bit_expr(bit->get_child(1), dim_cnt - 1)) {
+                        return {true, false, src_type};
+                    }
+                }
+            }
+        }
+    }
+    return {false, false, dst_type};
+}
+
 } // namespace <unnamed>
 
 UnpackBitsFunction::UnpackBitsFunction(const ValueType &res_type_in,
@@ -141,21 +189,18 @@ const TensorFunction &
 UnpackBitsFunction::optimize(const TensorFunction &expr, Stash &stash)
 {
     if (auto lambda = as<Lambda>(expr)) {
-        const ValueType &dst_type = lambda->result_type();
-        if (auto bit = as<Bit>(lambda->lambda().root())) {
-            if (auto peek = as<TensorPeek>(bit->get_child(0))) {
-                const ValueType &src_type = lambda->types().get_type(peek->param());
-                if (compatible_types(src_type, dst_type) &&
-                    valid_lambda_params(*lambda) &&
-                    is_byte_peek(*peek))
-                {
-                    size_t param_idx = lambda->bindings()[0];
-                    if (is_big_bit_expr(bit->get_child(1))) {
-                        return stash.create<UnpackBitsFunction>(dst_type, inject(src_type, param_idx, stash), true);
-                    } else if (is_little_bit_expr(bit->get_child(1))) {
-                        return stash.create<UnpackBitsFunction>(dst_type, inject(src_type, param_idx, stash), false);
-                    }
-                }
+        auto result = detect_unpack_bits(lambda->result_type(), lambda->bindings().size(), lambda->lambda(), lambda->types());
+        if (result.is_unpack_bits) {
+            assert(lambda->bindings().size() == 1);
+            const TensorFunction &input = inject(result.src_type, lambda->bindings()[0], stash);
+            return stash.create<UnpackBitsFunction>(lambda->result_type(), input, result.is_big_bitorder);
+        }
+    }
+    if (auto map_subspaces = as<MapSubspaces>(expr)) {
+        if (auto lambda = as<TensorLambda>(map_subspaces->lambda().root())) {
+            auto result = detect_unpack_bits(lambda->type(), lambda->bindings().size(), lambda->lambda(), map_subspaces->types());
+            if (result.is_unpack_bits) {
+                return stash.create<UnpackBitsFunction>(map_subspaces->result_type(), map_subspaces->child(), result.is_big_bitorder);
+            }
+        }
+    }
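
A closing note on the map_subspaces case: there the packed input can also carry mapped dimensions, and each dense subspace is unpacked independently while the sparse part of the value stays as-is. This is why my_unpack_bits_op above can switch from DenseValueView to a ValueView that reuses state.peek(0).index() directly. A rough standalone sketch of the per-subspace behavior (plain C++; buffer layout and names are hypothetical):

```cpp
#include <cstdint>
#include <cstdio>
#include <vector>

// Unpack bits subspace by subspace: the sparse "address" of each dense
// subspace is unchanged, only its dense payload grows by a factor of 8.
std::vector<int8_t> unpack_subspaces(const std::vector<int8_t> &packed,
                                     size_t subspace_size, bool big_bitorder) {
    std::vector<int8_t> out;
    out.reserve(packed.size() * 8);
    for (size_t base = 0; base < packed.size(); base += subspace_size) {
        for (size_t i = 0; i < subspace_size * 8; ++i) {
            uint8_t byte = static_cast<uint8_t>(packed[base + i / 8]);
            size_t bit = big_bitorder ? (7 - (i % 8)) : (i % 8);
            out.push_back(static_cast<int8_t>((byte >> bit) & 1));
        }
    }
    return out;
}

int main() {
    // e.g. three mapped labels ("foo","bar","baz"), one packed byte each
    auto bits = unpack_subspaces({85, -43, 127}, 1, true);
    printf("%zu bits across 3 subspaces\n", bits.size()); // prints 24
}
```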