diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-05-08 12:27:09 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-08-13 13:22:31 +0200 |
commit | b32202458cce6a00686fab7bac777b6cb9ee34de (patch) | |
tree | d4fb80f544a13442f56e5d7d30ec09e26c6f8fe3 /vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | |
parent | e15d87688f4da812e93500598fa653164b47b9bd (diff) |
Merge
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorType.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 33 |
1 files changed, 25 insertions, 8 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 9869f1e908c..319947607d2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -29,7 +29,17 @@ public class TensorType { public enum Value { // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below - DOUBLE, FLOAT; + DOUBLE("double"), FLOAT("float"); + + private final String id; + + Value(String id) { this.id = id; } + + public String id() { return id; } + + public boolean isEqualOrLargerThan(TensorType.Value other) { + return this == other || largestOf(this, other) == this; + } public static Value largestOf(List<Value> values) { if (values.isEmpty()) return Value.DOUBLE; // Default @@ -51,6 +61,15 @@ public class TensorType { @Override public String toString() { return name().toLowerCase(); } + public static Value fromId(String valueTypeString) { + switch (valueTypeString) { + case "double" : return Value.DOUBLE; + case "float" : return Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + }; /** The empty tensor type - which is the same as a double */ @@ -146,7 +165,7 @@ public class TensorType { } private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) { - if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed + if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false; if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); @@ -168,11 +187,9 @@ public class TensorType { @Override public String toString() { - if ((rank() == 0) || (valueType == Value.DOUBLE)) { - return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; - } else { - return "tensor<" + valueType + ">(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; - } + return "tensor" + + (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") + + "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; } @Override @@ -238,7 +255,7 @@ public class TensorType { @Override public int hashCode() { - return dimensions.hashCode(); + return Objects.hash(dimensions, valueType); } /** |