summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
author: Håvard Pettersen <havardpe@yahooinc.com>, 2023-08-30 14:53:19 +0000
committer: Håvard Pettersen <havardpe@yahooinc.com>, 2023-09-01 14:06:11 +0000
commit: 7385e6f7ca39c8eec1207b63ba4395ab38839833 (patch)
tree: 6ba1c4bd3d7f43665ba323c6f07f5089d1b2d78d /eval
parent: d70ed3ee222320cff442c93878e6f83a73c4bd61 (diff)
benchmark universal dot product vs other options
Diffstat (limited to 'eval')
-rw-r--r-- eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp | 258
-rw-r--r-- eval/src/vespa/eval/eval/interpreted_function.cpp | 2
-rw-r--r-- eval/src/vespa/eval/eval/interpreted_function.h | 1
-rw-r--r-- eval/src/vespa/eval/eval/optimize_tensor_function.cpp | 43
-rw-r--r-- eval/src/vespa/eval/eval/optimize_tensor_function.h | 13
-rw-r--r-- eval/src/vespa/eval/eval/test/eval_fixture.cpp | 25
-rw-r--r-- eval/src/vespa/eval/instruction/universal_dot_product.cpp | 14
-rw-r--r-- eval/src/vespa/eval/instruction/universal_dot_product.h | 5
8 files changed, 336 insertions, 25 deletions
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index 3f60ad69b86..95eb7b406e6 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -4,11 +4,19 @@
#include <vespa/eval/eval/value_codec.h>
#include <vespa/eval/eval/interpreted_function.h>
#include <vespa/eval/eval/tensor_function.h>
+#include <vespa/eval/eval/lazy_params.h>
+#include <vespa/eval/eval/make_tensor_function.h>
+#include <vespa/eval/eval/optimize_tensor_function.h>
+#include <vespa/eval/eval/compile_tensor_function.h>
#include <vespa/eval/instruction/universal_dot_product.h>
#include <vespa/eval/eval/test/reference_operations.h>
+#include <vespa/eval/eval/test/reference_evaluation.h>
#include <vespa/eval/eval/test/gen_spec.h>
+#include <vespa/vespalib/util/benchmark_timer.h>
+#include <vespa/vespalib/util/classname.h>
#include <vespa/vespalib/util/stringfmt.h>
#include <vespa/vespalib/gtest/gtest.h>
+#include <optional>
using namespace vespalib;
using namespace vespalib::eval;
@@ -17,6 +25,8 @@ using namespace vespalib::eval::test;
using vespalib::make_string_short::fmt;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
+bool bench = false; // benchmarking switch; main() sets it to true when the first argument is 'bench'
+double budget = 1.0; // time budget handed to BenchmarkTimer; main() sets 0.1 for 'fast' and 5.0 for 'slow'
GenSpec::seq_t N_16ths = [] (size_t i) noexcept { return (i + 33.0) / 16.0; };
@@ -43,6 +53,169 @@ const std::vector<std::vector<vespalib::string>> reductions = {
{}, {"x"}, {"y"}, {"z"}, {"x", "y"}, {"x", "z"}, {"y", "z"}
};
+std::vector<std::string> ns_list = { // namespace prefixes that strip_ns() erases, tried in the order listed here
+ {"vespalib::eval::instruction::(anonymous namespace)::"},
+ {"vespalib::eval::(anonymous namespace)::"},
+ {"vespalib::eval::InterpretedFunction::"},
+ {"vespalib::eval::tensor_function::"},
+ {"vespalib::eval::operation::"},
+ {"vespalib::eval::aggr::"},
+ {"vespalib::eval::"} // shortest prefix comes last; every longer entry above starts with it
+};
+std::string strip_ns(const vespalib::string &str) { // returns a copy of str with every occurrence of each ns_list entry removed
+ std::string tmp = str;
+ for (const auto &ns: ns_list) {
+ for (bool again = true; again;) { // repeat until this prefix no longer occurs in tmp
+ again = false;
+ if (auto pos = tmp.find(ns); pos < tmp.size()) { // a failed find() returns npos, which is never < size()
+ tmp.erase(pos, ns.size());
+ again = true;
+ }
+ }
+ }
+ return tmp;
+}
+
+TensorSpec make_spec(const vespalib::string &param_name, size_t idx) { // builds the input spec for one parameter; its name doubles as the GenSpec description
+ return GenSpec::from_desc(param_name).cells_double().seq(N(1 + idx)); // double cells; idx selects the sequence via N(1 + idx)
+}
+
+TensorSpec eval_ref(const Function &fun) { // evaluates fun with ReferenceEvaluation to get the expected result
+ std::vector<TensorSpec> params;
+ for (size_t i = 0; i < fun.num_params(); ++i) {
+ params.push_back(make_spec(fun.param_name(i), i)); // same make_spec() inputs that benchmark() feeds to the optimized function
+ }
+ return ReferenceEvaluation::eval(fun, params);
+}
+
+class Optimize // describes one optimization strategy that benchmark() can apply to a tensor function
+{
+private:
+ struct ctor_tag{}; // private tag type taken as the last constructor parameter
+public:
+ enum class With { NONE, CUSTOM, PROD, SPECIFIC }; // selects which branch of the switch in benchmark() is used
+ With with;
+ vespalib::string name; // label used in the printed output and as key in the cost map
+ OptimizeTensorFunctionOptions options; // passed to optimize_tensor_function() for With::CUSTOM
+ tensor_function_optimizer optimizer; // passed to apply_tensor_function_optimizer() for With::SPECIFIC
+ Optimize(With with_in, const vespalib::string name_in,
+ const OptimizeTensorFunctionOptions &options_in,
+ tensor_function_optimizer optimizer_in, ctor_tag)
+ : with(with_in), name(name_in), options(options_in), optimizer(optimizer_in) {}
+ static Optimize none() { return {With::NONE, "none", {}, {}, {}}; } // default options, empty optimizer
+ static Optimize prod() { return {With::PROD, "prod", {}, {}, {}}; } // default options, empty optimizer
+ static Optimize custom(const vespalib::string &name_in, const OptimizeTensorFunctionOptions &options_in) {
+ return {With::CUSTOM, name_in, options_in, {}, {}}; // caller-supplied options, empty optimizer
+ }
+ static Optimize specific(const vespalib::string &name_in, tensor_function_optimizer optimizer_in) {
+ return {With::SPECIFIC, name_in, {}, optimizer_in, {}}; // default options, caller-supplied optimizer
+ }
+ ~Optimize(); // declared here, defaulted out of line right after the class
+};
+Optimize::~Optimize() = default;
+
+Optimize baseline() { // CUSTOM strategy named "baseline": optimize_tensor_function() with universal dot product disabled
+ OptimizeTensorFunctionOptions my_options;
+ my_options.allow_universal_dot_product = false;
+ return Optimize::custom("baseline", my_options);
+}
+
+Optimize with_universal() { // CUSTOM strategy named "with_universal": optimize_tensor_function() with universal dot product enabled
+ OptimizeTensorFunctionOptions my_options;
+ my_options.allow_universal_dot_product = true;
+ return Optimize::custom("with_universal", my_options);
+}
+
+Optimize universal_only() { // SPECIFIC strategy named "universal_only": UniversalDotProduct::optimize is the only optimizer applied
+ auto my_optimizer = [](const TensorFunction &expr, Stash &stash)->const TensorFunction &
+ {
+ return UniversalDotProduct::optimize(expr, stash, true); // third argument is 'force'
+ };
+ return Optimize::specific("universal_only", my_optimizer);
+}
+
+using cost_map_t = std::map<vespalib::string,double>; // Optimize::name -> cost value (cost_us) stored by benchmark()
+std::vector<std::pair<vespalib::string,cost_map_t>> benchmark_results; // one (description, cost map) entry appended per benchmark() call
+
+void benchmark(const vespalib::string &desc, const vespalib::string &expr, std::vector<Optimize> list) { // runs expr under each strategy in list, checks it against the reference result, prints timings to stderr and appends the costs to benchmark_results
+ auto fun = Function::parse(expr);
+ ASSERT_FALSE(fun->has_error());
+ auto expected = eval_ref(*fun); // reference result that every strategy must reproduce
+ cost_map_t cost_map;
+ fprintf(stderr, "BENCH: %s (%s)\n", desc.c_str(), expr.c_str());
+ for (Optimize &optimize: list) {
+ std::vector<Value::UP> values; // parameter values are rebuilt for each strategy
+ for (size_t i = 0; i < fun->num_params(); ++i) {
+ auto value = value_from_spec(make_spec(fun->param_name(i), i), prod_factory);
+ values.push_back(std::move(value));
+ }
+ SimpleObjectParams params({});
+ std::vector<ValueType> param_types;
+ for (auto &&up: values) {
+ params.params.emplace_back(*up);
+ param_types.push_back(up->type());
+ }
+ NodeTypes node_types(*fun, param_types);
+ ASSERT_FALSE(node_types.get_type(fun->root()).is_error());
+ Stash stash;
+ const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash);
+ const TensorFunction *optimized = nullptr;
+ switch (optimize.with) { // pick the optimization path requested by this strategy
+ case Optimize::With::NONE:
+ optimized = std::addressof(plain_fun);
+ break;
+ case Optimize::With::PROD:
+ optimized = std::addressof(optimize_tensor_function(prod_factory, plain_fun, stash));
+ break;
+ case Optimize::With::CUSTOM:
+ optimized = std::addressof(optimize_tensor_function(prod_factory, plain_fun, stash, optimize.options));
+ break;
+ case Optimize::With::SPECIFIC:
+ size_t count = 0; // number of nodes replaced by the specific optimizer
+ optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash, &count));
+ ASSERT_GT(count, 0); // the optimizer must have replaced at least one node
+ break;
+ }
+ ASSERT_NE(optimized, nullptr);
+ CTFMetaData ctf_meta;
+ InterpretedFunction ifun(prod_factory, *optimized, &ctf_meta);
+ ASSERT_EQ(ctf_meta.steps.size(), ifun.program_size());
+ BenchmarkTimer timer(budget);
+ std::vector<duration> prev_time(ctf_meta.steps.size(), duration::zero());
+ std::vector<duration> min_time(ctf_meta.steps.size(), duration::max());
+ InterpretedFunction::ProfiledContext pctx(ifun);
+ for (bool first = true; timer.has_budget(); first = false) { // each round: one profiled eval, then one timed plain eval
+ const Value &profiled_result = ifun.eval(pctx, params);
+ if (first) { // results are only verified on the first round
+ EXPECT_EQ(spec_from_value(profiled_result), expected);
+ }
+ timer.before();
+ const Value &result = ifun.eval(pctx.context, params);
+ timer.after();
+ if (first) {
+ EXPECT_EQ(spec_from_value(result), expected);
+ }
+ for (size_t i = 0; i < ctf_meta.steps.size(); ++i) {
+ min_time[i] = std::min(min_time[i], pctx.cost[i].second - prev_time[i]); // NOTE(review): presumably pctx.cost accumulates across evals, hence the delta against prev_time - confirm
+ prev_time[i] = pctx.cost[i].second;
+ }
+ }
+ double cost_us = timer.min_time() * 1000.0 * 1000.0; // NOTE(review): assumes min_time() is in seconds, giving microseconds - confirm
+ cost_map.emplace(optimize.name, cost_us);
+ fprintf(stderr, " optimized with: %s: %g us {\n", optimize.name.c_str(), cost_us);
+ for (size_t i = 0; i < ctf_meta.steps.size(); ++i) {
+ auto name = strip_ns(ctf_meta.steps[i].class_name);
+ if (name.find("Inject") > name.size() && name.find("ConstValue") > name.size()) { // only print steps whose class name contains neither "Inject" nor "ConstValue"
+ fprintf(stderr, " %s: %zu ns\n", name.c_str(), count_ns(min_time[i]));
+ fprintf(stderr, " +-- %s\n", strip_ns(ctf_meta.steps[i].symbol_name).c_str());
+ }
+ }
+ fprintf(stderr, " }\n");
+ }
+ fprintf(stderr, "\n");
+ benchmark_results.emplace_back(desc, std::move(cost_map));
+}
+
TensorSpec perform_dot_product(const TensorSpec &a, const TensorSpec &b, const std::vector<vespalib::string> &dims)
{
Stash stash;
@@ -86,4 +259,87 @@ TEST(UniversalDotProductTest, generic_dot_product_works_for_various_cases) {
fprintf(stderr, "total test cases run: %zu\n", test_cases);
}
-GTEST_MAIN_RUN_ALL_TESTS()
+TEST(UniversalDotProductTest, bench_vector_dot_product) { // benchmarks a set of join-mul/reduce-sum expressions under three strategies and prints a comparison table
+ if (!bench) { // skipped unless the global 'bench' flag is set
+ fprintf(stderr, "benchmarking disabled, run with 'bench' parameter to enable\n");
+ return;
+ }
+ auto optimize_list = std::vector<Optimize>({baseline(), with_universal(), universal_only()});
+
+ benchmark("number number", "reduce(1.0*2.0,sum)", optimize_list);
+ benchmark("number vector", "reduce(5.0*x128,sum,x)", optimize_list);
+ benchmark("vector vector small", "reduce(x16*x16,sum,x)", optimize_list);
+ benchmark("vector vector large", "reduce(x768*x768,sum,x)", optimize_list);
+ benchmark("vector matrix full", "reduce(y64*x8y64,sum,x,y)", optimize_list);
+ benchmark("vector matrix inner", "reduce(y64*x8y64,sum,y)", optimize_list);
+ benchmark("vector matrix outer", "reduce(y64*x8y64,sum,x)", optimize_list);
+ benchmark("matrix matrix same", "reduce(a8y64*a8y64,sum,y)", optimize_list);
+ benchmark("matrix matrix different", "reduce(a8y64*b8y64,sum,y)", optimize_list);
+ benchmark("matmul", "reduce(a8b64*b64c8,sum,b)", optimize_list);
+ benchmark("sparse overlap", "reduce(x64_1*x64_1,sum,x)", optimize_list);
+ benchmark("sparse no overlap", "reduce(a64_1*b64_1,sum,b)", optimize_list);
+ benchmark("mixed dense", "reduce(a1_16x768*x768,sum,x)", optimize_list);
+ benchmark("mixed mixed complex", "reduce(a1_1x128*a2_1b64_1x128,sum,a,x)", optimize_list);
+
+ size_t max_desc_size = 0; // widest description, used to right-align the table below
+ for (const auto &[desc, cost_map]: benchmark_results) {
+ max_desc_size = std::max(max_desc_size, desc.size());
+ }
+ for (const auto &[desc, cost_map]: benchmark_results) { // one output line per benchmark() call
+ for (size_t i = 0; i < max_desc_size - desc.size(); ++i) { // left-pad so all descriptions end in the same column
+ fprintf(stderr, " ");
+ }
+ fprintf(stderr, "%s: ", desc.c_str());
+ size_t cnt = 0;
+ double baseline_cost = 0.0;
+ double with_universal_cost = 0.0;
+ double universal_only_cost = 0.0;
+ for (const auto &[name, cost]: cost_map) {
+ if (++cnt > 1) { // comma-separate every entry after the first
+ fprintf(stderr, ", ");
+ }
+ if (name == "baseline") {
+ baseline_cost = cost;
+ } else if (name == "with_universal") {
+ with_universal_cost = cost;
+ } else if (name == "universal_only") {
+ universal_only_cost = cost;
+ }
+ fprintf(stderr, "%s: %8.3f us", name.c_str(), cost);
+ }
+ if (with_universal_cost > 1.1 * baseline_cost) { // with_universal costs more than 1.1x baseline
+ fprintf(stderr, ", LOSS: %8.3f", with_universal_cost / baseline_cost);
+ }
+ if (baseline_cost > 1.1 * with_universal_cost) { // baseline costs more than 1.1x with_universal
+ fprintf(stderr, ", GAIN: %8.3f", baseline_cost / with_universal_cost);
+ }
+ if (with_universal_cost > 1.1 * universal_only_cost) { // with_universal costs more than 1.1x universal_only
+ fprintf(stderr, ", MISSED: %8.3f", with_universal_cost / universal_only_cost);
+ }
+ fprintf(stderr, "\n");
+ }
+ fprintf(stderr, "\n");
+}
+
+int main(int argc, char **argv) { // consumes optional leading 'bench', 'fast' and 'slow' arguments (checked in that order), then runs all tests
+ const std::string bench_option = "bench";
+ const std::string fast_option = "fast";
+ const std::string slow_option = "slow";
+ if ((argc > 1) && (bench_option == argv[1])) {
+ bench = true;
+ ++argv; // shift past the consumed option so the next check sees the following argument
+ --argc;
+ }
+ if ((argc > 1) && (fast_option == argv[1])) {
+ budget = 0.1; // smaller time budget per benchmark
+ ++argv;
+ --argc;
+ }
+ if ((argc > 1) && (slow_option == argv[1])) {
+ budget = 5.0; // larger time budget per benchmark; overrides 'fast' if both are given
+ ++argv;
+ --argc;
+ }
+ ::testing::InitGoogleTest(&argc, argv); // receives the remaining arguments
+ return RUN_ALL_TESTS();
+}
diff --git a/eval/src/vespa/eval/eval/interpreted_function.cpp b/eval/src/vespa/eval/eval/interpreted_function.cpp
index 48683291cfb..8a32b7d11ff 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.cpp
+++ b/eval/src/vespa/eval/eval/interpreted_function.cpp
@@ -68,6 +68,8 @@ InterpretedFunction::ProfiledContext::ProfiledContext(const InterpretedFunction
{
}
+InterpretedFunction::ProfiledContext::~ProfiledContext() = default; // out-of-line defaulted destructor for the declaration added to interpreted_function.h
+
vespalib::string
InterpretedFunction::Instruction::resolve_symbol() const
{
diff --git a/eval/src/vespa/eval/eval/interpreted_function.h b/eval/src/vespa/eval/eval/interpreted_function.h
index 4528ccb79aa..86ab22073da 100644
--- a/eval/src/vespa/eval/eval/interpreted_function.h
+++ b/eval/src/vespa/eval/eval/interpreted_function.h
@@ -74,6 +74,7 @@ public:
Context context;
std::vector<std::pair<size_t,duration>> cost;
ProfiledContext(const InterpretedFunction &ifun);
+ ~ProfiledContext();
};
using op_function = void (*)(State &, uint64_t);
class Instruction {
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
index 3d9152d6b80..4013021aaa4 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
@@ -44,6 +44,13 @@ LOG_SETUP(".eval.eval.optimize_tensor_function");
namespace vespalib::eval {
+OptimizeTensorFunctionOptions::OptimizeTensorFunctionOptions() noexcept // default options: universal dot product is not allowed
+ : allow_universal_dot_product(false)
+{
+}
+
+OptimizeTensorFunctionOptions::~OptimizeTensorFunctionOptions() = default; // out-of-line defaulted destructor
+
namespace {
using Child = TensorFunction::Child;
@@ -60,7 +67,9 @@ void run_optimize_pass(const Child &root, Func&& optimize_node) {
}
}
-const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const TensorFunction &expr, Stash &stash) {
+const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const TensorFunction &expr, Stash &stash,
+ const OptimizeTensorFunctionOptions &options)
+{
Child root(expr);
run_optimize_pass(root, [&stash](const Child &child)
{
@@ -78,7 +87,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const Te
child.set(L2Distance::optimize(child.get(), stash));
child.set(MixedL2Distance::optimize(child.get(), stash));
});
- run_optimize_pass(root, [&stash](const Child &child)
+ run_optimize_pass(root, [&stash,&options](const Child &child)
{
child.set(DenseDotProductFunction::optimize(child.get(), stash));
child.set(SparseDotProductFunction::optimize(child.get(), stash));
@@ -89,7 +98,9 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const Te
child.set(DenseHammingDistance::optimize(child.get(), stash));
child.set(SimpleJoinCount::optimize(child.get(), stash));
child.set(MappedLookup::optimize(child.get(), stash));
- // child.set(UniversalDotProduct::optimize(child.get(), stash));
+ if (options.allow_universal_dot_product) {
+ child.set(UniversalDotProduct::optimize(child.get(), stash, false));
+ }
});
run_optimize_pass(root, [&stash](const Child &child)
{
@@ -116,11 +127,33 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const Te
} // namespace vespalib::eval::<unnamed>
-const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash) {
+const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash, // optimizes 'function' under the given options, debug-logging the tree before and after
+ const OptimizeTensorFunctionOptions &options)
+{
LOG(debug, "tensor function before optimization:\n%s\n", function.as_string().c_str());
- const TensorFunction &optimized = optimize_for_factory(factory, function, stash);
+ const TensorFunction &optimized = optimize_for_factory(factory, function, stash, options); // options are forwarded to the optimization passes
LOG(debug, "tensor function after optimization:\n%s\n", optimized.as_string().c_str());
return optimized;
}
+const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash) { // convenience overload using default-constructed options
+ return optimize_tensor_function(factory, function, stash, OptimizeTensorFunctionOptions());
+}
+
+const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash, size_t *count) { // runs a single optimizer callback over the function tree via run_optimize_pass and returns the resulting root
+ Child root(function);
+ run_optimize_pass(root, [&](const Child &child)
+ {
+ const TensorFunction &child_before = child.get();
+ const TensorFunction &child_after = optimizer(child_before, stash);
+ if (std::addressof(child_after) != std::addressof(child_before)) { // a different object returned means the node was replaced
+ child.set(child_after);
+ if (count != nullptr) { // count is optional; it is incremented, never reset here
+ ++(*count);
+ }
+ }
+ });
+ return root.get();
+}
+
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.h b/eval/src/vespa/eval/eval/optimize_tensor_function.h
index d8ed104f3a6..4a5945860e7 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.h
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.h
@@ -2,13 +2,26 @@
#pragma once
+#include <functional>
+
namespace vespalib { class Stash; }
namespace vespalib::eval {
+struct OptimizeTensorFunctionOptions { // options accepted by optimize_tensor_function()
+ bool allow_universal_dot_product; // when true, the optimizer also tries UniversalDotProduct::optimize (non-forced)
+ OptimizeTensorFunctionOptions() noexcept; // sets allow_universal_dot_product to false
+ ~OptimizeTensorFunctionOptions();
+};
+
struct ValueBuilderFactory;
struct TensorFunction;
+const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash,
+ const OptimizeTensorFunctionOptions &options);
const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash);
+using tensor_function_optimizer = std::function<const TensorFunction &(const TensorFunction &expr, Stash &stash)>; // callback returning a replacement for expr, or expr itself to leave the node unchanged
+const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash, size_t *count = nullptr); // applies optimizer across the tree; *count (if given) is incremented per replaced node
+
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 90761e43a01..ef81fb27def 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -63,25 +63,18 @@ struct MyMutableInject : public tensor_function::Inject {
};
const TensorFunction &maybe_patch(bool allow_mutable, const TensorFunction &plain_fun, const std::set<size_t> &mutable_set, Stash &stash) {
- using Child = TensorFunction::Child;
if (!allow_mutable) {
return plain_fun;
}
- Child root(plain_fun);
- std::vector<Child::CREF> nodes({root});
- for (size_t i = 0; i < nodes.size(); ++i) {
- nodes[i].get().get().push_children(nodes);
- }
- while (!nodes.empty()) {
- const Child &child = nodes.back();
- if (auto inject = as<tensor_function::Inject>(child.get())) {
- if (mutable_set.count(inject->param_idx()) > 0) {
- child.set(stash.create<MyMutableInject>(inject->result_type(), inject->param_idx()));
- }
- }
- nodes.pop_back();
- }
- return root.get();
+ auto optimizer = [&mutable_set](const TensorFunction &node, Stash &my_stash)->const TensorFunction &{ // replaces Inject nodes whose parameter index is in mutable_set
+ if (auto inject = as<tensor_function::Inject>(node);
+ inject && mutable_set.count(inject->param_idx()) > 0)
+ {
+ return my_stash.create<MyMutableInject>(inject->result_type(), inject->param_idx());
+ }
+ return node; // returning the same node leaves it untouched
+ };
+ return apply_tensor_function_optimizer(plain_fun, optimizer, stash); // the shared helper replaces the hand-written tree walk removed above
}
std::vector<Value::UP> make_params(const ValueBuilderFactory &factory, const Function &function,
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
index 79a94d862bf..86e6be52de4 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
@@ -84,6 +84,14 @@ struct SelectUniversalDotProduct {
}
};
+bool check_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs) { // filter used by the non-forced UniversalDotProduct::optimize for the given result/input types
+ UniversalDotProductParam param(res, lhs, rhs);
+ if (param.vector_size < 8) { // reject when the computed vector_size is below 8
+ return false;
+ }
+ return true;
+}
+
} // namespace <unnamed>
UniversalDotProduct::UniversalDotProduct(const ValueType &res_type_in,
@@ -106,11 +114,13 @@ UniversalDotProduct::compile_self(const ValueBuilderFactory &, Stash &stash) con
}
const TensorFunction &
-UniversalDotProduct::optimize(const TensorFunction &expr, Stash &stash)
+UniversalDotProduct::optimize(const TensorFunction &expr, Stash &stash, bool force) // rewrites reduce(sum) over join(mul) into UniversalDotProduct; returns expr unchanged otherwise
{
if (auto reduce = as<Reduce>(expr); reduce && (reduce->aggr() == Aggr::SUM)) {
if (auto join = as<Join>(reduce->child()); join && (join->function() == Mul::f)) {
- return stash.create<UniversalDotProduct>(expr.result_type(), join->lhs(), join->rhs());
+ if (force || check_types(expr.result_type(), join->lhs().result_type(), join->rhs().result_type())) { // force bypasses the check_types() filter
+ return stash.create<UniversalDotProduct>(expr.result_type(), join->lhs(), join->rhs());
+ }
}
}
return expr;
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.h b/eval/src/vespa/eval/instruction/universal_dot_product.h
index ac5aa157f17..40fd109cc73 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.h
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.h
@@ -9,6 +9,9 @@ namespace vespalib::eval {
/**
* Tensor function performing dot product compatible operations
* (join:mul, reduce:sum) on values of arbitrary complexity.
+ *
+ * Note: can evaluate 'anything', but unless 'force' is given, optimize()
+ * will try to be a bit conservative about when to apply itself.
**/
class UniversalDotProduct : public tensor_function::Op2
{
@@ -16,7 +19,7 @@ public:
UniversalDotProduct(const ValueType &res_type, const TensorFunction &lhs, const TensorFunction &rhs);
InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
bool result_is_mutable() const override { return true; }
- static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash);
+ static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash, bool force);
};
} // namespace