aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2020-02-10 10:33:53 +0100
committerLester Solbakken <lesters@oath.com>2020-02-10 10:33:53 +0100
commitc6ea3aa88e8929c2cbfe90f9c9ffdde482b7adc5 (patch)
tree550c76a8310c4951a3c5ae4c6e53af889bb9b54c /searchlib/src
parent7b5b53d288ab8b3c9ec8e054d4d5ecf2f88f7ff0 (diff)
Add gather,slice,cast,unsqueeze onnx operations
Diffstat (limited to 'searchlib/src')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java50
1 files changed, 50 insertions, 0 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 6a87e0c6d46..807eb3aa7ce 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -386,6 +386,56 @@ public class EvaluationTestCase {
// tensor result dimensions are given from argument dimensions, not the resulting values
tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }");
tester.assertEvaluates("tensor(x{},y{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{},y{}):{ {x:1,y:0}:1, {x:2,y:1}:1 }");
+
+ }
+
+ @Test
+ public void testTake() {
+ EvaluationTester tester = new EvaluationTester();
+
+ // numpy.take(a, indices, axis) with tensors.
+
+ // 1 dim input, 1 dim indices
+ tester.assertEvaluates("tensor(d0[3]):[1, 3, 5]",
+ "tensor(d0[3])(tensor0{a0:(tensor1{indices0:(d0)})})",
+ "tensor(a0[6]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[3]):[0, 2, 4]");
+
+ // 1 dim input, 1 dim indices - negative indices
+ tester.assertEvaluates("tensor(d0[3]):[1, 5, 3]",
+ "tensor(d0[3])(tensor0{a0:(fmod(6 + tensor1{indices0:(d0)}, 6) ) })",
+ "tensor(a0[6]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[3]):[0, -2, -4]");
+
+ // 2 dim input, 1 dim indices - axis 0
+ tester.assertEvaluates("tensor(d0[4],d1[2]):[5, 6, 3, 4, 1, 2, 5, 6]",
+ "tensor(d0[4],d1[2])(tensor0{a0:(tensor1{indices0:(d0)}),a1:(d1)})",
+ "tensor(a0[3],a1[2]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[4]):[2, 1, 0, 2]");
+
+ // 1 dim input, 2 dim indices - axis 0
+ tester.assertEvaluates("tensor(d0[2],d1[2]):[1, 2, 4, 6]",
+ "tensor(d0[2],d1[2])(tensor0{a0:(tensor1{indices0:(d0),indices1:(d1)}) })",
+ "tensor(a0[6]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[2],indices1[2]):[0, 1, 3, 5]");
+
+ // 2 dim input, 2 dim indices - axis 0
+ tester.assertEvaluates("tensor(d0[2],d1[2],d2[2]):[1,2,3,4,3,4,5,6]",
+ "tensor(d0[2],d1[2],d2[2])(tensor0{a0:(tensor1{indices0:(d0),indices1:(d1)}),a1:(d2)})",
+ "tensor(a0[3],a1[2]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[2],indices1[2]):[0, 1, 1, 2]");
+
+ // 2 dim input, 1 dim indices - axis 1
+ tester.assertEvaluates("tensor(d0[3],d1[4]):[1,2,1,2,3,4,3,4,5,6,5,6]",
+ "tensor(d0[3],d1[4])(tensor0{a0:(d0), a1:(tensor1{indices0:(d1)}) })",
+ "tensor(a0[3],a1[2]):[1, 2, 3, 4, 5, 6]",
+ "tensor(indices0[4]):[0, 1, 0, 1]");
+
+ // 2 dim input, 2 dim indices - axis 1
+ tester.assertEvaluates("tensor(d0[3],d1[1],d2[2]):[1,3,4,6,7,9]",
+ "tensor(d0[3],d1[1],d2[2])(tensor0{a0:(d0), a1:(tensor1{indices0:(d1),indices1:(d2)}) })", // can add an if
+ "tensor(a0[3],a1[3]):[1, 2, 3, 4, 5, 6, 7, 8, 9]",
+ "tensor(indices0[1],indices1[2]):[0, 2]");
}
@Test