diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-28 21:35:59 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-28 21:35:59 +0100 |
commit | b5ffe229474223844c150e99d24ca618e5e9f8dd (patch) | |
tree | 7c9ac3da58ff567fae79019bc688ba2a4e4d904c /vespajlib/src/test/java/com | |
parent | 1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (diff) |
Complete prototype TensorFlow mapping
Diffstat (limited to 'vespajlib/src/test/java/com')
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java | 22 |
1 files changed, 12 insertions, 10 deletions
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java index 6606e278102..9643c0a56e7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java @@ -14,8 +14,8 @@ public class MatmulTestCase { @Test public void testMatmul2d() { - // Convention: a is the 'outermost' dimension, etc. - Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3])")); + // d0 is the 'outermost' dimension, etc. + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])")); ab.cell( 1,0, 0); ab.cell( 2,0, 1); ab.cell( 3,0, 2); @@ -24,7 +24,7 @@ public class MatmulTestCase { ab.cell( 6,1, 2); Tensor a = ab.build(); - Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[3],b[2])")); + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])")); bb.cell( 7,0, 0); bb.cell( 8,0, 1); bb.cell( 9,1, 0); @@ -33,21 +33,22 @@ public class MatmulTestCase { bb.cell(12,2, 1); Tensor b = bb.build(); - Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],c[2])")); + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])")); rb.cell( 58,0, 0); rb.cell( 64,0, 1); rb.cell(139,1, 0); rb.cell(154,1, 1); Tensor r = rb.build(); - Tensor result = a.matmul(b.rename(ImmutableList.of("a","b"),ImmutableList.of("b","c")), "b"); + Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1") + .rename("d2","d1"); assertEquals(r, result); } @Test public void testMatmul3d() { // Convention: a is the 'outermost' dimension, etc. - Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],c[3])")); + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])")); ab.cell( 1,0, 0, 0); ab.cell( 2,0, 0, 1); ab.cell( 3,0, 0, 2); @@ -62,7 +63,7 @@ public class MatmulTestCase { ab.cell(12,1, 1, 2); Tensor a = ab.build(); - Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3],c[2])")); + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])")); bb.cell(13,0, 0, 0); bb.cell(14,0, 0, 1); bb.cell(15,0, 1, 0); @@ -77,7 +78,7 @@ public class MatmulTestCase { bb.cell(24,1, 2, 1); Tensor b = bb.build(); - Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],d[2])")); + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])")); rb.cell( 94,0, 0, 0); rb.cell(100,0, 0, 1); rb.cell(229,0, 1, 0); @@ -88,8 +89,9 @@ public class MatmulTestCase { rb.cell(730,1, 1, 1); Tensor r = rb.build(); - Tensor result = a.matmul(b.rename(ImmutableList.of("b","c"),ImmutableList.of("c","d")), "c"); - System.out.println(result); + Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2") + .rename("d3","d2"); + assertEquals(r, result); } } |