diff options
author | Lester Solbakken <lesters@oath.com> | 2021-05-27 10:08:50 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-05-27 10:08:50 +0200 |
commit | 8bb0b79db9d224e632a2f49c77418415a5b6b0d4 (patch) | |
tree | 735f073c38f8a28f4887814207933021b005fa30 | |
parent | 4feb552579339c370db15c1c433f9269336506b8 (diff) |
Add compatibility for old ONNX default signature import
4 files changed, 16 insertions, 9 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java index 3dab19699a3..6a497460c5f 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/OnnxFeatureConverter.java @@ -44,7 +44,7 @@ public class OnnxFeatureConverter extends ExpressionTransformer<RankProfileTrans FeatureArguments arguments = asFeatureArguments(feature.getArguments()); ConvertedModel convertedModel = convertedOnnxModels.computeIfAbsent(arguments.path(), - path -> ConvertedModel.fromSourceOrStore(path, true, context, true)); + path -> ConvertedModel.fromSourceOrStore(path, true, context)); return convertedModel.expression(arguments, context); } catch (IllegalArgumentException | UncheckedIOException e) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index da26ea9daf2..62f911c9f1a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -73,10 +73,6 @@ public class ConvertedModel { this.sourceModel = sourceModel; } - public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, RankProfileTransformContext context) { - return fromSourceOrStore(modelPath, pathIsFile, context, false); - } - /** * Create and store a converted model for a rank profile given from either an imported model, * or (if unavailable) from stored application package data. @@ -84,12 +80,10 @@ public class ConvertedModel { * @param modelPath the path to the model * @param pathIsFile true if that path (this kind of model) is stored in a file, false if it is in a directory * @param context the transform context - * @param convertToNative force conversion to native Vespa expressions (if applicable) */ public static ConvertedModel fromSourceOrStore(Path modelPath, boolean pathIsFile, - RankProfileTransformContext context, - boolean convertToNative) { + RankProfileTransformContext context) { ImportedMlModel sourceModel = // TODO: Convert to name here, make sure its done just one way context.importedModels().get(sourceModelFile(context.rankProfile().applicationPackage(), modelPath)); ModelName modelName = new ModelName(context.rankProfile().getName(), modelPath, pathIsFile); @@ -99,7 +93,7 @@ public class ConvertedModel { context.importedModels().all().stream().map(ImportedMlModel::source).collect(Collectors.joining(", "))); if (sourceModel != null) { - if (convertToNative && ! sourceModel.isNative()) { + if ( ! sourceModel.isNative()) { sourceModel = sourceModel.asNative(); } return fromSource(modelName, diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index bf35a002e3a..bbf59fc66e4 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -138,8 +138,16 @@ public class ModelEvaluationTest { assertEquals(2, add_mul.functions().size()); assertNotNull(add_mul.evaluatorOf("output1")); assertNotNull(add_mul.evaluatorOf("output2")); + assertNotNull(add_mul.evaluatorOf("default.output1")); + assertNotNull(add_mul.evaluatorOf("default.output2")); + assertNotNull(add_mul.evaluatorOf("default", "output1")); + assertNotNull(add_mul.evaluatorOf("default", "output2")); assertNotNull(evaluator.evaluatorOf("add_mul", "output1")); assertNotNull(evaluator.evaluatorOf("add_mul", "output2")); + assertNotNull(evaluator.evaluatorOf("add_mul", "default.output1")); + assertNotNull(evaluator.evaluatorOf("add_mul", "default.output2")); + assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output1")); + assertNotNull(evaluator.evaluatorOf("add_mul", "default", "output2")); assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input1")); assertEquals(TensorType.fromSpec("tensor<float>(d0[1])"), add_mul.functions().get(0).argumentTypes().get("input2")); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 40a84a701ec..5590b9f0242 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -216,6 +216,11 @@ public class Model { return evaluatorOf("default" + name.substring("serving_default".length())); } + // To handle backward compatibility with ONNX conversion to native Vespa rank expressions + if (name.startsWith("default.")) { + return evaluatorOf(name.substring("default.".length())); + } + throwUndeterminedFunction("No function '" + name + "' in " + this); } else if (names.length == 2) { |