aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-05-08 12:27:09 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-08-13 13:22:31 +0200
commitb32202458cce6a00686fab7bac777b6cb9ee34de (patch)
treed4fb80f544a13442f56e5d7d30ec09e26c6f8fe3 /vespajlib
parente15d87688f4da812e93500598fa653164b47b9bd (diff)
Merge
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java33
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java2
4 files changed, 33 insertions, 21 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 6f37b9edea4..9c425570a7e 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1364,9 +1364,12 @@
"methods": [
"public static com.yahoo.tensor.TensorType$Value[] values()",
"public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)",
+ "public java.lang.String id()",
+ "public boolean isEqualOrLargerThan(com.yahoo.tensor.TensorType$Value)",
"public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)",
"public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)",
- "public java.lang.String toString()"
+ "public java.lang.String toString()",
+ "public static com.yahoo.tensor.TensorType$Value fromId(java.lang.String)"
],
"fields": [
"public static final enum com.yahoo.tensor.TensorType$Value DOUBLE",
@@ -1409,8 +1412,7 @@
],
"methods": [
"public void <init>()",
- "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
- "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)"
+ "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 9869f1e908c..319947607d2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -29,7 +29,17 @@ public class TensorType {
public enum Value {
// Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
- DOUBLE, FLOAT;
+ DOUBLE("double"), FLOAT("float");
+
+ private final String id;
+
+ Value(String id) { this.id = id; }
+
+ public String id() { return id; }
+
+ public boolean isEqualOrLargerThan(TensorType.Value other) {
+ return this == other || largestOf(this, other) == this;
+ }
public static Value largestOf(List<Value> values) {
if (values.isEmpty()) return Value.DOUBLE; // Default
@@ -51,6 +61,15 @@ public class TensorType {
@Override
public String toString() { return name().toLowerCase(); }
+ public static Value fromId(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return Value.DOUBLE;
+ case "float" : return Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
};
/** The empty tensor type - which is the same as a double */
@@ -146,7 +165,7 @@ public class TensorType {
}
private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) {
- if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed
+ if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false;
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);
@@ -168,11 +187,9 @@ public class TensorType {
@Override
public String toString() {
- if ((rank() == 0) || (valueType == Value.DOUBLE)) {
- return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
- } else {
- return "tensor<" + valueType + ">(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
- }
+ return "tensor" +
+ (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") +
+ "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
}
@Override
@@ -238,7 +255,7 @@ public class TensorType {
@Override
public int hashCode() {
- return dimensions.hashCode();
+ return Objects.hash(dimensions, valueType);
}
/**
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index 1f426942c5f..def3ab6b4ec 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -56,21 +56,12 @@ public class TensorTypeParser {
return new TensorType.Builder(valueType, dimensions).build();
}
- public static TensorType.Value toValueType(String valueTypeString) {
- switch (valueTypeString) {
- case "double" : return TensorType.Value.DOUBLE;
- case "float" : return TensorType.Value.FLOAT;
- default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
- " but was '" + valueTypeString + "'");
- }
- }
-
private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
try {
- return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ return TensorType.Value.fromId(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
}
catch (IllegalArgumentException e) {
throw formatException(fullSpecString, e.getMessage());
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index d3bb702175a..a547f941d8e 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -96,6 +96,8 @@ public class TensorTypeTestCase {
assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
+ assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString());
+ assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString());
}
private static void assertTensorType(String typeSpec) {