From b5ffe229474223844c150e99d24ca618e5e9f8dd Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 28 Nov 2017 21:35:59 +0100 Subject: Complete prototype TensorFlow mapping --- .../src/main/java/com/yahoo/tensor/TensorType.java | 3 +++ .../main/java/com/yahoo/tensor/functions/Join.java | 24 ++++++++++++++++++++++ .../java/com/yahoo/tensor/functions/Matmul.java | 5 +++++ .../java/com/yahoo/tensor/functions/Reduce.java | 9 ++++++++ .../java/com/yahoo/tensor/functions/Rename.java | 4 ++++ .../java/com/yahoo/tensor/functions/Softmax.java | 6 ++++++ 6 files changed, 51 insertions(+) (limited to 'vespajlib/src/main/java/com/yahoo') diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index c05c35d6df3..c27ac57415d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -52,6 +52,9 @@ public class TensorType { public static TensorType fromSpec(String specString) { return TensorTypeParser.fromSpec(specString); } + + /** Returns the number of dimensions of this: dimensions().size() */ + public int rank() { return dimensions.size(); } /** Returns an immutable list of the dimensions of this */ public List dimensions() { return dimensions; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 8c4dbfb0acb..c89f63c0395 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction { this.combinator = combinator; } + /** Returns the type resulting from applying Join to the two given types */ + public static TensorType resultType(TensorType a, TensorType b) { + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (int i = 0; i < a.dimensions().size(); ++i) { + TensorType.Dimension aDim = a.dimensions().get(i); + for (int j = 0; j < b.dimensions().size(); ++j) { + TensorType.Dimension bDim = b.dimensions().get(j); + if (aDim.name().equals(bDim.name())) { // include + if (aDim.isIndexed() && bDim.isIndexed()) { + if (aDim.size().isPresent() || bDim.size().isPresent()) + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), + bDim.size().orElse(Integer.MAX_VALUE))); + else + typeBuilder.indexed(aDim.name()); + } + else { + typeBuilder.mapped(aDim.name()); + } + } + } + } + return typeBuilder.build(); + } + public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index bb27e937699..cbb3f159623 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.TensorType; import java.util.List; @@ -20,6 +21,10 @@ public class Matmul extends CompositeTensorFunction { this.argument2 = argument2; this.dimension = dimension; } + + public static TensorType resultType(TensorType a, TensorType b, String dimension) { + return Reduce.resultType(Join.resultType(a, b), ImmutableList.of(dimension)); + } @Override public List functionArguments() { return ImmutableList.of(argument1, argument2); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index cfc78be7e0c..aa28a26deb2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,6 +61,15 @@ public class Reduce extends PrimitiveTensorFunction { this.dimensions = ImmutableList.copyOf(dimensions); } + public static TensorType resultType(TensorType type, List reduceDimensions) { + TensorType.Builder b = new TensorType.Builder(); + for (TensorType.Dimension dimension : type.dimensions()) { + if ( ! reduceDimensions.contains(dimension.name())) + b.dimension(dimension); + } + return b.build(); + } + public TensorFunction argument() { return argument; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 6b0daf1b49a..6e52760424e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -28,6 +28,10 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List fromDimensions; private final List toDimensions; + + public Rename(TensorFunction argument, String fromDimension, String toDimension) { + this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); + } public Rename(TensorFunction argument, List fromDimensions, List toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index bf279eb24d8..45f78389c16 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -2,6 +2,8 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; +import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.TensorType; import java.util.Collections; import java.util.List; @@ -19,6 +21,10 @@ public class Softmax extends CompositeTensorFunction { this.argument = argument; this.dimension = dimension; } + + public static TensorType resultType(TensorType type, String dimension) { + return Reduce.resultType(type, ImmutableList.of(dimension)); + } @Override public List functionArguments() { return Collections.singletonList(argument); } -- cgit v1.2.3