summaryrefslogtreecommitdiffstats
path: root/eval
diff options
context:
space:
mode:
authorArne Juul <arnej@verizonmedia.com>2021-02-03 13:18:27 +0000
committerArne Juul <arnej@verizonmedia.com>2021-02-03 13:18:27 +0000
commit26fb2e08c165617b82aa8bb1f7f60cd843be22dc (patch)
tree04f96cc43582ce6642ef2a0ee15c051028b00127 /eval
parent2319ebb2fbbc87ac54448ec6a433a0f99aab940c (diff)
use GenSpec in dense_dot_product_function_test
Diffstat (limited to 'eval')
-rw-r--r--eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp32
1 files changed, 13 insertions, 19 deletions
diff --git a/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp b/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp
index 82a0baa8741..5121c587d5b 100644
--- a/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp
+++ b/eval/src/tests/instruction/dense_dot_product_function/dense_dot_product_function_test.cpp
@@ -3,7 +3,7 @@
#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/tensor_function.h>
#include <vespa/eval/eval/test/eval_fixture.h>
-#include <vespa/eval/eval/test/tensor_model.hpp>
+#include <vespa/eval/eval/test/gen_spec.h>
#include <vespa/eval/instruction/dense_dot_product_function.h>
#include <vespa/vespalib/testkit/test_kit.h>
#include <vespa/vespalib/util/stash.h>
@@ -18,14 +18,8 @@ using namespace vespalib::eval::test;
const ValueBuilderFactory &prod_factory = FastValueBuilderFactory::get();
-struct MyVecSeq : Sequence {
- double bias;
- double operator[](size_t i) const override { return (i + bias); }
- MyVecSeq(double cellBias) : bias(cellBias) {}
-};
-
TensorSpec makeTensor(size_t numCells, double cellBias) {
- return spec({x(numCells)}, MyVecSeq(cellBias));
+ return GenSpec().idx("x", numCells).seq_bias(cellBias).gen();
}
const double leftBias = 3.0;
@@ -45,7 +39,7 @@ void check_gen_with_result(size_t l, size_t r, double wanted) {
param_repo.add("b", makeTensor(r, rightBias));
vespalib::string expr = "reduce(a*b,sum,x)";
EvalFixture evaluator(prod_factory, expr, param_repo, true);
- EXPECT_EQUAL(spec(wanted), evaluator.result());
+ EXPECT_EQUAL(GenSpec().seq_bias(wanted).gen(), evaluator.result());
EXPECT_EQUAL(evaluator.result(), EvalFixture::ref(expr, param_repo));
auto info = evaluator.find_all<DenseDotProductFunction>();
EXPECT_EQUAL(info.size(), 1u);
@@ -93,16 +87,16 @@ TEST("require that dot product with equal sizes is correct") {
EvalFixture::ParamRepo make_params() {
return EvalFixture::ParamRepo()
- .add("v01_x1", spec({x(1)}, MyVecSeq(2.0)))
- .add("v02_x3", spec({x(3)}, MyVecSeq(4.0)))
- .add("v03_x3", spec({x(3)}, MyVecSeq(5.0)))
- .add("v04_y3", spec({y(3)}, MyVecSeq(10)))
- .add("v05_x5", spec({x(5)}, MyVecSeq(6.0)))
- .add("v06_x5", spec({x(5)}, MyVecSeq(7.0)))
- .add("v07_x5f", spec(float_cells({x(5)}), MyVecSeq(7.0)))
- .add("v08_x5f", spec(float_cells({x(5)}), MyVecSeq(6.0)))
- .add("m01_x3y3", spec({x(3),y(3)}, MyVecSeq(1.0)))
- .add("m02_x3y3", spec({x(3),y(3)}, MyVecSeq(2.0)));
+ .add("v01_x1", GenSpec().idx("x", 1).seq_bias(2.0).gen())
+ .add("v02_x3", GenSpec().idx("x", 3).seq_bias(4.0).gen())
+ .add("v03_x3", GenSpec().idx("x", 3).seq_bias(5.0).gen())
+ .add("v04_y3", GenSpec().idx("y", 3).seq_bias(10).gen())
+ .add("v05_x5", GenSpec().idx("x", 5).seq_bias(6.0).gen())
+ .add("v06_x5", GenSpec().idx("x", 5).seq_bias(7.0).gen())
+ .add("v07_x5f", GenSpec().cells_float().idx("x", 5).seq_bias(7.0).gen())
+ .add("v08_x5f", GenSpec().cells_float().idx("x", 5).seq_bias(6.0).gen())
+ .add("m01_x3y3", GenSpec().idx("x", 3).idx("y", 3).seq_bias(1.0).gen())
+ .add("m02_x3y3", GenSpec().idx("x", 3).idx("y", 3).seq_bias(2.0).gen());
}
EvalFixture::ParamRepo param_repo = make_params();