diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-28 21:35:59 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-28 21:35:59 +0100 |
commit | b5ffe229474223844c150e99d24ca618e5e9f8dd (patch) | |
tree | 7c9ac3da58ff567fae79019bc688ba2a4e4d904c /vespajlib | |
parent | 1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (diff) |
Complete prototype TensorFlow mapping
Diffstat (limited to 'vespajlib')
7 files changed, 63 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index c05c35d6df3..c27ac57415d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -52,6 +52,9 @@ public class TensorType { public static TensorType fromSpec(String specString) { return TensorTypeParser.fromSpec(specString); } + + /** Returns the number of dimensions of this: dimensions().size() */ + public int rank() { return dimensions.size(); } /** Returns an immutable list of the dimensions of this */ public List<Dimension> dimensions() { return dimensions; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 8c4dbfb0acb..c89f63c0395 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction { this.combinator = combinator; } + /** Returns the type resulting from applying Join to the two given types */ + public static TensorType resultType(TensorType a, TensorType b) { + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (int i = 0; i < a.dimensions().size(); ++i) { + TensorType.Dimension aDim = a.dimensions().get(i); + for (int j = 0; j < b.dimensions().size(); ++j) { + TensorType.Dimension bDim = b.dimensions().get(j); + if (aDim.name().equals(bDim.name())) { // include + if (aDim.isIndexed() && bDim.isIndexed()) { + if (aDim.size().isPresent() || bDim.size().isPresent()) + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), + bDim.size().orElse(Integer.MAX_VALUE))); + else + typeBuilder.indexed(aDim.name()); + } + else { + typeBuilder.mapped(aDim.name()); + } + } + } + } + return typeBuilder.build(); + } + public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index bb27e937699..cbb3f159623 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.TensorType; import java.util.List; @@ -20,6 +21,10 @@ public class Matmul extends CompositeTensorFunction { this.argument2 = argument2; this.dimension = dimension; } + + public static TensorType resultType(TensorType a, TensorType b, String dimension) { + return Reduce.resultType(Join.resultType(a, b), ImmutableList.of(dimension)); + } @Override public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index cfc78be7e0c..aa28a26deb2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,6 +61,15 @@ public class Reduce extends PrimitiveTensorFunction { this.dimensions = ImmutableList.copyOf(dimensions); } + public static TensorType resultType(TensorType type, List<String> reduceDimensions) { + TensorType.Builder b = new TensorType.Builder(); + for (TensorType.Dimension dimension : type.dimensions()) { + if ( ! reduceDimensions.contains(dimension.name())) + b.dimension(dimension); + } + return b.build(); + } + public TensorFunction argument() { return argument; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 6b0daf1b49a..6e52760424e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -28,6 +28,10 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List<String> fromDimensions; private final List<String> toDimensions; + + public Rename(TensorFunction argument, String fromDimension, String toDimension) { + this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); + } public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index bf279eb24d8..45f78389c16 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -2,6 +2,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.TensorType; import java.util.Collections; import java.util.List; @@ -19,6 +21,10 @@ public class Softmax extends CompositeTensorFunction { this.argument = argument; this.dimension = dimension; } + + public static TensorType resultType(TensorType type, String dimension) { + return Reduce.resultType(type, ImmutableList.of(dimension)); + } @Override public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java index 6606e278102..9643c0a56e7 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java @@ -14,8 +14,8 @@ public class MatmulTestCase { @Test public void testMatmul2d() { - // Convention: a is the 'outermost' dimension, etc. - Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3])")); + // d0 is the 'outermost' dimension, etc. + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])")); ab.cell( 1,0, 0); ab.cell( 2,0, 1); ab.cell( 3,0, 2); @@ -24,7 +24,7 @@ public class MatmulTestCase { ab.cell( 6,1, 2); Tensor a = ab.build(); - Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[3],b[2])")); + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])")); bb.cell( 7,0, 0); bb.cell( 8,0, 1); bb.cell( 9,1, 0); @@ -33,21 +33,22 @@ public class MatmulTestCase { bb.cell(12,2, 1); Tensor b = bb.build(); - Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],c[2])")); + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])")); rb.cell( 58,0, 0); rb.cell( 64,0, 1); rb.cell(139,1, 0); rb.cell(154,1, 1); Tensor r = rb.build(); - Tensor result = a.matmul(b.rename(ImmutableList.of("a","b"),ImmutableList.of("b","c")), "b"); + Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1") + .rename("d2","d1"); assertEquals(r, result); } @Test public void testMatmul3d() { // Convention: a is the 'outermost' dimension, etc. - Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],c[3])")); + Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])")); ab.cell( 1,0, 0, 0); ab.cell( 2,0, 0, 1); ab.cell( 3,0, 0, 2); @@ -62,7 +63,7 @@ public class MatmulTestCase { ab.cell(12,1, 1, 2); Tensor a = ab.build(); - Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3],c[2])")); + Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])")); bb.cell(13,0, 0, 0); bb.cell(14,0, 0, 1); bb.cell(15,0, 1, 0); @@ -77,7 +78,7 @@ public class MatmulTestCase { bb.cell(24,1, 2, 1); Tensor b = bb.build(); - Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],d[2])")); + Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])")); rb.cell( 94,0, 0, 0); rb.cell(100,0, 0, 1); rb.cell(229,0, 1, 0); @@ -88,8 +89,9 @@ public class MatmulTestCase { rb.cell(730,1, 1, 1); Tensor r = rb.build(); - Tensor result = a.matmul(b.rename(ImmutableList.of("b","c"),ImmutableList.of("c","d")), "c"); - System.out.println(result); + Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2") + .rename("d3","d2"); + assertEquals(r, result); } } |