summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-02-01 20:52:16 +0000
committerArne Juul <arnej@verizonmedia.com>2021-02-02 10:03:58 +0000
commit9cc4c30a8a7f3c1ec2a2ee2261799b99f1b15aa6 (patch)
tree5b044ee882b1292e2ad8d169cd4a8b25cb964cd6 /eval
parent6895870413a8f96bd8b4d8f78dff8d4a6d0ac10c (diff)
use GenSpec in mixed_inner_product_function_test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp71
1 files changed, 31 insertions, 40 deletions
diff --git a/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp b/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp
index fbe71f3ed63..6b549b4d4d4 100644
--- a/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp
+++ b/eval/src/tests/instruction/mixed_inner_product_function/mixed_inner_product_function_test.cpp
@@ -3,7 +3,7 @@
#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/test/eval_fixture.h>
-#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/eval/test/gen_spec.h>
#include <vespa/eval/instruction/dense_dot_product_function.h>
#include <vespa/eval/instruction/dense_matmul_function.h>
#include <vespa/eval/instruction/dense_multi_matmul_function.h>
@@ -22,34 +22,25 @@ using namespace vespalib::eval::test;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
-struct MyVecSeq : Sequence {
- double bias;
- double operator[](size_t i) const override { return (i + bias); }
- MyVecSeq(double cellBias) : bias(cellBias) {}
-};
-
-std::function<double(size_t)> my_vec_gen(double cellBias) {
- return [=] (size_t i) noexcept { return i + cellBias; };
-}
//-----------------------------------------------------------------------------
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
- .add_vector("x", 3, my_vec_gen(2.0))
- .add_vector("x", 3, my_vec_gen(13.25))
- .add_vector("y", 3, my_vec_gen(4.0))
- .add_vector("z", 3, my_vec_gen(0.25))
- .add_matrix("x", 3, "y", 1, my_vec_gen(5.0))
- .add_matrix("x", 1, "y", 3, my_vec_gen(6.0))
- .add_matrix("x", 3, "y", 3, my_vec_gen(1.5))
- .add_matrix("x", 3, "z", 3, my_vec_gen(2.5))
- .add_cube("x", 3, "y", 3, "z", 3, my_vec_gen(-4.0))
- .add("mix_x3zm", spec({x(3),z({"c","d"})}, MyVecSeq(0.5)))
- .add("mix_y3zm", spec({y(3),z({"c","d"})}, MyVecSeq(3.5)))
- .add("mix_x3zm_f", spec(float_cells({x(3),z({"c","d"})}), MyVecSeq(0.5)))
- .add("mix_y3zm_f", spec(float_cells({y(3),z({"c","d"})}), MyVecSeq(3.5)))
- .add("mix_x3y3zm", spec({x(3),y(3),z({"c","d"})}, MyVecSeq(0.0)))
+ .add_variants("x3", GenSpec().idx("x", 3).seq_bias(2.0))
+ .add_variants("x3$2", GenSpec().idx("x", 3).seq_bias(13.25))
+ .add_variants("y3", GenSpec().idx("y", 3).seq_bias(4.0))
+ .add_variants("z3", GenSpec().idx("z", 3).seq_bias(0.25))
+ .add_variants("x3y3", GenSpec().idx("x", 3).idx("y", 3).seq_bias(5.0))
+ .add_variants("x1y3", GenSpec().idx("x", 1).idx("y", 3).seq_bias(6.0))
+ .add_variants("x3y1", GenSpec().idx("x", 3).idx("y", 1).seq_bias(1.5))
+ .add_variants("x3z3", GenSpec().idx("x", 3).idx("z", 3).seq_bias(2.5))
+ .add_variants("x3y3z3", GenSpec().idx("x", 3).idx("y", 3).idx("z", 3).seq_bias(-4.0))
+ .add("mix_x3zm", GenSpec().idx("x", 3).map("z", {"c","d"}).seq_bias(0.5).gen())
+ .add("mix_y3zm", GenSpec().idx("y", 3).map("z", {"c","d"}).seq_bias(3.5).gen())
+ .add("mix_x3zm_f", GenSpec().idx("x", 3).map("z", {"c","d"}).cells_float().seq_bias(0.5).gen())
+ .add("mix_y3zm_f", GenSpec().idx("y", 3).map("z", {"c","d"}).cells_float().seq_bias(3.5).gen())
+ .add("mix_x3y3zm", GenSpec().idx("x", 3).idx("y", 3).map("z", {"c","d"}).seq_bias(0.0).gen())
;
}
@@ -101,35 +92,35 @@ TEST(MixedInnerProduct, use_dense_optimizers_when_possible) {
TEST(MixedInnerProduct, trigger_optimizer_when_possible) {
assert_mixed_optimized("reduce(x3 * mix_x3zm,sum,x)");
- assert_mixed_optimized("reduce(x3f * mix_x3zm,sum,x)");
+ assert_mixed_optimized("reduce(x3_f * mix_x3zm,sum,x)");
assert_mixed_optimized("reduce(x3 * mix_x3zm_f,sum,x)");
- assert_mixed_optimized("reduce(x3f * mix_x3zm_f,sum,x)");
+ assert_mixed_optimized("reduce(x3_f * mix_x3zm_f,sum,x)");
assert_mixed_optimized("reduce(x3$2 * mix_x3zm,sum,x)");
- assert_mixed_optimized("reduce(x3f$2 * mix_x3zm,sum,x)");
+ assert_mixed_optimized("reduce(x3$2_f * mix_x3zm,sum,x)");
assert_mixed_optimized("reduce(y3 * mix_y3zm,sum,y)");
- assert_mixed_optimized("reduce(y3f * mix_y3zm,sum,y)");
+ assert_mixed_optimized("reduce(y3_f * mix_y3zm,sum,y)");
assert_mixed_optimized("reduce(y3 * mix_y3zm_f,sum,y)");
- assert_mixed_optimized("reduce(y3f * mix_y3zm_f,sum,y)");
+ assert_mixed_optimized("reduce(y3_f * mix_y3zm_f,sum,y)");
assert_mixed_optimized("reduce(x3y1 * mix_x3zm,sum,x)");
- assert_mixed_optimized("reduce(x3y1f * mix_x3zm,sum,x)");
+ assert_mixed_optimized("reduce(x3y1_f * mix_x3zm,sum,x)");
assert_mixed_optimized("reduce(x3y1 * mix_x3zm,sum,x,y)");
- assert_mixed_optimized("reduce(x3y1f * mix_x3zm,sum,x,y)");
+ assert_mixed_optimized("reduce(x3y1_f * mix_x3zm,sum,x,y)");
assert_mixed_optimized("reduce(x1y3 * mix_y3zm,sum,y)");
- assert_mixed_optimized("reduce(x1y3f * mix_y3zm,sum,y)");
+ assert_mixed_optimized("reduce(x1y3_f * mix_y3zm,sum,y)");
assert_mixed_optimized("reduce(x1y3 * x1y3,sum,y)");
- assert_mixed_optimized("reduce(x1y3 * x1y3f,sum,y)");
- assert_mixed_optimized("reduce(x1y3f * x1y3,sum,y)");
- assert_mixed_optimized("reduce(x1y3f * x1y3f,sum,y)");
+ assert_mixed_optimized("reduce(x1y3 * x1y3_f,sum,y)");
+ assert_mixed_optimized("reduce(x1y3_f * x1y3,sum,y)");
+ assert_mixed_optimized("reduce(x1y3_f * x1y3_f,sum,y)");
assert_mixed_optimized("reduce(x1y3 * mix_y3zm,sum,y)");
- assert_mixed_optimized("reduce(x1y3f * mix_y3zm,sum,y)");
+ assert_mixed_optimized("reduce(x1y3_f * mix_y3zm,sum,y)");
assert_mixed_optimized("reduce(mix_x3zm * x3,sum,x)");
- assert_mixed_optimized("reduce(mix_x3zm * x3f,sum,x)");
+ assert_mixed_optimized("reduce(mix_x3zm * x3_f,sum,x)");
assert_mixed_optimized("reduce(mix_x3zm * x3y1,sum,x)");
- assert_mixed_optimized("reduce(mix_x3zm * x3y1f,sum,x)");
+ assert_mixed_optimized("reduce(mix_x3zm * x3y1_f,sum,x)");
assert_mixed_optimized("reduce(mix_y3zm * y3,sum,y)");
- assert_mixed_optimized("reduce(mix_y3zm * y3f,sum,y)");
+ assert_mixed_optimized("reduce(mix_y3zm * y3_f,sum,y)");
assert_mixed_optimized("reduce(mix_y3zm * x1y3,sum,y)");
- assert_mixed_optimized("reduce(mix_y3zm * x1y3f,sum,y)");
+ assert_mixed_optimized("reduce(mix_y3zm * x1y3_f,sum,y)");
}
TEST(MixedInnerProduct, should_not_trigger_optimizer_for_other_cases) {