aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:59 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-11-28 21:35:59 +0100
commitb5ffe229474223844c150e99d24ca618e5e9f8dd (patch)
tree7c9ac3da58ff567fae79019bc688ba2a4e4d904c /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
parent1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (diff)
Complete prototype TensorFlow mapping
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java24
1 files changed, 24 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 8c4dbfb0acb..c89f63c0395 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction {
this.combinator = combinator;
}
+ /** Returns the type resulting from applying Join to the two given types */
+ public static TensorType resultType(TensorType a, TensorType b) {
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (int i = 0; i < a.dimensions().size(); ++i) {
+ TensorType.Dimension aDim = a.dimensions().get(i);
+ for (int j = 0; j < b.dimensions().size(); ++j) {
+ TensorType.Dimension bDim = b.dimensions().get(j);
+ if (aDim.name().equals(bDim.name())) { // include
+ if (aDim.isIndexed() && bDim.isIndexed()) {
+ if (aDim.size().isPresent() || bDim.size().isPresent())
+ typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE),
+ bDim.size().orElse(Integer.MAX_VALUE)));
+ else
+ typeBuilder.indexed(aDim.name());
+ }
+ else {
+ typeBuilder.mapped(aDim.name());
+ }
+ }
+ }
+ }
+ return typeBuilder.build();
+ }
+
public TensorFunction argumentA() { return argumentA; }
public TensorFunction argumentB() { return argumentB; }
public DoubleBinaryOperator combinator() { return combinator; }