diff options
author | Haavard <havardpe@yahoo-inc.com> | 2017-03-01 10:19:09 +0000 |
---|---|---|
committer | Haavard <havardpe@yahoo-inc.com> | 2017-03-02 14:16:56 +0000 |
commit | e027f0b792c068bf505229165bf0eeea83cf5985 (patch) | |
tree | 8ee15b106c4e7ddc2c0f6d1d856ee9655e1d0d0d /eval/src | |
parent | d95edfd1d1e83ba2e3654355ec574f9a1b99bb6b (diff) |
analyze parameter usage
Diffstat (limited to 'eval/src')
-rw-r--r-- | eval/src/tests/eval/param_usage/CMakeLists.txt | 8 | ||||
-rw-r--r-- | eval/src/tests/eval/param_usage/param_usage_test.cpp | 65 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/CMakeLists.txt | 1 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/param_usage.cpp | 99 | ||||
-rw-r--r-- | eval/src/vespa/eval/eval/param_usage.h | 35 |
5 files changed, 208 insertions, 0 deletions
diff --git a/eval/src/tests/eval/param_usage/CMakeLists.txt b/eval/src/tests/eval/param_usage/CMakeLists.txt
new file mode 100644
index 00000000000..9ddc005dc87
--- /dev/null
+++ b/eval/src/tests/eval/param_usage/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_param_usage_test_app TEST
+    SOURCES
+    param_usage_test.cpp
+    DEPENDS
+    vespaeval
+)
+vespa_add_test(NAME eval_param_usage_test_app COMMAND eval_param_usage_test_app)
diff --git a/eval/src/tests/eval/param_usage/param_usage_test.cpp b/eval/src/tests/eval/param_usage/param_usage_test.cpp
new file mode 100644
index 00000000000..ff0c6667279
--- /dev/null
+++ b/eval/src/tests/eval/param_usage/param_usage_test.cpp
@@ -0,0 +1,65 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/eval/eval/function.h>
+#include <vespa/eval/eval/param_usage.h>
+#include <vespa/vespalib/test/insertion_operators.h>
+
+using vespalib::approx_equal;
+using namespace vespalib::eval;
+
+struct List {
+    std::vector<double> list;
+    List(std::vector<double> list_in) : list(std::move(list_in)) {}
+    bool operator==(const List &rhs) const {
+        if (list.size() != rhs.list.size()) {
+            return false;
+        }
+        for (size_t i = 0; i < list.size(); ++i) {
+            if (!approx_equal(list[i], rhs.list[i])) {
+                return false;
+            }
+        }
+        return true;
+    }
+};
+
+std::ostream &operator<<(std::ostream &out, const List &list) {
+    return out << list.list;
+}
+
+TEST("require that simple expression has appropriate parameter usage") {
+    std::vector<vespalib::string> params({"x", "y", "z"});
+    Function function = Function::parse(params, "(x+y)*y");
+    EXPECT_EQUAL(List(count_param_usage(function)), List({1.0, 2.0, 0.0}));
+    EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 1.0, 0.0}));
+}
+
+TEST("require that if children have 50% probability each by default") {
+    std::vector<vespalib::string> params({"x", "y", "z", "w"});
+    Function function = Function::parse(params, "if(w,(x+y)*y,(y+z)*z)");
+    EXPECT_EQUAL(List(count_param_usage(function)), List({0.5, 1.5, 1.0, 1.0}));
+    EXPECT_EQUAL(List(check_param_usage(function)), List({0.5, 1.0, 0.5, 1.0}));
+}
+
+TEST("require that if children probability can be adjusted") {
+    std::vector<vespalib::string> params({"x", "y", "z"});
+    Function function = Function::parse(params, "if(z,x*x,y*y,0.8)");
+    EXPECT_EQUAL(List(count_param_usage(function)), List({1.6, 0.4, 1.0}));
+    EXPECT_EQUAL(List(check_param_usage(function)), List({0.8, 0.2, 1.0}));
+}
+
+TEST("require that chained if statements are combined correctly") {
+    std::vector<vespalib::string> params({"x", "y", "z", "w"});
+    Function function = Function::parse(params, "if(z,x,y)+if(w,y,x)");
+    EXPECT_EQUAL(List(count_param_usage(function)), List({1.0, 1.0, 1.0, 1.0}));
+    EXPECT_EQUAL(List(check_param_usage(function)), List({0.75, 0.75, 1.0, 1.0}));
+}
+
+TEST("require that multi-level if statements are combined correctly") {
+    std::vector<vespalib::string> params({"x", "y", "z", "w"});
+    Function function = Function::parse(params, "if(z,if(w,y*x,x*x),if(w,y*x,x*x))");
+    EXPECT_EQUAL(List(count_param_usage(function)), List({1.5, 0.5, 1.0, 1.0}));
+    EXPECT_EQUAL(List(check_param_usage(function)), List({1.0, 0.5, 1.0, 1.0}));
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/CMakeLists.txt b/eval/src/vespa/eval/eval/CMakeLists.txt
index 9bcc3dd0742..145409af3e1 100644
--- a/eval/src/vespa/eval/eval/CMakeLists.txt
+++ b/eval/src/vespa/eval/eval/CMakeLists.txt
@@ -12,6 +12,7 @@ vespa_add_library(eval_eval OBJECT
     node_types.cpp
     operation.cpp
     operator_nodes.cpp
+    param_usage.cpp
     simple_tensor.cpp
     simple_tensor_engine.cpp
     tensor.cpp
diff --git a/eval/src/vespa/eval/eval/param_usage.cpp b/eval/src/vespa/eval/eval/param_usage.cpp
new file mode 100644
index 00000000000..35886df119e
--- /dev/null
+++ b/eval/src/vespa/eval/eval/param_usage.cpp
@@ -0,0 +1,99 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/fastos/fastos.h>
+#include "param_usage.h"
+#include "function.h"
+#include "node_traverser.h"
+#include "basic_nodes.h"
+
+namespace vespalib {
+namespace eval {
+
+using namespace nodes;
+
+namespace {
+
+//-----------------------------------------------------------------------------
+
+struct CountUsage : NodeTraverser {
+    double p;
+    std::vector<double> result;
+    CountUsage(size_t num_params) : p(1.0), result(num_params, 0.0) {}
+    bool open(const Node &node) override {
+        if (auto if_node = as<If>(node)) {
+            double my_p = p;
+            if_node->cond().traverse(*this);
+            p = my_p * if_node->p_true();
+            if_node->true_expr().traverse(*this);
+            p = my_p * (1 - if_node->p_true());
+            if_node->false_expr().traverse(*this);
+            p = my_p;
+            return false;
+        }
+        return true;
+    }
+    void close(const Node &node) override {
+        auto symbol = as<Symbol>(node);
+        if (symbol && (symbol->id() >= 0)) {
+            result[symbol->id()] += p;
+        }
+    }
+};
+
+//-----------------------------------------------------------------------------
+
+struct CheckUsage : NodeTraverser {
+    std::vector<double> result;
+    CheckUsage(size_t num_params) : result(num_params) {}
+    void merge(const std::vector<double> &true_result,
+               const std::vector<double> &false_result,
+               double p_true)
+    {
+        for (size_t i = 0; i < result.size(); ++i) {
+            double p_mixed = (true_result[i] * p_true) + (false_result[i] * (1 - p_true));
+            double p_not_used = (1 - result[i]) * (1 - p_mixed);
+            result[i] = (1 - p_not_used);
+        }
+    }
+    bool open(const Node &node) override {
+        if (auto if_node = as<If>(node)) {
+            if_node->cond().traverse(*this);
+            CheckUsage check_true(result.size());
+            if_node->true_expr().traverse(check_true);
+            CheckUsage check_false(result.size());
+            if_node->false_expr().traverse(check_false);
+            merge(check_true.result, check_false.result, if_node->p_true());
+            return false;
+        }
+        return true;
+    }
+    void close(const Node &node) override {
+        auto symbol = as<Symbol>(node);
+        if (symbol && (symbol->id() >= 0)) {
+            result[symbol->id()] = 1.0;
+        }
+    }
+};
+
+//-----------------------------------------------------------------------------
+
+} // namespace vespalib::eval::<unnamed>
+
+std::vector<double>
+count_param_usage(const Function &function)
+{
+    CountUsage count_usage(function.num_params());
+    function.root().traverse(count_usage);
+    return count_usage.result;
+}
+
+std::vector<double>
+check_param_usage(const Function &function)
+{
+    CheckUsage check_usage(function.num_params());
+    function.root().traverse(check_usage);
+    return check_usage.result;
+}
+
+} // namespace vespalib::eval
+} // namespace vespalib
diff --git a/eval/src/vespa/eval/eval/param_usage.h b/eval/src/vespa/eval/eval/param_usage.h
new file mode 100644
index 00000000000..e7ed82907ed
--- /dev/null
+++ b/eval/src/vespa/eval/eval/param_usage.h
@@ -0,0 +1,35 @@
+// Copyright 2017 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vector>
+
+namespace vespalib {
+namespace eval {
+
+class Function;
+
+/**
+ * Calculate the expected number of times each parameter will be
+ * used. Note: Correlation between condition checks and effects of
+ * short-circuit evaluation and constant value optimizations are not
+ * taken into account.
+ *
+ * @return expected parameter usage per parameter
+ * @param function the function to analyze
+ **/
+std::vector<double> count_param_usage(const Function &function);
+
+/**
+ * Calculate the probability that each parameter will be used. Note:
+ * Correlation between condition checks and effects of short-circuit
+ * evaluation and constant value optimizations are not taken into
+ * account.
+ *
+ * @return parameter usage probability per parameter
+ * @param function the function to analyze
+ **/
+std::vector<double> check_param_usage(const Function &function);
+
+} // namespace vespalib::eval
+} // namespace vespalib