aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArnstein Ressem <aressem@gmail.com>2019-05-08 12:24:41 +0200
committerGitHub <noreply@github.com>2019-05-08 12:24:41 +0200
commit172698ac2c7af46e1446f7709ad0d67a444744c0 (patch)
tree7648de3a0ea53aa2aa3e41570d52c9df4ed7d904
parent6c8283fc5264ae59f0f5eb90b073add1d3552ab3 (diff)
Revert "Bratseth/emit float tensors in config"
-rw-r--r--config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java8
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java1
-rw-r--r--config-model/src/main/javacc/SDParser.jj4
-rw-r--r--config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd2
-rw-r--r--config-model/src/test/derived/tensor/attributes.cfg2
-rw-r--r--config-model/src/test/derived/tensor/documenttypes.cfg2
-rw-r--r--config-model/src/test/derived/tensor/rank-profiles.cfg10
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java30
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java6
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java6
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj2
-rw-r--r--vespajlib/abi-spec.json8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java2
16 files changed, 52 insertions, 73 deletions
diff --git a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
index cef0a5d4d8b..7a981cd6a53 100644
--- a/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
+++ b/config-model/src/main/java/com/yahoo/config/model/deploy/DeployState.java
@@ -376,17 +376,15 @@ public class DeployState implements ConfigDefinitionStore {
String searchName = builder.importReader(reader, readerName, logger);
String sdName = stripSuffix(readerName, ApplicationPackage.SD_NAME_SUFFIX);
names.put(searchName, sdName);
- if ( ! sdName.equals(searchName)) {
+ if (!sdName.equals(searchName)) {
throw new IllegalArgumentException("Search definition file name ('" + sdName + "') and name of " +
"search element ('" + searchName +
"') are not equal for file '" + readerName + "'");
}
} catch (ParseException e) {
- throw new IllegalArgumentException("Could not parse search definition file '" +
- getSearchDefinitionRelativePath(reader.getName()) + "': " + e.getMessage(), e);
+ throw new IllegalArgumentException("Could not parse search definition file '" + getSearchDefinitionRelativePath(reader.getName()) + "': " + e.getMessage(), e);
} catch (IOException e) {
- throw new IllegalArgumentException("Could not read search definition file '" +
- getSearchDefinitionRelativePath(reader.getName()) + "': " + e.getMessage(), e);
+ throw new IllegalArgumentException("Could not read search definition file '" + getSearchDefinitionRelativePath(reader.getName()) + "': " + e.getMessage(), e);
} finally {
closeIgnoreException(reader.getReader());
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
index 69cd6f09046..bd4daa58253 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/SearchBuilder.java
@@ -171,7 +171,6 @@ public class SearchBuilder {
} catch (TokenMgrException e) {
throw new ParseException("Unknown symbol: " + e.getMessage());
} catch (ParseException pe) {
- pe.printStackTrace();
throw new ParseException(stream.formatException(Exceptions.toMessageString(pe)));
}
return importRawSearch(search);
diff --git a/config-model/src/main/javacc/SDParser.jj b/config-model/src/main/javacc/SDParser.jj
index b0af54812a4..571ad452b01 100644
--- a/config-model/src/main/javacc/SDParser.jj
+++ b/config-model/src/main/javacc/SDParser.jj
@@ -298,7 +298,6 @@ TOKEN :
| < FASTSEARCH: "fast-search" >
| < HUGE: "huge" >
| < TENSOR_TYPE: "tensor(" (~["(",")"])+ ")" >
-| < TENSOR_TYPE_WITH_VALUE_TYPE: "tensor<" (["a"-"z","A"-"Z"])+ ">(" (~["(",")"])+ ")" >
| < TENSOR_VALUE_SL: "value" (" ")* ":" (" ")* ("{"<BRACE_SL_LEVEL_1>) ("\n")? >
| < TENSOR_VALUE_ML: "value" (<SEARCHLIB_SKIP>)? "{" (["\n"," "])* ("{"<BRACE_ML_LEVEL_1>) (["\n"," "])* "}" ("\n")? >
| < COMPRESSION: "compression" >
@@ -842,7 +841,6 @@ DataType dataType() :
| LOOKAHEAD(<MAP> <LESSTHAN>) ( mapType = mapDataType() { return mapType; } )
| LOOKAHEAD(<ANNOTATIONREFERENCE> <LESSTHAN>) ( mapType = annotationRefDataType() { return mapType; } )
| LOOKAHEAD(<TENSOR_TYPE>) ( tensorType = tensorType("Field type") { return DataType.getTensor(tensorType); } )
- | LOOKAHEAD(<TENSOR_TYPE_WITH_VALUE_TYPE>) ( tensorType = tensorType("Field type") { return DataType.getTensor(tensorType); } )
| LOOKAHEAD(<REFERENCE>) ( <REFERENCE> <LESSTHAN> referenceType = referenceType() <GREATERTHAN> { return ReferenceDataType.createWithInferredId(referenceType); } )
| ( typeName = identifier() ["[]" { isArrayOldStyle = true; }] )
)
@@ -2407,7 +2405,7 @@ TensorType tensorType(String errorMessage) :
String tensorTypeString;
}
{
- ( <TENSOR_TYPE> | <TENSOR_TYPE_WITH_VALUE_TYPE> ) { tensorTypeString = token.image; }
+ <TENSOR_TYPE> { tensorTypeString = token.image; }
{
TensorType tensorType;
try {
diff --git a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
index ab5e42f983d..d7b3cc76339 100644
--- a/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
+++ b/config-model/src/test/cfg/application/ml_models/searchdefinitions/test.sd
@@ -2,7 +2,7 @@
search test {
document test {
- field argument type tensor<float>(d0[],d1[784]) {
+ field argument type tensor(d0[],d1[784]) {
indexing: attribute
}
}
diff --git a/config-model/src/test/derived/tensor/attributes.cfg b/config-model/src/test/derived/tensor/attributes.cfg
index a72772276a8..4b54e67f8b8 100644
--- a/config-model/src/test/derived/tensor/attributes.cfg
+++ b/config-model/src/test/derived/tensor/attributes.cfg
@@ -17,7 +17,7 @@ attribute[].arity 8
attribute[].lowerbound -9223372036854775808
attribute[].upperbound 9223372036854775807
attribute[].densepostinglistthreshold 0.4
-attribute[].tensortype "tensor<float>(x[2],y[1])"
+attribute[].tensortype "tensor(x[2],y[])"
attribute[].imported false
attribute[].name "f3"
attribute[].datatype TENSOR
diff --git a/config-model/src/test/derived/tensor/documenttypes.cfg b/config-model/src/test/derived/tensor/documenttypes.cfg
index fba11509ad8..56818298eb8 100644
--- a/config-model/src/test/derived/tensor/documenttypes.cfg
+++ b/config-model/src/test/derived/tensor/documenttypes.cfg
@@ -29,7 +29,7 @@ documenttype[].datatype[].sstruct.field[].name "f2"
documenttype[].datatype[].sstruct.field[].id 2080644671
documenttype[].datatype[].sstruct.field[].id_v6 1424572148
documenttype[].datatype[].sstruct.field[].datatype 21
-documenttype[].datatype[].sstruct.field[].detailedtype "tensor<float>(x[2],y[1])"
+documenttype[].datatype[].sstruct.field[].detailedtype "tensor(x[2],y[])"
documenttype[].datatype[].sstruct.field[].name "f3"
documenttype[].datatype[].sstruct.field[].id 1295091863
documenttype[].datatype[].sstruct.field[].id_v6 1444109654
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg
index 45233f3e1ca..471343da63c 100644
--- a/config-model/src/test/derived/tensor/rank-profiles.cfg
+++ b/config-model/src/test/derived/tensor/rank-profiles.cfg
@@ -1,6 +1,6 @@
rankprofile[].name "default"
rankprofile[].fef.property[].name "vespa.type.attribute.f2"
-rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
+rankprofile[].fef.property[].value "tensor(x[2],y[])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
@@ -15,7 +15,7 @@ rankprofile[].fef.property[].value "0"
rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures"
rankprofile[].fef.property[].value "true"
rankprofile[].fef.property[].name "vespa.type.attribute.f2"
-rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
+rankprofile[].fef.property[].value "tensor(x[2],y[])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
@@ -26,7 +26,7 @@ rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
rankprofile[].fef.property[].value "reduce(map(attribute(f4), f(x)(x * x)) + reduce(tensor(x[2],y[3])(random), count) * rename(attribute(f4), (x, y), (y, x)), sum)"
rankprofile[].fef.property[].name "vespa.type.attribute.f2"
-rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
+rankprofile[].fef.property[].value "tensor(x[2],y[])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
@@ -37,7 +37,7 @@ rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
rankprofile[].fef.property[].value "reduce(reduce(join(attribute(f4), tensor(x[2],y[2],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)"
rankprofile[].fef.property[].name "vespa.type.attribute.f2"
-rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
+rankprofile[].fef.property[].value "tensor(x[2],y[])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
@@ -52,7 +52,7 @@ rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
rankprofile[].fef.property[].value "reduce(rankingExpression(joinedtensors), sum)"
rankprofile[].fef.property[].name "vespa.type.attribute.f2"
-rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
+rankprofile[].fef.property[].value "tensor(x[2],y[])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
rankprofile[].fef.property[].value "tensor(x{})"
rankprofile[].fef.property[].name "vespa.type.attribute.f4"
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index 3dee10b5f8a..e3e8cf43347 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -5,7 +5,7 @@ search tensor {
field f1 type tensor(x[]) {
indexing: summary
}
- field f2 type tensor<float>(x[2],y[1]) {
+ field f2 type tensor(x[2],y[]) {
indexing: attribute
}
field f3 type tensor(x{}) {
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 094068fc723..8944409e1e9 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -63,7 +63,7 @@ public class RankingExpressionWithOnnxTestCase {
public void testOnnxReferenceWithConstantFeature() {
RankProfileSearchFixture search = fixtureWith("constant(mytensor)",
"onnx('mnist_softmax.onnx')",
- "constant mytensor { file: ignored\ntype: tensor<float>(d0[7],d1[784]) }",
+ "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
null);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
}
@@ -73,7 +73,7 @@ public class RankingExpressionWithOnnxTestCase {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType =
"<query-profile-type id='root'>" +
- " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[3],d1[784])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
@@ -93,7 +93,7 @@ public class RankingExpressionWithOnnxTestCase {
RankProfileSearchFixture search = fixtureWith("attribute(mytensor)",
"onnx('mnist_softmax.onnx')",
null,
- "field mytensor type tensor<float>(d0[],d1[784]) { indexing: attribute }",
+ "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -105,15 +105,15 @@ public class RankingExpressionWithOnnxTestCase {
String queryProfile = "<query-profile id='default' type='root'/>";
String queryProfileType =
"<query-profile-type id='root'>" +
- " <field name='query(mytensor)' type='tensor&lt;float&gt;(d0[3],d1[784],d2[10])'/>" +
+ " <field name='query(mytensor)' type='tensor(d0[3],d1[784],d2[10])'/>" +
"</query-profile-type>";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir,
queryProfile,
queryProfileType);
RankProfileSearchFixture search = fixtureWith("sum(query(mytensor) * attribute(mytensor) * constant(mytensor),d2)",
"onnx('mnist_softmax.onnx')",
- "constant mytensor { file: ignored\ntype: tensor<float>(d0[7],d1[784]) }",
- "field mytensor type tensor<float>(d0[],d1[784]) { indexing: attribute }",
+ "constant mytensor { file: ignored\ntype: tensor(d0[7],d1[784]) }",
+ "field mytensor type tensor(d0[],d1[784]) { indexing: attribute }",
"Placeholder",
application);
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -122,7 +122,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testNestedOnnxReference() {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"5 + sum(onnx('mnist_softmax.onnx'))");
search.assertFirstPhaseExpression("5 + reduce(" + vespaExpression + ", sum)", "my_profile");
}
@@ -145,7 +145,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder' of type tensor<float>(d0[],d1[784]) but this function is " +
+ "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this function is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -154,7 +154,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testOnnxReferenceWithWrongFunctionType() {
try {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d5[10])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d5[10])(0.0)",
"onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
fail("Expecting exception");
@@ -162,8 +162,8 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[],d1[784]), " +
- "but this function returns tensor<float>(d0[2],d5[10])",
+ "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "but this function returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
}
@@ -186,7 +186,7 @@ public class RankingExpressionWithOnnxTestCase {
@Test
public void testImportingFromStoredExpressions() throws IOException {
- RankProfileSearchFixture search = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"onnx('mnist_softmax.onnx')");
search.assertFirstPhaseExpression(vespaExpression, "my_profile");
@@ -197,7 +197,7 @@ public class RankingExpressionWithOnnxTestCase {
IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory);
- RankProfileSearchFixture searchFromStored = fixtureWith("tensor<float>(d0[2],d1[784])(0.0)",
+ RankProfileSearchFixture searchFromStored = fixtureWith("tensor(d0[2],d1[784])(0.0)",
"onnx('mnist_softmax.onnx')",
null,
null,
@@ -217,10 +217,10 @@ public class RankingExpressionWithOnnxTestCase {
String rankProfile =
" rank-profile my_profile {\n" +
" function Placeholder() {\n" +
- " expression: tensor<float>(d0[2],d1[784])(0.0)\n" +
+ " expression: tensor(d0[2],d1[784])(0.0)\n" +
" }\n" +
" function " + name + "_Variable() {\n" +
- " expression: tensor<float>(d1[10],d2[784])(0.0)\n" +
+ " expression: tensor(d1[10],d2[784])(0.0)\n" +
" }\n" +
" first-phase {\n" +
" expression: onnx('mnist_softmax.onnx')" +
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
index d66f376ed6a..c19f1a769de 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/MlModelsTest.java
@@ -63,9 +63,9 @@ public class MlModelsTest {
private final String testProfile =
"rankingExpression(input).rankingScript: attribute(argument)\n" +
- "rankingExpression(input).type: tensor<float>(d0[],d1[784])\n" +
+ "rankingExpression(input).type: tensor(d0[],d1[784])\n" +
"rankingExpression(Placeholder).rankingScript: attribute(argument)\n" +
- "rankingExpression(Placeholder).type: tensor<float>(d0[],d1[784])\n" +
+ "rankingExpression(Placeholder).type: tensor(d0[],d1[784])\n" +
"rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add).rankingScript: join(reduce(join(rename(rankingExpression(input), (d0, d1), (d0, d4)), constant(mnist_saved_dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(mnist_saved_dnn_hidden1_bias_read), f(a,b)(a + b))\n" +
"rankingExpression(mnist_tensorflow).rankingScript: join(reduce(join(map(join(reduce(join(join(join(rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), 0.009999999776482582, f(a,b)(a * b)), rankingExpression(imported_ml_function_mnist_saved_dnn_hidden1_add), f(a,b)(max(a,b))), constant(mnist_saved_dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(mnist_saved_dnn_hidden2_bias_read), f(a,b)(a + b)), f(a)(1.0507009873554805 * if (a >= 0, a, 1.6732632423543772 * (exp(a) - 1)))), constant(mnist_saved_dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(mnist_saved_dnn_outputs_bias_read), f(a,b)(a + b))\n" +
"rankingExpression(mnist_softmax_tensorflow).rankingScript: join(reduce(join(rename(rankingExpression(Placeholder), (d0, d1), (d0, d2)), constant(mnist_softmax_saved_layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(mnist_softmax_saved_layer_Variable_1_read), f(a,b)(a + b))\n" +
@@ -73,6 +73,6 @@ public class MlModelsTest {
"rankingExpression(my_xgboost).rankingScript: if (f29 < -0.1234567, if (f56 < -0.242398, 1.71218, -1.70044), if (f109 < 0.8723473, -1.94071, 1.85965)) + if (f60 < -0.482947, if (f29 < -4.2387498, 0.784718, -0.96853), -6.23624)\n" +
"vespa.rank.firstphase: rankingExpression(firstphase)\n" +
"rankingExpression(firstphase).rankingScript: rankingExpression(mnist_tensorflow) + rankingExpression(mnist_softmax_tensorflow) + rankingExpression(mnist_softmax_onnx) + rankingExpression(my_xgboost)\n" +
- "vespa.type.attribute.argument: tensor<float>(d0[],d1[784])\n";
+ "vespa.type.attribute.argument: tensor(d0[],d1[784])\n";
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 35c853bd746..07814687dc6 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -31,13 +31,13 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable"));
assertNotNull(constant0);
- assertEquals(new TensorType.Builder(TensorType.Value.FLOAT).indexed("d2", 784).indexed("d1", 10).build(),
+ assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1"));
assertNotNull(constant1);
- assertEquals(new TensorType.Builder(TensorType.Value.FLOAT).indexed("d1", 10).build(), constant1.type());
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type());
assertEquals(10, constant1.size());
// Check inputs
@@ -52,7 +52,7 @@ public class OnnxMnistSoftmaxImportTestCase {
output.expression());
assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"),
model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
- assertEquals("{Placeholder=tensor<float>(d0[],d1[784])}", output.argumentTypes().toString());
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}
@Test
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index ea65a508047..c83de4ced0a 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -617,7 +617,7 @@ TensorType.Value optionalTensorValueTypeParameter() :
}
{
( <LT> valueType = identifier() <GT> )?
- { return TensorType.Value.fromId(valueType); }
+ { return TensorTypeParser.toValueType(valueType); }
}
// NOTE: Only indexed bound dimensions are parsed currently, as that is what we need
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 65fbf49d334..4f81f3baea8 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1346,11 +1346,8 @@
"methods": [
"public static com.yahoo.tensor.TensorType$Value[] values()",
"public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)",
- "public java.lang.String id()",
- "public boolean isEqualOrLargerThan(com.yahoo.tensor.TensorType$Value)",
"public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)",
- "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)",
- "public static com.yahoo.tensor.TensorType$Value fromId(java.lang.String)"
+ "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)"
],
"fields": [
"public static final enum com.yahoo.tensor.TensorType$Value DOUBLE",
@@ -1393,7 +1390,8 @@
],
"methods": [
"public void <init>()",
- "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)"
+ "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
+ "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 7f73ef41032..b1c7a2341c0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -29,17 +29,7 @@ public class TensorType {
public enum Value {
// Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
- DOUBLE("double"), FLOAT("float");
-
- private final String id;
-
- Value(String id) { this.id = id; }
-
- public String id() { return id; }
-
- public boolean isEqualOrLargerThan(TensorType.Value other) {
- return this == other || largestOf(this, other) == this;
- }
+ DOUBLE, FLOAT;
public static Value largestOf(List<Value> values) {
if (values.isEmpty()) return Value.DOUBLE; // Default
@@ -58,15 +48,6 @@ public class TensorType {
return FLOAT;
}
- public static Value fromId(String valueTypeString) {
- switch (valueTypeString) {
- case "double" : return Value.DOUBLE;
- case "float" : return Value.FLOAT;
- default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
- " but was '" + valueTypeString + "'");
- }
- }
-
};
/** The empty tensor type - which is the same as a double */
@@ -162,7 +143,7 @@ public class TensorType {
}
private boolean isConvertibleOrAssignableTo(TensorType generalization, boolean convertible, boolean considerName) {
- if ( ! generalization.valueType().isEqualOrLargerThan(this.valueType) ) return false;
+ if ( this.valueType() != generalization.valueType()) return false; // TODO: This can be relaxed
if (generalization.dimensions().size() != this.dimensions().size()) return false;
for (int i = 0; i < generalization.dimensions().size(); i++) {
Dimension thisDimension = this.dimensions().get(i);
@@ -184,9 +165,7 @@ public class TensorType {
@Override
public String toString() {
- return "tensor" +
- (valueType == Value.DOUBLE ? "" : "<" + valueType.id() + ">") +
- "(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
+ return "tensor(" + dimensions.stream().map(Dimension::toString).collect(Collectors.joining(",")) + ")";
}
@Override
@@ -251,7 +230,7 @@ public class TensorType {
@Override
public int hashCode() {
- return Objects.hash(dimensions, valueType);
+ return dimensions.hashCode();
}
/**
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index ba23868381c..d5f77be0dd0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -55,12 +55,21 @@ public class TensorTypeParser {
return new TensorType.Builder(valueType, dimensions).build();
}
+ public static TensorType.Value toValueType(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
try {
- return TensorType.Value.fromId(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
}
catch (IllegalArgumentException e) {
throw formatException(fullSpecString, e.getMessage());
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index a547f941d8e..d3bb702175a 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -96,8 +96,6 @@ public class TensorTypeTestCase {
assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
- assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString());
- assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString());
}
private static void assertTensorType(String typeSpec) {