summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorHåvard Pettersen <havardpe@yahooinc.com>2023-09-12 12:15:29 +0000
committerHåvard Pettersen <havardpe@yahooinc.com>2023-09-12 12:15:29 +0000
commit7a3a53ec4e11eff07cf425f368e5faed477e5fb5 (patch)
tree76223812496ca7160027ad3ce169cad6bccb12b7 /eval
parent5e335474fccb3dbfe0e631e72648d3ae8b1ff703 (diff)
improve testing by verifying corner cases
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp71
-rw-r--r--eval/src/vespa/eval/eval/optimize_tensor_function.cpp6
-rw-r--r--eval/src/vespa/eval/eval/optimize_tensor_function.h4
-rw-r--r--eval/src/vespa/eval/instruction/universal_dot_product.cpp30
-rw-r--r--eval/src/vespa/eval/instruction/universal_dot_product.h3
5 files changed, 78 insertions, 36 deletions
diff --git a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
index e1967f012cb..bf9aeead461 100644
--- a/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
+++ b/eval/src/tests/instruction/universal_dot_product/universal_dot_product_test.cpp
@@ -15,6 +15,7 @@
#include <vespa/vespalib/util/benchmark_timer.h>
#include <vespa/vespalib/util/classname.h>
#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/trinary.h>
#include <vespa/vespalib/gtest/gtest.h>
#include <optional>
@@ -115,7 +116,17 @@ Optimize universal_only() {
return Optimize::specific("universal_only", my_optimizer);
}
-void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
+Trinary tri(bool value) {
+ return value ? Trinary::True : Trinary::False;
+}
+
+bool satisfies(bool actual, Trinary expect) {
+ return (expect == Trinary::Undefined) || (actual == (expect == Trinary::True));
+}
+
+void verify(const vespalib::string &expr, select_cell_type_t select_cell_type,
+ Trinary expect_forward, Trinary expect_distinct, Trinary expect_single)
+{
++verify_cnt;
auto fun = Function::parse(expr);
ASSERT_FALSE(fun->has_error());
@@ -134,10 +145,18 @@ void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
const ValueType &expected_type = node_types.get_type(fun->root());
ASSERT_FALSE(expected_type.is_error());
Stash stash;
- size_t count = 0;
+ std::vector<const TensorFunction *> list;
const TensorFunction &plain_fun = make_tensor_function(prod_factory, fun->root(), node_types, stash);
- const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash, &count);
- ASSERT_GT(count, 0);
+ const TensorFunction &optimized = apply_tensor_function_optimizer(plain_fun, universal_only().optimizer, stash,
+ [&list](const auto &node){
+ list.push_back(std::addressof(node));
+ });
+ ASSERT_EQ(list.size(), 1);
+ auto node = as<UniversalDotProduct>(*list[0]);
+ ASSERT_TRUE(node);
+ EXPECT_TRUE(satisfies(node->forward(), expect_forward));
+ EXPECT_TRUE(satisfies(node->distinct(), expect_distinct));
+ EXPECT_TRUE(satisfies(node->single(), expect_single));
InterpretedFunction ifun(prod_factory, optimized);
InterpretedFunction::Context ctx(ifun);
const Value &actual = ifun.eval(ctx, params);
@@ -152,7 +171,12 @@ void verify(const vespalib::string &expr, select_cell_type_t select_cell_type) {
auto expected = eval_ref(*fun, select_cell_type);
EXPECT_EQ(spec_from_value(actual), expected);
}
-void verify(const vespalib::string &expr) { verify(expr, always_double); }
+void verify(const vespalib::string &expr) {
+ verify(expr, always_double, Trinary::Undefined, Trinary::Undefined, Trinary::Undefined);
+}
+void verify(const vespalib::string &expr, select_cell_type_t select_cell_type, bool forward, bool distinct, bool single) {
+ verify(expr, select_cell_type, tri(forward), tri(distinct), tri(single));
+}
using cost_list_t = std::vector<std::pair<vespalib::string,double>>;
std::vector<std::pair<vespalib::string,cost_list_t>> benchmark_results;
@@ -192,8 +216,9 @@ void benchmark(const vespalib::string &expr, std::vector<Optimize> list) {
break;
case Optimize::With::SPECIFIC:
size_t count = 0;
- optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash, &count));
- ASSERT_GT(count, 0);
+ optimized = std::addressof(apply_tensor_function_optimizer(plain_fun, optimize.optimizer, stash,
+ [&count](const auto &)noexcept{ ++count; }));
+ ASSERT_EQ(count, 1);
break;
}
ASSERT_NE(optimized, nullptr);
@@ -255,36 +280,26 @@ TEST(UniversalDotProductTest, test_select_cell_types) {
}
TEST(UniversalDotProductTest, universal_dot_product_works_for_various_cases) {
- // forward, distinct, single
- verify("reduce(2.0*3.0, sum)");
+ // forward, distinct, single
+ verify("reduce(2.0*3.0, sum)", always_double, true, true, true);
for (CellType lct: CellTypeUtils::list_types()) {
for (CellType rct: CellTypeUtils::list_types()) {
auto sel2 = select(lct, rct);
- // !forward, !distinct, !single
- verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2);
-
- // !forward, !distinct, single
- verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2);
-
- // !forward, distinct, !single
- verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2);
-
- // forward, !distinct, !single
- verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2);
-
- // forward, !distinct, single
- verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2);
-
- // forward, distinct, !single
- verify("reduce(a4_1x8*x8,sum,x)", sel2);
+ // forward, distinct, single
+ verify("reduce(a4_1x8*a2_1x8,sum,a,x)", sel2, false, false, false);
+ verify("reduce(a4_1x8*a2_1x8,sum,a)", sel2, false, false, true);
+ verify("reduce(a4_1x8*a2_1x8,sum,x)", sel2, false, true, false);
+ verify("reduce(a4_1x8*b2_1x8,sum,b,x)", sel2, true, false, false);
+ verify("reduce(a4_1x8*b2_1x8,sum,b)", sel2, true, false, true);
+ verify("reduce(a4_1x8*x8,sum,x)", sel2, true, true, false);
}
}
// !forward, distinct, single
-
+ //
// This case is not possible since 'distinct' implies '!single' as
// long as we reduce anything. The only expression allowed to
- // reduce nothing is the scalar case.
+ // reduce nothing is the scalar case, which satisfies 'forward'
}
TEST(UniversalDotProductTest, universal_dot_product_works_with_complex_dimension_nesting) {
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
index 4013021aaa4..7255d308c81 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.cpp
@@ -140,7 +140,7 @@ const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factor
return optimize_tensor_function(factory, function, stash, OptimizeTensorFunctionOptions());
}
-const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash, size_t *count) {
+const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash, tensor_function_listener listener) {
Child root(function);
run_optimize_pass(root, [&](const Child &child)
{
@@ -148,9 +148,7 @@ const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &func
const TensorFunction &child_after = optimizer(child_before, stash);
if (std::addressof(child_after) != std::addressof(child_before)) {
child.set(child_after);
- if (count != nullptr) {
- ++(*count);
- }
+ listener(child_after);
}
});
return root.get();
diff --git a/eval/src/vespa/eval/eval/optimize_tensor_function.h b/eval/src/vespa/eval/eval/optimize_tensor_function.h
index 4a5945860e7..fd8c9b33d8c 100644
--- a/eval/src/vespa/eval/eval/optimize_tensor_function.h
+++ b/eval/src/vespa/eval/eval/optimize_tensor_function.h
@@ -22,6 +22,8 @@ const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factor
const TensorFunction &optimize_tensor_function(const ValueBuilderFactory &factory, const TensorFunction &function, Stash &stash);
using tensor_function_optimizer = std::function<const TensorFunction &(const TensorFunction &expr, Stash &stash)>;
-const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash, size_t *count = nullptr);
+using tensor_function_listener = std::function<void(const TensorFunction &expr)>;
+const TensorFunction &apply_tensor_function_optimizer(const TensorFunction &function, tensor_function_optimizer optimizer, Stash &stash,
+ tensor_function_listener = [](const TensorFunction &)noexcept{});
} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.cpp b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
index 414a54f09a8..e023609114a 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.cpp
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.cpp
@@ -40,6 +40,9 @@ struct UniversalDotProductParam {
dense_plan.res_stride.pop_back();
}
}
+ bool forward() const { return sparse_plan.maybe_forward_lhs_index(); }
+ bool distinct() const { return sparse_plan.is_distinct() && dense_plan.is_distinct(); }
+ bool single() const { return vector_size == 1; }
};
template <typename OCT>
@@ -204,12 +207,33 @@ UniversalDotProduct::compile_self(const ValueBuilderFactory &, Stash &stash) con
auto op = typify_invoke<6,MyTypify,SelectUniversalDotProduct>(lhs().result_type().cell_meta(),
rhs().result_type().cell_meta(),
result_type().cell_meta().is_scalar,
- param.sparse_plan.maybe_forward_lhs_index(),
- param.sparse_plan.is_distinct() && param.dense_plan.is_distinct(),
- param.vector_size == 1);
+ param.forward(),
+ param.distinct(),
+ param.single());
return InterpretedFunction::Instruction(op, wrap_param<UniversalDotProductParam>(param));
}
+bool
+UniversalDotProduct::forward() const
+{
+ UniversalDotProductParam param(result_type(), lhs().result_type(), rhs().result_type());
+ return param.forward();
+}
+
+bool
+UniversalDotProduct::distinct() const
+{
+ UniversalDotProductParam param(result_type(), lhs().result_type(), rhs().result_type());
+ return param.distinct();
+}
+
+bool
+UniversalDotProduct::single() const
+{
+ UniversalDotProductParam param(result_type(), lhs().result_type(), rhs().result_type());
+ return param.single();
+}
+
const TensorFunction &
UniversalDotProduct::optimize(const TensorFunction &expr, Stash &stash, bool force)
{
diff --git a/eval/src/vespa/eval/instruction/universal_dot_product.h b/eval/src/vespa/eval/instruction/universal_dot_product.h
index 40fd109cc73..2572ab47c65 100644
--- a/eval/src/vespa/eval/instruction/universal_dot_product.h
+++ b/eval/src/vespa/eval/instruction/universal_dot_product.h
@@ -19,6 +19,9 @@ public:
UniversalDotProduct(const ValueType &res_type, const TensorFunction &lhs, const TensorFunction &rhs);
InterpretedFunction::Instruction compile_self(const ValueBuilderFactory &factory, Stash &stash) const override;
bool result_is_mutable() const override { return true; }
+ bool forward() const;
+ bool distinct() const;
+ bool single() const;
static const TensorFunction &optimize(const TensorFunction &expr, Stash &stash, bool force);
};