summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:59 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:59 +0100
commitb5ffe229474223844c150e99d24ca618e5e9f8dd (patch)
tree7c9ac3da58ff567fae79019bc688ba2a4e4d904c /vespajlib
parent1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (diff)
Complete prototype TensorFlow mapping
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java6
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java22
7 files changed, 63 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index c05c35d6df3..c27ac57415d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -52,6 +52,9 @@ public class TensorType {
public static TensorType fromSpec(String specString) {
return TensorTypeParser.fromSpec(specString);
}
+
+ /** Returns the number of dimensions of this: dimensions().size() */
+ public int rank() { return dimensions.size(); }
/** Returns an immutable list of the dimensions of this */
public List<Dimension> dimensions() { return dimensions; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 8c4dbfb0acb..c89f63c0395 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction {
this.combinator = combinator;
}
+ /** Returns the type resulting from applying Join to the two given types */
+ public static TensorType resultType(TensorType a, TensorType b) {
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (int i = 0; i < a.dimensions().size(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ for (int j = 0; j < b.dimensions().size(); ++j) {
+ TensorType.Dimension bDim = b.dimensions().get(j);
+ if (aDim.name().equals(bDim.name())) { // include
+ if (aDim.isIndexed() && bDim.isIndexed()) {
+ if (aDim.size().isPresent() || bDim.size().isPresent())
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
+ bDim.size().orElse(Integer.MAX_VALUE)));
+ else
+ typeBuilder.indexed(aDim.name());
+ }
+ else {
+ typeBuilder.mapped(aDim.name());
+ }
+ }
+ }
+ }
+ return typeBuilder.build();
+ }
+
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index bb27e937699..cbb3f159623 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
import java.util.List;
@@ -20,6 +21,10 @@ public class Matmul extends CompositeTensorFunction {
this.argument2 = argument2;
this.dimension = dimension;
}
+
+ public static TensorType resultType(TensorType a, TensorType b, String dimension) {
+ return Reduce.resultType(Join.resultType(a, b), ImmutableList.of(dimension));
+ }
@Override
public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index cfc78be7e0c..aa28a26deb2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -61,6 +61,15 @@ public class Reduce extends PrimitiveTensorFunction {
this.dimensions = ImmutableList.copyOf(dimensions);
}
+ public static TensorType resultType(TensorType type, List<String> reduceDimensions) {
+ TensorType.Builder b = new TensorType.Builder();
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if ( ! reduceDimensions.contains(dimension.name()))
+ b.dimension(dimension);
+ }
+ return b.build();
+ }
+
public TensorFunction argument() { return argument; }
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 6b0daf1b49a..6e52760424e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -28,6 +28,10 @@ public class Rename extends PrimitiveTensorFunction {
private final TensorFunction argument;
private final List<String> fromDimensions;
private final List<String> toDimensions;
+
+ public Rename(TensorFunction argument, String fromDimension, String toDimension) {
+ this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
+ }
public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index bf279eb24d8..45f78389c16 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -2,6 +2,8 @@
package com.yahoo.tensor.functions;
import com.google.common.annotations.Beta;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.TensorType;
import java.util.Collections;
import java.util.List;
@@ -19,6 +21,10 @@ public class Softmax extends CompositeTensorFunction {
this.argument = argument;
this.dimension = dimension;
}
+
+ public static TensorType resultType(TensorType type, String dimension) {
+ return Reduce.resultType(type, ImmutableList.of(dimension));
+ }
@Override
public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); }
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
index 6606e278102..9643c0a56e7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/MatmulTestCase.java
@@ -14,8 +14,8 @@ public class MatmulTestCase {
@Test
public void testMatmul2d() {
- // Convention: a is the 'outermost' dimension, etc.
- Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3])"));
+ // d0 is the 'outermost' dimension, etc.
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3])"));
ab.cell( 1,0, 0);
ab.cell( 2,0, 1);
ab.cell( 3,0, 2);
@@ -24,7 +24,7 @@ public class MatmulTestCase {
ab.cell( 6,1, 2);
Tensor a = ab.build();
- Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[3],b[2])"));
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[3],d1[2])"));
bb.cell( 7,0, 0);
bb.cell( 8,0, 1);
bb.cell( 9,1, 0);
@@ -33,21 +33,22 @@ public class MatmulTestCase {
bb.cell(12,2, 1);
Tensor b = bb.build();
- Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],c[2])"));
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2])"));
rb.cell( 58,0, 0);
rb.cell( 64,0, 1);
rb.cell(139,1, 0);
rb.cell(154,1, 1);
Tensor r = rb.build();
- Tensor result = a.matmul(b.rename(ImmutableList.of("a","b"),ImmutableList.of("b","c")), "b");
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d0","d1"), ImmutableList.of("d1","d2")), "d1")
+ .rename("d2","d1");
assertEquals(r, result);
}
@Test
public void testMatmul3d() {
// Convention: a is the 'outermost' dimension, etc.
- Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],c[3])"));
+ Tensor.Builder ab = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[3])"));
ab.cell( 1,0, 0, 0);
ab.cell( 2,0, 0, 1);
ab.cell( 3,0, 0, 2);
@@ -62,7 +63,7 @@ public class MatmulTestCase {
ab.cell(12,1, 1, 2);
Tensor a = ab.build();
- Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[3],c[2])"));
+ Tensor.Builder bb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[3],d2[2])"));
bb.cell(13,0, 0, 0);
bb.cell(14,0, 0, 1);
bb.cell(15,0, 1, 0);
@@ -77,7 +78,7 @@ public class MatmulTestCase {
bb.cell(24,1, 2, 1);
Tensor b = bb.build();
- Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(a[2],b[2],d[2])"));
+ Tensor.Builder rb = Tensor.Builder.of(TensorType.fromSpec("tensor(d0[2],d1[2],d2[2])"));
rb.cell( 94,0, 0, 0);
rb.cell(100,0, 0, 1);
rb.cell(229,0, 1, 0);
@@ -88,8 +89,9 @@ public class MatmulTestCase {
rb.cell(730,1, 1, 1);
Tensor r = rb.build();
- Tensor result = a.matmul(b.rename(ImmutableList.of("b","c"),ImmutableList.of("c","d")), "c");
- System.out.println(result);
+ Tensor result = a.matmul(b.rename(ImmutableList.of("d1","d2"), ImmutableList.of("d2","d3")), "d2")
+ .rename("d3","d2");
+ assertEquals(r, result);
}
}