diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-28 21:35:59 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2017-11-28 21:35:59 +0100 |
commit | b5ffe229474223844c150e99d24ca618e5e9f8dd (patch) | |
tree | 7c9ac3da58ff567fae79019bc688ba2a4e4d904c /vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | |
parent | 1d6791e6fa004ae80e85dbc6a6c7c2e4b8037a4f (diff) |
Complete prototype TensorFlow mapping
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 24 |
1 files changed, 24 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 8c4dbfb0acb..c89f63c0395 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -46,6 +46,30 @@ public class Join extends PrimitiveTensorFunction { this.combinator = combinator; } + /** Returns the type resulting from applying Join to the two given types */ + public static TensorType resultType(TensorType a, TensorType b) { + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (int i = 0; i < a.dimensions().size(); ++i) { + TensorType.Dimension aDim = a.dimensions().get(i); + for (int j = 0; j < b.dimensions().size(); ++j) { + TensorType.Dimension bDim = b.dimensions().get(j); + if (aDim.name().equals(bDim.name())) { // include + if (aDim.isIndexed() && bDim.isIndexed()) { + if (aDim.size().isPresent() || bDim.size().isPresent()) + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), + bDim.size().orElse(Integer.MAX_VALUE))); + else + typeBuilder.indexed(aDim.name()); + } + else { + typeBuilder.mapped(aDim.name()); + } + } + } + } + return typeBuilder.build(); + } + public TensorFunction argumentA() { return argumentA; } public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } |