summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2017-01-10 14:47:44 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2017-01-10 14:47:44 +0100
commit7d48aa76c6c89851bf5d99109e41d2b485bc87ab (patch)
tree9a624a94ca3d5e8071c90f6c9a8cbb5bb4c55e44 /vespajlib/src/main/java/com/yahoo/tensor
parent8c6329d755c778850bba7c1c1ed69eafebba8863 (diff)
Maintain TensorType in documents
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java19
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java1
2 files changed, 20 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 82f36972a47..e58a08c8d31 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -77,6 +77,25 @@ public class TensorType {
return Optional.empty();
}
+ /**
+ * Returns whether a tensor of the given type can be assigned to this type,
+ * i.e of this type is a generalization of the given type.
+ */
+ public boolean isAssignableTo(TensorType other) {
+ if (other.dimensions().size() != this.dimensions().size()) return false;
+ for (int i = 0; i < other.dimensions().size(); i++) {
+ Dimension thisDimension = this.dimensions().get(i);
+ Dimension otherDimension = other.dimensions().get(i);
+ if (thisDimension.isIndexed() != other.isIndexed()) return false;
+ if ( ! thisDimension.name().equals(otherDimension.name())) return false;
+ if (thisDimension.size().isPresent()) {
+ if ( ! otherDimension.size().isPresent()) return false;
+ if (otherDimension.size().get() > thisDimension.size().get() ) return false;
+ }
+ }
+ return true;
+ }
+
@Override
public String toString() {
return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 5a45f20b6d8..5bb93b9da83 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor.serialization;
import com.google.common.annotations.Beta;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
/**
* Class used by clients for serializing a Tensor object into binary format or