summaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorArnstein Ressem <aressem@gmail.com>2019-05-08 12:24:41 +0200
committerGitHub <noreply@github.com>2019-05-08 12:24:41 +0200
commit172698ac2c7af46e1446f7709ad0d67a444744c0 (patch)
tree7648de3a0ea53aa2aa3e41570d52c9df4ed7d904 /vespajlib/src
parent6c8283fc5264ae59f0f5eb90b073add1d3552ab3 (diff)
Revert "Bratseth/emit float tensors in config"
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java2
3 files changed, 14 insertions, 28 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 7f73ef41032..b1c7a2341c0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -29,17 +29,7 @@ public class TensorType {
public enum Value {
// Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
- DOUBLE("double"), FLOAT("float");
-
- private final String id;
-
- Value(String id) { this.id = id; }
-
- public String id() { return id; }
-
- public boolean isEqualOrLargerThan(TensorType.Value other) {
- return this == other || largestOf(this, other) == this;
- }
+ DOUBLE, FLOAT;
public static Value largestOf(List<Value> values) {
if (values.isEmpty()) return Value.DOUBLE; // Default
@@ -58,15 +48,6 @@ public class TensorType {
return FLOAT;
}
- public static Value fromId(String valueTypeString) {
- switch (valueTypeString) {
- case "double" : return Value.DOUBLE;
- case "float" : return Value.FLOAT;
- default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
- " but was '" + valueTypeString + "'");
- }
- }
-
};
/** The empty tensor type - which is the same as a double */
@@ -162,7 +143,7 @@ public class TensorType {
}
private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) {
- if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false;
+ if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);
@@ -184,9 +165,7 @@ public class TensorType {
@Override
public String toString() {
- return "tensor" +
- (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") +
- "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
+ return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
}
@Override
@@ -251,7 +230,7 @@ public class TensorType {
@Override
public int hashCode() {
- return Objects.hash(dimensions, valueType);
+ return dimensions.hashCode();
}
/**
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index ba23868381c..d5f77be0dd0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -55,12 +55,21 @@ public class TensorTypeParser {
return new TensorType.Builder(valueType, dimensions).build();
}
+ public static TensorType.Value toValueType(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
try {
- return TensorType.Value.fromId(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
}
catch (IllegalArgumentException e) {
throw formatException(fullSpecString, e.getMessage());
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index a547f941d8e..d3bb702175a 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -96,8 +96,6 @@ public class TensorTypeTestCase {
assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
- assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString());
- assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString());
}
private static void assertTensorType(String typeSpec) {