diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-05-07 16:31:29 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-05-07 16:31:29 +0200 |
commit | 84738dbd4aa45ecd54f2f3d04af3b31490fdf766 (patch) | |
tree | 64d1587f3a937aedd22822bdfea1e936cecec7ab /vespajlib | |
parent | a2b9e7ec76a39f31890fd854bbd43887e9507675 (diff) |
Emit float tensor types in config when specified
Diffstat (limited to 'vespajlib')
4 files changed, 32 insertions, 17 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 4f81f3baea8..0fded291a83 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1346,8 +1346,10 @@ "methods": [ "public static com.yahoo.tensor.TensorType$Value[] values()", "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)", + "public java.lang.String id()", "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)", - "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)" + "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)", + "public static com.yahoo.tensor.TensorType$Value fromId(java.lang.String)" ], "fields": [ "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", @@ -1390,8 +1392,7 @@ ], "methods": [ "public void <init>()", - "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", - "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)" + "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index b1c7a2341c0..7f73ef41032 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -29,7 +29,17 @@ public class TensorType { public enum Value { // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below - DOUBLE, FLOAT; + DOUBLE("double"), FLOAT("float"); + + private final String id; + + Value(String id) { this.id = id; } + + public String id() { return id; } + + public boolean isEqualOrLargerThan(TensorType.Value other) { + return this == other || largestOf(this, other) == this; + } public static Value largestOf(List<Value> values) { if (values.isEmpty()) return Value.DOUBLE; // Default @@ -48,6 +58,15 @@ public class TensorType { return FLOAT; } + public static Value fromId(String valueTypeString) { + switch (valueTypeString) { + case "double" : return Value.DOUBLE; + case "float" : return Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + }; /** The empty tensor type - which is the same as a double */ @@ -143,7 +162,7 @@ public class TensorType { } private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) { - if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed + if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false; if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); @@ -165,7 +184,9 @@ public class TensorType { @Override public String toString() { - return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; + return "tensor" + + (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") + + "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; } @Override @@ -230,7 +251,7 @@ public class TensorType { @Override public int hashCode() { - return dimensions.hashCode(); + return Objects.hash(dimensions, valueType); } /** diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index d5f77be0dd0..ba23868381c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -55,21 +55,12 @@ public class TensorTypeParser { return new TensorType.Builder(valueType, dimensions).build(); } - public static TensorType.Value toValueType(String valueTypeString) { - switch (valueTypeString) { - case "double" : return TensorType.Value.DOUBLE; - case "float" : return TensorType.Value.FLOAT; - default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + - " but was '" + valueTypeString + "'"); - } - } - private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) { if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">")) throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>")); try { - return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); + return TensorType.Value.fromId(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); } catch (IllegalArgumentException e) { throw formatException(fullSpecString, e.getMessage()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index d3bb702175a..a547f941d8e 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -96,6 +96,8 @@ public class TensorTypeTestCase { assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])"); assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])"); + assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString()); + assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString()); } private static void assertTensorType(String typeSpec) { |