summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/Tensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java23
1 files changed, 15 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 260c48ace7f..6f655fd5860 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -73,8 +73,6 @@ public interface Tensor {
if (type().dimensions().size() > 0)
throw new IllegalStateException("This tensor is not dimensionless. Dimensions: " + type().dimensions().size());
if (size() == 0) return Double.NaN;
- if (size() > 1)
- throw new IllegalStateException("This tensor does not have a single value, it has " + size());
return valueIterator().next();
}
@@ -213,9 +211,7 @@ public interface Tensor {
/** Returns true if the two given tensors are mathematically equivalent, that is whether both have the same content */
static boolean equals(Tensor a, Tensor b) {
- if (a == b) return true;
- if ( ! a.cells().equals(b.cells())) return false;
- return true;
+ return a == b || a.cells().equals(b.cells());
}
// ----------------- Factories
@@ -250,9 +246,8 @@ public interface Tensor {
}
interface Builder {
-
+
/** Creates a suitable builder for the given type */
- // TODO: Create version of this which takes size info and use it when possible
static Builder of(TensorType type) {
boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
@@ -263,7 +258,19 @@ public interface Tensor {
else // indexed or empty
return IndexedTensor.Builder.of(type);
}
-
+
+ /** Creates a suitable builder for the given type */
+ static Builder of(TensorType type, int[] dimensionSizes) {
+ boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed());
+ boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
+ if (containsIndexed && containsMapped)
+ throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet");
+ if (containsMapped)
+ return MappedTensor.Builder.of(type);
+ else // indexed or empty
+ return IndexedTensor.Builder.of(type, dimensionSizes);
+ }
+
/** Returns the type this is building */
TensorType type();