summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-02-05 16:04:42 +0100
committerLester Solbakken <lesters@oath.com>2018-02-22 12:54:34 +0100
commitb1f46fcd0495dbce905fb8b7318781f4cf5965a7 (patch)
treed0a0506fe66e5af4af2a927101a0eb9ed9420d38 /vespajlib
parente307df56eaaf5b0ebca5aefb7f7e0c5c3a970bdb (diff)
Refactor TensorFlow import and add dimension renaming.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java19
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
2 files changed, 19 insertions, 2 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 14cd3e70866..33c8ba4f5e6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -207,7 +207,7 @@ public class TensorType {
/** Returns a copy of this with the name set to the given name */
public abstract Dimension withName(String name);
- /** Returns true if this is an indexed bound or unboun type */
+ /** Returns true if this is an indexed bound or unbound type */
public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; }
/**
@@ -254,6 +254,14 @@ public class TensorType {
return new IndexedBoundDimension(name, size);
}
+ public static Dimension indexed(String name) {
+ return new IndexedUnboundDimension(name);
+ }
+
+ public static Dimension mapped(String name) {
+ return new MappedDimension(name);
+ }
+
}
public static class IndexedBoundDimension extends TensorType.Dimension {
@@ -367,6 +375,15 @@ public class TensorType {
addDimensionsOf(type);
}
+ /**
+ * Creates a builder from the given dimensions.
+ */
+ public Builder(Iterable<Dimension> dimensions) {
+ for (TensorType.Dimension dimension : dimensions) {
+ dimension(dimension);
+ }
+ }
+
private static final boolean supportsMixedTypes = false;
private void addDimensionsOf(TensorType type) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 7812c985091..3f815a118d5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -251,7 +251,7 @@ public class Join extends PrimitiveTensorFunction {
int[] aToIndexes = mapIndexes(a.type(), joinedType);
int[] bToIndexes = mapIndexes(b.type(), joinedType);
joinTo(a, b, joinedType, joinedSize, aToIndexes, bToIndexes, false, builder);
- joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
+// joinTo(b, a, joinedType, joinedSize, bToIndexes, aToIndexes, true, builder);
return builder.build();
}