author     Håvard Pettersen <havardpe@yahooinc.com>    2023-11-09 16:08:21 +0000
committer  Håvard Pettersen <havardpe@yahooinc.com>    2023-11-09 16:10:09 +0000
commit     1642f0e8f1b3195fcce3cc2fc476f2b69d2b18d4 (patch)
tree       cbefd82d9f1a0bc8c8720be95120f2c90c7a4a39 /eval
parent     c5b365a96a3a9c06685bf81f704cd94ced5ec28a (diff)
detect unpack bits in more cases
- with multiple dimensions
- inside map_subspaces
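
For illustration (expressions taken from the test cases added below; the parameter name 'a' and the dimension sizes are just the ones used by the tests), the optimization now also covers multi-dimensional unpacking such as

    tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x),y:(y/8)},7-y%8))

as well as the same lambda applied per dense subspace of a mixed tensor via map_subspaces:

    map_subspaces(a,f(a)(tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x),y:(y/8)},7-y%8))))

The tests exercise each expression both directly on a dense input and wrapped in map_subspaces, with and without an extra mapped dimension on the input.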
Diffstat (limited to 'eval')
-rw-r--r--  eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp  109
-rw-r--r--  eval/src/vespa/eval/instruction/unpack_bits_function.cpp                       125
2 files changed, 155 insertions, 79 deletions
diff --git a/eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp b/eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp
index 1125ef4272a..a6f44831043 100644
--- a/eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp
+++ b/eval/src/tests/instruction/unpack_bits_function/unpack_bits_function_test.cpp
@@ -16,16 +16,15 @@ const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get();
auto my_seq = Seq({-128, -43, 85, 127});
-EvalFixture::ParamRepo make_params() {
- return EvalFixture::ParamRepo()
- .add("full", GenSpec(-128).idx("x", 256).cells(CellType::INT8))
- .add("vx8", GenSpec().seq(my_seq).idx("x", 8).cells(CellType::INT8))
- .add("vy8", GenSpec().seq(my_seq).idx("y", 8).cells(CellType::INT8))
- .add("vxf", GenSpec().seq(my_seq).idx("x", 8).cells(CellType::FLOAT));
-}
-EvalFixture::ParamRepo param_repo = make_params();
-
-void assert_optimized(const vespalib::string &expr) {
+auto full = GenSpec(-128).idx("x", 32).cells(CellType::INT8);
+auto vx8 = GenSpec().seq(my_seq).idx("x", 8).cells(CellType::INT8);
+auto vy8 = GenSpec().seq(my_seq).idx("y", 8).cells(CellType::INT8);
+auto vxf = GenSpec().seq(my_seq).idx("x", 8).cells(CellType::FLOAT);
+auto tmxy8 = GenSpec().seq(my_seq).idx("t", 1).idx("x", 3).idx("y", 4).cells(CellType::INT8);
+
+void assert_expr(const GenSpec &spec, const vespalib::string &expr, bool optimized) {
+ EvalFixture::ParamRepo param_repo;
+ param_repo.add("a", spec);
EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
EvalFixture test_fixture(test_factory, expr, param_repo, true);
EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
@@ -33,62 +32,94 @@ void assert_optimized(const vespalib::string &expr) {
EXPECT_EQ(fast_fixture.result(), expect);
EXPECT_EQ(test_fixture.result(), expect);
EXPECT_EQ(slow_fixture.result(), expect);
- EXPECT_EQ(fast_fixture.find_all<UnpackBitsFunction>().size(), 1u);
- EXPECT_EQ(test_fixture.find_all<UnpackBitsFunction>().size(), 1u);
+ EXPECT_EQ(fast_fixture.find_all<UnpackBitsFunction>().size(), optimized ? 1u : 0u);
+ EXPECT_EQ(test_fixture.find_all<UnpackBitsFunction>().size(), optimized ? 1u : 0u);
EXPECT_EQ(slow_fixture.find_all<UnpackBitsFunction>().size(), 0u);
}
-void assert_not_optimized(const vespalib::string &expr) {
- EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
- EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQ(fast_fixture.find_all<UnpackBitsFunction>().size(), 0u);
+void assert_impl(const GenSpec &spec, const vespalib::string &expr, bool optimized) {
+ assert_expr(spec, expr, optimized);
+ vespalib::string wrapped_expr("map_subspaces(a,f(a)(");
+ wrapped_expr.append(expr);
+ wrapped_expr.append("))");
+ assert_expr(spec, wrapped_expr, optimized);
+ assert_expr(spec.cpy().map("m", {"foo", "bar", "baz"}), wrapped_expr, optimized);
+}
+
+void assert_optimized(const GenSpec &spec, const vespalib::string &expr) {
+ assert_impl(spec, expr, true);
+}
+
+void assert_not_optimized(const GenSpec &spec, const vespalib::string &expr) {
+ assert_impl(spec, expr, false);
}
//-----------------------------------------------------------------------------
TEST(UnpackBitsTest, expression_can_be_optimized_with_big_bitorder) {
- assert_optimized("tensor<int8>(x[2048])(bit(full{x:(x/8)},7-x%8))");
- assert_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},7-x%8))");
+ assert_optimized(full, "tensor<int8>(x[256])(bit(a{x:(x/8)},7-x%8))");
+ assert_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},7-x%8))");
}
TEST(UnpackBitsTest, expression_can_be_optimized_with_small_bitorder) {
- assert_optimized("tensor<int8>(x[2048])(bit(full{x:(x/8)},x%8))");
- assert_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},x%8))");
+ assert_optimized(full, "tensor<int8>(x[256])(bit(a{x:(x/8)},x%8))");
+ assert_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},x%8))");
}
-TEST(UnpackBitsTest, unpack_bits_can_rename_dimension) {
- assert_optimized("tensor<int8>(x[64])(bit(vy8{y:(x/8)},7-x%8))");
- assert_optimized("tensor<int8>(x[64])(bit(vy8{y:(x/8)},x%8))");
+TEST(UnpackBitsTest, result_may_have_other_cell_types_than_int8) {
+ assert_optimized(vx8, "tensor<bfloat16>(x[64])(bit(a{x:(x/8)},7-x%8))");
+ assert_optimized(vx8, "tensor<float>(x[64])(bit(a{x:(x/8)},7-x%8))");
+ assert_optimized(vx8, "tensor<double>(x[64])(bit(a{x:(x/8)},7-x%8))");
+
+ assert_optimized(vx8, "tensor<bfloat16>(x[64])(bit(a{x:(x/8)},x%8))");
+ assert_optimized(vx8, "tensor<float>(x[64])(bit(a{x:(x/8)},x%8))");
+ assert_optimized(vx8, "tensor<double>(x[64])(bit(a{x:(x/8)},x%8))");
}
-TEST(UnpackBitsTest, result_may_have_other_cell_types_than_int8) {
- assert_optimized("tensor<bfloat16>(x[64])(bit(vx8{x:(x/8)},7-x%8))");
- assert_optimized("tensor<float>(x[64])(bit(vx8{x:(x/8)},7-x%8))");
- assert_optimized("tensor<double>(x[64])(bit(vx8{x:(x/8)},7-x%8))");
+TEST(UnpackBitsTest, unpack_bits_can_have_multiple_dimensions) {
+ assert_optimized(tmxy8, "tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x),y:(y/8)},7-y%8))");
+ assert_optimized(tmxy8, "tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x),y:(y/8)},y%8))");
+}
- assert_optimized("tensor<bfloat16>(x[64])(bit(vx8{x:(x/8)},x%8))");
- assert_optimized("tensor<float>(x[64])(bit(vx8{x:(x/8)},x%8))");
- assert_optimized("tensor<double>(x[64])(bit(vx8{x:(x/8)},x%8))");
+TEST(UnpackBitsTest, unpack_bits_can_rename_dimensions) {
+ assert_optimized(tmxy8, "tensor<int8>(e[1],f[3],g[32])(bit(a{t:(e),x:(f),y:(g/8)},7-g%8))");
+ assert_optimized(tmxy8, "tensor<int8>(e[1],f[3],g[32])(bit(a{t:(e),x:(f),y:(g/8)},g%8))");
}
//-----------------------------------------------------------------------------
TEST(UnpackBitsTest, source_must_be_int8) {
- assert_not_optimized("tensor<int8>(x[64])(bit(vxf{x:(x/8)},7-x%8))");
+ assert_not_optimized(vxf, "tensor<int8>(x[64])(bit(a{x:(x/8)},7-x%8))");
}
TEST(UnpackBitsTest, dimension_sizes_must_be_appropriate) {
- assert_not_optimized("tensor<int8>(x[60])(bit(vx8{x:(x/8)},7-x%8))");
- assert_not_optimized("tensor<int8>(x[68])(bit(vx8{x:(x/8)},7-x%8))");
+ assert_not_optimized(vx8, "tensor<int8>(x[60])(bit(a{x:(x/8)},7-x%8))");
+ assert_not_optimized(vx8, "tensor<int8>(x[68])(bit(a{x:(x/8)},7-x%8))");
+ assert_not_optimized(tmxy8, "tensor<int8>(e[1],f[2],g[32])(bit(a{t:(e),x:(f),y:(g/8)},7-g%8))");
+ assert_not_optimized(tmxy8, "tensor<int8>(e[2],f[3],g[32])(bit(a{t:(e),x:(f),y:(g/8)},7-g%8))");
+}
+
+TEST(UnpackBitsTest, must_unpack_inner_dimension) {
+ assert_not_optimized(tmxy8, "tensor<int8>(t[1],x[24],y[4])(bit(a{t:(t),x:(x/8),y:(y)},7-x%8))");
+}
+
+TEST(UnpackBitsTest, cannot_transpose_even_trivial_dimensions) {
+ assert_not_optimized(tmxy8, "tensor<int8>(f[1],e[3],g[32])(bit(a{t:(f),x:(e),y:(g/8)},7-g%8))");
+ assert_not_optimized(tmxy8, "tensor<int8>(f[1],e[3],g[32])(bit(a{t:(f),x:(e),y:(g/8)},g%8))");
+}
+
+TEST(UnpackBitsTest, outer_dimensions_must_be_dimension_index_directly) {
+ assert_not_optimized(tmxy8, "tensor<int8>(t[1],x[3],y[32])(bit(a{t:0,x:(x),y:(y/8)},7-y%8))");
+ assert_not_optimized(tmxy8, "tensor<int8>(t[1],x[3],y[32])(bit(a{t:(t),x:(x+1-1),y:(y/8)},7-y%8))");
}
TEST(UnpackBitsTest, similar_expressions_are_not_optimized) {
- assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x*8)},7-x%8))");
- assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/9)},7-x%8))");
- assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},8-x%8))");
- assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},7+x%8))");
- assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},7-x/8))");
- assert_not_optimized("tensor<int8>(x[64])(bit(vx8{x:(x/8)},7-x%9))");
+ assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x*8)},7-x%8))");
+ assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/9)},7-x%8))");
+ assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},8-x%8))");
+ assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},7+x%8))");
+ assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},7-x/8))");
+ assert_not_optimized(vx8, "tensor<int8>(x[64])(bit(a{x:(x/8)},7-x%9))");
}
//-----------------------------------------------------------------------------
diff --git a/eval/src/vespa/eval/instruction/unpack_bits_function.cpp b/eval/src/vespa/eval/instruction/unpack_bits_function.cpp
index b568dd9a220..117bac79fe3 100644
--- a/eval/src/vespa/eval/instruction/unpack_bits_function.cpp
+++ b/eval/src/vespa/eval/instruction/unpack_bits_function.cpp
@@ -9,11 +9,13 @@
#include <vespa/eval/eval/call_nodes.h>
#include <vespa/eval/eval/tensor_nodes.h>
#include <vespa/eval/eval/wrap_param.h>
+#include <cassert>
namespace vespalib::eval {
using namespace vespalib::eval::nodes;
using tensor_function::Lambda;
+using tensor_function::MapSubspaces;
using tensor_function::wrap_param;
using tensor_function::unwrap_param;
using tensor_function::inject;
@@ -39,7 +41,7 @@ void my_unpack_bits_op(InterpretedFunction::State &state, uint64_t param) {
}
}
}
- Value &result_ref = state.stash.create<DenseValueView>(res_type, TypedCells(unpacked_cells));
+ Value &result_ref = state.stash.create<ValueView>(res_type, state.peek(0).index(), TypedCells(unpacked_cells));
state.pop_push(result_ref);
}
@@ -55,63 +57,79 @@ using MyTypify = TypifyValue<TypifyCellType,TypifyBool>;
//-----------------------------------------------------------------------------
-bool valid_lambda_params(const Lambda &lambda) {
- return ((lambda.lambda().num_params() == 2) &&
- (lambda.bindings().size() == 1));
-}
-
-bool valid_type(const ValueType &type, bool must_be_int8) {
- return ((type.is_dense()) &&
- (type.dimensions().size() == 1) &&
- (!must_be_int8 || (type.cell_type() == CellType::INT8)));
-}
-
bool compatible_types(const ValueType &packed, const ValueType &unpacked) {
- return (valid_type(packed, true) && valid_type(unpacked, false) &&
- (unpacked.dimensions()[0].size == (packed.dimensions()[0].size * 8)));
+ const auto &pdims = packed.dimensions();
+ const auto &udims = unpacked.dimensions();
+ if ((pdims.size() > 0) &&
+ (packed.cell_type() == CellType::INT8) &&
+ (packed.is_dense() && unpacked.is_dense()) &&
+ (pdims.size() == udims.size()))
+ {
+ for (size_t i = 0; i < pdims.size() - 1; ++i) {
+ if (udims[i].size != pdims[i].size) {
+ return false;
+ }
+ }
+ return udims.back().size == (pdims.back().size * 8);
+ }
+ return false;
}
-bool is_little_bit_expr(const Node &node) {
+bool is_little_bit_expr(const Node &node, size_t wanted_param) {
// 'x%8'
if (auto mod = as<Mod>(node)) {
if (auto param = as<Symbol>(mod->lhs())) {
if (auto eight = as<Number>(mod->rhs())) {
- return ((param->id() == 0) && (eight->value() == 8.0));
+ return ((param->id() == wanted_param) && (eight->value() == 8.0));
}
}
}
return false;
}
-bool is_big_bit_expr(const Node &node) {
+bool is_big_bit_expr(const Node &node, size_t wanted_param) {
// '7-(x%8)'
if (auto sub = as<Sub>(node)) {
if (auto seven = as<Number>(sub->lhs())) {
- return ((seven->value() == 7.0) && is_little_bit_expr(sub->rhs()));
+ return ((seven->value() == 7.0) && is_little_bit_expr(sub->rhs(), wanted_param));
}
}
return false;
}
-bool is_byte_expr(const Node &node) {
+bool is_ident_expr(const Node &node, size_t wanted_param) {
+ // 'x'
+ if (auto param = as<Symbol>(node)) {
+ return (param->id() == wanted_param);
+ }
+ return false;
+}
+
+bool is_byte_expr(const Node &node, size_t wanted_param) {
// 'x/8'
if (auto div = as<Div>(node)) {
if (auto param = as<Symbol>(div->lhs())) {
if (auto eight = as<Number>(div->rhs())) {
- return ((param->id() == 0) && (eight->value() == 8.0));
+ return ((param->id() == wanted_param) && (eight->value() == 8.0));
}
}
}
return false;
}
-bool is_byte_peek(const TensorPeek &peek) {
+bool is_byte_peek(const TensorPeek &peek, size_t dim_cnt) {
if (auto param = as<Symbol>(peek.param())) {
- if ((param->id() == 1) &&
- (peek.dim_list().size() == 1) &&
- (peek.num_children() == 2))
+ if ((dim_cnt > 0) &&
+ (param->id() == dim_cnt) &&
+ (peek.dim_list().size() == dim_cnt) &&
+ (peek.num_children() == (dim_cnt + 1)))
{
- return is_byte_expr(peek.get_child(1));
+ for (size_t i = 0; i < dim_cnt - 1; ++i) {
+ if (!is_ident_expr(peek.get_child(i + 1), i)) {
+ return false;
+ }
+ }
+ return is_byte_expr(peek.get_child(dim_cnt), dim_cnt - 1);
}
}
return false;
@@ -119,6 +137,36 @@ bool is_byte_peek(const TensorPeek &peek) {
//-----------------------------------------------------------------------------
+struct Result {
+ const bool is_unpack_bits;
+ const bool is_big_bitorder;
+ const ValueType &src_type;
+};
+
+Result detect_unpack_bits(const ValueType &dst_type,
+ size_t num_bindings,
+ const Function &lambda,
+ const NodeTypes &types)
+{
+ size_t dim_cnt = dst_type.count_indexed_dimensions();
+ if ((num_bindings == 1) && (lambda.num_params() == (dim_cnt + 1))) {
+ if (auto bit = as<Bit>(lambda.root())) {
+ if (auto peek = as<TensorPeek>(bit->get_child(0))) {
+ const ValueType &src_type = types.get_type(peek->param());
+ if (compatible_types(src_type, dst_type) && is_byte_peek(*peek, dim_cnt)) {
+ assert(dim_cnt > 0);
+ if (is_big_bit_expr(bit->get_child(1), dim_cnt - 1)) {
+ return {true, true, src_type};
+ } else if (is_little_bit_expr(bit->get_child(1), dim_cnt - 1)) {
+ return {true, false, src_type};
+ }
+ }
+ }
+ }
+ }
+ return {false, false, dst_type};
+}
+
} // namespace <unnamed>
UnpackBitsFunction::UnpackBitsFunction(const ValueType &res_type_in,
@@ -141,21 +189,18 @@ const TensorFunction &
UnpackBitsFunction::optimize(const TensorFunction &expr, Stash &stash)
{
if (auto lambda = as<Lambda>(expr)) {
- const ValueType &dst_type = lambda->result_type();
- if (auto bit = as<Bit>(lambda->lambda().root())) {
- if (auto peek = as<TensorPeek>(bit->get_child(0))) {
- const ValueType &src_type = lambda->types().get_type(peek->param());
- if (compatible_types(src_type, dst_type) &&
- valid_lambda_params(*lambda) &&
- is_byte_peek(*peek))
- {
- size_t param_idx = lambda->bindings()[0];
- if (is_big_bit_expr(bit->get_child(1))) {
- return stash.create<UnpackBitsFunction>(dst_type, inject(src_type, param_idx, stash), true);
- } else if (is_little_bit_expr(bit->get_child(1))) {
- return stash.create<UnpackBitsFunction>(dst_type, inject(src_type, param_idx, stash), false);
- }
- }
+ auto result = detect_unpack_bits(lambda->result_type(), lambda->bindings().size(), lambda->lambda(), lambda->types());
+ if (result.is_unpack_bits) {
+ assert(lambda->bindings().size() == 1);
+ const TensorFunction &input = inject(result.src_type, lambda->bindings()[0], stash);
+ return stash.create<UnpackBitsFunction>(lambda->result_type(), input, result.is_big_bitorder);
+ }
+ }
+ if (auto map_subspaces = as<MapSubspaces>(expr)) {
+ if (auto lambda = as<TensorLambda>(map_subspaces->lambda().root())) {
+ auto result = detect_unpack_bits(lambda->type(), lambda->bindings().size(), lambda->lambda(), map_subspaces->types());
+ if (result.is_unpack_bits) {
+ return stash.create<UnpackBitsFunction>(map_subspaces->result_type(), map_subspaces->child(), result.is_big_bitorder);
}
}
}