summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-05-27 10:08:50 +0200
committerLester Solbakken <lesters@oath.com>2021-05-27 10:08:50 +0200
commit8bb0b79db9d224e632a2f49c77418415a5b6b0d4 (patch)
tree735f073c38f8a28f4887814207933021b005fa30
parent4feb552579339c370db15c1c433f9269336506b8 (diff)
Add compatibility for old ONNX default signature import
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java2
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java10
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java8
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java5
4 files changed, 16 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
index 3dab19699a3..6a497460c5f 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java
@@ -44,7 +44,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans
FeatureArguments arguments = asFeatureArguments(feature.getArguments());
ConvertedModel convertedModel =
convertedOnnxModels.computeIfAbsent(arguments.path(),
- path -> ConvertedModel.fromSourceOrStore(path, true, context, true));
+ path -> ConvertedModel.fromSourceOrStore(path, true, context));
return convertedModel.expression(arguments, context);
}
catch (IllegalArgumentException | UncheckedIOException e) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index da26ea9daf2..62f911c9f1a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -73,10 +73,6 @@ public class ConvertedModel {
this.sourceModel = sourceModel;
}
- public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) {
- return fromSourceOrStore(modelPath, pathIsFile, context, false);
- }
-
/**
* Create and store a converted model for a rank profile given from either an imported model,
* or (if unavailable) from stored application package data.
@@ -84,12 +80,10 @@ public class ConvertedModel {
* @param modelPath the path to the model
* @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory
* @param context the transform context
- * @param convertToNative force conversion to native Vespa expressions (if applicable)
*/
public static ConvertedModel fromSourceOrStore(Path modelPath,
boolean pathIsFile,
- RankProfileTransformContext context,
- boolean convertToNative) {
+ RankProfileTransformContext context) {
ImportedMlModel sourceModel = // TODO: Convert to name here, make sure its done just one way
context.importedModels().get(sourceModelFile(context.rankProfile().applicationPackage(), modelPath));
ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile);
@@ -99,7 +93,7 @@ public class ConvertedModel {
context.importedModels().all().stream().map(ImportedMlModel::source).collect(Collectors.joining(", ")));
if (sourceModel != null) {
- if (convertToNative && ! sourceModel.isNative()) {
+ if ( ! sourceModel.isNative()) {
sourceModel = sourceModel.asNative();
}
return fromSource(modelName,
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index bf35a002e3a..bbf59fc66e4 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -138,8 +138,16 @@ public class ModelEvaluationTest {
assertEquals(2, add_mul.functions().size());
assertNotNull(add_mul.evaluatorOf("output1"));
assertNotNull(add_mul.evaluatorOf("output2"));
+ assertNotNull(add_mul.evaluatorOf("default.output1"));
+ assertNotNull(add_mul.evaluatorOf("default.output2"));
+ assertNotNull(add_mul.evaluatorOf("default", "output1"));
+ assertNotNull(add_mul.evaluatorOf("default", "output2"));
assertNotNull(evaluator.evaluatorOf("add_mul", "output1"));
assertNotNull(evaluator.evaluatorOf("add_mul", "output2"));
+ assertNotNull(evaluator.evaluatorOf("add_mul", "default.output1"));
+ assertNotNull(evaluator.evaluatorOf("add_mul", "default.output2"));
+ assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output1"));
+ assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output2"));
assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1"));
assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2"));
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 40a84a701ec..5590b9f0242 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -216,6 +216,11 @@ public class Model {
return evaluatorOf("default" + name.substring("serving_default".length()));
}
+ // To handle backward compatibility with ONNX conversion to native Vespa rank expressions
+ if (name.startsWith("default.")) {
+ return evaluatorOf(name.substring("default.".length()));
+ }
+
throwUndeterminedFunction("No function '" + name + "' in " + this);
}
else if (names.length == 2) {