// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "dense_simple_expand_function.h" #include #include #include #include #include #include #include namespace vespalib::eval{ using vespalib::ArrayRef; using namespace operation; using namespace tensor_function; using Inner = DenseSimpleExpandFunction::Inner; using op_function = InterpretedFunction::op_function; using Instruction = InterpretedFunction::Instruction; using State = InterpretedFunction::State; namespace { struct ExpandParams { const ValueType &result_type; size_t result_size; join_fun_t function; ExpandParams(const ValueType &result_type_in, size_t result_size_in, join_fun_t function_in) : result_type(result_type_in), result_size(result_size_in), function(function_in) {} }; template void my_simple_expand_op(State &state, uint64_t param) { using ICT = typename std::conditional::type; using OCT = typename std::conditional::type; using OP = typename std::conditional,Fun>::type; const ExpandParams ¶ms = unwrap_param(param); OP my_op(params.function); auto inner_cells = state.peek(rhs_inner ? 0 : 1).cells().typify(); auto outer_cells = state.peek(rhs_inner ? 1 : 0).cells().typify(); auto dst_cells = state.stash.create_array(params.result_size); DCT *dst = dst_cells.begin(); for (OCT outer_cell: outer_cells) { apply_op2_vec_num(dst, inner_cells.begin(), outer_cell, inner_cells.size(), my_op); dst += inner_cells.size(); } state.pop_pop_push(state.stash.create(params.result_type, TypedCells(dst_cells))); } //----------------------------------------------------------------------------- struct SelectDenseSimpleExpand { template static auto invoke() { constexpr CellMeta ocm = CellMeta::join(LCM::value, RCM::value); using LCT = CellValueType; using RCT = CellValueType; using OCT = CellValueType; return my_simple_expand_op; } }; using MyTypify = TypifyValue; //----------------------------------------------------------------------------- std::optional detect_simple_expand(const TensorFunction &lhs, const TensorFunction &rhs) { std::vector a = lhs.result_type().nontrivial_indexed_dimensions(); std::vector b = rhs.result_type().nontrivial_indexed_dimensions(); if (a.empty() || b.empty()) { return std::nullopt; } else if (a.back().name < b.front().name) { return Inner::RHS; } else if (b.back().name < a.front().name) { return Inner::LHS; } else { return std::nullopt; } } } // namespace //----------------------------------------------------------------------------- DenseSimpleExpandFunction::DenseSimpleExpandFunction(const ValueType &result_type, const TensorFunction &lhs, const TensorFunction &rhs, join_fun_t function_in, Inner inner_in) : Join(result_type, lhs, rhs, function_in), _inner(inner_in) { } DenseSimpleExpandFunction::~DenseSimpleExpandFunction() = default; Instruction DenseSimpleExpandFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const { size_t result_size = result_type().dense_subspace_size(); const ExpandParams ¶ms = stash.create(result_type(), result_size, function()); auto op = typify_invoke<4,MyTypify,SelectDenseSimpleExpand>(lhs().result_type().cell_meta().not_scalar(), rhs().result_type().cell_meta().not_scalar(), function(), (_inner == Inner::RHS)); return Instruction(op, wrap_param(params)); } const TensorFunction & DenseSimpleExpandFunction::optimize(const TensorFunction &expr, Stash &stash) { if (auto join = as(expr)) { const TensorFunction &lhs = join->lhs(); const TensorFunction &rhs = join->rhs(); if (lhs.result_type().is_dense() && rhs.result_type().is_dense()) { if (std::optional inner = detect_simple_expand(lhs, rhs)) { assert(expr.result_type().dense_subspace_size() == (lhs.result_type().dense_subspace_size() * rhs.result_type().dense_subspace_size())); return stash.create(join->result_type(), lhs, rhs, join->function(), inner.value()); } } } return expr; } } // namespace