diff options
author | Lester Solbakken <lesters@oath.com> | 2022-01-10 14:34:04 +0100 |
---|---|---|
committer | gjoranv <gv@verizonmedia.com> | 2022-06-08 11:45:28 +0200 |
commit | 19f9c783c2f1ca136a6ed874656e0a2c93b4adca (patch) | |
tree | fe0a0ccaff7e0ba5293d9950b1725d7e03c5237d /config-model | |
parent | 45a16605fe7caa2ebcbc1068fa2b48cbfa3b28c1 (diff) |
onnxModel to onnx in summary/matchfeatures
Diffstat (limited to 'config-model')
5 files changed, 19 insertions, 20 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java index c6c807f2dbb..cbf120e1ee0 100644 --- a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java @@ -267,7 +267,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); - reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); + reference = new Reference("onnx", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); if ( ! featureTypes.containsKey(reference)) { throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); } diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 5479ecf323f..56786c733ec 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1089,11 +1089,11 @@ public class RankProfile implements Cloneable { Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context); TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes); - context.setType(new Reference("onnxModel", args, null), defaultOutputType); + context.setType(new Reference("onnx", args, null), defaultOutputType); for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { TensorType type = model.getTensorType(mapping.getKey(), inputTypes); - context.setType(new Reference("onnxModel", args, mapping.getValue()), type); + context.setType(new Reference("onnx", args, mapping.getValue()), type); } } return context; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java index 4c38c257602..8797deefcb6 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java @@ -20,17 +20,16 @@ import java.util.List; /** * Transforms ONNX model features of the forms: * - * onnxModel(config_name) - * onnxModel(config_name).output - * onnxModel("path/to/model") - * onnxModel("path/to/model").output - * onnxModel("path/to/model", "path/to/output") - * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused - * onnx(...) // same as with onnxModel, onnx is an alias of onnxModel + * onnx(config_name) + * onnx(config_name).output + * onnx("path/to/model") + * onnx("path/to/model").output + * onnx("path/to/model", "path/to/output") + * onnx("path/to/model", "unused", "path/to/output") // signature is unused * * To the format expected by the backend: * - * onnxModel(config_name).output + * onnx(config_name).output * * @author lesters */ @@ -84,7 +83,7 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans throw new IllegalArgumentException(featureName + " argument '" + output + "' output not found in model '" + onnxModel.getFileName() + "'"); } - return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); + return new ReferenceNode("onnx", List.of(new ReferenceNode(modelConfigName)), output); } public static String getModelConfigName(Reference reference) { diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java index 713e11fd608..1280895bfc0 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java @@ -139,7 +139,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name()); assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name()); - assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); + assertEquals("onnx(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); assertEquals("test_generated_model_config", config.rankprofile(3).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name()); @@ -149,25 +149,25 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name()); assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name()); - assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); + assertEquals("onnx(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); assertEquals("test_summary_features", config.rankprofile(4).name()); assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name()); assertEquals("1", config.rankprofile(4).fef().property(3).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name()); - assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value()); + assertEquals("onnx(another_model).out", config.rankprofile(4).fef().property(4).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name()); - assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value()); + assertEquals("onnx(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value()); assertEquals("test_dynamic_model", config.rankprofile(5).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name()); - assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); + assertEquals("onnx(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); assertEquals("test_dynamic_model_2", config.rankprofile(6).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); - assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); + assertEquals("onnx(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); @@ -176,7 +176,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("test_unbound_model", config.rankprofile(8).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(8).fef().property(3).name()); - assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value()); + assertEquals("onnx(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value()); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index c60817704cd..694b908478d 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -170,7 +170,7 @@ public class ModelEvaluationTest { } private final String profile = - "rankingExpression(output).rankingScript: onnxModel(small_constants_and_functions).output\n" + + "rankingExpression(output).rankingScript: onnx(small_constants_and_functions).output\n" + "rankingExpression(output).type: tensor<float>(d0[3])\n"; private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) { |