summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
commit5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch)
tree2b65d4f48b92bf7ec846b3efd5d5259244bc234a /vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
parent6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff)
Add tensor value type
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java27
1 files changed, 24 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index bded55405c0..5bd44cbc327 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -25,8 +25,29 @@ import java.util.stream.Collectors;
public class TensorType {
/** The permissible cell value types. Default is double. */
- // Types added here must also be added to TensorTypeParser.parseValueTypeSpec
- public enum Value { DOUBLE, FLOAT};
+ public enum Value {
+
+ // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
+ DOUBLE, FLOAT;
+
+ public static Value largestOf(List<Value> values) {
+ if (values.isEmpty()) return Value.DOUBLE; // Default
+ Value largest = null;
+ for (Value value : values) {
+ if (largest == null)
+ largest = value;
+ else
+ largest = largestOf(largest, value);
+ }
+ return largest;
+ }
+
+ public static Value largestOf(Value value1, Value value2) {
+ if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE;
+ return FLOAT;
+ }
+
+ };
/** The empty tensor type - which is the same as a double */
public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList());
@@ -170,7 +191,7 @@ public class TensorType {
if (this.equals(other)) return Optional.of(this); // shortcut
if (this.dimensions.size() != other.dimensions.size()) return Optional.empty();
- Builder b = new Builder();
+ Builder b = new Builder(TensorType.Value.largestOf(valueType, other.valueType));
for (int i = 0; i < dimensions.size(); i++) {
Dimension thisDim = this.dimensions().get(i);
Dimension otherDim = other.dimensions().get(i);