diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-09-02 09:20:54 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-09-02 09:20:54 +0200 |
commit | 96e2cf880899cb204000e0693bb1bc51e2f52520 (patch) | |
tree | a75f048347f4c806d9332f81f0ffafdb7c549f30 /config-model | |
parent | b6fd9b3e3381b733923263b667cd9a7d52ed8715 (diff) |
Propagate float value types from Onnx and TF
Diffstat (limited to 'config-model')
-rw-r--r-- | config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java | 2 | ||||
-rw-r--r-- | config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java | 8 |
2 files changed, 5 insertions, 5 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 3e70ac4705e..877b1ac72a9 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -284,7 +284,7 @@ public class ConvertedModel { RankProfile.RankingExpressionFunction rankingExpressionFunctionOverridingConstant = profile.getFunctions().get(constantName); if (rankingExpressionFunctionOverridingConstant != null) { TensorType functionType = rankingExpressionFunctionOverridingConstant.function().getBody().type(profile.typeContext(queryProfiles)); - if ( ! functionType.equals(constantValue.type())) + if ( ! constantValue.type().isAssignableTo(functionType)) throw new IllegalArgumentException("Function '" + constantName + "' replaces the constant with this name. " + typeMismatchExplanation(constantValue.type(), functionType)); constantsReplacedByFunctions.add(constantName); // will replace constant(constantName) by constantName later diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java index 2a4292f70fc..754f161c70b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithOnnxTestCase.java @@ -145,7 +145,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder' of type tensor(d0[],d1[784]) but this function is " + + "Model refers input 'Placeholder' of type tensor<float>(d0[],d1[784]) but this function is " + "not present in rank profile 'my_profile'", Exceptions.toMessageString(expected)); } @@ -162,7 +162,7 @@ public class RankingExpressionWithOnnxTestCase { catch (IllegalArgumentException expected) { assertEquals("Rank profile 'my_profile' is invalid: Could not use Onnx model from " + "onnx('mnist_softmax.onnx'): " + - "Model refers input 'Placeholder'. The required type of this is tensor(d0[],d1[784]), " + + "Model refers input 'Placeholder'. The required type of this is tensor<float>(d0[],d1[784]), " + "but this function returns tensor(d0[2],d5[10])", Exceptions.toMessageString(expected)); } @@ -217,10 +217,10 @@ public class RankingExpressionWithOnnxTestCase { String rankProfile = " rank-profile my_profile {\n" + " function Placeholder() {\n" + - " expression: tensor(d0[2],d1[784])(0.0)\n" + + " expression: tensor<float>(d0[2],d1[784])(0.0)\n" + " }\n" + " function " + name + "_Variable() {\n" + - " expression: tensor(d1[10],d2[784])(0.0)\n" + + " expression: tensor<float>(d1[10],d2[784])(0.0)\n" + " }\n" + " first-phase {\n" + " expression: onnx('mnist_softmax.onnx')" + |