summary | refs | log | tree | commit | diff | stats
path: root/eval
diff options
context:
space:
mode:
author: Haavard <havardpe@yahoo-inc.com> 2017-03-01 10:19:09 +0000
committer: Haavard <havardpe@yahoo-inc.com> 2017-03-02 14:16:56 +0000
commit: e027f0b792c068bf505229165bf0eeea83cf5985 (patch)
tree: 8ee15b106c4e7ddc2c0f6d1d856ee9655e1d0d0d /eval
parent: d95edfd1d1e83ba2e3654355ec574f9a1b99bb6b (diff)
analyze parameter usage
Diffstat (limited to 'eval')
-rw-r--r-- eval/CMakeLists.txt 3
-rw-r--r-- eval/src/tests/eval/param_usage/CMakeLists.txt 8
-rw-r--r-- eval/src/tests/eval/param_usage/param_usage_test.cpp 65
-rw-r--r-- eval/src/vespa/eval/eval/CMakeLists.txt 1
-rw-r--r-- eval/src/vespa/eval/eval/param_usage.cpp 99
-rw-r--r-- eval/src/vespa/eval/eval/param_usage.h 35
6 files changed, 210 insertions, 1 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index 5bf24a04202..79998163249 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -15,15 +15,16 @@ vespa_define_module(
src/tests/eval/gbdt
src/tests/eval/interpreted_function
src/tests/eval/node_types
+ src/tests/eval/param_usage
src/tests/eval/simple_tensor
src/tests/eval/tensor_function
src/tests/eval/value_cache
src/tests/eval/value_type
- src/tests/tensor/sparse_tensor_builder
src/tests/tensor/dense_dot_product_function
src/tests/tensor/dense_tensor_address_combiner
src/tests/tensor/dense_tensor_builder
src/tests/tensor/dense_tensor_function_compiler
+ src/tests/tensor/sparse_tensor_builder
src/tests/tensor/tensor_address
src/tests/tensor/tensor_conformance
src/tests/tensor/tensor_mapper
diff --git a/eval/src/tests/eval/param_usage/CMakeLists.txt b/eval/src/tests/eval/param_usage/CMakeLists.txt
new file mode 100644
index 00000000000..9ddc005dc87
--- /dev/null
+++ b/eval/src/tests/eval/param_usage/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_param_usage_test_app TEST
+ SOURCES
+ param_usage_test.cpp
+ DEPENDS
+ vespaeval
+)
+vespa_add_test(NAME eval_param_usage_test_app COMMAND eval_param_usage_test_app)
diff --git a/eval/src/tests/eval/param_usage/param_usage_test.cpp b/eval/src/tests/eval/param_usage/param_usage_test.cpp
new file mode 100644
index 00000000000..ff0c6667279
--- /dev/null
+++ b/eval/src/tests/eval/param_usage/param_usage_test.cpp
@@ -0,0 +1,65 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/eval/eval/function.h>
+#include <vespa/eval/eval/param_usage.h>
+#include <vespa/vespalib/test/insertion_operators.h>
+
+using vespalib::approx_equal;
+using namespace vespalib::eval;
+
+// Test helper wrapping a list of doubles so that two lists can be
+// compared element by element with approx_equal rather than exact
+// floating-point equality.
+struct List {
+    std::vector<double> list;
+    List(std::vector<double> list_in) : list(std::move(list_in)) {}
+    // Two lists match when they have the same length and every pair of
+    // corresponding elements is approximately equal.
+    bool operator==(const List &rhs) const {
+        const size_t count = list.size();
+        bool same = (count == rhs.list.size());
+        for (size_t pos = 0; same && (pos < count); ++pos) {
+            same = approx_equal(list[pos], rhs.list[pos]);
+        }
+        return same;
+    }
+};
+
+// Print a List by streaming its wrapped vector, so that failed
+// comparisons can show both operands.
+std::ostream &operator<<(std::ostream &out, const List &list) {
+    out << list.list;
+    return out;
+}
+
+TEST("require that simple expression has appropriate parameter usage") { // baseline: expression without any 'if'
+ std::vector<vespalib::string> params({"x", "y", "z"});
+ Function function = Function::parse(params, "(x+y)*y");
+ EXPECT_EQUAL(List(count_param_usage(function)), List({1.0, 2.0, 0.0})); // x occurs once, y twice, z never
+ EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 1.0, 0.0})); // x and y always used, z never
+}
+
+TEST("require that if children have 50% probability each by default") { // 'if' without an explicit probability argument
+ std::vector<vespalib::string> params({"x", "y", "z", "w"});
+ Function function = Function::parse(params, "if(w,(x+y)*y,(y+z)*z)");
+ EXPECT_EQUAL(List(count_param_usage(function)), List({0.5, 1.5, 1.0, 1.0})); // y: 0.5*2 + 0.5*1; z: 0.5*2; condition w counts fully
+ EXPECT_EQUAL(List(check_param_usage(function)), List({0.5, 1.0, 0.5, 1.0})); // y occurs in both branches, so it is always used
+}
+
+TEST("require that if children probability can be adjusted") { // fourth 'if' argument weights the true branch
+ std::vector<vespalib::string> params({"x", "y", "z"});
+ Function function = Function::parse(params, "if(z,x*x,y*y,0.8)");
+ EXPECT_EQUAL(List(count_param_usage(function)), List({1.6, 0.4, 1.0})); // x: 0.8*2; y: 0.2*2; condition z counts fully
+ EXPECT_EQUAL(List(check_param_usage(function)), List({0.8, 0.2, 1.0})); // x only in true branch, y only in false branch
+}
+
+TEST("require that chained if statements are combined correctly") { // two sibling 'if' nodes under one '+'
+ std::vector<vespalib::string> params({"x", "y", "z", "w"});
+ Function function = Function::parse(params, "if(z,x,y)+if(w,y,x)");
+ EXPECT_EQUAL(List(count_param_usage(function)), List({1.0, 1.0, 1.0, 1.0})); // x and y: 0.5 from each 'if'
+ EXPECT_EQUAL(List(check_param_usage(function)), List({0.75, 0.75, 1.0, 1.0})); // x and y: 1 - (0.5 * 0.5)
+}
+
+TEST("require that multi-level if statements are combined correctly") { // 'if' nodes nested inside both branches of an outer 'if'
+ std::vector<vespalib::string> params({"x", "y", "z", "w"});
+ Function function = Function::parse(params, "if(z,if(w,y*x,x*x),if(w,y*x,x*x))");
+ EXPECT_EQUAL(List(count_param_usage(function)), List({1.5, 0.5, 1.0, 1.0})); // x: 0.5*1 + 0.5*2 in either outer branch; y: 0.5*1
+ EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 0.5, 1.0, 1.0})); // x occurs in every leaf; y only in the 'w' true leaves
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/CMakeLists.txt b/eval/src/vespa/eval/eval/CMakeLists.txt
index 9bcc3dd0742..145409af3e1 100644
--- a/eval/src/vespa/eval/eval/CMakeLists.txt
+++ b/eval/src/vespa/eval/eval/CMakeLists.txt
@@ -12,6 +12,7 @@ vespa_add_library(eval_eval OBJECT
node_types.cpp
operation.cpp
operator_nodes.cpp
+ param_usage.cpp
simple_tensor.cpp
simple_tensor_engine.cpp
tensor.cpp
diff --git a/eval/src/vespa/eval/eval/param_usage.cpp b/eval/src/vespa/eval/eval/param_usage.cpp
new file mode 100644
index 00000000000..35886df119e
--- /dev/null
+++ b/eval/src/vespa/eval/eval/param_usage.cpp
@@ -0,0 +1,99 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/fastos/fastos.h>
+#include "param_usage.h"
+#include "function.h"
+#include "node_traverser.h"
+#include "basic_nodes.h"
+
+namespace vespalib {
+namespace eval {
+
+using namespace nodes;
+
+namespace {
+
+//-----------------------------------------------------------------------------
+
+// Traverser that accumulates, per parameter, the expected number of
+// times the parameter is referenced. 'p' holds the probability of
+// reaching the part of the tree currently being visited; it starts at
+// 1.0 and is scaled while visiting the branches of an 'if'.
+struct CountUsage : NodeTraverser {
+    double p;
+    std::vector<double> result;
+    CountUsage(size_t num_params) : p(1.0), result(num_params, 0.0) {}
+    // 'if' nodes are traversed by hand so each branch can be visited
+    // with its own probability; returning false tells the caller that
+    // the children have already been handled.
+    bool open(const Node &node) override {
+        auto branch = as<If>(node);
+        if (!branch) {
+            return true;
+        }
+        const double outer_p = p;
+        // the condition is visited with the probability of the 'if' itself
+        branch->cond().traverse(*this);
+        const double true_weight = branch->p_true();
+        p = outer_p * true_weight;
+        branch->true_expr().traverse(*this);
+        p = outer_p * (1 - true_weight);
+        branch->false_expr().traverse(*this);
+        // restore the probability for whatever follows the 'if'
+        p = outer_p;
+        return false;
+    }
+    // Each symbol with a non-negative id adds the current probability
+    // to the slot indexed by that id; other symbols are ignored.
+    void close(const Node &node) override {
+        if (auto symbol = as<Symbol>(node)) {
+            if (symbol->id() >= 0) {
+                result[symbol->id()] += p;
+            }
+        }
+    }
+};
+
+//-----------------------------------------------------------------------------
+
+// Traverser that estimates, per parameter, the probability that the
+// parameter is referenced at least once. The branches of an 'if' are
+// analyzed by separate CheckUsage instances and then folded into this
+// one by merge().
+struct CheckUsage : NodeTraverser {
+    std::vector<double> result;
+    CheckUsage(size_t num_params) : result(num_params) {}
+    // Fold the results of the two branches of an 'if' into 'result'.
+    // The branch results are mixed according to p_true, and the mix is
+    // combined with the existing value by multiplying the
+    // probabilities of the parameter not being used.
+    void merge(const std::vector<double> &true_result,
+               const std::vector<double> &false_result,
+               double p_true)
+    {
+        const size_t count = result.size();
+        for (size_t idx = 0; idx < count; ++idx) {
+            const double from_true = (true_result[idx] * p_true);
+            const double from_false = (false_result[idx] * (1 - p_true));
+            const double p_mixed = from_true + from_false;
+            const double p_not_used = (1 - result[idx]) * (1 - p_mixed);
+            result[idx] = (1 - p_not_used);
+        }
+    }
+    // 'if' nodes are traversed by hand: the condition is visited by
+    // this traverser, each branch by a fresh one. Returning false
+    // tells the caller that the children have already been handled.
+    bool open(const Node &node) override {
+        auto branch = as<If>(node);
+        if (!branch) {
+            return true;
+        }
+        branch->cond().traverse(*this);
+        const size_t num_params = result.size();
+        CheckUsage when_true(num_params);
+        branch->true_expr().traverse(when_true);
+        CheckUsage when_false(num_params);
+        branch->false_expr().traverse(when_false);
+        merge(when_true.result, when_false.result, branch->p_true());
+        return false;
+    }
+    // A symbol with a non-negative id marks the slot indexed by that
+    // id as used with certainty; other symbols are ignored.
+    void close(const Node &node) override {
+        if (auto symbol = as<Symbol>(node)) {
+            if (symbol->id() >= 0) {
+                result[symbol->id()] = 1.0;
+            }
+        }
+    }
+};
+
+//-----------------------------------------------------------------------------
+
+} // namespace vespalib::eval::<unnamed>
+
+// Run a CountUsage traversal over the function and return its
+// per-parameter result; the result has function.num_params() entries.
+std::vector<double>
+count_param_usage(const Function &function)
+{
+    CountUsage counter(function.num_params());
+    const auto &root = function.root();
+    root.traverse(counter);
+    return counter.result;
+}
+
+// Run a CheckUsage traversal over the function and return its
+// per-parameter result; the result has function.num_params() entries.
+std::vector<double>
+check_param_usage(const Function &function)
+{
+    CheckUsage checker(function.num_params());
+    const auto &root = function.root();
+    root.traverse(checker);
+    return checker.result;
+}
+
+} // namespace vespalib::eval
+} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/param_usage.h b/eval/src/vespa/eval/eval/param_usage.h
new file mode 100644
index 00000000000..e7ed82907ed
--- /dev/null
+++ b/eval/src/vespa/eval/eval/param_usage.h
@@ -0,0 +1,35 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vector>
+
+namespace vespalib {
+namespace eval {
+
+class Function;
+
+/**
+ * Calculate the expected number of times each parameter will be
+ * used. Every occurrence of a parameter is weighted by the
+ * probability of reaching it: the branches of an 'if' are weighted by
+ * the true-probability of that 'if' (and its complement), and nested
+ * 'if' nodes multiply their weights. An 'if' condition is weighted
+ * like the 'if' itself. Note: correlation between condition checks,
+ * and the effects of short-circuit evaluation and constant value
+ * optimizations, are not taken into account.
+ *
+ * @return expected parameter usage per parameter; one entry per
+ *         function parameter
+ * @param function the function to analyze
+ **/
+std::vector<double> count_param_usage(const Function &function);
+
+/**
+ * Calculate the probability that each parameter will be used at
+ * least once. The results for the two branches of an 'if' are mixed
+ * according to the true-probability of that 'if', and the mix is
+ * combined with usage found elsewhere by multiplying the
+ * probabilities of the parameter not being used. Note: correlation
+ * between condition checks, and the effects of short-circuit
+ * evaluation and constant value optimizations, are not taken into
+ * account.
+ *
+ * @return parameter usage probability per parameter; one entry per
+ *         function parameter
+ * @param function the function to analyze
+ **/
+std::vector<double> check_param_usage(const Function &function);
+
+} // namespace vespalib::eval
+} // namespace vespalib