diff options
12 files changed, 179 insertions, 84 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt index f063faf19c3..00ab5b347ea 100644 --- a/eval/CMakeLists.txt +++ b/eval/CMakeLists.txt @@ -27,7 +27,7 @@ vespa_define_module( src/tests/tensor/dense_dot_product_function src/tests/tensor/dense_tensor_address_combiner src/tests/tensor/dense_tensor_builder - src/tests/tensor/dense_tensor_function_compiler + src/tests/tensor/dense_tensor_function_optimizer src/tests/tensor/dense_xw_product_function src/tests/tensor/sparse_tensor_builder src/tests/tensor/tensor_address diff --git a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp index 681a4dabc19..641ebddfec2 100644 --- a/eval/src/tests/eval/tensor_function/tensor_function_test.cpp +++ b/eval/src/tests/eval/tensor_function/tensor_function_test.cpp @@ -162,4 +162,28 @@ TEST("require that tensor join works") { TEST_DO(verify_equal(*expect, ctx.eval(prog))); } +TEST("require that push_children works") { + Stash stash; + std::vector<Node::Child::CREF> refs; + const Node &a = inject(ValueType::double_type(), 0, stash); + const Node &b = inject(ValueType::double_type(), 1, stash); + a.push_children(refs); + b.push_children(refs); + ASSERT_EQUAL(refs.size(), 0u); + //------------------------------------------------------------------------- + reduce(a, Aggr::SUM, {}, stash).push_children(refs); + ASSERT_EQUAL(refs.size(), 1u); + EXPECT_EQUAL(&refs[0].get().get(), &a); + //------------------------------------------------------------------------- + map(b, operation::Neg::f, stash).push_children(refs); + ASSERT_EQUAL(refs.size(), 2u); + EXPECT_EQUAL(&refs[1].get().get(), &b); + //------------------------------------------------------------------------- + join(a, b, operation::Add::f, stash).push_children(refs); + ASSERT_EQUAL(refs.size(), 4u); + EXPECT_EQUAL(&refs[2].get().get(), &a); + EXPECT_EQUAL(&refs[3].get().get(), &b); + //------------------------------------------------------------------------- +} + TEST_MAIN() { TEST_RUN_ALL(); } diff --git a/eval/src/tests/tensor/dense_tensor_function_compiler/CMakeLists.txt b/eval/src/tests/tensor/dense_tensor_function_compiler/CMakeLists.txt deleted file mode 100644 index b49a439b0ab..00000000000 --- a/eval/src/tests/tensor/dense_tensor_function_compiler/CMakeLists.txt +++ /dev/null @@ -1,8 +0,0 @@ -# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -vespa_add_executable(eval_dense_tensor_function_compiler_test_app TEST - SOURCES - dense_tensor_function_compiler_test.cpp - DEPENDS - vespaeval -) -vespa_add_test(NAME eval_dense_tensor_function_compiler_test_app COMMAND eval_dense_tensor_function_compiler_test_app) diff --git a/eval/src/tests/tensor/dense_tensor_function_optimizer/CMakeLists.txt b/eval/src/tests/tensor/dense_tensor_function_optimizer/CMakeLists.txt new file mode 100644 index 00000000000..3a95ef776d7 --- /dev/null +++ b/eval/src/tests/tensor/dense_tensor_function_optimizer/CMakeLists.txt @@ -0,0 +1,8 @@ +# Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +vespa_add_executable(eval_dense_tensor_function_optimizer_test_app TEST + SOURCES + dense_tensor_function_optimizer_test.cpp + DEPENDS + vespaeval +) +vespa_add_test(NAME eval_dense_tensor_function_optimizer_test_app COMMAND eval_dense_tensor_function_optimizer_test_app) diff --git a/eval/src/tests/tensor/dense_tensor_function_compiler/FILES b/eval/src/tests/tensor/dense_tensor_function_optimizer/FILES index 3c4ec2f1753..3c4ec2f1753 100644 --- a/eval/src/tests/tensor/dense_tensor_function_compiler/FILES +++ b/eval/src/tests/tensor/dense_tensor_function_optimizer/FILES diff --git a/eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp b/eval/src/tests/tensor/dense_tensor_function_optimizer/dense_tensor_function_optimizer_test.cpp index 7df436d85a1..57d03c09686 100644 --- a/eval/src/tests/tensor/dense_tensor_function_compiler/dense_tensor_function_compiler_test.cpp +++ b/eval/src/tests/tensor/dense_tensor_function_optimizer/dense_tensor_function_optimizer_test.cpp @@ -3,7 +3,7 @@ #include <vespa/vespalib/testkit/test_kit.h> #include <vespa/eval/tensor/dense/dense_dot_product_function.h> #include <vespa/eval/tensor/dense/dense_xw_product_function.h> -#include <vespa/eval/tensor/dense/dense_tensor_function_compiler.h> +#include <vespa/eval/tensor/dense/dense_tensor_function_optimizer.h> #include <vespa/eval/eval/operation.h> using namespace vespalib::eval; @@ -15,7 +15,7 @@ using vespalib::Stash; //----------------------------------------------------------------------------- const TensorFunction & -compileDotProduct(const vespalib::string &lhsType, +optimizeDotProduct(const vespalib::string &lhsType, const vespalib::string &rhsType, Stash &stash) { @@ -23,15 +23,15 @@ compileDotProduct(const vespalib::string &lhsType, inject(ValueType::from_spec(rhsType), 3, stash), Mul::f, stash), Aggr::SUM, {}, stash); - return DenseTensorFunctionCompiler::compile(reduceNode, stash); + return DenseTensorFunctionOptimizer::optimize(reduceNode, stash); } void -assertCompiledDotProduct(const vespalib::string &lhsType, +assertOptimizedDotProduct(const vespalib::string &lhsType, const vespalib::string &rhsType) { Stash stash; - const TensorFunction &func = compileDotProduct(lhsType, rhsType, stash); + const TensorFunction &func = optimizeDotProduct(lhsType, rhsType, stash); const DenseDotProductFunction *dotProduct = as<DenseDotProductFunction>(func); ASSERT_TRUE(dotProduct); EXPECT_EQUAL(1u, dotProduct->lhsTensorId()); @@ -39,11 +39,11 @@ assertCompiledDotProduct(const vespalib::string &lhsType, } void -assertNotCompiledDotProduct(const vespalib::string &lhsType, +assertNotOptimizedDotProduct(const vespalib::string &lhsType, const vespalib::string &rhsType) { Stash stash; - const TensorFunction &func = compileDotProduct(lhsType, rhsType, stash); + const TensorFunction &func = optimizeDotProduct(lhsType, rhsType, stash); const Reduce *reduce = as<Reduce>(func); EXPECT_TRUE(reduce); } @@ -51,7 +51,7 @@ assertNotCompiledDotProduct(const vespalib::string &lhsType, //----------------------------------------------------------------------------- const TensorFunction & -compileXWProduct(const vespalib::string &lhsType, +optimizeXWProduct(const vespalib::string &lhsType, const vespalib::string &rhsType, const vespalib::string &dim, Stash &stash) @@ -60,17 +60,17 @@ compileXWProduct(const vespalib::string &lhsType, inject(ValueType::from_spec(rhsType), 3, stash), Mul::f, stash), Aggr::SUM, {dim}, stash); - return DenseTensorFunctionCompiler::compile(reduceNode, stash); + return DenseTensorFunctionOptimizer::optimize(reduceNode, stash); } void -assertCompiledXWProduct(const vespalib::string &vecTypeStr, +assertOptimizedXWProduct(const vespalib::string &vecTypeStr, const vespalib::string &matTypeStr, const vespalib::string &dim) { Stash stash; - const TensorFunction &func = compileXWProduct(vecTypeStr, matTypeStr, dim, stash); - const TensorFunction &inv_func = compileXWProduct(matTypeStr, vecTypeStr, dim, stash); + const TensorFunction &func = optimizeXWProduct(vecTypeStr, matTypeStr, dim, stash); + const TensorFunction &inv_func = optimizeXWProduct(matTypeStr, vecTypeStr, dim, stash); const DenseXWProductFunction *xwProduct = as<DenseXWProductFunction>(func); const DenseXWProductFunction *inv_xwProduct = as<DenseXWProductFunction>(inv_func); ValueType vecType = ValueType::from_spec(vecTypeStr); @@ -92,13 +92,13 @@ assertCompiledXWProduct(const vespalib::string &vecTypeStr, } void -assertNotCompiledXWProduct(const vespalib::string &vecType, +assertNotOptimizedXWProduct(const vespalib::string &vecType, const vespalib::string &matType, const vespalib::string &dim) { Stash stash; - const TensorFunction &func = compileXWProduct(vecType, matType, dim, stash); - const TensorFunction &inv_func = compileXWProduct(matType, vecType, dim, stash); + const TensorFunction &func = optimizeXWProduct(vecType, matType, dim, stash); + const TensorFunction &inv_func = optimizeXWProduct(matType, vecType, dim, stash); const Reduce *reduce = as<Reduce>(func); const Reduce *inv_reduce = as<Reduce>(inv_func); EXPECT_TRUE(reduce); @@ -107,45 +107,45 @@ assertNotCompiledXWProduct(const vespalib::string &vecType, //----------------------------------------------------------------------------- -TEST("require that dot product with compatible dimensions is compiled") +TEST("require that dot product with compatible dimensions is optimized") { - TEST_DO(assertCompiledDotProduct("tensor(x[5])", "tensor(x[5])")); - TEST_DO(assertCompiledDotProduct("tensor(x[3])", "tensor(x[5])")); - TEST_DO(assertCompiledDotProduct("tensor(x[5])", "tensor(x[3])")); - TEST_DO(assertCompiledDotProduct("tensor(x[])", "tensor(x[5])")); - TEST_DO(assertCompiledDotProduct("tensor(x[5])", "tensor(x[])")); - TEST_DO(assertCompiledDotProduct("tensor(x[])", "tensor(x[])")); + TEST_DO(assertOptimizedDotProduct("tensor(x[5])", "tensor(x[5])")); + TEST_DO(assertOptimizedDotProduct("tensor(x[3])", "tensor(x[5])")); + TEST_DO(assertOptimizedDotProduct("tensor(x[5])", "tensor(x[3])")); + TEST_DO(assertOptimizedDotProduct("tensor(x[])", "tensor(x[5])")); + TEST_DO(assertOptimizedDotProduct("tensor(x[5])", "tensor(x[])")); + TEST_DO(assertOptimizedDotProduct("tensor(x[])", "tensor(x[])")); } -TEST("require that dot product with incompatible dimensions is NOT compiled") +TEST("require that dot product with incompatible dimensions is NOT optimized") { - TEST_DO(assertNotCompiledDotProduct("tensor(x[5])", "tensor(y[5])")); - TEST_DO(assertNotCompiledDotProduct("tensor(y[5])", "tensor(x[5])")); - TEST_DO(assertNotCompiledDotProduct("tensor(y[])", "tensor(x[])")); - TEST_DO(assertNotCompiledDotProduct("tensor(x[5])", "tensor(x[5],y[7])")); - TEST_DO(assertNotCompiledDotProduct("tensor(x[5],y[7])", "tensor(x[5],y[7])")); + TEST_DO(assertNotOptimizedDotProduct("tensor(x[5])", "tensor(y[5])")); + TEST_DO(assertNotOptimizedDotProduct("tensor(y[5])", "tensor(x[5])")); + TEST_DO(assertNotOptimizedDotProduct("tensor(y[])", "tensor(x[])")); + TEST_DO(assertNotOptimizedDotProduct("tensor(x[5])", "tensor(x[5],y[7])")); + TEST_DO(assertNotOptimizedDotProduct("tensor(x[5],y[7])", "tensor(x[5],y[7])")); } //----------------------------------------------------------------------------- -TEST("require that xw products with compatible dimensions are compiled") { - TEST_DO(assertCompiledXWProduct("tensor(x[3])", "tensor(x[3],y[4])", "x")); - TEST_DO(assertCompiledXWProduct("tensor(y[4])", "tensor(x[3],y[4])", "y")); +TEST("require that xw products with compatible dimensions are optimized") { + TEST_DO(assertOptimizedXWProduct("tensor(x[3])", "tensor(x[3],y[4])", "x")); + TEST_DO(assertOptimizedXWProduct("tensor(y[4])", "tensor(x[3],y[4])", "y")); } -TEST("require that xw products with incompatible dimensions are not compiled") { - TEST_DO(assertNotCompiledXWProduct("tensor(x[3])", "tensor(x[3],y[4])", "y")); - TEST_DO(assertNotCompiledXWProduct("tensor(x[])", "tensor(x[3],y[4])", "x")); - TEST_DO(assertNotCompiledXWProduct("tensor(x[3])", "tensor(x[],y[4])", "x")); - TEST_DO(assertNotCompiledXWProduct("tensor(x[3])", "tensor(x[3],y[])", "x")); - TEST_DO(assertNotCompiledXWProduct("tensor(x[2])", "tensor(x[3],y[4])", "x")); - TEST_DO(assertNotCompiledXWProduct("tensor(x[4])", "tensor(x[3],y[4])", "x")); - TEST_DO(assertNotCompiledXWProduct("tensor(x[3])", "tensor(y[3],z[4])", "x")); - TEST_DO(assertNotCompiledXWProduct("tensor(x[3])", "tensor(y[3],z[4])", "y")); - TEST_DO(assertNotCompiledXWProduct("tensor(x[3])", "tensor(y[3],z[4])", "z")); - TEST_DO(assertNotCompiledXWProduct("tensor(y[4])", "tensor(x[3],y[4])", "x")); - TEST_DO(assertNotCompiledXWProduct("tensor(y[3])", "tensor(x[3],y[4])", "y")); - TEST_DO(assertNotCompiledXWProduct("tensor(y[5])", "tensor(x[3],y[4])", "y")); +TEST("require that xw products with incompatible dimensions are not optimized") { + TEST_DO(assertNotOptimizedXWProduct("tensor(x[3])", "tensor(x[3],y[4])", "y")); + TEST_DO(assertNotOptimizedXWProduct("tensor(x[])", "tensor(x[3],y[4])", "x")); + TEST_DO(assertNotOptimizedXWProduct("tensor(x[3])", "tensor(x[],y[4])", "x")); + TEST_DO(assertNotOptimizedXWProduct("tensor(x[3])", "tensor(x[3],y[])", "x")); + TEST_DO(assertNotOptimizedXWProduct("tensor(x[2])", "tensor(x[3],y[4])", "x")); + TEST_DO(assertNotOptimizedXWProduct("tensor(x[4])", "tensor(x[3],y[4])", "x")); + TEST_DO(assertNotOptimizedXWProduct("tensor(x[3])", "tensor(y[3],z[4])", "x")); + TEST_DO(assertNotOptimizedXWProduct("tensor(x[3])", "tensor(y[3],z[4])", "y")); + TEST_DO(assertNotOptimizedXWProduct("tensor(x[3])", "tensor(y[3],z[4])", "z")); + TEST_DO(assertNotOptimizedXWProduct("tensor(y[4])", "tensor(x[3],y[4])", "x")); + TEST_DO(assertNotOptimizedXWProduct("tensor(y[3])", "tensor(x[3],y[4])", "y")); + TEST_DO(assertNotOptimizedXWProduct("tensor(y[5])", "tensor(x[3],y[4])", "y")); } //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/eval/tensor_function.cpp b/eval/src/vespa/eval/eval/tensor_function.cpp index 9cd7c7fc9c2..763f1cc39ff 100644 --- a/eval/src/vespa/eval/eval/tensor_function.cpp +++ b/eval/src/vespa/eval/eval/tensor_function.cpp @@ -28,31 +28,61 @@ Inject::eval(ConstArrayRef<Value::CREF> params, Stash &) const return params[tensor_id]; } +void +Inject::push_children(std::vector<Child::CREF> &) const +{ +} + +//----------------------------------------------------------------------------- + const Value & Reduce::eval(ConstArrayRef<Value::CREF> params, Stash &stash) const { - const Value &a = tensor.eval(params, stash); + const Value &a = tensor.get().eval(params, stash); const TensorEngine &engine = infer_engine({a}); return engine.reduce(a, aggr, dimensions, stash); } +void +Reduce::push_children(std::vector<Child::CREF> &children) const +{ + children.emplace_back(tensor); +} + +//----------------------------------------------------------------------------- + const Value & Map::eval(ConstArrayRef<Value::CREF> params, Stash &stash) const { - const Value &a = tensor.eval(params, stash); + const Value &a = tensor.get().eval(params, stash); const TensorEngine &engine = infer_engine({a}); return engine.map(a, function, stash); } +void +Map::push_children(std::vector<Child::CREF> &children) const +{ + children.emplace_back(tensor); +} + +//----------------------------------------------------------------------------- + const Value & Join::eval(ConstArrayRef<Value::CREF> params, Stash &stash) const { - const Value &a = lhs_tensor.eval(params, stash); - const Value &b = rhs_tensor.eval(params, stash); + const Value &a = lhs_tensor.get().eval(params, stash); + const Value &b = rhs_tensor.get().eval(params, stash); const TensorEngine &engine = infer_engine({a,b}); return engine.join(a, b, function, stash); } +void +Join::push_children(std::vector<Child::CREF> &children) const +{ + children.emplace_back(lhs_tensor); + children.emplace_back(rhs_tensor); +} + //----------------------------------------------------------------------------- const Node &inject(const ValueType &type, size_t tensor_id, Stash &stash) { diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h index e23dc8c6fc0..4b0db486971 100644 --- a/eval/src/vespa/eval/eval/tensor_function.h +++ b/eval/src/vespa/eval/eval/tensor_function.h @@ -70,15 +70,36 @@ using join_fun_t = double (*)(double, double); * will invoke the immediate API on the tensor engine associated with * the input tensors. In other words, the intermediate representation * 'compiles to itself'. + * + * The reason for using the top-level TensorFunction interface when + * referencing downwards in the tree is to enable mixed-mode execution + * resulting from partial optimization where the intermediate + * representation is partially replaced by implementation-specific + * tensor functions, which may or may not rely on lower-level tensor + * functions that may in turn be mixed-mode. **/ struct Node : public TensorFunction { + /** + * Reference to a sub-tree. References are replaceable to enable + * in-place bottom-up optimization during compilation. + **/ + class Child { + private: + mutable const TensorFunction *ptr; + public: + using CREF = std::reference_wrapper<const Child>; + Child(const TensorFunction &child) : ptr(&child) {} + const TensorFunction &get() const { return *ptr; } + void set(const TensorFunction &child) const { ptr = &child; } + }; const ValueType result_type; Node(const ValueType &result_type_in) : result_type(result_type_in) {} Node(const Node &) = delete; Node &operator=(const Node &) = delete; Node(Node &&) = delete; Node &operator=(Node &&) = delete; + virtual void push_children(std::vector<Child::CREF> &children) const = 0; }; struct Inject : Node { @@ -87,41 +108,45 @@ struct Inject : Node { size_t tensor_id_in) : Node(result_type_in), tensor_id(tensor_id_in) {} const Value &eval(ConstArrayRef<Value::CREF> params, Stash &) const override; + void push_children(std::vector<Child::CREF> &children) const override; }; struct Reduce : Node { - const Node &tensor; + Child tensor; const Aggr aggr; const std::vector<vespalib::string> dimensions; Reduce(const ValueType &result_type_in, - const Node &tensor_in, + const TensorFunction &tensor_in, Aggr aggr_in, const std::vector<vespalib::string> &dimensions_in) : Node(result_type_in), tensor(tensor_in), aggr(aggr_in), dimensions(dimensions_in) {} const Value &eval(ConstArrayRef<Value::CREF> params, Stash &stash) const override; + void push_children(std::vector<Child::CREF> &children) const override; }; struct Map : Node { - const Node &tensor; + Child tensor; const map_fun_t function; Map(const ValueType &result_type_in, - const Node &tensor_in, + const TensorFunction &tensor_in, map_fun_t function_in) : Node(result_type_in), tensor(tensor_in), function(function_in) {} const Value &eval(ConstArrayRef<Value::CREF> params, Stash &stash) const override; + void push_children(std::vector<Child::CREF> &children) const override; }; struct Join : Node { - const Node &lhs_tensor; - const Node &rhs_tensor; + Child lhs_tensor; + Child rhs_tensor; const join_fun_t function; Join(const ValueType &result_type_in, - const Node &lhs_tensor_in, - const Node &rhs_tensor_in, + const TensorFunction &lhs_tensor_in, + const TensorFunction &rhs_tensor_in, join_fun_t function_in) : Node(result_type_in), lhs_tensor(lhs_tensor_in), rhs_tensor(rhs_tensor_in), function(function_in) {} const Value &eval(ConstArrayRef<Value::CREF> params, Stash &stash) const override; + void push_children(std::vector<Child::CREF> &children) const override; }; const Node &inject(const ValueType &type, size_t tensor_id, Stash &stash); diff --git a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp index 773d2364b7d..c9f3be9d588 100644 --- a/eval/src/vespa/eval/tensor/default_tensor_engine.cpp +++ b/eval/src/vespa/eval/tensor/default_tensor_engine.cpp @@ -7,7 +7,7 @@ #include "serialization/typed_binary_format.h" #include "dense/dense_tensor.h" #include "dense/dense_tensor_builder.h" -#include "dense/dense_tensor_function_compiler.h" +#include "dense/dense_tensor_function_optimizer.h" #include <vespa/eval/eval/value.h> #include <vespa/eval/eval/tensor_spec.h> #include <vespa/eval/eval/simple_tensor_engine.h> @@ -208,7 +208,22 @@ DefaultTensorEngine::decode(nbostream &input) const const TensorFunction & DefaultTensorEngine::compile(const eval::tensor_function::Node &expr, Stash &stash) const { - return DenseTensorFunctionCompiler::compile(expr, stash); + using Node = eval::tensor_function::Node; + using Child = Node::Child; + Child root(expr); + std::vector<Child::CREF> nodes({root}); + for (size_t i = 0; i < nodes.size(); ++i) { + const Child &child = nodes[i]; + const Node *node = dynamic_cast<const Node *>(&child.get()); + assert(node != nullptr); + node->push_children(nodes); + } + while (!nodes.empty()) { + const Child &child = nodes.back(); + child.set(DenseTensorFunctionOptimizer::optimize(child.get(), stash)); + nodes.pop_back(); + } + return root.get(); } //----------------------------------------------------------------------------- diff --git a/eval/src/vespa/eval/tensor/dense/CMakeLists.txt b/eval/src/vespa/eval/tensor/dense/CMakeLists.txt index 97343ffd380..1fa839ca4b2 100644 --- a/eval/src/vespa/eval/tensor/dense/CMakeLists.txt +++ b/eval/src/vespa/eval/tensor/dense/CMakeLists.txt @@ -8,7 +8,7 @@ vespa_add_library(eval_tensor_dense OBJECT dense_tensor_address_combiner.cpp dense_tensor_builder.cpp dense_tensor_cells_iterator.cpp - dense_tensor_function_compiler.cpp + dense_tensor_function_optimizer.cpp dense_tensor_view.cpp dense_tensor_reduce.cpp mutable_dense_tensor_view.cpp diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp index 22e2a3fb78c..bd57db009b9 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.cpp +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.cpp @@ -2,7 +2,7 @@ #include "dense_dot_product_function.h" #include "dense_xw_product_function.h" -#include "dense_tensor_function_compiler.h" +#include "dense_tensor_function_optimizer.h" #include <vespa/eval/eval/operation.h> #include <vespa/vespalib/test/insertion_operators.h> #include <iostream> @@ -54,24 +54,25 @@ const TensorFunction &createDenseXWProduct(const ValueType &res, const Inject &v common_is_inner); } -struct InnerProductFunctionCompiler +struct InnerProductFunctionOptimizer { - static const TensorFunction &compile(const Node &expr, Stash &stash) { + static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash) { const Reduce *reduce = as<Reduce>(expr); if (reduce && (reduce->aggr == Aggr::SUM)) { - const Join *join = as<Join>(reduce->tensor); + const ValueType &result_type = reduce->result_type; + const Join *join = as<Join>(reduce->tensor.get()); if (join && (join->function == Mul::f)) { - const Inject *lhs = as<Inject>(join->lhs_tensor); - const Inject *rhs = as<Inject>(join->rhs_tensor); + const Inject *lhs = as<Inject>(join->lhs_tensor.get()); + const Inject *rhs = as<Inject>(join->rhs_tensor.get()); if (lhs && rhs) { - if (isDenseDotProduct(expr.result_type, lhs->result_type, rhs->result_type)) { + if (isDenseDotProduct(result_type, lhs->result_type, rhs->result_type)) { return stash.create<DenseDotProductFunction>(lhs->tensor_id, rhs->tensor_id); } - if (isDenseXWProduct(expr.result_type, lhs->result_type, rhs->result_type)) { - return createDenseXWProduct(expr.result_type, *lhs, *rhs, stash); + if (isDenseXWProduct(result_type, lhs->result_type, rhs->result_type)) { + return createDenseXWProduct(result_type, *lhs, *rhs, stash); } - if (isDenseXWProduct(expr.result_type, rhs->result_type, lhs->result_type)) { - return createDenseXWProduct(expr.result_type, *rhs, *lhs, stash); + if (isDenseXWProduct(result_type, rhs->result_type, lhs->result_type)) { + return createDenseXWProduct(result_type, *rhs, *lhs, stash); } } } @@ -83,9 +84,9 @@ struct InnerProductFunctionCompiler } const TensorFunction & -DenseTensorFunctionCompiler::compile(const eval::tensor_function::Node &expr, Stash &stash) +DenseTensorFunctionOptimizer::optimize(const eval::TensorFunction &expr, Stash &stash) { - return InnerProductFunctionCompiler::compile(expr, stash); + return InnerProductFunctionOptimizer::optimize(expr, stash); } } diff --git a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.h index 61c3af079e3..2478447ca48 100644 --- a/eval/src/vespa/eval/tensor/dense/dense_tensor_function_compiler.h +++ b/eval/src/vespa/eval/tensor/dense/dense_tensor_function_optimizer.h @@ -10,11 +10,11 @@ namespace vespalib::tensor { /** * Class that recognizes calculations over dense tensors (in tensor function intermediate representation) - * and compiles this into an explicit tensor function. + * and optimizes this into an explicit tensor function. */ -struct DenseTensorFunctionCompiler +struct DenseTensorFunctionOptimizer { - static const eval::TensorFunction &compile(const eval::tensor_function::Node &expr, Stash &stash); + static const eval::TensorFunction &optimize(const eval::TensorFunction &expr, Stash &stash); }; } |