diff options
author | Lester Solbakken <lesters@oath.com> | 2018-02-05 16:04:42 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-02-22 12:54:34 +0100 |
commit | b1f46fcd0495dbce905fb8b7318781f4cf5965a7 (patch) | |
tree | d0a0506fe66e5af4af2a927101a0eb9ed9420d38 /vespajlib | |
parent | e307df56eaaf5b0ebca5aefb7f7e0c5c3a970bdb (diff) |
Refactor TensorFlow import and add dimension renaming.
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 19 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 2 |
2 files changed, 19 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 14cd3e70866..33c8ba4f5e6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -207,7 +207,7 @@ public class TensorType { /** Returns a copy of this with the name set to the given name */ public abstract Dimension withName(String name); - /** Returns true if this is an indexed bound or unboun type */ + /** Returns true if this is an indexed bound or unbound type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } /** @@ -254,6 +254,14 @@ public class TensorType { return new IndexedBoundDimension(name, size); } + public static Dimension indexed(String name) { + return new IndexedUnboundDimension(name); + } + + public static Dimension mapped(String name) { + return new MappedDimension(name); + } + } public static class IndexedBoundDimension extends TensorType.Dimension { @@ -367,6 +375,15 @@ public class TensorType { addDimensionsOf(type); } + /** + * Creates a builder from the given dimensions. + */ + public Builder(Iterable<Dimension> dimensions) { + for (TensorType.Dimension dimension : dimensions) { + dimension(dimension); + } + } + private static final boolean supportsMixedTypes = false; private void addDimensionsOf(TensorType type) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 7812c985091..3f815a118d5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -251,7 +251,7 @@ public class Join extends PrimitiveTensorFunction { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder); - joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); +// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); return builder.build(); } |