summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2018-02-23 08:30:24 +0100
committerGitHub <noreply@github.com>2018-02-23 08:30:24 +0100
commit58cef4ed623166635bd0afe07a8175c2ff5ec3dc (patch)
treec178d859808ed4b50f54a2826aaac22ecab97061 /vespajlib
parentd572d662ab899d1fd4a3832c8fc984f2d71b2b42 (diff)
parentb1f46fcd0495dbce905fb8b7318781f4cf5965a7 (diff)
Merge pull request #5118 from vespa-engine/lesters/rename-tensorflow-constants
Refactor TensorFlow import and add dimension renaming.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java19
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
2 files changed, 19 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index bf1825446e4..0176dac6821 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -214,7 +214,7 @@ public class TensorType {
/** Returns a copy of this with the name set to the given name */
public abstract Dimension withName(String name);
- /** Returns true if this is an indexed bound or unboun type */
+ /** Returns true if this is an indexed bound or unbound type */
public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; }
/**
@@ -261,6 +261,14 @@ public class TensorType {
return new IndexedBoundDimension(name, size);
}
+ public static Dimension indexed(String name) {
+ return new IndexedUnboundDimension(name);
+ }
+
+ public static Dimension mapped(String name) {
+ return new MappedDimension(name);
+ }
+
}
public static class IndexedBoundDimension extends TensorType.Dimension {
@@ -374,6 +382,15 @@ public class TensorType {
addDimensionsOf(type);
}
+ /**
+ * Creates a builder from the given dimensions.
+ */
+ public Builder(Iterable<Dimension> dimensions) {
+ for (TensorType.Dimension dimension : dimensions) {
+ dimension(dimension);
+ }
+ }
+
private static final boolean supportsMixedTypes = false;
private void addDimensionsOf(TensorType type) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 17e1c103ea3..50b0e706a43 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -251,7 +251,7 @@ public class Join extends PrimitiveTensorFunction {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder);
- joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
+// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
return builder.build();
}