diff options
author | Arne Juul <arnej@verizonmedia.com> | 2021-01-21 10:50:54 +0000 |
---|---|---|
committer | Arne Juul <arnej@verizonmedia.com> | 2021-01-21 10:50:54 +0000 |
commit | e127f6cb47f95977d441808db3b265faecae5310 (patch) | |
tree | c9bfb03ea03bade5a1d9b82133ca3afbe40b8c5f /eval | |
parent | b63d5ee7262d7ed78742fdd01e4f7cfc2edbf0ee (diff) |
extend "Dense Simple Map" function to handle all mixed cases.
Diffstat (limited to 'eval')
4 files changed, 32 insertions, 23 deletions
diff --git a/eval/src/tests/instruction/dense_simple_map_function/dense_simple_map_function_test.cpp b/eval/src/tests/instruction/dense_simple_map_function/dense_simple_map_function_test.cpp index 13a24c13a2e..d83a8b141a8 100644 --- a/eval/src/tests/instruction/dense_simple_map_function/dense_simple_map_function_test.cpp +++ b/eval/src/tests/instruction/dense_simple_map_function/dense_simple_map_function_test.cpp @@ -19,6 +19,8 @@ EvalFixture::ParamRepo make_params() { .add("b", spec(2.5)) .add("sparse", spec({x({"a"})}, N())) .add("mixed", spec({x({"a"}),y(5)}, N())) + .add("@sparse", spec({x({"a"})}, N()), true) + .add("@mixed", spec({x({"a"}),y(5)}, N()), true) .add_matrix("x", 5, "y", 3); } EvalFixture::ParamRepo param_repo = make_params(); @@ -28,7 +30,7 @@ void verify_optimized(const vespalib::string &expr, bool inplace) { EvalFixture fixture(prod_factory, expr, param_repo, true, true); EXPECT_EQ(fixture.result(), EvalFixture::ref(expr, param_repo)); EXPECT_EQ(fixture.result(), slow_fixture.result()); - auto info = fixture.find_all<DenseSimpleMapFunction>(); + auto info = fixture.find_all<MixedMapFunction>(); ASSERT_EQ(info.size(), 1u); EXPECT_TRUE(info[0]->result_is_mutable()); EXPECT_EQ(info[0]->inplace(), inplace); @@ -45,7 +47,7 @@ void verify_not_optimized(const vespalib::string &expr) { EvalFixture fixture(prod_factory, expr, param_repo, true); EXPECT_EQ(fixture.result(), EvalFixture::ref(expr, param_repo)); EXPECT_EQ(fixture.result(), slow_fixture.result()); - auto info = fixture.find_all<DenseSimpleMapFunction>(); + auto info = fixture.find_all<MixedMapFunction>(); EXPECT_TRUE(info.empty()); } @@ -63,12 +65,20 @@ TEST(MapTest, scalar_map_is_not_optimized) { verify_not_optimized("map(a,f(x)(x+10))"); } -TEST(MapTest, sparse_map_is_not_optimized) { - verify_not_optimized("map(sparse,f(x)(x+10))"); +TEST(MapTest, sparse_map_is_optimized) { + verify_optimized("map(sparse,f(x)(x+10))", false); } -TEST(MapTest, mixed_map_is_not_optimized) { - verify_not_optimized("map(mixed,f(x)(x+10))"); +TEST(MapTest, sparse_map_can_be_inplace) { + verify_optimized("map(@sparse,f(x)(x+10))", true); +} + +TEST(MapTest, mixed_map_is_optimized) { + verify_optimized("map(mixed,f(x)(x+10))", false); +} + +TEST(MapTest, mixed_map_can_be_inplace) { + verify_optimized("map(@mixed,f(x)(x+10))", true); } GTEST_MAIN_RUN_ALL_TESTS() diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp index 0d7a6937c0d..b50e30576a1 100644 --- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp +++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp @@ -68,7 +68,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &factory, c child.set(DenseLambdaPeekOptimizer::optimize(child.get(), stash)); child.set(FastRenameOptimizer::optimize(child.get(), stash)); child.set(PowAsMapOptimizer::optimize(child.get(), stash)); - child.set(DenseSimpleMapFunction::optimize(child.get(), stash)); + child.set(MixedMapFunction::optimize(child.get(), stash)); child.set(DenseSimpleJoinFunction::optimize(child.get(), stash)); child.set(JoinWithNumberFunction::optimize(child.get(), stash)); child.set(DenseSingleReduceFunction::optimize(child.get(), stash)); diff --git a/eval/src/vespa/eval/instruction/dense_simple_map_function.cpp b/eval/src/vespa/eval/instruction/dense_simple_map_function.cpp index ec7d2014436..4ca05b74886 100644 --- a/eval/src/vespa/eval/instruction/dense_simple_map_function.cpp +++ b/eval/src/vespa/eval/instruction/dense_simple_map_function.cpp @@ -36,7 +36,7 @@ void my_simple_map_op(State &state, uint64_t param) { auto dst_cells = make_dst_cells<CT, inplace>(src_cells, state.stash); apply_op1_vec(dst_cells.begin(), src_cells.begin(), dst_cells.size(), my_fun); if (!inplace) { - state.pop_push(state.stash.create<DenseValueView>(child.type(), TypedCells(dst_cells))); + state.pop_push(state.stash.create<ValueView>(child.type(), child.index(), TypedCells(dst_cells))); } } @@ -54,17 +54,17 @@ using MyTypify = TypifyValue<TypifyCellType,TypifyOp1,TypifyBool>; //----------------------------------------------------------------------------- -DenseSimpleMapFunction::DenseSimpleMapFunction(const ValueType &result_type, - const TensorFunction &child, - map_fun_t function_in) +MixedMapFunction::MixedMapFunction(const ValueType &result_type, + const TensorFunction &child, + map_fun_t function_in) : Map(result_type, child, function_in) { } -DenseSimpleMapFunction::~DenseSimpleMapFunction() = default; +MixedMapFunction::~MixedMapFunction() = default; Instruction -DenseSimpleMapFunction::compile_self(const ValueBuilderFactory &, Stash &) const +MixedMapFunction::compile_self(const ValueBuilderFactory &, Stash &) const { auto op = typify_invoke<3,MyTypify,MyGetFun>(result_type().cell_type(), function(), inplace()); static_assert(sizeof(uint64_t) == sizeof(function())); @@ -72,11 +72,11 @@ DenseSimpleMapFunction::compile_self(const ValueBuilderFactory &, Stash &) const } const TensorFunction & -DenseSimpleMapFunction::optimize(const TensorFunction &expr, Stash &stash) +MixedMapFunction::optimize(const TensorFunction &expr, Stash &stash) { if (auto map = as<Map>(expr)) { - if (map->child().result_type().is_dense()) { - return stash.create<DenseSimpleMapFunction>(map->result_type(), map->child(), map->function()); + if (! map->child().result_type().is_scalar()) { + return stash.create<MixedMapFunction>(map->result_type(), map->child(), map->function()); } } return expr; diff --git a/eval/src/vespa/eval/instruction/dense_simple_map_function.h b/eval/src/vespa/eval/instruction/dense_simple_map_function.h index 40432f35c58..3f3b821d15e 100644 --- a/eval/src/vespa/eval/instruction/dense_simple_map_function.h +++ b/eval/src/vespa/eval/instruction/dense_simple_map_function.h @@ -7,17 +7,16 @@ namespace vespalib::eval { /** - * Tensor function for simple map operations on dense tensors. - * TODO: Fix generic map to handle inplace, and remove this. + * Tensor function optimizing map operations on tensors. **/ -class DenseSimpleMapFunction : public tensor_function::Map +class MixedMapFunction : public tensor_function::Map { public: using map_fun_t = operation::op1_t; - DenseSimpleMapFunction(const ValueType &result_type, - const TensorFunction &child, - map_fun_t function_in); - ~DenseSimpleMapFunction() override; + MixedMapFunction(const ValueType &result_type, + const TensorFunction &child, + map_fun_t function_in); + ~MixedMapFunction() override; bool inplace() const { return child().result_is_mutable(); } InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override; static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash); |