summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-09-02 09:20:54 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-09-02 09:20:54 +0200
commit96e2cf880899cb204000e0693bb1bc51e2f52520 (patch)
treea75f048347f4c806d9332f81f0ffafdb7c549f30 /config-model
parentb6fd9b3e3381b733923263b667cd9a7d52ed8715 (diff)
Propagate float value types from Onnx and TF
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java2
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java8
2 files changed, 5 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 3e70ac4705e..877b1ac72a9 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -284,7 +284,7 @@ public class ConvertedModel {
RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName);
if (rankingExpressionFunctionOverridingConstant != null) {
TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles));
- if ( ! functionType.equals(constantValue.type()))
+ if ( ! constantValue.type().isAssignableTo(functionType))
throw new IllegalArgumentException("Function '" + constantName + "' replaces the constant with this name. " +
typeMismatchExplanation(constantValue.type(), functionType));
constantsReplacedByFunctions.add(constantName); // will replace constant(constantName) by constantName later
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
index 2a4292f70fc..754f161c70b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java
@@ -145,7 +145,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this function is " +
+ "Model refers input 'Placeholder' of type tensor<float>(d0[],d1[784]) but this function is " +
"not present in rank profile 'my_profile'",
Exceptions.toMessageString(expected));
}
@@ -162,7 +162,7 @@ public class RankingExpressionWithOnnxTestCase {
catch (IllegalArgumentException expected) {
assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " +
"onnx('mnist_softmax.onnx'): " +
- "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " +
+ "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[],d1[784]), " +
"but this function returns tensor(d0[2],d5[10])",
Exceptions.toMessageString(expected));
}
@@ -217,10 +217,10 @@ public class RankingExpressionWithOnnxTestCase {
String rankProfile =
" rank-profile my_profile {\n" +
" function Placeholder() {\n" +
- " expression: tensor(d0[2],d1[784])(0.0)\n" +
+ " expression: tensor<float>(d0[2],d1[784])(0.0)\n" +
" }\n" +
" function " + name + "_Variable() {\n" +
- " expression: tensor(d1[10],d2[784])(0.0)\n" +
+ " expression: tensor<float>(d1[10],d2[784])(0.0)\n" +
" }\n" +
" first-phase {\n" +
" expression: onnx('mnist_softmax.onnx')" +