summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <3535158+havardpe@users.noreply.github.com>2020-04-30 13:42:32 +0200
committerGitHub <noreply@github.com>2020-04-30 13:42:32 +0200
commit417021f18f70f1a4dcc0690911430bdad2d14317 (patch)
tree771eb0bccedcd90740032bad84f06fb8334f86da
parent1e98945c9313adeae298b9be79b986678fb6ca39 (diff)
parent6cf232863ef9047412aa62cc6465ebbdd719a642 (diff)
Merge pull request #13120 from vespa-engine/havardpe/tensor-lambda-peek-optimizer
lambda peek optimizer
-rw-r--r--eval/CMakeLists.txt1
-rw-r--r--eval/src/tests/eval/node_tools/CMakeLists.txt8
-rw-r--r--eval/src/tests/eval/node_tools/node_tools_test.cpp120
-rw-r--r--eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp112
-rw-r--r--eval/src/vespa/eval/eval/CMakeLists.txt1
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.cpp12
-rw-r--r--eval/src/vespa/eval/eval/llvm/compiled_function.h13
-rw-r--r--eval/src/vespa/eval/eval/node_tools.cpp209
-rw-r--r--eval/src/vespa/eval/eval/node_tools.h16
-rw-r--r--eval/src/vespa/eval/eval/tensor_function.h3
-rw-r--r--eval/src/vespa/eval/eval/test/eval_fixture.cpp9
-rw-r--r--eval/src/vespa/eval/tensor/dense/CMakeLists.txt4
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_cell_range_function.cpp53
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_cell_range_function.h31
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_lambda_peek_function.cpp89
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_lambda_peek_function.h31
-rw-r--r--eval/src/vespa/eval/tensor/dense/dense_lambda_peek_optimizer.cpp196
17 files changed, 878 insertions, 30 deletions
diff --git a/eval/CMakeLists.txt b/eval/CMakeLists.txt
index 8eb198a9a0b..bb13638cf1d 100644
--- a/eval/CMakeLists.txt
+++ b/eval/CMakeLists.txt
@@ -18,6 +18,7 @@ vespa_define_module(
src/tests/eval/function_speed
src/tests/eval/gbdt
src/tests/eval/interpreted_function
+ src/tests/eval/node_tools
src/tests/eval/node_types
src/tests/eval/param_usage
src/tests/eval/simple_tensor
diff --git a/eval/src/tests/eval/node_tools/CMakeLists.txt b/eval/src/tests/eval/node_tools/CMakeLists.txt
new file mode 100644
index 00000000000..437aea3ac8f
--- /dev/null
+++ b/eval/src/tests/eval/node_tools/CMakeLists.txt
@@ -0,0 +1,8 @@
+# Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+vespa_add_executable(eval_node_tools_test_app TEST
+ SOURCES
+ node_tools_test.cpp
+ DEPENDS
+ vespaeval
+)
+vespa_add_test(NAME eval_node_tools_test_app COMMAND eval_node_tools_test_app)
diff --git a/eval/src/tests/eval/node_tools/node_tools_test.cpp b/eval/src/tests/eval/node_tools/node_tools_test.cpp
new file mode 100644
index 00000000000..ca89650127e
--- /dev/null
+++ b/eval/src/tests/eval/node_tools/node_tools_test.cpp
@@ -0,0 +1,120 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include <vespa/vespalib/testkit/test_kit.h>
+#include <vespa/eval/eval/function.h>
+#include <vespa/eval/eval/node_tools.h>
+
+using namespace vespalib::eval;
+
+auto make_copy(const Function &fun) { // build a new Function from NodeTools::copy of fun's root, reusing fun's parameter names
+ std::vector<vespalib::string> params;
+ for (size_t i = 0; i < fun.num_params(); ++i) {
+ params.push_back(fun.param_name(i)); // keep parameter names in their original order
+ }
+ return Function::create(NodeTools::copy(fun.root()), params);
+}
+
+void verify_copy(const vespalib::string &expr, const vespalib::string &expect) { // parse expr, copy it via make_copy, and check that the copy dumps as 'expect'
+ auto fun = Function::parse(expr);
+ auto fun_copy = make_copy(*fun);
+ EXPECT_EQUAL(fun_copy->dump(), expect);
+}
+void verify_copy(const vespalib::string &expr) { verify_copy(expr, expr); } // round-trip form: the copy must dump back to expr itself
+
+TEST("require that required parameter count can be detected") { // checks NodeTools::min_num_params on every sub-tree of (c+a)+(b+1)
+ auto function = Function::parse({"a","b","c"}, "(c+a)+(b+1)");
+ const auto &root = function->root();
+ ASSERT_EQUAL(root.num_children(), 2u);
+ const auto &n_c_a = root.get_child(0); // the (c+a) sub-tree
+ const auto &n_b_1 = root.get_child(1); // the (b+1) sub-tree
+ ASSERT_EQUAL(n_c_a.num_children(), 2u);
+ const auto &n_c = n_c_a.get_child(0);
+ const auto &n_a = n_c_a.get_child(1);
+ ASSERT_EQUAL(n_b_1.num_children(), 2u);
+ const auto &n_b = n_b_1.get_child(0);
+ const auto &n_1 = n_b_1.get_child(1);
+ EXPECT_EQUAL(NodeTools::min_num_params(root), 3u); // uses c, the third declared parameter
+ EXPECT_EQUAL(NodeTools::min_num_params(n_c_a), 3u);
+ EXPECT_EQUAL(NodeTools::min_num_params(n_b_1), 2u);
+ EXPECT_EQUAL(NodeTools::min_num_params(n_c), 3u);
+ EXPECT_EQUAL(NodeTools::min_num_params(n_a), 1u); // a alone needs only the first parameter
+ EXPECT_EQUAL(NodeTools::min_num_params(n_b), 2u);
+ EXPECT_EQUAL(NodeTools::min_num_params(n_1), 0u); // a constant needs no parameters
+}
+
+TEST("require that basic node types can be copied") { // each expression must dump identically after being copied, unless an explicit expectation is given
+ TEST_DO(verify_copy("123"));
+ TEST_DO(verify_copy("foo"));
+ TEST_DO(verify_copy("\"string value\""));
+ TEST_DO(verify_copy("(a in [1,\"2\",3])"));
+ TEST_DO(verify_copy("(-a)"));
+ TEST_DO(verify_copy("(!a)"));
+ TEST_DO(verify_copy("if(a,b,c)"));
+ TEST_DO(verify_copy("if(a,b,c,0.7)"));
+ TEST_DO(verify_copy("#", "[]...[missing value]...[#]"));
+}
+
+TEST("require that operator node types can be copied") { // one round-trip check per binary operator
+ TEST_DO(verify_copy("(a+b)"));
+ TEST_DO(verify_copy("(a-b)"));
+ TEST_DO(verify_copy("(a*b)"));
+ TEST_DO(verify_copy("(a/b)"));
+ TEST_DO(verify_copy("(a%b)"));
+ TEST_DO(verify_copy("(a^b)"));
+ TEST_DO(verify_copy("(a==b)"));
+ TEST_DO(verify_copy("(a!=b)"));
+ TEST_DO(verify_copy("(a~=b)"));
+ TEST_DO(verify_copy("(a<b)"));
+ TEST_DO(verify_copy("(a<=b)"));
+ TEST_DO(verify_copy("(a>b)"));
+ TEST_DO(verify_copy("(a>=b)"));
+ TEST_DO(verify_copy("(a&&b)"));
+ TEST_DO(verify_copy("(a||b)"));
+}
+
+TEST("require that call node types can be copied") { // one round-trip check per built-in function call, with one or two arguments
+ TEST_DO(verify_copy("cos(a)"));
+ TEST_DO(verify_copy("sin(a)"));
+ TEST_DO(verify_copy("tan(a)"));
+ TEST_DO(verify_copy("cosh(a)"));
+ TEST_DO(verify_copy("sinh(a)"));
+ TEST_DO(verify_copy("tanh(a)"));
+ TEST_DO(verify_copy("acos(a)"));
+ TEST_DO(verify_copy("asin(a)"));
+ TEST_DO(verify_copy("atan(a)"));
+ TEST_DO(verify_copy("exp(a)"));
+ TEST_DO(verify_copy("log10(a)"));
+ TEST_DO(verify_copy("log(a)"));
+ TEST_DO(verify_copy("sqrt(a)"));
+ TEST_DO(verify_copy("ceil(a)"));
+ TEST_DO(verify_copy("fabs(a)"));
+ TEST_DO(verify_copy("floor(a)"));
+ TEST_DO(verify_copy("atan2(a,b)"));
+ TEST_DO(verify_copy("ldexp(a,b)"));
+ TEST_DO(verify_copy("pow(a,b)"));
+ TEST_DO(verify_copy("fmod(a,b)"));
+ TEST_DO(verify_copy("min(a,b)"));
+ TEST_DO(verify_copy("max(a,b)"));
+ TEST_DO(verify_copy("isNan(a)"));
+ TEST_DO(verify_copy("relu(a)"));
+ TEST_DO(verify_copy("sigmoid(a)"));
+ TEST_DO(verify_copy("elu(a)"));
+}
+
+TEST("require that tensor node types can NOT be copied (yet)") { // copying any expression containing a tensor node must yield a result that dumps as "not implemented"
+ TEST_DO(verify_copy("map(a,f(x)(x))", "not implemented"));
+ TEST_DO(verify_copy("join(a,b,f(x,y)(x*y))", "not implemented"));
+ TEST_DO(verify_copy("merge(a,b,f(x,y)(y))", "not implemented"));
+ TEST_DO(verify_copy("reduce(a,sum)", "not implemented"));
+ TEST_DO(verify_copy("rename(a,x,y)", "not implemented"));
+ TEST_DO(verify_copy("concat(a,b,x)", "not implemented"));
+ TEST_DO(verify_copy("tensor(x[3]):[1,2,3]", "not implemented"));
+ TEST_DO(verify_copy("tensor(x[3])(x)", "not implemented"));
+ TEST_DO(verify_copy("a{x:0}", "not implemented"));
+}
+
+TEST("require that nested expressions can be copied") { // mixes a call, if, comparison, operators and negation in a single tree
+ TEST_DO(verify_copy("min(a,if(((b+3)==7),(!c),(d+7)))"));
+}
+
+TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp
index 5b0f2cf0a7e..3c35f90c521 100644
--- a/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp
+++ b/eval/src/tests/eval/tensor_lambda/tensor_lambda_test.cpp
@@ -6,6 +6,8 @@
#include <vespa/eval/eval/simple_tensor_engine.h>
#include <vespa/eval/tensor/default_tensor_engine.h>
#include <vespa/eval/tensor/dense/dense_replace_type_function.h>
+#include <vespa/eval/tensor/dense/dense_cell_range_function.h>
+#include <vespa/eval/tensor/dense/dense_lambda_peek_function.h>
#include <vespa/eval/tensor/dense/dense_fast_rename_optimizer.h>
#include <vespa/eval/tensor/dense/dense_tensor.h>
#include <vespa/eval/eval/test/tensor_model.hpp>
@@ -27,42 +29,81 @@ EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
.add("a", spec(1))
.add("x3", spec({x(3)}, N()))
- .add("x3f", spec(float_cells({x(3)}), N()));
+ .add("x3f", spec(float_cells({x(3)}), N()))
+ .add("x3m", spec({x({"0", "1", "2"})}, N()))
+ .add("x3y5", spec({x(3), y(5)}, N()))
+ .add("x3y5f", spec(float_cells({x(3), y(5)}), N()))
+ .add("x15", spec({x(15)}, N()))
+ .add("x15f", spec(float_cells({x(15)}), N()));
}
EvalFixture::ParamRepo param_repo = make_params();
-void verify_dynamic(const vespalib::string &expr, const vespalib::string &expect) {
+template <typename T, typename F>
+void verify_impl(const vespalib::string &expr, const vespalib::string &expect, F &&inspect) {
EvalFixture fixture(prod_engine, expr, param_repo, true);
+ EvalFixture slow_fixture(prod_engine, expr, param_repo, false);
+ EXPECT_EQUAL(fixture.result(), slow_fixture.result());
EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expect, param_repo));
- auto info = fixture.find_all<Lambda>();
- EXPECT_EQUAL(info.size(), 1u);
+ auto info = fixture.find_all<T>();
+ if (EXPECT_EQUAL(info.size(), 1u)) {
+ inspect(info[0]);
+ }
+}
+template <typename T>
+void verify_impl(const vespalib::string &expr, const vespalib::string &expect) {
+ verify_impl<T>(expr, expect, [](const T*){});
+}
+
+void verify_generic(const vespalib::string &expr, const vespalib::string &expect) {
+ verify_impl<Lambda>(expr, expect);
+}
+
+void verify_reshape(const vespalib::string &expr, const vespalib::string &expect) {
+ verify_impl<DenseReplaceTypeFunction>(expr, expect);
+}
+
+void verify_range(const vespalib::string &expr, const vespalib::string &expect) {
+ verify_impl<DenseCellRangeFunction>(expr, expect);
+}
+
+void verify_compiled(const vespalib::string &expr, const vespalib::string &expect,
+ const vespalib::string &expect_idx_fun)
+{
+ verify_impl<DenseLambdaPeekFunction>(expr, expect,
+ [&](const DenseLambdaPeekFunction *info)
+ {
+ EXPECT_EQUAL(info->idx_fun_dump(), expect_idx_fun);
+ });
}
void verify_const(const vespalib::string &expr, const vespalib::string &expect) {
- EvalFixture fixture(prod_engine, expr, param_repo, true);
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expr, param_repo));
- EXPECT_EQUAL(fixture.result(), EvalFixture::ref(expect, param_repo));
- auto info = fixture.find_all<ConstValue>();
- EXPECT_EQUAL(info.size(), 1u);
+ verify_impl<ConstValue>(expr, expect);
}
+//-----------------------------------------------------------------------------
+
TEST("require that simple constant tensor lambda works") {
TEST_DO(verify_const("tensor(x[3])(x+1)", "tensor(x[3]):[1,2,3]"));
}
TEST("require that simple dynamic tensor lambda works") {
- TEST_DO(verify_dynamic("tensor(x[3])(x+a)", "tensor(x[3]):[1,2,3]"));
+ TEST_DO(verify_generic("tensor(x[3])(x+a)", "tensor(x[3]):[1,2,3]"));
}
TEST("require that tensor lambda can be used for tensor slicing") {
- TEST_DO(verify_dynamic("tensor(x[2])(x3{x:(x+a)})", "tensor(x[2]):[2,3]"));
- TEST_DO(verify_dynamic("tensor(x[2])(a+x3{x:(x)})", "tensor(x[2]):[2,3]"));
+ TEST_DO(verify_generic("tensor(x[2])(x3{x:(x+a)})", "tensor(x[2]):[2,3]"));
+ TEST_DO(verify_generic("tensor(x[2])(a+x3{x:(x)})", "tensor(x[2]):[2,3]"));
}
-TEST("require that tensor lambda can be used for tensor casting") {
- TEST_DO(verify_dynamic("tensor(x[3])(x3f{x:(x)})", "tensor(x[3]):[1,2,3]"));
- TEST_DO(verify_dynamic("tensor<float>(x[3])(x3{x:(x)})", "tensor<float>(x[3]):[1,2,3]"));
+TEST("require that tensor lambda can be used for cell type casting") {
+ TEST_DO(verify_compiled("tensor(x[3])(x3f{x:(x)})", "tensor(x[3]):[1,2,3]", "f(x)(x)"));
+ TEST_DO(verify_compiled("tensor<float>(x[3])(x3{x:(x)})", "tensor<float>(x[3]):[1,2,3]", "f(x)(x)"));
+}
+
+TEST("require that tensor lambda can be used to convert from sparse to dense tensors") {
+ TEST_DO(verify_generic("tensor(x[3])(x3m{x:(x)})", "tensor(x[3]):[1,2,3]"));
+ TEST_DO(verify_generic("tensor(x[2])(x3m{x:(x)})", "tensor(x[2]):[1,2]"));
}
TEST("require that constant nested tensor lambda using tensor peek works") {
@@ -70,7 +111,46 @@ TEST("require that constant nested tensor lambda using tensor peek works") {
}
TEST("require that dynamic nested tensor lambda using tensor peek works") {
- TEST_DO(verify_dynamic("tensor(x[2])(tensor(y[2])((x+y)+a){y:(x)})", "tensor(x[2]):[1,3]"));
+ TEST_DO(verify_generic("tensor(x[2])(tensor(y[2])((x+y)+a){y:(x)})", "tensor(x[2]):[1,3]"));
+}
+
+TEST("require that tensor reshape is optimized") {
+ TEST_DO(verify_reshape("tensor(x[15])(x3y5{x:(x/5),y:(x%5)})", "x15"));
+ TEST_DO(verify_reshape("tensor(x[3],y[5])(x15{x:(x*5+y)})", "x3y5"));
+ TEST_DO(verify_reshape("tensor<float>(x[15])(x3y5f{x:(x/5),y:(x%5)})", "x15f"));
+}
+
+TEST("require that tensor reshape with non-matching cell type requires cell copy") {
+ TEST_DO(verify_compiled("tensor(x[15])(x3y5f{x:(x/5),y:(x%5)})", "x15", "f(x)((floor((x/5))*5)+(x%5))"));
+ TEST_DO(verify_compiled("tensor<float>(x[15])(x3y5{x:(x/5),y:(x%5)})", "x15f", "f(x)((floor((x/5))*5)+(x%5))"));
+ TEST_DO(verify_compiled("tensor(x[3],y[5])(x15f{x:(x*5+y)})", "x3y5", "f(x,y)((x*5)+y)"));
+ TEST_DO(verify_compiled("tensor<float>(x[3],y[5])(x15{x:(x*5+y)})", "x3y5f", "f(x,y)((x*5)+y)"));
+}
+
+TEST("require that tensor cell subrange view is optimized") {
+ TEST_DO(verify_range("tensor(y[5])(x3y5{x:1,y:(y)})", "x3y5{x:1}"));
+ TEST_DO(verify_range("tensor(x[3])(x15{x:(x+5)})", "tensor(x[3]):[6,7,8]"));
+ TEST_DO(verify_range("tensor<float>(y[5])(x3y5f{x:1,y:(y)})", "x3y5f{x:1}"));
+ TEST_DO(verify_range("tensor<float>(x[3])(x15f{x:(x+5)})", "tensor<float>(x[3]):[6,7,8]"));
+}
+
+TEST("require that tensor cell subrange with non-matching cell type requires cell copy") {
+ TEST_DO(verify_compiled("tensor(x[3])(x15f{x:(x+5)})", "tensor(x[3]):[6,7,8]", "f(x)(x+5)"));
+ TEST_DO(verify_compiled("tensor<float>(x[3])(x15{x:(x+5)})", "tensor<float>(x[3]):[6,7,8]", "f(x)(x+5)"));
+}
+
+TEST("require that non-continuous cell extraction is optimized") {
+ TEST_DO(verify_compiled("tensor(x[3])(x3y5{x:(x),y:2})", "x3y5{y:2}", "f(x)((floor(x)*5)+2)"));
+ TEST_DO(verify_compiled("tensor(x[3])(x3y5f{x:(x),y:2})", "x3y5{y:2}", "f(x)((floor(x)*5)+2)"));
+ TEST_DO(verify_compiled("tensor<float>(x[3])(x3y5{x:(x),y:2})", "x3y5f{y:2}", "f(x)((floor(x)*5)+2)"));
+ TEST_DO(verify_compiled("tensor<float>(x[3])(x3y5f{x:(x),y:2})", "x3y5f{y:2}", "f(x)((floor(x)*5)+2)"));
+}
+
+TEST("require that out-of-bounds cell extraction is not optimized") {
+ TEST_DO(verify_generic("tensor(x[3])(x3y5{x:1,y:(x+3)})", "tensor(x[3]):[9,10,0]"));
+ TEST_DO(verify_generic("tensor(x[3])(x3y5{x:1,y:(x-1)})", "tensor(x[3]):[0,6,7]"));
+ TEST_DO(verify_generic("tensor(x[3])(x3y5{x:(x+1),y:(x)})", "tensor(x[3]):[6,12,0]"));
+ TEST_DO(verify_generic("tensor(x[3])(x3y5{x:(x-1),y:(x)})", "tensor(x[3]):[0,2,8]"));
}
TEST("require that non-double result from inner tensor lambda function fails type resolving") {
diff --git a/eval/src/vespa/eval/eval/CMakeLists.txt b/eval/src/vespa/eval/eval/CMakeLists.txt
index 20a7141393f..002a027c3a9 100644
--- a/eval/src/vespa/eval/eval/CMakeLists.txt
+++ b/eval/src/vespa/eval/eval/CMakeLists.txt
@@ -13,6 +13,7 @@ vespa_add_library(eval_eval OBJECT
key_gen.cpp
lazy_params.cpp
make_tensor_function.cpp
+ node_tools.cpp
node_types.cpp
operation.cpp
operator_nodes.cpp
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
index 4101cf10e1f..043ad248251 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.cpp
@@ -31,16 +31,16 @@ double my_resolve(void *ctx, size_t idx) { return ((double *)ctx)[idx]; }
} // namespace vespalib::eval::<unnamed>
-CompiledFunction::CompiledFunction(const Function &function_in, PassParams pass_params_in,
+CompiledFunction::CompiledFunction(const nodes::Node &root_in, size_t num_params_in, PassParams pass_params_in,
const gbdt::Optimize::Chain &forest_optimizers)
: _llvm_wrapper(),
_address(nullptr),
- _num_params(function_in.num_params()),
+ _num_params(num_params_in),
_pass_params(pass_params_in)
{
- size_t id = _llvm_wrapper.make_function(function_in.num_params(),
+ size_t id = _llvm_wrapper.make_function(num_params_in,
_pass_params,
- function_in.root(),
+ root_in,
forest_optimizers);
_llvm_wrapper.compile();
_address = _llvm_wrapper.get_function_address(id);
@@ -120,7 +120,7 @@ CompiledFunction::estimate_cost_us(const std::vector<double> &params, double bud
}
Function::Issues
-CompiledFunction::detect_issues(const Function &function)
+CompiledFunction::detect_issues(const nodes::Node &node)
{
struct NotSupported : NodeTraverser {
std::vector<vespalib::string> issues;
@@ -141,7 +141,7 @@ CompiledFunction::detect_issues(const Function &function)
}
}
} checker;
- function.root().traverse(checker);
+ node.traverse(checker);
return Function::Issues(std::move(checker.issues));
}
diff --git a/eval/src/vespa/eval/eval/llvm/compiled_function.h b/eval/src/vespa/eval/eval/llvm/compiled_function.h
index 40737e3ff92..6fb68e2df72 100644
--- a/eval/src/vespa/eval/eval/llvm/compiled_function.h
+++ b/eval/src/vespa/eval/eval/llvm/compiled_function.h
@@ -38,10 +38,14 @@ private:
public:
typedef std::unique_ptr<CompiledFunction> UP;
- CompiledFunction(const Function &function_in, PassParams pass_params_in,
+ CompiledFunction(const nodes::Node &root_in, size_t num_params_in, PassParams pass_params_in,
const gbdt::Optimize::Chain &forest_optimizers);
+ CompiledFunction(const Function &function_in, PassParams pass_params_in, const gbdt::Optimize::Chain &forest_optimizers)
+ : CompiledFunction(function_in.root(), function_in.num_params(), pass_params_in, forest_optimizers) {}
+ CompiledFunction(const nodes::Node &root_in, size_t num_params_in, PassParams pass_params_in)
+ : CompiledFunction(root_in, num_params_in, pass_params_in, gbdt::Optimize::best) {}
CompiledFunction(const Function &function_in, PassParams pass_params_in)
- : CompiledFunction(function_in, pass_params_in, gbdt::Optimize::best) {}
+ : CompiledFunction(function_in.root(), function_in.num_params(), pass_params_in, gbdt::Optimize::best) {}
CompiledFunction(CompiledFunction &&rhs);
size_t num_params() const { return _num_params; }
PassParams pass_params() const { return _pass_params; }
@@ -63,7 +67,10 @@ public:
return _llvm_wrapper.get_forests();
}
double estimate_cost_us(const std::vector<double> &params, double budget = 5.0) const;
- static Function::Issues detect_issues(const Function &function);
+ static Function::Issues detect_issues(const nodes::Node &node);
+ static Function::Issues detect_issues(const Function &function) {
+ return detect_issues(function.root());
+ }
static bool should_use_lazy_params(const Function &function);
};
diff --git a/eval/src/vespa/eval/eval/node_tools.cpp b/eval/src/vespa/eval/eval/node_tools.cpp
new file mode 100644
index 00000000000..7bbe095c060
--- /dev/null
+++ b/eval/src/vespa/eval/eval/node_tools.cpp
@@ -0,0 +1,209 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "node_tools.h"
+#include <vespa/eval/eval/node_traverser.h>
+#include <vespa/eval/eval/node_visitor.h>
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::nodes;
+
+namespace vespalib::eval {
+
+namespace {
+
+struct CountParams : NodeTraverser, EmptyNodeVisitor { // traverser that records the largest symbol id seen, plus one
+ size_t result = 0; // stays 0 when no Symbol node is visited
+ void visit(const Symbol &symbol) override {
+ result = std::max(result, symbol.id() + 1); // id + 1 so that a symbol with id 0 counts as one parameter
+ }
+ bool open(const Node &) override { return true; } // never asks the traversal to stop
+ void close(const Node &node) override { node.accept(*this); } // dispatch to visit(); only the Symbol overload is overridden here
+};
+
+struct CopyNode : NodeTraverser, NodeVisitor { // builds a copy of a node tree on a stack of partial results while the tree is traversed
+
+ std::unique_ptr<Error> error; // first failure recorded by fail(), if any
+ std::vector<Node_UP> stack; // copied nodes not yet claimed by a parent node
+
+ CopyNode() : error(), stack() {}
+ ~CopyNode() override;
+
+ Node_UP result() { // the recorded error if any, otherwise the single node left on the stack
+ if (error) {
+ return std::move(error);
+ }
+ if (stack.size() != 1) { // anything other than exactly one remaining node is reported as an Error node
+ return std::make_unique<Error>("invalid result stack");
+ }
+ return std::move(stack.back());
+ }
+
+ //-------------------------------------------------------------------------
+
+ void fail(const vespalib::string &msg) { // only the first failure message is kept
+ if (!error) {
+ error = std::make_unique<Error>(msg);
+ }
+ }
+
+ void not_implemented(const Node &) { // shared failure path for node types this copier does not handle
+ fail("not implemented");
+ }
+
+ std::vector<Node_UP> get_children(size_t n) { // removes the top n stack entries and returns them in their stack order
+ std::vector<Node_UP> result;
+ if (stack.size() >= n) {
+ for (size_t i = 0; i < n; ++i) {
+ result.push_back(std::move(stack[stack.size() - (n - i)])); // oldest of the n entries first
+ }
+ stack.resize(stack.size() - n);
+ } else {
+ fail("stack underflow");
+ for (size_t i = 0; i < n; ++i) {
+ result.push_back(std::make_unique<Error>("placeholder")); // still return n entries so callers can index the list
+ }
+ }
+ return result;
+ }
+
+ //-------------------------------------------------------------------------
+
+ void wire_operator(Operator_UP op) { // binds the two most recent results to op and pushes op
+ auto list = get_children(2);
+ op->bind(std::move(list[0]), std::move(list[1]));
+ stack.push_back(std::move(op));
+ }
+
+ void wire_call(Call_UP call) { // binds call->num_params() results to call, in order, and pushes call
+ auto list = get_children(call->num_params());
+ for (size_t i = 0; i < list.size(); ++i) {
+ call->bind_next(std::move(list[i]));
+ }
+ stack.push_back(std::move(call));
+ }
+
+ template <typename T> void copy_operator(const T &) { wire_operator(T::create()); } // the source node only selects T; its contents are not read
+ template <typename T> void copy_call(const T &) { wire_call(T::create()); } // same as copy_operator, for call nodes
+
+ //-------------------------------------------------------------------------
+
+ // basic nodes
+ void visit(const Number &node) override {
+ stack.push_back(std::make_unique<Number>(node.value()));
+ }
+ void visit(const Symbol &node) override {
+ stack.push_back(std::make_unique<Symbol>(node.id()));
+ }
+ void visit(const String &node) override {
+ stack.push_back(std::make_unique<String>(node.value()));
+ }
+ void visit(const In &node) override {
+ for (size_t i = 0; i < node.num_entries(); ++i) {
+ // only String/Number allowed here; copy to stack
+ node.get_entry(i).accept(*this);
+ }
+ auto list = get_children(node.num_entries() + 1); // one extra entry: list[0] becomes the constructor argument of the new In node
+ auto my_node = std::make_unique<In>(std::move(list[0]));
+ for (size_t i = 1; i < list.size(); ++i) {
+ my_node->add_entry(std::move(list[i]));
+ }
+ stack.push_back(std::move(my_node));
+ }
+ void visit(const Neg &) override {
+ auto list = get_children(1);
+ stack.push_back(std::make_unique<Neg>(std::move(list[0])));
+ }
+ void visit(const Not &) override {
+ auto list = get_children(1);
+ stack.push_back(std::make_unique<Not>(std::move(list[0])));
+ }
+ void visit(const If &node) override {
+ auto list = get_children(3);
+ stack.push_back(std::make_unique<If>(std::move(list[0]), std::move(list[1]), std::move(list[2]), node.p_true())); // p_true is carried over from the source node
+ }
+ void visit(const Error &node) override {
+ stack.push_back(std::make_unique<Error>(node.message()));
+ }
+
+ // tensor nodes
+ void visit(const TensorMap &node) override { not_implemented(node); }
+ void visit(const TensorJoin &node) override { not_implemented(node); }
+ void visit(const TensorMerge &node) override { not_implemented(node); }
+ void visit(const TensorReduce &node) override { not_implemented(node); }
+ void visit(const TensorRename &node) override { not_implemented(node); }
+ void visit(const TensorConcat &node) override { not_implemented(node); }
+ void visit(const TensorCreate &node) override { not_implemented(node); }
+ void visit(const TensorLambda &node) override { not_implemented(node); }
+ void visit(const TensorPeek &node) override { not_implemented(node); }
+
+ // operator nodes
+ void visit(const Add &node) override { copy_operator(node); }
+ void visit(const Sub &node) override { copy_operator(node); }
+ void visit(const Mul &node) override { copy_operator(node); }
+ void visit(const Div &node) override { copy_operator(node); }
+ void visit(const Mod &node) override { copy_operator(node); }
+ void visit(const Pow &node) override { copy_operator(node); }
+ void visit(const Equal &node) override { copy_operator(node); }
+ void visit(const NotEqual &node) override { copy_operator(node); }
+ void visit(const Approx &node) override { copy_operator(node); }
+ void visit(const Less &node) override { copy_operator(node); }
+ void visit(const LessEqual &node) override { copy_operator(node); }
+ void visit(const Greater &node) override { copy_operator(node); }
+ void visit(const GreaterEqual &node) override { copy_operator(node); }
+ void visit(const And &node) override { copy_operator(node); }
+ void visit(const Or &node) override { copy_operator(node); }
+
+ // call nodes
+ void visit(const Cos &node) override { copy_call(node); }
+ void visit(const Sin &node) override { copy_call(node); }
+ void visit(const Tan &node) override { copy_call(node); }
+ void visit(const Cosh &node) override { copy_call(node); }
+ void visit(const Sinh &node) override { copy_call(node); }
+ void visit(const Tanh &node) override { copy_call(node); }
+ void visit(const Acos &node) override { copy_call(node); }
+ void visit(const Asin &node) override { copy_call(node); }
+ void visit(const Atan &node) override { copy_call(node); }
+ void visit(const Exp &node) override { copy_call(node); }
+ void visit(const Log10 &node) override { copy_call(node); }
+ void visit(const Log &node) override { copy_call(node); }
+ void visit(const Sqrt &node) override { copy_call(node); }
+ void visit(const Ceil &node) override { copy_call(node); }
+ void visit(const Fabs &node) override { copy_call(node); }
+ void visit(const Floor &node) override { copy_call(node); }
+ void visit(const Atan2 &node) override { copy_call(node); }
+ void visit(const Ldexp &node) override { copy_call(node); }
+ void visit(const Pow2 &node) override { copy_call(node); }
+ void visit(const Fmod &node) override { copy_call(node); }
+ void visit(const Min &node) override { copy_call(node); }
+ void visit(const Max &node) override { copy_call(node); }
+ void visit(const IsNan &node) override { copy_call(node); }
+ void visit(const Relu &node) override { copy_call(node); }
+ void visit(const Sigmoid &node) override { copy_call(node); }
+ void visit(const Elu &node) override { copy_call(node); }
+
+ // traverse nodes
+ bool open(const Node &) override { return !error; } // returns false once a failure has been recorded
+ void close(const Node &node) override { node.accept(*this); } // copy the node via its visit() overload, which takes its children from the stack
+};
+
+CopyNode::~CopyNode() = default; // out-of-line definition of the destructor declared inside the struct
+
+} // namespace vespalib::eval::<unnamed>
+
+size_t
+NodeTools::min_num_params(const Node &node) // largest symbol id found under node plus one; 0 when no symbol is found
+{
+ CountParams count_params;
+ node.traverse(count_params);
+ return count_params.result;
+}
+
+Node_UP
+NodeTools::copy(const Node &node) // copy of the tree under node; an Error node is returned instead when CopyNode records a failure
+{
+ CopyNode copy_node;
+ node.traverse(copy_node);
+ return copy_node.result();
+}
+
+} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/node_tools.h b/eval/src/vespa/eval/eval/node_tools.h
new file mode 100644
index 00000000000..a358056c90f
--- /dev/null
+++ b/eval/src/vespa/eval/eval/node_tools.h
@@ -0,0 +1,16 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <memory>
+
+namespace vespalib::eval {
+
+namespace nodes { struct Node; }
+
+struct NodeTools { // static helper functions operating on expression node trees
+ static size_t min_num_params(const nodes::Node &node); // largest symbol id used under node, plus one (0 if none)
+ static std::unique_ptr<nodes::Node> copy(const nodes::Node &node); // copy of the tree under node, or an Error node if copying fails
+};
+
+} // namespace vespalib::eval
diff --git a/eval/src/vespa/eval/eval/tensor_function.h b/eval/src/vespa/eval/eval/tensor_function.h
index 2cc70f50b15..e1961079017 100644
--- a/eval/src/vespa/eval/eval/tensor_function.h
+++ b/eval/src/vespa/eval/eval/tensor_function.h
@@ -343,6 +343,9 @@ private:
public:
Lambda(const ValueType &result_type_in, const std::vector<size_t> &bindings_in, const Function &lambda_in, NodeTypes lambda_types_in)
: Node(result_type_in), _bindings(bindings_in), _lambda(lambda_in.shared_from_this()), _lambda_types(std::move(lambda_types_in)) {}
+ const std::vector<size_t> &bindings() const { return _bindings; }
+ const Function &lambda() const { return *_lambda; }
+ const NodeTypes &types() const { return _lambda_types; }
static TensorSpec create_spec_impl(const ValueType &type, const LazyParams &params, const std::vector<size_t> &bind, const InterpretedFunction &fun);
TensorSpec create_spec(const LazyParams &params, const InterpretedFunction &fun) const {
return create_spec_impl(result_type(), params, _bindings, fun);
diff --git a/eval/src/vespa/eval/eval/test/eval_fixture.cpp b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
index 7c4d4caf854..325fb208319 100644
--- a/eval/src/vespa/eval/eval/test/eval_fixture.cpp
+++ b/eval/src/vespa/eval/eval/test/eval_fixture.cpp
@@ -26,7 +26,14 @@ NodeTypes get_types(const Function &function, const ParamRepo &param_repo) {
param_types.push_back(ValueType::from_spec(pos->second.value.type()));
ASSERT_TRUE(!param_types.back().is_error());
}
- return NodeTypes(function, param_types);
+ NodeTypes node_types(function, param_types);
+ if (!node_types.errors().empty()) {
+ for (const auto &msg: node_types.errors()) {
+ fprintf(stderr, "eval_fixture: type error: %s\n", msg.c_str());
+ }
+ }
+ ASSERT_TRUE(node_types.errors().empty());
+ return node_types;
}
std::set<size_t> get_mutable(const Function &function, const ParamRepo &param_repo) {
diff --git a/eval/src/vespa/eval/tensor/dense/CMakeLists.txt b/eval/src/vespa/eval/tensor/dense/CMakeLists.txt
index 9e4c9857bd1..1b9b51d6ad2 100644
--- a/eval/src/vespa/eval/tensor/dense/CMakeLists.txt
+++ b/eval/src/vespa/eval/tensor/dense/CMakeLists.txt
@@ -2,11 +2,14 @@
vespa_add_library(eval_tensor_dense OBJECT
SOURCES
dense_add_dimension_optimizer.cpp
+ dense_cell_range_function.cpp
dense_dimension_combiner.cpp
dense_dot_product_function.cpp
dense_fast_rename_optimizer.cpp
dense_inplace_join_function.cpp
dense_inplace_map_function.cpp
+ dense_lambda_peek_function.cpp
+ dense_lambda_peek_optimizer.cpp
dense_matmul_function.cpp
dense_remove_dimension_optimizer.cpp
dense_replace_type_function.cpp
@@ -14,7 +17,6 @@ vespa_add_library(eval_tensor_dense OBJECT
dense_tensor_address_mapper.cpp
dense_tensor_cells_iterator.cpp
dense_tensor_create_function.cpp
- dense_lambda_peek_optimizer.cpp
dense_tensor_modify.cpp
dense_tensor_peek_function.cpp
dense_tensor_reduce.cpp
diff --git a/eval/src/vespa/eval/tensor/dense/dense_cell_range_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_cell_range_function.cpp
new file mode 100644
index 00000000000..9b93f5e7d72
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/dense/dense_cell_range_function.cpp
@@ -0,0 +1,53 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "dense_cell_range_function.h"
+#include "dense_tensor_view.h"
+#include <vespa/eval/eval/value.h>
+
+namespace vespalib::tensor {
+
+using eval::Value;
+using eval::ValueType;
+using eval::TensorFunction;
+using eval::TensorEngine;
+using eval::as;
+using namespace eval::tensor_function;
+
+namespace {
+
+template <typename CT> // CT: cell type used to interpret the input cells
+void my_cell_range_op(eval::InterpretedFunction::State &state, uint64_t param) { // pushes a DenseTensorView over a sub-range of the cells of the value at state.peek(0)
+ const auto *self = (const DenseCellRangeFunction *)(param); // param is the DenseCellRangeFunction pointer stored by compile_self
+ auto old_cells = DenseTensorView::typify_cells<CT>(state.peek(0));
+ ConstArrayRef<CT> new_cells(&old_cells[self->offset()], self->length()); // NOTE(review): no bounds check here; offset/length are assumed valid for the input — confirm the caller guarantees this
+ state.pop_push(state.stash.create<DenseTensorView>(self->result_type(), TypedCells(new_cells))); // the new view is created in the state's stash
+}
+
+struct MyCellRangeOp { // adapter handed to select_1 so it can obtain my_cell_range_op for a concrete cell type
+ template <typename CT>
+ static auto get_fun() { return my_cell_range_op<CT>; }
+};
+
+} // namespace vespalib::tensor::<unnamed>
+
+DenseCellRangeFunction::DenseCellRangeFunction(const eval::ValueType &result_type, // type of the produced view
+ const eval::TensorFunction &child, // function producing the tensor whose cells are viewed
+ size_t offset, size_t length) // index of the first cell and number of cells in the range
+ : eval::tensor_function::Op1(result_type, child),
+ _offset(offset),
+ _length(length)
+{
+}
+
+DenseCellRangeFunction::~DenseCellRangeFunction() = default;
+
+eval::InterpretedFunction::Instruction
+DenseCellRangeFunction::compile_self(const TensorEngine &, Stash &) const // builds the instruction that runs my_cell_range_op with this object as its parameter
+{
+ static_assert(sizeof(uint64_t) == sizeof(this)); // 'this' is passed through the uint64_t instruction parameter below
+ assert(result_type().cell_type() == child().result_type().cell_type()); // the op only creates a view, so input and output cell types must match
+ auto op = select_1<MyCellRangeOp>(result_type().cell_type()); // pick the op instantiation for the cell type
+ return eval::InterpretedFunction::Instruction(op, (uint64_t)this);
+}
+
+} // namespace vespalib::tensor
diff --git a/eval/src/vespa/eval/tensor/dense/dense_cell_range_function.h b/eval/src/vespa/eval/tensor/dense/dense_cell_range_function.h
new file mode 100644
index 00000000000..3f220826324
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/dense/dense_cell_range_function.h
@@ -0,0 +1,31 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/tensor_function.h>
+
+namespace vespalib::tensor {
+
+/**
+ * Tensor function creating a view to a contiguous range of cells in
+ * another tensor. The value type will (typically) change, but the
+ * cell type must remain the same.
+ **/
+class DenseCellRangeFunction : public eval::tensor_function::Op1
+{
+private:
+ // index of the first child cell covered by the view
+ size_t _offset;
+ // number of cells covered by the view
+ size_t _length;
+
+public:
+ // 'offset' and 'length' describe the range of cells in the result
+ // of 'child' to expose; 'result_type' is the type of the view.
+ DenseCellRangeFunction(const eval::ValueType &result_type,
+ const eval::TensorFunction &child,
+ size_t offset, size_t length);
+ ~DenseCellRangeFunction() override;
+ size_t offset() const { return _offset; }
+ size_t length() const { return _length; }
+ eval::InterpretedFunction::Instruction compile_self(const eval::TensorEngine &engine, Stash &stash) const override;
+ // mutability is inherited from the child whose cells are viewed
+ bool result_is_mutable() const override { return child().result_is_mutable(); }
+};
+
+} // namespace vespalib::tensor
diff --git a/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_function.cpp b/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_function.cpp
new file mode 100644
index 00000000000..0c7debde4ef
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_function.cpp
@@ -0,0 +1,89 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "dense_lambda_peek_function.h"
+#include "dense_tensor_view.h"
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/llvm/compile_cache.h>
+
+namespace vespalib::tensor {
+
+using eval::CompileCache;
+using eval::Function;
+using eval::InterpretedFunction;
+using eval::PassParams;
+using eval::TensorEngine;
+using eval::TensorFunction;
+using eval::Value;
+using eval::ValueType;
+using eval::as;
+using namespace eval::tensor_function;
+
+namespace {
+
+// State referenced by the lambda peek instruction: the type of the
+// tensor to create and a compile cache token for the index function
+// (compiled with array parameter passing).
+struct Self {
+    const ValueType &result_type;
+    CompileCache::Token::UP compile_token;
+
+    Self(const ValueType &result_type_in, const Function &function)
+        : result_type(result_type_in),
+          compile_token{CompileCache::compile(function, PassParams::ARRAY)}
+    {}
+};
+
+/**
+ * Advances 'params' (one index per dimension of 'type', stored as
+ * doubles) to the next cell address, with the last dimension varying
+ * fastest. Returns false when all indexes have wrapped back to zero,
+ * i.e. when every address has been visited.
+ **/
+bool step_params(std::vector<double> &params, const ValueType &type) {
+    const auto &dims = type.dimensions();
+    size_t pos = params.size();
+    while (pos > 0) {
+        --pos;
+        params[pos] += 1.0;
+        if (size_t(params[pos]) < dims[pos].size) {
+            return true;
+        }
+        // this dimension is exhausted; reset it and carry into the previous one
+        params[pos] = 0.0;
+    }
+    return false;
+}
+
+/**
+ * Interpreted instruction replacing the dense tensor on top of the
+ * stack with a new dense tensor. Each destination cell is copied from
+ * the source cell whose index is produced by the compiled index
+ * function, evaluated with the destination dimension indexes as
+ * parameters. The instruction parameter carries the address of the
+ * 'Self' state object.
+ **/
+template <typename DST_CT, typename SRC_CT>
+void my_lambda_peek_op(InterpretedFunction::State &state, uint64_t param) {
+    const auto &self = *reinterpret_cast<const Self *>(param);
+    const ValueType &res_type = self.result_type;
+    auto src_cells = DenseTensorView::typify_cells<SRC_CT>(state.peek(0));
+    ArrayRef<DST_CT> dst_cells = state.stash.create_array<DST_CT>(res_type.dense_subspace_size());
+    auto idx_fun = self.compile_token->get().get_function();
+    std::vector<double> params(res_type.dimensions().size(), 0.0);
+    size_t out = 0;
+    do {
+        const size_t src_idx = size_t(idx_fun(&params[0]));
+        dst_cells[out++] = src_cells[src_idx];
+    } while (step_params(params, res_type));
+    state.pop_push(state.stash.create<DenseTensorView>(res_type, TypedCells(dst_cells)));
+}
+
+// Selector handed to select_2 by compile_self; exposes the
+// my_lambda_peek_op instantiation for a destination/source cell type pair.
+struct MyLambdaPeekOp {
+    template <typename DST_CT, typename SRC_CT>
+    static auto get_fun() {
+        return &my_lambda_peek_op<DST_CT, SRC_CT>;
+    }
+};
+
+} // namespace vespalib::tensor::<unnamed>
+
+// Result type and child are forwarded to Op1; shared ownership of the
+// index function is taken over from the caller.
+DenseLambdaPeekFunction::DenseLambdaPeekFunction(const ValueType &result_type,
+                                                 const TensorFunction &child,
+                                                 std::shared_ptr<Function const> idx_fun)
+    : Op1(result_type, child), _idx_fun{std::move(idx_fun)}
+{}
+
+DenseLambdaPeekFunction::~DenseLambdaPeekFunction() = default;
+
+/**
+ * Creates the interpreted instruction for this function. A 'Self'
+ * state object (result type + compiled index function) is created in
+ * the stash and its address is passed as the instruction parameter.
+ * The operation is selected from the result and child cell types.
+ **/
+InterpretedFunction::Instruction
+DenseLambdaPeekFunction::compile_self(const TensorEngine &, Stash &stash) const
+{
+    const ValueType &src_type = child().result_type();
+    const Self &self = stash.create<Self>(result_type(), *_idx_fun);
+    auto op = select_2<MyLambdaPeekOp>(result_type().cell_type(), src_type.cell_type());
+    // the state object address must fit in the 64-bit instruction parameter
+    static_assert(sizeof(const Self *) == sizeof(uint64_t));
+    assert(src_type.is_dense());
+    return InterpretedFunction::Instruction(op, reinterpret_cast<uint64_t>(&self));
+}
+
+// Returns the index function dumped as a lambda expression string.
+vespalib::string
+DenseLambdaPeekFunction::idx_fun_dump() const {
+    const Function &fun = *_idx_fun;
+    return fun.dump_as_lambda();
+}
+
+} // namespace vespalib::tensor
diff --git a/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_function.h b/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_function.h
new file mode 100644
index 00000000000..a4146787bfa
--- /dev/null
+++ b/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_function.h
@@ -0,0 +1,31 @@
+// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/eval/eval/tensor_function.h>
+
+namespace vespalib::tensor {
+
+/**
+ * Tensor function creating a new dense tensor based on peeking cells
+ * of a single existing dense tensor. Which cells to peek is described
+ * by a single (compilable) function mapping the individual dimension
+ * indexes of the tensor to be created into global cell indexes of the
+ * tensor to be peeked.
+ **/
+class DenseLambdaPeekFunction : public eval::tensor_function::Op1
+{
+private:
+ // function mapping result dimension indexes to a child cell index
+ std::shared_ptr<eval::Function const> _idx_fun;
+
+public:
+ // 'child' produces the tensor to peek into; 'idx_fun' is the
+ // index mapping function described in the class comment.
+ DenseLambdaPeekFunction(const eval::ValueType &result_type,
+ const eval::TensorFunction &child,
+ std::shared_ptr<eval::Function const> idx_fun);
+ ~DenseLambdaPeekFunction() override;
+ eval::InterpretedFunction::Instruction compile_self(const eval::TensorEngine &engine, Stash &stash) const override;
+ // textual dump of the index function
+ vespalib::string idx_fun_dump() const;
+ // always reports a mutable result
+ bool result_is_mutable() const override { return true; }
+};
+
+} // namespace vespalib::tensor
diff --git a/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_optimizer.cpp b/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_optimizer.cpp
index 14954a77834..cb42ff86fbe 100644
--- a/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_optimizer.cpp
+++ b/eval/src/vespa/eval/tensor/dense/dense_lambda_peek_optimizer.cpp
@@ -1,13 +1,203 @@
// Copyright Verizon Media. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "dense_lambda_peek_optimizer.h"
+#include "dense_tensor_view.h"
+#include "dense_replace_type_function.h"
+#include "dense_cell_range_function.h"
+#include "dense_lambda_peek_function.h"
+#include <vespa/eval/eval/value.h>
+#include <vespa/eval/eval/node_tools.h>
+#include <vespa/eval/eval/basic_nodes.h>
+#include <vespa/eval/eval/operator_nodes.h>
+#include <vespa/eval/eval/call_nodes.h>
+#include <vespa/eval/eval/tensor_nodes.h>
+#include <vespa/eval/eval/llvm/compile_cache.h>
+#include <optional>
+
+using namespace vespalib::eval;
+using namespace vespalib::eval::nodes;
namespace vespalib::tensor {
-const eval::TensorFunction &
-DenseLambdaPeekOptimizer::optimize(const eval::TensorFunction &expr, Stash &)
+namespace {
+
+// 'simple peek': deterministic peek into a single parameter with
+// compilable dimension index expressions. Returns the peek node at
+// the root of the lambda function when it qualifies, nullptr otherwise.
+const TensorPeek *find_simple_peek(const tensor_function::Lambda &lambda) {
+    const Function &function = lambda.lambda();
+    const size_t num_dims = lambda.result_type().dimensions().size();
+    auto peek = as<TensorPeek>(function.root());
+    // the lambda must take its own dimension indexes plus exactly one extra parameter
+    if (!peek || (function.num_params() != (num_dims + 1))) {
+        return nullptr;
+    }
+    // the peeked object must be that extra (last) parameter
+    auto param = as<Symbol>(peek->get_child(0));
+    if (!param || (param->id() != num_dims)) {
+        return nullptr;
+    }
+    const size_t child_cnt = peek->num_children();
+    for (size_t child = 1; child < child_cnt; ++child) {
+        const Node &dim_expr = peek->get_child(child);
+        // index expressions may only depend on the dimension indexes
+        if (NodeTools::min_num_params(dim_expr) > num_dims) {
+            return nullptr;
+        }
+        // index expressions must be compilable
+        if (CompiledFunction::detect_issues(dim_expr)) {
+            return nullptr;
+        }
+    }
+    return peek;
+}
+
+// Creates an expression node for a peek dimension spec: a copy of the
+// index expression when the spec is an expression, otherwise a number
+// node holding the label converted with as_number.
+Node_UP make_dim_expr(const TensorPeek::Dim &src_dim) {
+    const auto &spec = src_dim.second;
+    if (!spec.is_expr()) {
+        return std::make_unique<Number>(as_number(spec.label));
+    }
+    return NodeTools::copy(*spec.expr);
+}
+
+// Creates a node of type OP and binds 'a' and 'b' as its two operands.
+template <typename OP>
+Node_UP make_op(Node_UP a, Node_UP b) {
+    auto node = std::make_unique<OP>();
+    node->bind(std::move(a), std::move(b));
+    return node;
+}
+
+// Creates a Floor node with 'a' bound as its argument.
+Node_UP make_floor(Node_UP a) {
+    auto node = std::make_unique<Floor>();
+    node->bind_next(std::move(a));
+    return node;
+}
+
+/**
+ * Analyzes a 'simple peek' by evaluating the per-dimension index
+ * expressions for every cell address of the destination tensor. The
+ * analysis tells whether all peeked cells are inside the source
+ * tensor and whether they form one run of consecutive source cells.
+ * The constructor also builds 'src_idx_fun', a single function
+ * mapping destination dimension indexes to a flat source cell index.
+ **/
+struct PeekAnalyzer {
+    std::vector<size_t> dst_dim_sizes;              // size of each destination dimension
+    std::vector<size_t> src_dim_sizes;              // size of each source dimension
+    std::vector<CompiledFunction::UP> src_dim_funs; // compiled index expression per peek dimension
+    std::shared_ptr<Function const> src_idx_fun;    // destination indexes -> flat source cell index
+
+    // run of consecutive source cells: [offset, offset + length)
+    struct CellRange {
+        size_t offset;
+        size_t length;
+        bool is_full(size_t num_cells) const {
+            return ((offset == 0) && (length == num_cells));
+        }
+    };
+
+    // outcome of the analysis:
+    //  simple  - valid, peeked cells form a single consecutive range
+    //  complex - valid, but cells are not visited consecutively
+    //  invalid - at least one index falls outside the source tensor
+    struct Result {
+        bool valid;
+        std::optional<CellRange> cell_range;
+        static Result simple(CellRange range) { return Result{true, range}; }
+        static Result complex() { return Result{true, std::nullopt}; }
+        static Result invalid() { return Result{false, std::nullopt}; }
+    };
+
+    // 'dim_list' is indexed in parallel with the dimensions of 'src_type'
+    PeekAnalyzer(const ValueType &dst_type, const ValueType &src_type,
+                 const TensorPeek::DimList &dim_list)
+    {
+        // bind by reference; copying each dimension object is not needed to read its size
+        dst_dim_sizes.reserve(dst_type.dimensions().size());
+        for (const auto &dim: dst_type.dimensions()) {
+            dst_dim_sizes.push_back(dim.size);
+        }
+        src_dim_sizes.reserve(src_type.dimensions().size());
+        for (const auto &dim: src_type.dimensions()) {
+            src_dim_sizes.push_back(dim.size);
+        }
+        Node_UP idx_expr;
+        const size_t num_params = dst_dim_sizes.size();
+        src_dim_funs.reserve(dim_list.size());
+        for (size_t i = 0; i < dim_list.size(); ++i) {
+            auto dim_expr = make_dim_expr(dim_list[i]);
+            src_dim_funs.push_back(std::make_unique<CompiledFunction>(*dim_expr, num_params, PassParams::ARRAY));
+            if (i == 0) {
+                idx_expr = std::move(dim_expr);
+            } else {
+                // idx = floor(idx) * size_of_this_dimension + index_in_this_dimension
+                idx_expr = make_floor(std::move(idx_expr));
+                idx_expr = make_op<Mul>(std::move(idx_expr), std::make_unique<Number>(src_dim_sizes[i]));
+                idx_expr = make_op<Add>(std::move(idx_expr), std::move(dim_expr));
+            }
+        }
+        src_idx_fun = Function::create(std::move(idx_expr), dst_type.dimension_names());
+    }
+
+    // Advances 'params' to the next destination address (last
+    // dimension fastest); returns false when wrapping back to all zeros.
+    bool step_params(std::vector<double> &params) const {
+        for (size_t idx = params.size(); idx-- > 0; ) {
+            if (size_t(params[idx] += 1.0) < dst_dim_sizes[idx]) {
+                return true;
+            } else {
+                params[idx] = 0.0;
+            }
+        }
+        return false;
+    }
+
+    // Flat source cell index of a per-dimension source address.
+    size_t calculate_index(const std::vector<size_t> &src_address) const {
+        size_t result = 0;
+        for (size_t i = 0; i < src_address.size(); ++i) {
+            result *= src_dim_sizes[i];
+            result += src_address[i];
+        }
+        return result;
+    }
+
+    // Evaluates the index expressions for all destination addresses
+    // and classifies the peek as simple, complex or invalid.
+    Result analyze_indexes() {
+        CellRange range{0,0};
+        bool is_complex = false;
+        std::vector<double> params(dst_dim_sizes.size(), 0.0);
+        std::vector<size_t> src_address(src_dim_sizes.size(), 0);
+        do {
+            for (size_t i = 0; i < src_dim_funs.size(); ++i) {
+                auto dim_fun = src_dim_funs[i]->get_function();
+                const double dim_value = dim_fun(params.data());
+                // Reject negative and NaN indexes before converting to
+                // an unsigned integer: the conversion is undefined for
+                // NaN and for values <= -1, and values in (-1,0) would
+                // truncate to 0 here while the floor-based src_idx_fun
+                // maps them to a different cell. The negated
+                // comparisons make NaN fail the check as well.
+                if (!(dim_value >= 0.0) || !(dim_value < double(src_dim_sizes[i]))) {
+                    return Result::invalid();
+                }
+                src_address[i] = size_t(dim_value);
+            }
+            size_t idx = calculate_index(src_address);
+            if (range.length == 0) {
+                range.offset = idx;
+            }
+            if (idx == (range.offset + range.length)) {
+                ++range.length;
+            } else {
+                is_complex = true;
+            }
+        } while(step_params(params));
+        if (is_complex) {
+            return Result::complex();
+        }
+        return Result::simple(range);
+    }
+};
+
+} // namespace vespalib::tensor::<unnamed>
+
+// Replaces a tensor lambda that is a 'simple peek' into a single
+// dense parameter with a cheaper tensor function created in 'stash'.
+// Returns 'expr' itself when it is not such a lambda, when the
+// peeked parameter is not dense or when the index analysis reports
+// an invalid result.
+const TensorFunction &
+DenseLambdaPeekOptimizer::optimize(const TensorFunction &expr, Stash &stash)
{
+ if (auto lambda = as<tensor_function::Lambda>(expr)) {
+ if (auto peek = find_simple_peek(*lambda)) {
+ const ValueType &dst_type = lambda->result_type();
+ const ValueType &src_type = lambda->types().get_type(peek->param());
+ if (src_type.is_dense()) {
+ assert(lambda->bindings().size() == 1);
+ assert(src_type.dimensions().size() == peek->dim_list().size());
+ size_t param_idx = lambda->bindings()[0];
+ PeekAnalyzer analyzer(dst_type, src_type, peek->dim_list());
+ auto result = analyzer.analyze_indexes();
+ if (result.valid) {
+ const auto &get_param = tensor_function::inject(src_type, param_idx, stash);
+ // a consecutive cell range can only be exposed directly when no cell type conversion is needed
+ if (result.cell_range && (dst_type.cell_type() == src_type.cell_type())) {
+ auto cell_range = result.cell_range.value();
+ // all source cells in order: only the type needs to be replaced
+ if (cell_range.is_full(src_type.dense_subspace_size())) {
+ return DenseReplaceTypeFunction::create_compact(dst_type, get_param, stash);
+ } else {
+ return stash.create<DenseCellRangeFunction>(dst_type, get_param,
+ cell_range.offset, cell_range.length);
+ }
+ } else {
+ // general case: look up each cell through the combined index function
+ return stash.create<DenseLambdaPeekFunction>(dst_type, get_param,
+ std::move(analyzer.src_idx_fun));
+ }
+ }
+ }
+ }
+ }
return expr;
}
-}
+} // namespace vespalib::tensor