diff options
16 files changed, 65 insertions, 50 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java index 977946cbf71..86c20bf96af 100644 --- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java +++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java @@ -400,15 +400,17 @@ public class DeployState implements ConfigDefinitionStore { String searchName = builder.importReader(reader, readerName, logger); String sdName = stripSuffix(readerName, ApplicationPackage.SD_NAME_SUFFIX); names.put(searchName, sdName); - if (!sdName.equals(searchName)) { + if ( ! sdName.equals(searchName)) { throw new IllegalArgumentException("Search definition file name ('" + sdName + "') and name of " + "search element ('" + searchName + "') are not equal for file '" + readerName + "'"); } } catch (ParseException e) { - throw new IllegalArgumentException("Could not parse search definition file '" + getSearchDefinitionRelativePath(reader.getName()) + "': " + e.getMessage(), e); + throw new IllegalArgumentException("Could not parse search definition file '" + + getSearchDefinitionRelativePath(reader.getName()) + "': " + e.getMessage(), e); } catch (IOException e) { - throw new IllegalArgumentException("Could not read search definition file '" + getSearchDefinitionRelativePath(reader.getName()) + "': " + e.getMessage(), e); + throw new IllegalArgumentException("Could not read search definition file '" + + getSearchDefinitionRelativePath(reader.getName()) + "': " + e.getMessage(), e); } finally { closeIgnoreException(reader.getReader()); } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java index bf1939e2c3d..176784a792e 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java @@ -172,6 +172,7 @@ public class SearchBuilder { } catch (TokenMgrException e) { throw new ParseException("Unknown symbol: " + e.getMessage()); } catch (ParseException pe) { + pe.printStackTrace(); throw new ParseException(stream.formatException(Exceptions.toMessageString(pe))); } return importRawSearch(search); diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj index 3141f7f7164..db22e73268c 100644 --- a/config-model/src/main/javacc/SDParser.jj +++ b/config-model/src/main/javacc/SDParser.jj @@ -2405,7 +2405,7 @@ TensorType tensorType(String errorMessage) : String tensorTypeString; } { - <TENSOR_TYPE> { tensorTypeString = token.image; } + ( <TENSOR_TYPE> ) { tensorTypeString = token.image; } { TensorType tensorType; try { diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd index d7b3cc76339..ab5e42f983d 100644 --- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd +++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd @@ -2,7 +2,7 @@ search test { document test { - field argument type tensor(d0[],d1[784]) { + field argument type tensor<float>(d0[],d1[784]) { indexing: attribute } } diff --git a/config-model/src/test/derived/tensor/attributes.cfg b/config-model/src/test/derived/tensor/attributes.cfg index f1c95da7084..0c556aad868 100644 --- a/config-model/src/test/derived/tensor/attributes.cfg +++ b/config-model/src/test/derived/tensor/attributes.cfg @@ -17,7 +17,7 @@ attribute[].arity 8 attribute[].lowerbound -9223372036854775808 attribute[].upperbound 9223372036854775807 attribute[].densepostinglistthreshold 0.4 -attribute[].tensortype "tensor(x[2],y[])" +attribute[].tensortype "tensor<float>(x[2],y[1])" attribute[].imported false attribute[].name "f3" attribute[].datatype TENSOR diff --git a/config-model/src/test/derived/tensor/documenttypes.cfg b/config-model/src/test/derived/tensor/documenttypes.cfg index 0b75644f4ae..4e04ad4c309 100644 --- a/config-model/src/test/derived/tensor/documenttypes.cfg +++ b/config-model/src/test/derived/tensor/documenttypes.cfg @@ -29,7 +29,7 @@ documenttype[].datatype[].sstruct.field[].name "f2" documenttype[].datatype[].sstruct.field[].id 2080644671 documenttype[].datatype[].sstruct.field[].id_v6 1424572148 documenttype[].datatype[].sstruct.field[].datatype 21 -documenttype[].datatype[].sstruct.field[].detailedtype "tensor(x[2],y[])" +documenttype[].datatype[].sstruct.field[].detailedtype "tensor<float>(x[2],y[1])" documenttype[].datatype[].sstruct.field[].name "f3" documenttype[].datatype[].sstruct.field[].id 1295091863 documenttype[].datatype[].sstruct.field[].id_v6 1444109654 diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 87c0b6fab42..1ce9227d323 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -1,6 +1,6 @@ rankprofile[].name "default" rankprofile[].fef.property[].name "vespa.type.attribute.f2" -rankprofile[].fef.property[].value "tensor(x[2],y[])" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" @@ -17,7 +17,7 @@ rankprofile[].fef.property[].value "0" rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures" rankprofile[].fef.property[].value "true" rankprofile[].fef.property[].name "vespa.type.attribute.f2" -rankprofile[].fef.property[].value "tensor(x[2],y[])" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" @@ -30,7 +30,7 @@ rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(map(attribute(f4), f(x)(x * x)) + reduce(tensor(x[2],y[3])(random), count) * rename(attribute(f4), (x, y), (y, x)), sum)" rankprofile[].fef.property[].name "vespa.type.attribute.f2" -rankprofile[].fef.property[].value "tensor(x[2],y[])" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" @@ -43,7 +43,7 @@ rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)" rankprofile[].fef.property[].name "vespa.type.attribute.f2" -rankprofile[].fef.property[].value "tensor(x[2],y[])" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" @@ -60,7 +60,7 @@ rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(rankingExpression(joinedtensors), sum)" rankprofile[].fef.property[].name "vespa.type.attribute.f2" -rankprofile[].fef.property[].value "tensor(x[2],y[])" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" @@ -73,7 +73,7 @@ rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(attribute(f5), sum)" rankprofile[].fef.property[].name "vespa.type.attribute.f2" -rankprofile[].fef.property[].value "tensor(x[2],y[])" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" rankprofile[].fef.property[].value "tensor(x{})" rankprofile[].fef.property[].name "vespa.type.attribute.f4" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index 622be033229..b31352a2105 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -5,7 +5,7 @@ search tensor { field f1 type tensor(x[]) { indexing: summary } - field f2 type tensor(x[2],y[]) { + field f2 type tensor<float>(x[2],y[1]) { indexing: attribute } field f3 type tensor<double>(x{}) { diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 8944409e1e9..2a4292f70fc 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -63,7 +63,7 @@ public class RankingExpressionWithOnnxTestCase { public void testOnnxReferenceWithConstantFeature() { RankProfileSearchFixture search = fixtureWith("constant(mytensor)", "onnx('mnist_softmax.onnx')", - "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", + "constant mytensor { file: ignored\ntype: tensor<float>(d0[7],d1[784]) }", null); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); } @@ -73,7 +73,7 @@ public class RankingExpressionWithOnnxTestCase { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" + + " <field name='query(mytensor)' type='tensor<float>(d0[3],d1[784])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, @@ -93,7 +93,7 @@ public class RankingExpressionWithOnnxTestCase { RankProfileSearchFixture search = fixtureWith("attribute(mytensor)", "onnx('mnist_softmax.onnx')", null, - "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }", + "field mytensor type tensor<float>(d0[],d1[784]) { indexing: attribute }", "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -105,15 +105,15 @@ public class RankingExpressionWithOnnxTestCase { String queryProfile = "<query-profile id='default' type='root'/>"; String queryProfileType = "<query-profile-type id='root'>" + - " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" + + " <field name='query(mytensor)' type='tensor<float>(d0[3],d1[784],d2[10])'/>" + "</query-profile-type>"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir, queryProfile, queryProfileType); RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)", "onnx('mnist_softmax.onnx')", - "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }", - "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }", + "constant mytensor { file: ignored\ntype: tensor<float>(d0[7],d1[784]) }", + "field mytensor type tensor<float>(d0[],d1[784]) { indexing: attribute }", "Placeholder", application); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -122,7 +122,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testNestedOnnxReference() { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", "5 + sum(onnx('mnist_softmax.onnx'))"); search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile"); } @@ -186,7 +186,7 @@ public class RankingExpressionWithOnnxTestCase { @Test public void testImportingFromStoredExpressions() throws IOException { - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", "onnx('mnist_softmax.onnx')"); search.assertFirstPhaseExpression(vespaExpression, "my_profile"); @@ -197,7 +197,7 @@ public class RankingExpressionWithOnnxTestCase { IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); - RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)", + RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)", "onnx('mnist_softmax.onnx')", null, null, diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java index c19f1a769de..d66f376ed6a 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java @@ -63,9 +63,9 @@ public class MlModelsTest { private final String testProfile = "rankingExpression(input).rankingScript: attribute(argument)\n" + - "rankingExpression(input).type: tensor(d0[],d1[784])\n" + + "rankingExpression(input).type: tensor<float>(d0[],d1[784])\n" + "rankingExpression(Placeholder).rankingScript: attribute(argument)\n" + - "rankingExpression(Placeholder).type: tensor(d0[],d1[784])\n" + + "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" + "rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(rankingExpression(input), (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" + "rankingExpression(mnist_tensorflow).rankingScript: join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.0507009873554805 * if (a >= 0, a, 1.6732632423543772 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" + "rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" + @@ -73,6 +73,6 @@ public class MlModelsTest { "rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" + "vespa.rank.firstphase: rankingExpression(firstphase)\n" + "rankingExpression(firstphase).rankingScript: rankingExpression(mnist_tensorflow) + rankingExpression(mnist_softmax_tensorflow) + rankingExpression(mnist_softmax_onnx) + rankingExpression(my_xgboost)\n" + - "vespa.type.attribute.argument: tensor(d0[],d1[784])\n"; + "vespa.type.attribute.argument: tensor<float>(d0[],d1[784])\n"; } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 424e4d6c57c..68df59bf93f 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -31,13 +31,13 @@ public class OnnxMnistSoftmaxImportTestCase { Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable")); assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), + assertEquals(new TensorType.Builder(TensorType.Value.DOUBLE).indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1")); assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); + assertEquals(new TensorType.Builder(TensorType.Value.DOUBLE).indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); // Check inputs diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index c83de4ced0a..ea65a508047 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -617,7 +617,7 @@ TensorType.Value optionalTensorValueTypeParameter() : } { ( <LT> valueType = identifier() <GT> )? - { return TensorTypeParser.toValueType(valueType); } + { return TensorType.Value.fromId(valueType); } } // NOTE: Only indexed bound dimensions are parsed currently, as that is what we need diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 6f37b9edea4..9c425570a7e 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1364,9 +1364,12 @@ "methods": [ "public static com.yahoo.tensor.TensorType$Value[] values()", "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)", + "public java.lang.String id()", + "public boolean isEqualOrLargerThan(com.yahoo.tensor.TensorType$Value)", "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)", "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)", - "public java.lang.String toString()" + "public java.lang.String toString()", + "public static com.yahoo.tensor.TensorType$Value fromId(java.lang.String)" ], "fields": [ "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", @@ -1409,8 +1412,7 @@ ], "methods": [ "public void <init>()", - "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", - "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)" + "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 9869f1e908c..319947607d2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -29,7 +29,17 @@ public class TensorType { public enum Value { // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below - DOUBLE, FLOAT; + DOUBLE("double"), FLOAT("float"); + + private final String id; + + Value(String id) { this.id = id; } + + public String id() { return id; } + + public boolean isEqualOrLargerThan(TensorType.Value other) { + return this == other || largestOf(this, other) == this; + } public static Value largestOf(List<Value> values) { if (values.isEmpty()) return Value.DOUBLE; // Default @@ -51,6 +61,15 @@ public class TensorType { @Override public String toString() { return name().toLowerCase(); } + public static Value fromId(String valueTypeString) { + switch (valueTypeString) { + case "double" : return Value.DOUBLE; + case "float" : return Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + }; /** The empty tensor type - which is the same as a double */ @@ -146,7 +165,7 @@ public class TensorType { } private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) { - if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed + if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false; if (generalization.dimensions().size() != this.dimensions().size()) return false; for (int i = 0; i < generalization.dimensions().size(); i++) { Dimension thisDimension = this.dimensions().get(i); @@ -168,11 +187,9 @@ public class TensorType { @Override public String toString() { - if ((rank() == 0) || (valueType == Value.DOUBLE)) { - return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; - } else { - return "tensor<" + valueType + ">(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; - } + return "tensor" + + (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") + + "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")"; } @Override @@ -238,7 +255,7 @@ public class TensorType { @Override public int hashCode() { - return dimensions.hashCode(); + return Objects.hash(dimensions, valueType); } /** diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index 1f426942c5f..def3ab6b4ec 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -56,21 +56,12 @@ public class TensorTypeParser { return new TensorType.Builder(valueType, dimensions).build(); } - public static TensorType.Value toValueType(String valueTypeString) { - switch (valueTypeString) { - case "double" : return TensorType.Value.DOUBLE; - case "float" : return TensorType.Value.FLOAT; - default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + - " but was '" + valueTypeString + "'"); - } - } - private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) { if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">")) throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>")); try { - return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); + return TensorType.Value.fromId(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); } catch (IllegalArgumentException e) { throw formatException(fullSpecString, e.getMessage()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java index d3bb702175a..a547f941d8e 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java @@ -96,6 +96,8 @@ public class TensorTypeTestCase { assertValueType(TensorType.Value.DOUBLE, "tensor(x[])"); assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])"); assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])"); + assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString()); + assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString()); } private static void assertTensorType(String typeSpec) { |