diff options
Diffstat (limited to 'vespajlib')
4 files changed, 17 insertions, 33 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 65fbf49d334..4f81f3baea8 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1346,11 +1346,8 @@ "methods": [ "public static com.yahoo.tensor.TensorType$Value[] values()", "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)", - "public java.lang.String id()", - "public boolean isEqualOrLargerThan(com.yahoo.tensor.TensorType$Value)", "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)", - "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)", - "public static com.yahoo.tensor.TensorType$Value fromId(java.lang.String)" + "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)" ], "fields": [ "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", @@ -1393,7 +1390,8 @@ ], "methods": [ "public void <init>()", - "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)" + "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", + "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 7f73ef41032..b1c7a2341c0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -29,17 +29,7 @@ public class TensorType { public enum Value { // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below - DOUBLE("double"), FLOAT("float"); - - private final String id; - - Value(String id) { this.id = id; } - - public String id() { return id; } - - public boolean isEqualOrLargerThan(TensorType.Value other) { - return this == other || largestOf(this, other) == this; - } + DOUBLE, FLOAT; public static Value largestOf(List<Value> values) { if (values.isEmpty()) return Value.DOUBLE; // Default @@ -58,15 +48,6 @@ public class TensorType { return FLOAT; } - public static Value fromId(String valueTypeString) { - switch (valueTypeString) { - case "double" : return Value.DOUBLE; - case "float" : return Value.FLOAT; - default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + - " but was '" + valueTypeString + "'"); - } - } - }; /** The empty tensor type - which is the same as a double */ @@ -162,7 +143,7 @@ public class TensorType { } private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) { - if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false; + if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); @@ -184,9 +165,7 @@ public class TensorType { @Override public String toString() { - return "tensor" + - (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") + - "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; + return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; } @Override @@ -251,7 +230,7 @@ public class TensorType { @Override public int hashCode() { - return Objects.hash(dimensions, valueType); + return dimensions.hashCode(); } /** diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index ba23868381c..d5f77be0dd0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -55,12 +55,21 @@ public class TensorTypeParser { return new TensorType.Builder(valueType, dimensions).build(); } + public static TensorType.Value toValueType(String valueTypeString) { + switch (valueTypeString) { + case "double" : return TensorType.Value.DOUBLE; + case "float" : return TensorType.Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) { if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">")) throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>")); try { - return TensorType.Value.fromId(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); + return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); } catch (IllegalArgumentException e) { throw formatException(fullSpecString, e.getMessage()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index a547f941d8e..d3bb702175a 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -96,8 +96,6 @@ public class TensorTypeTestCase { assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])"); assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])"); - assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString()); - assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString()); } private static void assertTensorType(String typeSpec) { |