aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
diff options
context:
space:
mode:
authorArnstein Ressem <aressem@gmail.com>2019-05-08 12:24:41 +0200
committerGitHub <noreply@github.com>2019-05-08 12:24:41 +0200
commit172698ac2c7af46e1446f7709ad0d67a444744c0 (patch)
tree7648de3a0ea53aa2aa3e41570d52c9df4ed7d904 /vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
parent6c8283fc5264ae59f0f5eb90b073add1d3552ab3 (diff)
Revert "Bratseth/emit float tensors in config"
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java11
1 files changed, 10 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index ba23868381c..d5f77be0dd0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -55,12 +55,21 @@ public class TensorTypeParser {
return new TensorType.Builder(valueType, dimensions).build();
}
+ public static TensorType.Value toValueType(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
try {
- return TensorType.Value.fromId(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
}
catch (IllegalArgumentException e) {
throw formatException(fullSpecString, e.getMessage());