diff options
author | Jon Bratseth <jonbratseth@yahoo.com> | 2018-02-23 08:30:24 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-02-23 08:30:24 +0100 |
commit | 58cef4ed623166635bd0afe07a8175c2ff5ec3dc (patch) | |
tree | c178d859808ed4b50f54a2826aaac22ecab97061 /vespajlib | |
parent | d572d662ab899d1fd4a3832c8fc984f2d71b2b42 (diff) | |
parent | b1f46fcd0495dbce905fb8b7318781f4cf5965a7 (diff) |
Merge pull request #5118 from vespa-engine/lesters/rename-tensorflow-constants
Refactor TensorFlow import and add dimension renaming.
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 19 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java | 2 |
2 files changed, 19 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index bf1825446e4..0176dac6821 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -214,7 +214,7 @@ public class TensorType { /** Returns a copy of this with the name set to the given name */ public abstract Dimension withName(String name); - /** Returns true if this is an indexed bound or unboun type */ + /** Returns true if this is an indexed bound or unbound type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } /** @@ -261,6 +261,14 @@ public class TensorType { return new IndexedBoundDimension(name, size); } + public static Dimension indexed(String name) { + return new IndexedUnboundDimension(name); + } + + public static Dimension mapped(String name) { + return new MappedDimension(name); + } + } public static class IndexedBoundDimension extends TensorType.Dimension { @@ -374,6 +382,15 @@ public class TensorType { addDimensionsOf(type); } + /** + * Creates a builder from the given dimensions. + */ + public Builder(Iterable<Dimension> dimensions) { + for (TensorType.Dimension dimension : dimensions) { + dimension(dimension); + } + } + private static final boolean supportsMixedTypes = false; private void addDimensionsOf(TensorType type) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 17e1c103ea3..50b0e706a43 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -251,7 +251,7 @@ public class Join extends PrimitiveTensorFunction { int[] aToIndexes = mapIndexes(a.type(), joinedType); int[] bToIndexes = mapIndexes(b.type(), joinedType); joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder); - joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); +// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder); return builder.build(); } |