summary | refs | log | tree | commit | diff | stats
path: root/eval
diff options
context:
space:
mode:
author: Håvard Pettersen <havardpe@oath.com> 2021-02-01 15:13:22 +0000
committer: Håvard Pettersen <havardpe@oath.com> 2021-02-01 18:38:36 +0000
commit: 0261658338a6f7ad28bfca6f16f8a4b7c35d9cae (patch)
tree: 14f4494de17b6a64fc2f916fafcf66c1e723ec93 /eval
parent: fe6300b1e9b81c09aa0235b5049439198c6a2206 (diff)
sparse dot product
Diffstat (limited to 'eval')
-rw-r--r-- eval/CMakeLists.txt 1
-rw-r--r-- eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt 9
-rw-r--r-- eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp 85
-rw-r--r-- eval/src/vespa/eval/eval/optimize_tensor_function.cpp 8
-rw-r--r-- eval/src/vespa/eval/eval/test/eval_fixture.cpp 5
-rw-r--r-- eval/src/vespa/eval/instruction/CMakeLists.txt 1
-rw-r--r-- eval/src/vespa/eval/instruction/generic_join.cpp 11
-rw-r--r-- eval/src/vespa/eval/instruction/generic_join.h 10
-rw-r--r-- eval/src/vespa/eval/instruction/sparse_dot_product_function.cpp 107
-rw-r--r-- eval/src/vespa/eval/instruction/sparse_dot_product_function.h 23
10 files changed, 249 insertions, 11 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index fe621f6e9f0..23127cc12b5 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -65,6 +65,7 @@ vespa_define_module(
src/tests/instruction/mixed_simple_join_function
src/tests/instruction/pow_as_map_optimizer
src/tests/instruction/remove_trivial_dimension_optimizer
+ src/tests/instruction/sparse_dot_product_function
src/tests/instruction/sum_max_dot_product_function
src/tests/instruction/vector_from_doubles_function
src/tests/streamed/value
diff --git a/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt b/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt
new file mode 100644
index 00000000000..076f1d79796
--- /dev/null
+++ b/eval/src/tests/instruction/sparse_dot_product_function/CMakeLists.txt
@@ -0,0 +1,9 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_sparse_dot_product_function_test_app TEST
+ SOURCES
+ sparse_dot_product_function_test.cpp
+ DEPENDS
+ vespaeval
+ GTest::GTest
+)
+vespa_add_test(NAME eval_sparse_dot_product_function_test_app COMMAND eval_sparse_dot_product_function_test_app)
diff --git a/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp b/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp
new file mode 100644
index 00000000000..65eab2778aa
--- /dev/null
+++ b/eval/src/tests/instruction/sparse_dot_product_function/sparse_dot_product_function_test.cpp
@@ -0,0 +1,85 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/eval/eval/fast_value.h>
+#include <vespa/eval/eval/simple_value.h>
+#include <vespa/eval/instruction/sparse_dot_product_function.h>
+#include <vespa/eval/eval/test/eval_fixture.h>
+#include <vespa/eval/eval/test/gen_spec.h>
+#include <vespa/vespalib/gtest/gtest.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::test;
+
+const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get(); // factory behind fast_fixture and slow_fixture in the helpers below
+const ValueBuilderFactory &test_factory = SimpleValueBuilderFactory::get(); // factory behind test_fixture in assert_optimized
+
+//-----------------------------------------------------------------------------
+
+// Builds the repository of named tensor parameters that the test
+// expressions in this file refer to.
+EvalFixture::ParamRepo make_params() {
+    EvalFixture::ParamRepo repo;
+    return repo
+        // vectors built with map("x", ...), with and without cells_float()
+        .add("v1_x",   GenSpec().map("x", 32, 1).seq_bias(3.0).gen())
+        .add("v1_x_f", GenSpec().map("x", 32, 1).seq_bias(3.0).cells_float().gen())
+        .add("v2_x",   GenSpec().map("x", 16, 2).seq_bias(7.0).gen())
+        .add("v2_x_f", GenSpec().map("x", 16, 2).seq_bias(7.0).cells_float().gen())
+        // vector over a different dimension name
+        .add("v3_y",   GenSpec().map("y", 10, 1).gen())
+        // built with idx() only
+        .add("v4_xd",  GenSpec().idx("x", 10).gen())
+        // two map() dimensions
+        .add("m1_xy",  GenSpec().map("x", 32, 1).map("y", 16, 2).seq_bias(3.0).gen())
+        .add("m2_xy",  GenSpec().map("x", 16, 2).map("y", 32, 1).seq_bias(7.0).gen())
+        // one map() dimension combined with one idx() dimension
+        .add("m3_xym", GenSpec().map("x", 8, 1).idx("y", 5).gen());
+}
+EvalFixture::ParamRepo param_repo = make_params();
+
+/**
+ * Verifies that 'expr' is rewritten into exactly one
+ * SparseDotProductFunction when the fixture is created with its last
+ * argument set to true (for both value builder factories), that no
+ * such node shows up when that argument is false, and that all three
+ * fixtures produce the reference result.
+ *
+ * @param expr expression to evaluate against 'param_repo'
+ */
+void assert_optimized(const vespalib::string &expr) {
+    // The reference result only depends on the expression and the
+    // parameters; compute it once instead of once per comparison.
+    const auto expected = EvalFixture::ref(expr, param_repo);
+    EvalFixture fast_fixture(prod_factory, expr, param_repo, true);
+    EvalFixture test_fixture(test_factory, expr, param_repo, true);
+    EvalFixture slow_fixture(prod_factory, expr, param_repo, false);
+    EXPECT_EQ(fast_fixture.result(), expected);
+    EXPECT_EQ(test_fixture.result(), expected);
+    EXPECT_EQ(slow_fixture.result(), expected);
+    EXPECT_EQ(fast_fixture.find_all<SparseDotProductFunction>().size(), 1u);
+    EXPECT_EQ(test_fixture.find_all<SparseDotProductFunction>().size(), 1u);
+    EXPECT_EQ(slow_fixture.find_all<SparseDotProductFunction>().size(), 0u);
+}
+
+void assert_not_optimized(const vespalib::string &expr) { // checks that 'expr' evaluates correctly without being rewritten
+ EvalFixture fast_fixture(prod_factory, expr, param_repo, true); // same arguments as fast_fixture in assert_optimized
+ EXPECT_EQ(fast_fixture.result(), EvalFixture::ref(expr, param_repo)); // result must still match the reference evaluation
+ EXPECT_EQ(fast_fixture.find_all<SparseDotProductFunction>().size(), 0u); // no SparseDotProductFunction node may be present
+}
+
+//-----------------------------------------------------------------------------
+
+TEST(SparseDotProduct, expression_can_be_optimized)
+{
+    // single-dimension vectors, with and without an explicit reduce
+    // dimension, covering the combinations of the plain and the
+    // cells_float() parameters
+    const char *exprs[] = {
+        "reduce(v1_x*v2_x,sum,x)",
+        "reduce(v2_x*v1_x,sum)",
+        "reduce(v1_x*v2_x_f,sum)",
+        "reduce(v1_x_f*v2_x,sum)",
+        "reduce(v1_x_f*v2_x_f,sum)",
+    };
+    for (const char *expr : exprs) {
+        assert_optimized(expr);
+    }
+}
+
+TEST(SparseDotProduct, multi_dimensional_expression_can_be_optimized)
+{
+    // both inputs have the dimensions x and y; all of them are reduced
+    const char *exprs[] = {
+        "reduce(m1_xy*m2_xy,sum,x,y)",
+        "reduce(m1_xy*m2_xy,sum)",
+    };
+    for (const char *expr : exprs) {
+        assert_optimized(expr);
+    }
+}
+
+TEST(SparseDotProduct, embedded_dot_product_is_not_optimized)
+{
+    // only x is reduced while one input also has y
+    const char *exprs[] = {
+        "reduce(m1_xy*v1_x,sum,x)",
+        "reduce(v1_x*m1_xy,sum,x)",
+    };
+    for (const char *expr : exprs) {
+        assert_not_optimized(expr);
+    }
+}
+
+TEST(SparseDotProduct, similar_expressions_are_not_optimized)
+{
+    const char *exprs[] = {
+        "reduce(m1_xy*v1_x,sum)",    // inputs have different dimensions
+        "reduce(v1_x*v3_y,sum)",     // inputs have different dimensions
+        "reduce(v2_x*v1_x,max)",     // aggregator is not sum
+        "reduce(v2_x+v1_x,sum)",     // join function is not multiplication
+        "reduce(v4_xd*v4_xd,sum)",   // input built with idx() only
+        "reduce(m3_xym*m3_xym,sum)", // input combines map() and idx()
+    };
+    for (const char *expr : exprs) {
+        assert_not_optimized(expr);
+    }
+}
+
+//-----------------------------------------------------------------------------
+
+GTEST_MAIN_RUN_ALL_TESTS()
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
index 2e8c89f88fc..25612b8d5fd 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
@@ -5,6 +5,7 @@
#include "simple_value.h"
#include <vespa/eval/instruction/dense_dot_product_function.h>
+#include <vespa/eval/instruction/sparse_dot_product_function.h>
#include <vespa/eval/instruction/mixed_inner_product_function.h>
#include <vespa/eval/instruction/sum_max_dot_product_function.h>
#include <vespa/eval/instruction/dense_xw_product_function.h>
@@ -31,11 +32,7 @@ namespace vespalib::eval {
namespace {
-const TensorFunction &optimize_for_factory(const ValueBuilderFactory &factory, const TensorFunction &expr, Stash &stash) {
- if (&factory == &SimpleValueBuilderFactory::get()) {
- // never optimize simple value evaluation
- return expr;
- }
+const TensorFunction &optimize_for_factory(const ValueBuilderFactory &, const TensorFunction &expr, Stash &stash) {
using Child = TensorFunction::Child;
Child root(expr);
{
@@ -47,6 +44,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &factory, c
const Child &child = nodes.back().get();
child.set(SumMaxDotProductFunction::optimize(child.get(), stash));
child.set(DenseDotProductFunction::optimize(child.get(), stash));
+ child.set(SparseDotProductFunction::optimize(child.get(), stash));
child.set(DenseXWProductFunction::optimize(child.get(), stash));
child.set(DenseMatMulFunction::optimize(child.get(), stash));
child.set(DenseMultiMatMulFunction::optimize(child.get(), stash));
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 58d8905baf3..9b1789fbeea 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -28,7 +28,10 @@ NodeTypes get_types(const Function &function, const ParamRepo &param_repo) {
std::vector<ValueType> param_types;
for (size_t i = 0; i < function.num_params(); ++i) {
auto pos = param_repo.map.find(function.param_name(i));
- ASSERT_TRUE(pos != param_repo.map.end());
+ if (pos == param_repo.map.end()) {
+ TEST_STATE(fmt("param name: '%s'", function.param_name(i).data()).c_str());
+ ASSERT_TRUE(pos != param_repo.map.end());
+ }
param_types.push_back(ValueType::from_spec(pos->second.value.type()));
ASSERT_TRUE(!param_types.back().is_error());
}
diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt
index 58d5290f5d9..cac69d23640 100644
--- a/eval/src/vespa/eval/instruction/CMakeLists.txt
+++ b/eval/src/vespa/eval/instruction/CMakeLists.txt
@@ -32,6 +32,7 @@ vespa_add_library(eval_instruction OBJECT
pow_as_map_optimizer.cpp
remove_trivial_dimension_optimizer.cpp
replace_type_function.cpp
+ sparse_dot_product_function.cpp
sum_max_dot_product_function.cpp
vector_from_doubles_function.cpp
)
diff --git a/eval/src/vespa/eval/instruction/generic_join.cpp b/eval/src/vespa/eval/instruction/generic_join.cpp
index abe29b8228c..6d6f86b7c4d 100644
--- a/eval/src/vespa/eval/instruction/generic_join.cpp
+++ b/eval/src/vespa/eval/instruction/generic_join.cpp
@@ -308,6 +308,17 @@ SparseJoinPlan::SparseJoinPlan(const ValueType &lhs_type, const ValueType &rhs_t
[](const auto &a, const auto &b){ return (a.name < b.name); });
}
+// Full overlap plan: all 'num_mapped_dims' dimensions are marked as
+// coming from both inputs, and dimension number 'dim' is listed as
+// overlap entry 'dim' on both the lhs and the rhs side.
+SparseJoinPlan::SparseJoinPlan(size_t num_mapped_dims)
+  : sources(num_mapped_dims, Source::BOTH), lhs_overlap(), rhs_overlap()
+{
+    lhs_overlap.reserve(num_mapped_dims);
+    for (size_t dim = 0; dim != num_mapped_dims; ++dim) {
+        lhs_overlap.push_back(dim);
+    }
+    rhs_overlap.reserve(num_mapped_dims);
+    for (size_t dim = 0; dim != num_mapped_dims; ++dim) {
+        rhs_overlap.push_back(dim);
+    }
+}
+
bool
SparseJoinPlan::should_forward_lhs_index() const
{
diff --git a/eval/src/vespa/eval/instruction/generic_join.h b/eval/src/vespa/eval/instruction/generic_join.h
index 1fcfcf416cc..026a2938971 100644
--- a/eval/src/vespa/eval/instruction/generic_join.h
+++ b/eval/src/vespa/eval/instruction/generic_join.h
@@ -58,6 +58,7 @@ struct SparseJoinPlan {
bool should_forward_lhs_index() const;
bool should_forward_rhs_index() const;
SparseJoinPlan(const ValueType &lhs_type, const ValueType &rhs_type);
+ SparseJoinPlan(size_t num_mapped_dims); // full overlap plan
~SparseJoinPlan();
};
@@ -70,15 +71,14 @@ struct SparseJoinState {
const Value::Index &first_index;
const Value::Index &second_index;
const std::vector<size_t> &second_view_dims;
- std::vector<string_id> full_address;
- std::vector<string_id*> first_address;
- std::vector<const string_id*> address_overlap;
- std::vector<string_id*> second_only_address;
+ std::vector<string_id> full_address;
+ std::vector<string_id*> first_address;
+ std::vector<const string_id*> address_overlap;
+ std::vector<string_id*> second_only_address;
size_t lhs_subspace;
size_t rhs_subspace;
size_t &first_subspace;
size_t &second_subspace;
-
SparseJoinState(const SparseJoinPlan &plan, const Value::Index &lhs, const Value::Index &rhs);
~SparseJoinState();
};
diff --git a/eval/src/vespa/eval/instruction/sparse_dot_product_function.cpp b/eval/src/vespa/eval/instruction/sparse_dot_product_function.cpp
new file mode 100644
index 00000000000..1c9e552521d
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/sparse_dot_product_function.cpp
@@ -0,0 +1,107 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "sparse_dot_product_function.h"
+#include "generic_join.h"
+#include "detect_type.h"
+#include <vespa/eval/eval/fast_value.hpp>
+
+namespace vespalib::eval {
+
+using namespace tensor_function;
+using namespace operation;
+using namespace instruction;
+
+namespace {
+
+/**
+ * Dot product over two FastValueIndex instances: walks every entry of
+ * 'small_idx', looks the same address up in 'big_idx' (re-using the
+ * hash handed out by the walk) and accumulates the product of the
+ * matching cells.
+ *
+ * @param small_idx index that is iterated
+ * @param big_idx index that is probed
+ * @param small_cells cells addressed by subspaces of 'small_idx'
+ * @param big_cells cells addressed by subspaces of 'big_idx'
+ * @return sum of the products of all matching cell pairs
+ */
+template <typename SCT, typename BCT>
+double my_fast_sparse_dot_product(const FastValueIndex &small_idx, const FastValueIndex &big_idx,
+                                  const SCT *small_cells, const BCT *big_cells)
+{
+    double sum = 0.0;
+    auto accumulate = [&](auto subspace, auto hash) {
+        auto match = big_idx.map.lookup(small_idx.map.get_addr(subspace), hash);
+        if (match == FastAddrMap::npos()) {
+            return; // address only present in the iterated index
+        }
+        sum += (small_cells[subspace] * big_cells[match]);
+    };
+    small_idx.map.each_map_entry(accumulate);
+    return sum;
+}
+
+/**
+ * Instruction implementation: computes the dot product of the two
+ * values on top of the stack and replaces them with a double scalar
+ * allocated in the state's stash.
+ *
+ * @param state evaluation state; peek(1) is used as the left operand
+ *        and peek(0) as the right operand
+ * @param num_mapped_dims number of mapped dimensions, used to build
+ *        the full overlap join plan on the generic path
+ */
+template <typename LCT, typename RCT>
+void my_sparse_dot_product_op(InterpretedFunction::State &state, uint64_t num_mapped_dims) {
+    const auto &lhs_idx = state.peek(1).index();
+    const auto &rhs_idx = state.peek(0).index();
+    const LCT *lhs_cells = state.peek(1).cells().typify<LCT>().cbegin();
+    const RCT *rhs_cells = state.peek(0).cells().typify<RCT>().cbegin();
+    double result = 0.0;
+    // The branch hints are attached to the branch statements themselves
+    // rather than to null statements inside the blocks.
+    if (auto indexes = detect_type<FastValueIndex>(lhs_idx, rhs_idx)) [[likely]] {
+        const auto &lhs_fast = indexes.get<0>();
+        const auto &rhs_fast = indexes.get<1>();
+        // iterate the index with the fewest entries and probe the other one
+        result = (rhs_fast.map.size() < lhs_fast.map.size())
+                 ? my_fast_sparse_dot_product(rhs_fast, lhs_fast, rhs_cells, lhs_cells)
+                 : my_fast_sparse_dot_product(lhs_fast, rhs_fast, lhs_cells, rhs_cells);
+    } else [[unlikely]] {
+        // generic path through the Value::Index view API
+        SparseJoinPlan plan(num_mapped_dims);
+        SparseJoinState sparse(plan, lhs_idx, rhs_idx);
+        auto outer = sparse.first_index.create_view({});
+        auto inner = sparse.second_index.create_view(sparse.second_view_dims);
+        outer->lookup({});
+        while (outer->next_result(sparse.first_address, sparse.first_subspace)) {
+            inner->lookup(sparse.address_overlap);
+            // only the first result of the inner lookup is consumed
+            if (inner->next_result(sparse.second_only_address, sparse.second_subspace)) {
+                result += (lhs_cells[sparse.lhs_subspace] * rhs_cells[sparse.rhs_subspace]);
+            }
+        }
+    }
+    state.pop_pop_push(state.stash.create<ScalarValue<double>>(result));
+}
+
+// Functor handed to typify_invoke; yields the op instantiated for the
+// given left/right cell types.
+struct MyGetFun {
+    template <typename LCT, typename RCT>
+    static auto invoke() {
+        return &my_sparse_dot_product_op<LCT,RCT>;
+    }
+};
+
+} // namespace <unnamed>
+
+// The result type is always a double value without dimensions.
+SparseDotProductFunction::SparseDotProductFunction(const TensorFunction &lhs_in, const TensorFunction &rhs_in)
+  : Op2(ValueType::make_type(CellType::DOUBLE, {}), lhs_in, rhs_in)
+{}
+
+// Selects the op matching the cell types of the two children and
+// passes the number of mapped dimensions of the lhs as its parameter.
+InterpretedFunction::Instruction
+SparseDotProductFunction::compile_self(const ValueBuilderFactory &, Stash &) const
+{
+    const ValueType &lhs_type = lhs().result_type();
+    const ValueType &rhs_type = rhs().result_type();
+    auto op = typify_invoke<2,TypifyCellType,MyGetFun>(lhs_type.cell_type(), rhs_type.cell_type());
+    return InterpretedFunction::Instruction(op, lhs_type.count_mapped_dimensions());
+}
+
+// True when 'res' is a double scalar, 'lhs' is sparse and 'rhs' has
+// exactly the same dimensions as 'lhs'.
+bool
+SparseDotProductFunction::compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs)
+{
+    if (!res.is_scalar() || (res.cell_type() != CellType::DOUBLE)) {
+        return false;
+    }
+    if (!lhs.is_sparse()) {
+        return false;
+    }
+    return (rhs.dimensions() == lhs.dimensions());
+}
+
+// Replaces reduce(join(lhs, rhs, mul), sum) with a
+// SparseDotProductFunction created in 'stash' when the involved types
+// are compatible; returns 'expr' unchanged otherwise.
+const TensorFunction &
+SparseDotProductFunction::optimize(const TensorFunction &expr, Stash &stash)
+{
+    auto reduce = as<Reduce>(expr);
+    if (!reduce || (reduce->aggr() != Aggr::SUM)) {
+        return expr;
+    }
+    auto join = as<Join>(reduce->child());
+    if (!join || (join->function() != Mul::f)) {
+        return expr;
+    }
+    const TensorFunction &lhs = join->lhs();
+    const TensorFunction &rhs = join->rhs();
+    if (!compatible_types(expr.result_type(), lhs.result_type(), rhs.result_type())) {
+        return expr;
+    }
+    return stash.create<SparseDotProductFunction>(lhs, rhs);
+}
+
+} // namespace
diff --git a/eval/src/vespa/eval/instruction/sparse_dot_product_function.h b/eval/src/vespa/eval/instruction/sparse_dot_product_function.h
new file mode 100644
index 00000000000..ccc7a61f5e8
--- /dev/null
+++ b/eval/src/vespa/eval/instruction/sparse_dot_product_function.h
@@ -0,0 +1,23 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/tensor_function.h>
+
+namespace vespalib::eval {
+
+/**
+ * Tensor function for a dot product between two sparse tensors.
+ * The result is a double scalar.
+ */
+class SparseDotProductFunction : public tensor_function::Op2
+{
+public:
+    SparseDotProductFunction(const TensorFunction &lhs_in, const TensorFunction &rhs_in);
+
+    // Creates the instruction that evaluates this node.
+    InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
+
+    bool result_is_mutable() const override { return true; }
+
+    // Checks whether the given result/input types can be handled by this function.
+    static bool compatible_types(const ValueType &res, const ValueType &lhs, const ValueType &rhs);
+
+    // Returns a replacement for 'expr' if it can be handled by this
+    // function, otherwise 'expr' itself.
+    static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash);
+};
+
+} // namespace