diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 29 |
1 files changed, 25 insertions, 4 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b1c7a2341c0..7f73ef41032 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -29,7 +29,17 @@ public class TensorType { public enum Value { // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below - DOUBLE, FLOAT; + DOUBLE("double"), FLOAT("float"); + + private final String id; + + Value(String id) { this.id = id; } + + public String id() { return id; } + + public boolean isEqualOrLargerThan(TensorType.Value other) { + return this == other || largestOf(this, other) == this; + } public static Value largestOf(List<Value> values) { if (values.isEmpty()) return Value.DOUBLE; // Default @@ -48,6 +58,15 @@ public class TensorType { return FLOAT; } + public static Value fromId(String valueTypeString) { + switch (valueTypeString) { + case "double" : return Value.DOUBLE; + case "float" : return Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + }; /** The empty tensor type - which is the same as a double */ @@ -143,7 +162,7 @@ public class TensorType { } private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) { - if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed + if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false; if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); @@ -165,7 +184,9 @@ public class TensorType { @Override public String toString() { - return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; + return "tensor" + + (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") + + "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; } @Override @@ -230,7 +251,7 @@ public class TensorType { @Override public int hashCode() { - return dimensions.hashCode(); + return Objects.hash(dimensions, valueType); } /** |