aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/test/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:59 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:59 +0100
commitb5ffe229474223844c150e99d24ca618e5e9f8dd (patch)
tree7c9ac3da58ff567fae79019bc688ba2a4e4d904c /vespajlib/src/test/java/com
parent1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (diff)
Complete prototype TensorFlow mapping
Diffstat (limited to 'vespajlib/src/test/java/com')
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java22
1 files changed, 12 insertions, 10 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
index 6606e278102..9643c0a56e7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
@@ -14,8 +14,8 @@ public class MatmulTestCase {
@Test
public void testMatmul2d() {
- // Convention: a is the 'outermost' dimension, etc.
- Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3])"));
+ // d0 is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])"));
ab.cell( 1,0, 0);
ab.cell( 2,0, 1);
ab.cell( 3,0, 2);
@@ -24,7 +24,7 @@ public class MatmulTestCase {
ab.cell( 6,1, 2);
Tensor a = ab.build();
- Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[3],b[2])"));
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])"));
bb.cell( 7,0, 0);
bb.cell( 8,0, 1);
bb.cell( 9,1, 0);
@@ -33,21 +33,22 @@ public class MatmulTestCase {
bb.cell(12,2, 1);
Tensor b = bb.build();
- Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],c[2])"));
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])"));
rb.cell( 58,0, 0);
rb.cell( 64,0, 1);
rb.cell(139,1, 0);
rb.cell(154,1, 1);
Tensor r = rb.build();
- Tensor result = a.matmul(b.rename(ImmutableList.of("a","b"),ImmutableList.of("b","c")), "b");
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1")
+ .rename("d2","d1");
assertEquals(r, result);
}
@Test
public void testMatmul3d() {
// Convention: a is the 'outermost' dimension, etc.
- Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],c[3])"));
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])"));
ab.cell( 1,0, 0, 0);
ab.cell( 2,0, 0, 1);
ab.cell( 3,0, 0, 2);
@@ -62,7 +63,7 @@ public class MatmulTestCase {
ab.cell(12,1, 1, 2);
Tensor a = ab.build();
- Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3],c[2])"));
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])"));
bb.cell(13,0, 0, 0);
bb.cell(14,0, 0, 1);
bb.cell(15,0, 1, 0);
@@ -77,7 +78,7 @@ public class MatmulTestCase {
bb.cell(24,1, 2, 1);
Tensor b = bb.build();
- Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],d[2])"));
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])"));
rb.cell( 94,0, 0, 0);
rb.cell(100,0, 0, 1);
rb.cell(229,0, 1, 0);
@@ -88,8 +89,9 @@ public class MatmulTestCase {
rb.cell(730,1, 1, 1);
Tensor r = rb.build();
- Tensor result = a.matmul(b.rename(ImmutableList.of("b","c"),ImmutableList.of("c","d")), "c");
- System.out.println(result);
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2")
+ .rename("d3","d2");
+ assertEquals(r, result);
}
}