summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-05-07 16:31:29 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-05-07 16:31:29 +0200
commit84738dbd4aa45ecd54f2f3d04af3b31490fdf766 (patch)
tree64d1587f3a937aedd22822bdfea1e936cecec7ab /vespajlib
parenta2b9e7ec76a39f31890fd854bbd43887e9507675 (diff)
Emit float tensor types in config when specified
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java2
4 files changed, 32 insertions, 17 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 4f81f3baea8..0fded291a83 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1346,8 +1346,10 @@
"methods": [
"public static com.yahoo.tensor.TensorType$Value[] values()",
"public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)",
+ "public java.lang.String id()",
"public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)",
- "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)"
+ "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)",
+ "public static com.yahoo.tensor.TensorType$Value fromId(java.lang.String)"
],
"fields": [
"public static final enum com.yahoo.tensor.TensorType$Value DOUBLE",
@@ -1390,8 +1392,7 @@
],
"methods": [
"public void <init>()",
- "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
- "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)"
+ "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index b1c7a2341c0..7f73ef41032 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -29,7 +29,17 @@ public class TensorType {
public enum Value {
// Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
- DOUBLE, FLOAT;
+ DOUBLE("double"), FLOAT("float");
+
+ private final String id;
+
+ Value(String id) { this.id = id; }
+
+ public String id() { return id; }
+
+ public boolean isEqualOrLargerThan(TensorType.Value other) {
+ return this == other || largestOf(this, other) == this;
+ }
public static Value largestOf(List<Value> values) {
if (values.isEmpty()) return Value.DOUBLE; // Default
@@ -48,6 +58,15 @@ public class TensorType {
return FLOAT;
}
+ public static Value fromId(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return Value.DOUBLE;
+ case "float" : return Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
};
/** The empty tensor type - which is the same as a double */
@@ -143,7 +162,7 @@ public class TensorType {
}
private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) {
- if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed
+ if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false;
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);
@@ -165,7 +184,9 @@ public class TensorType {
@Override
public String toString() {
- return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
+ return "tensor" +
+ (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") +
+ "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
}
@Override
@@ -230,7 +251,7 @@ public class TensorType {
@Override
public int hashCode() {
- return dimensions.hashCode();
+ return Objects.hash(dimensions, valueType);
}
/**
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index d5f77be0dd0..ba23868381c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -55,21 +55,12 @@ public class TensorTypeParser {
return new TensorType.Builder(valueType, dimensions).build();
}
- public static TensorType.Value toValueType(String valueTypeString) {
- switch (valueTypeString) {
- case "double" : return TensorType.Value.DOUBLE;
- case "float" : return TensorType.Value.FLOAT;
- default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
- " but was '" + valueTypeString + "'");
- }
- }
-
private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
try {
- return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ return TensorType.Value.fromId(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
}
catch (IllegalArgumentException e) {
throw formatException(fullSpecString, e.getMessage());
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index d3bb702175a..a547f941d8e 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -96,6 +96,8 @@ public class TensorTypeTestCase {
assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
+ assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString());
+ assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString());
}
private static void assertTensorType(String typeSpec) {