about | summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author Håvard Pettersen <havardpe@oath.com> 2021-01-21 13:29:28 +0000
committer Håvard Pettersen <havardpe@oath.com> 2021-01-22 12:02:08 +0000
commit 6c2c269516896dc97204ec07abb73211d3dc29c0 (patch)
tree a197535d6e4a174d470f3b18073c3a7915fa505d
parent e41754597890c4611980fee95e8aec8f9b29e476 (diff)
mixed simple join
-rw-r--r-- eval/CMakeLists.txt 2
-rw-r--r-- eval/src/tests/instruction/dense_simple_join_function/CMakeLists.txt 8
-rw-r--r-- eval/src/tests/instruction/mixed_simple_join_function/CMakeLists.txt 8
-rw-r--r-- eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp (renamed from eval/src/tests/instruction/dense_simple_join_function/dense_simple_join_function_test.cpp) 55
-rw-r--r-- eval/src/vespa/eval/eval/optimize_tensor_function.cpp 4
-rw-r--r-- eval/src/vespa/eval/instruction/CMakeLists.txt 2
-rw-r--r-- eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp (renamed from eval/src/vespa/eval/instruction/dense_simple_join_function.cpp) 73
-rw-r--r-- eval/src/vespa/eval/instruction/mixed_simple_join_function.h (renamed from eval/src/vespa/eval/instruction/dense_simple_join_function.h) 18
8 files changed, 109 insertions, 61 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index 051cf777f2c..1d9fa7478a0 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -44,7 +44,6 @@ vespa_define_module(
src/tests/instruction/dense_multi_matmul_function
src/tests/instruction/dense_replace_type_function
src/tests/instruction/dense_simple_expand_function
- src/tests/instruction/dense_simple_join_function
src/tests/instruction/dense_single_reduce_function
src/tests/instruction/dense_tensor_create_function
src/tests/instruction/dense_tensor_peek_function
@@ -62,6 +61,7 @@ vespa_define_module(
src/tests/instruction/join_with_number
src/tests/instruction/mixed_inner_product_function
src/tests/instruction/mixed_map_function
+ src/tests/instruction/mixed_simple_join_function
src/tests/instruction/pow_as_map_optimizer
src/tests/instruction/remove_trivial_dimension_optimizer
src/tests/instruction/vector_from_doubles_function
diff --git a/eval/src/tests/instruction/dense_simple_join_function/CMakeLists.txt b/eval/src/tests/instruction/dense_simple_join_function/CMakeLists.txt
deleted file mode 100644
index 8a2df392145..00000000000
--- a/eval/src/tests/instruction/dense_simple_join_function/CMakeLists.txt
+++ /dev/null
@@ -1,8 +0,0 @@
-# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-vespa_add_executable(eval_dense_simple_join_function_test_app TEST
- SOURCES
- dense_simple_join_function_test.cpp
- DEPENDS
- vespaeval
-)
-vespa_add_test(NAME eval_dense_simple_join_function_test_app COMMAND eval_dense_simple_join_function_test_app)
diff --git a/eval/src/tests/instruction/mixed_simple_join_function/CMakeLists.txt b/eval/src/tests/instruction/mixed_simple_join_function/CMakeLists.txt
new file mode 100644
index 00000000000..f603c600691
--- /dev/null
+++ b/eval/src/tests/instruction/mixed_simple_join_function/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_mixed_simple_join_function_test_app TEST
+ SOURCES
+ mixed_simple_join_function_test.cpp
+ DEPENDS
+ vespaeval
+)
+vespa_add_test(NAME eval_mixed_simple_join_function_test_app COMMAND eval_mixed_simple_join_function_test_app)
diff --git a/eval/src/tests/instruction/dense_simple_join_function/dense_simple_join_function_test.cpp b/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp
index 2186d49385e..9c891adf179 100644
--- a/eval/src/tests/instruction/dense_simple_join_function/dense_simple_join_function_test.cpp
+++ b/eval/src/tests/instruction/mixed_simple_join_function/mixed_simple_join_function_test.cpp
@@ -2,7 +2,7 @@
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/eval/eval/tensor_function.h>
-#include <vespa/eval/instruction/dense_simple_join_function.h>
+#include <vespa/eval/instruction/mixed_simple_join_function.h>
#include <vespa/eval/eval/test/eval_fixture.h>
#include <vespa/eval/eval/test/tensor_model.hpp>
@@ -15,8 +15,8 @@ using namespace vespalib::eval::tensor_function;
using vespalib::make_string_short::fmt;
-using Primary = DenseSimpleJoinFunction::Primary;
-using Overlap = DenseSimpleJoinFunction::Overlap;
+using Primary = MixedSimpleJoinFunction::Primary;
+using Overlap = MixedSimpleJoinFunction::Overlap;
namespace vespalib::eval {
@@ -47,12 +47,14 @@ EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("a", spec(1.5))
.add("b", spec(2.5))
- .add("sparse", spec({x({"a"})}, N()))
- .add("mixed", spec({x({"a"}),y(5)}, N()))
+ .add("sparse", spec({x({"a", "b", "c"})}, N()))
+ .add("mixed", spec({x({"a", "b", "c"}),y(5),z(3)}, N()))
+ .add("empty_mixed", spec({x({}),y(5),z(3)}, N()))
+ .add_mutable("@mixed", spec({x({"a", "b", "c"}),y(5),z(3)}, N()))
.add_cube("a", 1, "b", 1, "c", 1)
.add_cube("x", 1, "y", 1, "z", 1)
.add_cube("x", 3, "y", 5, "z", 3)
- .add_vector("x", 5)
+ .add_vector("z", 3)
.add_dense({{"c", 5}, {"d", 1}})
.add_dense({{"b", 1}, {"c", 5}})
.add_matrix("x", 3, "y", 5, [](size_t idx) noexcept { return double((idx * 2) + 3); })
@@ -69,7 +71,7 @@ void verify_optimized(const vespalib::string &expr, Primary primary, Overlap ove
EvalFixture fixture(prod_factory, expr, param_repo, true, true);
EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<DenseSimpleJoinFunction>();
+ auto info = fixture.find_all<MixedSimpleJoinFunction>();
ASSERT_EQUAL(info.size(), 1u);
EXPECT_TRUE(info[0]->result_is_mutable());
EXPECT_EQUAL(info[0]->primary(), primary);
@@ -81,7 +83,9 @@ void verify_optimized(const vespalib::string &expr, Primary primary, Overlap ove
if (i == size_t(p_inplace)) {
EXPECT_EQUAL(fixture.get_param(i), fixture.result());
} else {
- EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result());
+ if (!fixture.result().cells().empty()) {
+ EXPECT_NOT_EQUAL(fixture.get_param(i), fixture.result());
+ }
}
}
}
@@ -91,7 +95,7 @@ void verify_not_optimized(const vespalib::string &expr) {
EvalFixture fixture(prod_factory, expr, param_repo, true);
EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
EXPECT_EQUAL(fixture.result(), slow_fixture.result());
- auto info = fixture.find_all<DenseSimpleJoinFunction>();
+ auto info = fixture.find_all<MixedSimpleJoinFunction>();
EXPECT_TRUE(info.empty());
}
@@ -112,7 +116,7 @@ TEST("require that outer nesting is preferred to inner nesting") {
}
TEST("require that non-subset join is not optimized") {
- TEST_DO(verify_not_optimized("x5+y5"));
+ TEST_DO(verify_not_optimized("y5+z3"));
}
TEST("require that subset join with complex overlap is not optimized") {
@@ -207,7 +211,7 @@ TEST("require that scalar values are not optimized") {
TEST_DO(verify_not_optimized("mixed+a"));
}
-TEST("require that mapped tensors are not optimized") {
+TEST("require that sparse tensors are mostly not optimized") {
TEST_DO(verify_not_optimized("sparse+sparse"));
TEST_DO(verify_not_optimized("sparse+y5"));
TEST_DO(verify_not_optimized("y5+sparse"));
@@ -215,10 +219,33 @@ TEST("require that mapped tensors are not optimized") {
TEST_DO(verify_not_optimized("mixed+sparse"));
}
-TEST("require mixed tensors are not optimized") {
+TEST("require that sparse tensor joined with trivial dense tensor is optimized") {
+ TEST_DO(verify_optimized("sparse+a1b1c1", Primary::LHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("a1b1c1+sparse", Primary::RHS, Overlap::FULL, false, 1));
+}
+
+TEST("require that primary tensor can be empty") {
+ TEST_DO(verify_optimized("empty_mixed+y5z3", Primary::LHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("y5z3+empty_mixed", Primary::RHS, Overlap::FULL, false, 1));
+}
+
+TEST("require that mixed tensors can be optimized") {
TEST_DO(verify_not_optimized("mixed+mixed"));
- TEST_DO(verify_not_optimized("mixed+y5"));
- TEST_DO(verify_not_optimized("y5+mixed"));
+ TEST_DO(verify_optimized("mixed+y5z3", Primary::LHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("mixed+y5", Primary::LHS, Overlap::OUTER, false, 3));
+ TEST_DO(verify_optimized("mixed+z3", Primary::LHS, Overlap::INNER, false, 5));
+ TEST_DO(verify_optimized("y5z3+mixed", Primary::RHS, Overlap::FULL, false, 1));
+ TEST_DO(verify_optimized("y5+mixed", Primary::RHS, Overlap::OUTER, false, 3));
+ TEST_DO(verify_optimized("z3+mixed", Primary::RHS, Overlap::INNER, false, 5));
+}
+
+TEST("require that mixed tensors can be inplace") {
+ TEST_DO(verify_optimized("@mixed+y5z3", Primary::LHS, Overlap::FULL, true, 1, 0));
+ TEST_DO(verify_optimized("@mixed+y5", Primary::LHS, Overlap::OUTER, true, 3, 0));
+ TEST_DO(verify_optimized("@mixed+z3", Primary::LHS, Overlap::INNER, true, 5, 0));
+ TEST_DO(verify_optimized("y5z3+@mixed", Primary::RHS, Overlap::FULL, true, 1, 1));
+ TEST_DO(verify_optimized("y5+@mixed", Primary::RHS, Overlap::OUTER, true, 3, 1));
+ TEST_DO(verify_optimized("z3+@mixed", Primary::RHS, Overlap::INNER, true, 5, 1));
}
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
index f50e3b187f0..ea63171b260 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
@@ -15,7 +15,7 @@
#include <vespa/eval/instruction/remove_trivial_dimension_optimizer.h>
#include <vespa/eval/instruction/dense_lambda_peek_optimizer.h>
#include <vespa/eval/instruction/dense_simple_expand_function.h>
-#include <vespa/eval/instruction/dense_simple_join_function.h>
+#include <vespa/eval/instruction/mixed_simple_join_function.h>
#include <vespa/eval/instruction/join_with_number_function.h>
#include <vespa/eval/instruction/pow_as_map_optimizer.h>
#include <vespa/eval/instruction/mixed_map_function.h>
@@ -69,7 +69,7 @@ const TensorFunction &optimize_for_factory(const ValueBuilderFactory &factory, c
child.set(FastRenameOptimizer::optimize(child.get(), stash));
child.set(PowAsMapOptimizer::optimize(child.get(), stash));
child.set(MixedMapFunction::optimize(child.get(), stash));
- child.set(DenseSimpleJoinFunction::optimize(child.get(), stash));
+ child.set(MixedSimpleJoinFunction::optimize(child.get(), stash));
child.set(JoinWithNumberFunction::optimize(child.get(), stash));
child.set(DenseSingleReduceFunction::optimize(child.get(), stash));
nodes.pop_back();
diff --git a/eval/src/vespa/eval/instruction/CMakeLists.txt b/eval/src/vespa/eval/instruction/CMakeLists.txt
index fdb4bd2b5cb..113df255658 100644
--- a/eval/src/vespa/eval/instruction/CMakeLists.txt
+++ b/eval/src/vespa/eval/instruction/CMakeLists.txt
@@ -10,7 +10,6 @@ vespa_add_library(eval_instruction OBJECT
dense_matmul_function.cpp
dense_multi_matmul_function.cpp
dense_simple_expand_function.cpp
- dense_simple_join_function.cpp
dense_single_reduce_function.cpp
dense_tensor_create_function.cpp
dense_tensor_peek_function.cpp
@@ -29,6 +28,7 @@ vespa_add_library(eval_instruction OBJECT
join_with_number_function.cpp
mixed_inner_product_function.cpp
mixed_map_function.cpp
+ mixed_simple_join_function.cpp
pow_as_map_optimizer.cpp
remove_trivial_dimension_optimizer.cpp
replace_type_function.cpp
diff --git a/eval/src/vespa/eval/instruction/dense_simple_join_function.cpp b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
index 76d020eef9d..b6be8347220 100644
--- a/eval/src/vespa/eval/instruction/dense_simple_join_function.cpp
+++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.cpp
@@ -1,6 +1,6 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-#include "dense_simple_join_function.h"
+#include "mixed_simple_join_function.h"
#include <vespa/vespalib/objects/objectvisitor.h>
#include <vespa/eval/eval/value.h>
#include <vespa/eval/eval/operation.h>
@@ -16,8 +16,8 @@ using vespalib::ArrayRef;
using namespace operation;
using namespace tensor_function;
-using Primary = DenseSimpleJoinFunction::Primary;
-using Overlap = DenseSimpleJoinFunction::Overlap;
+using Primary = MixedSimpleJoinFunction::Primary;
+using Overlap = MixedSimpleJoinFunction::Overlap;
using op_function = InterpretedFunction::op_function;
using Instruction = InterpretedFunction::Instruction;
@@ -40,9 +40,10 @@ struct TypifyOverlap {
struct JoinParams {
const ValueType &result_type;
size_t factor;
+ size_t subspace_size;
join_fun_t function;
JoinParams(const ValueType &result_type_in, size_t factor_in, join_fun_t function_in)
- : result_type(result_type_in), factor(factor_in), function(function_in) {}
+ : result_type(result_type_in), factor(factor_in), subspace_size(result_type.dense_subspace_size()), function(function_in) {}
};
template <typename OCT, bool pri_mut, typename PCT>
@@ -65,25 +66,33 @@ void my_simple_join_op(State &state, uint64_t param) {
auto pri_cells = state.peek(swap ? 0 : 1).cells().typify<PCT>();
auto sec_cells = state.peek(swap ? 1 : 0).cells().typify<SCT>();
auto dst_cells = make_dst_cells<OCT, pri_mut>(pri_cells, state.stash);
- if (overlap == Overlap::FULL) {
- apply_op2_vec_vec(dst_cells.begin(), pri_cells.begin(), sec_cells.begin(), dst_cells.size(), my_op);
- } else if (overlap == Overlap::OUTER) {
- size_t offset = 0;
- size_t factor = params.factor;
- for (SCT cell: sec_cells) {
- apply_op2_vec_num(dst_cells.begin() + offset, pri_cells.begin() + offset, cell, factor, my_op);
- offset += factor;
- }
- } else {
- assert(overlap == Overlap::INNER);
- size_t offset = 0;
- size_t factor = params.factor;
- for (size_t i = 0; i < factor; ++i) {
- apply_op2_vec_vec(dst_cells.begin() + offset, pri_cells.begin() + offset, sec_cells.begin(), sec_cells.size(), my_op);
- offset += sec_cells.size();
+ const auto &index = state.peek(swap ? 0 : 1).index();
+ size_t subspace_size = params.subspace_size;
+ const PCT *pri = pri_cells.begin();
+ OCT *dst = dst_cells.begin();
+ for (; pri < pri_cells.end(); pri += subspace_size, dst += subspace_size) {
+ if constexpr (overlap == Overlap::FULL) {
+ apply_op2_vec_vec(dst, pri, sec_cells.begin(), subspace_size, my_op);
+ } else if constexpr (overlap == Overlap::OUTER) {
+ size_t offset = 0;
+ size_t factor = params.factor;
+ for (SCT cell: sec_cells) {
+ apply_op2_vec_num(dst + offset, pri + offset, cell, factor, my_op);
+ offset += factor;
+ }
+ } else {
+ static_assert(overlap == Overlap::INNER);
+ size_t offset = 0;
+ size_t factor = params.factor;
+ for (size_t i = 0; i < factor; ++i) {
+ apply_op2_vec_vec(dst + offset, pri + offset, sec_cells.begin(), sec_cells.size(), my_op);
+ offset += sec_cells.size();
+ }
}
}
- state.pop_pop_push(state.stash.create<DenseValueView>(params.result_type, TypedCells(dst_cells)));
+ assert(pri == pri_cells.end());
+ assert(dst == dst_cells.end());
+ state.pop_pop_push(state.stash.create<ValueView>(params.result_type, index, TypedCells(dst_cells)));
}
//-----------------------------------------------------------------------------
@@ -103,6 +112,11 @@ bool can_use_as_output(const TensorFunction &fun, CellType result_cell_type) {
}
Primary select_primary(const TensorFunction &lhs, const TensorFunction &rhs, CellType result_cell_type) {
+ if (!lhs.result_type().is_dense()) {
+ return Primary::LHS;
+ } else if (!rhs.result_type().is_dense()) {
+ return Primary::RHS;
+ }
size_t lhs_size = lhs.result_type().dense_subspace_size();
size_t rhs_size = rhs.result_type().dense_subspace_size();
if (lhs_size > rhs_size) {
@@ -124,6 +138,7 @@ Primary select_primary(const TensorFunction &lhs, const TensorFunction &rhs, Cel
std::optional<Overlap> detect_overlap(const TensorFunction &primary, const TensorFunction &secondary) {
std::vector<ValueType::Dimension> a = primary.result_type().nontrivial_indexed_dimensions();
std::vector<ValueType::Dimension> b = secondary.result_type().nontrivial_indexed_dimensions();
+ assert(secondary.result_type().is_dense());
if (b.size() > a.size()) {
return std::nullopt;
} else if (b == a) {
@@ -146,7 +161,7 @@ std::optional<Overlap> detect_overlap(const TensorFunction &lhs, const TensorFun
//-----------------------------------------------------------------------------
-DenseSimpleJoinFunction::DenseSimpleJoinFunction(const ValueType &result_type,
+MixedSimpleJoinFunction::MixedSimpleJoinFunction(const ValueType &result_type,
const TensorFunction &lhs,
const TensorFunction &rhs,
join_fun_t function_in,
@@ -158,10 +173,10 @@ DenseSimpleJoinFunction::DenseSimpleJoinFunction(const ValueType &result_type,
{
}
-DenseSimpleJoinFunction::~DenseSimpleJoinFunction() = default;
+MixedSimpleJoinFunction::~MixedSimpleJoinFunction() = default;
bool
-DenseSimpleJoinFunction::primary_is_mutable() const
+MixedSimpleJoinFunction::primary_is_mutable() const
{
if (_primary == Primary::LHS) {
return lhs().result_is_mutable();
@@ -171,7 +186,7 @@ DenseSimpleJoinFunction::primary_is_mutable() const
}
size_t
-DenseSimpleJoinFunction::factor() const
+MixedSimpleJoinFunction::factor() const
{
const TensorFunction &p = (_primary == Primary::LHS) ? lhs() : rhs();
const TensorFunction &s = (_primary == Primary::LHS) ? rhs() : lhs();
@@ -182,7 +197,7 @@ DenseSimpleJoinFunction::factor() const
}
Instruction
-DenseSimpleJoinFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const
+MixedSimpleJoinFunction::compile_self(const ValueBuilderFactory &, Stash &stash) const
{
const JoinParams &params = stash.create<JoinParams>(result_type(), factor(), function());
auto op = typify_invoke<6,MyTypify,MyGetFun>(lhs().result_type().cell_type(),
@@ -193,18 +208,18 @@ DenseSimpleJoinFunction::compile_self(const ValueBuilderFactory &, Stash &stash)
}
const TensorFunction &
-DenseSimpleJoinFunction::optimize(const TensorFunction &expr, Stash &stash)
+MixedSimpleJoinFunction::optimize(const TensorFunction &expr, Stash &stash)
{
if (auto join = as<Join>(expr)) {
const TensorFunction &lhs = join->lhs();
const TensorFunction &rhs = join->rhs();
- if (lhs.result_type().is_dense() && rhs.result_type().is_dense()) {
+ if (lhs.result_type().is_dense() || rhs.result_type().is_dense()) {
Primary primary = select_primary(lhs, rhs, join->result_type().cell_type());
std::optional<Overlap> overlap = detect_overlap(lhs, rhs, primary);
if (overlap.has_value()) {
const TensorFunction &ptf = (primary == Primary::LHS) ? lhs : rhs;
assert(ptf.result_type().dense_subspace_size() == join->result_type().dense_subspace_size());
- return stash.create<DenseSimpleJoinFunction>(join->result_type(), lhs, rhs, join->function(),
+ return stash.create<MixedSimpleJoinFunction>(join->result_type(), lhs, rhs, join->function(),
primary, overlap.value());
}
}
diff --git a/eval/src/vespa/eval/instruction/dense_simple_join_function.h b/eval/src/vespa/eval/instruction/mixed_simple_join_function.h
index 8fa0be9d021..94e5f3c52b5 100644
--- a/eval/src/vespa/eval/instruction/dense_simple_join_function.h
+++ b/eval/src/vespa/eval/instruction/mixed_simple_join_function.h
@@ -8,11 +8,17 @@
namespace vespalib::eval {
/**
- * Tensor function for simple join operations on dense tensors.
- * TODO: consider if this is useful anymore, maybe we just need
- * to handle inplace.
+ * Tensor function for simple join operations between a primary and a
+ * secondary tensor that may be evaluated in-place if the primary
+ * tensor is mutable and has the same cell-type as the result.
+ *
+ * The secondary tensor must be dense and contain a subset of the
+ * dimensions present in the dense subspace of the primary tensor. The
+ * common dimensions must have a simple overlap pattern ('inner',
+ * 'outer' or 'full'). The primary tensor may be mixed, in which case
+ * the index will be forwarded to the result.
**/
-class DenseSimpleJoinFunction : public tensor_function::Join
+class MixedSimpleJoinFunction : public tensor_function::Join
{
using Super = tensor_function::Join;
public:
@@ -23,13 +29,13 @@ private:
Primary _primary;
Overlap _overlap;
public:
- DenseSimpleJoinFunction(const ValueType &result_type,
+ MixedSimpleJoinFunction(const ValueType &result_type,
const TensorFunction &lhs,
const TensorFunction &rhs,
join_fun_t function_in,
Primary primary_in,
Overlap overlap_in);
- ~DenseSimpleJoinFunction() override;
+ ~MixedSimpleJoinFunction() override;
Primary primary() const { return _primary; }
Overlap overlap() const { return _overlap; }
bool primary_is_mutable() const;