// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #include "dense_dot_product_function.h" #include #include namespace vespalib::eval { using namespace tensor_function; using namespace operation; namespace { template void my_dot_product_op(InterpretedFunction::State &state, uint64_t) { auto lhs_cells = state.peek(1).cells().typify(); auto rhs_cells = state.peek(0).cells().typify(); double result = DotProduct::apply(lhs_cells.cbegin(), rhs_cells.cbegin(), lhs_cells.size()); state.pop_pop_push(state.stash.create(result)); } struct MyDotProductOp { template static auto invoke() { return my_dot_product_op; } }; } // namespace DenseDotProductFunction::DenseDotProductFunction(const TensorFunction &lhs_in, const TensorFunction &rhs_in) : tensor_function::Op2(ValueType::double_type(), lhs_in, rhs_in) { } InterpretedFunction::Instruction DenseDotProductFunction::compile_self(const ValueBuilderFactory &, Stash &) const { auto op = typify_invoke<2,TypifyCellType,MyDotProductOp>(lhs().result_type().cell_type(), rhs().result_type().cell_type()); return InterpretedFunction::Instruction(op); } bool DenseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) { return (res.is_double() && lhs.is_dense() && (rhs.dimensions() == lhs.dimensions())); } const TensorFunction & DenseDotProductFunction::optimize(const TensorFunction &expr, Stash &stash) { auto reduce = as(expr); if (reduce && (reduce->aggr() == Aggr::SUM)) { auto join = as(reduce->child()); if (join && (join->function() == Mul::f)) { const TensorFunction &lhs = join->lhs(); const TensorFunction &rhs = join->rhs(); if (compatible_types(expr.result_type(), lhs.result_type(), rhs.result_type())) { return stash.create(lhs, rhs); } } } return expr; } } // namespace