diff options
author | Lester Solbakken <lesters@oath.com> | 2022-01-10 14:34:04 +0100 |
---|---|---|
committer | gjoranv <gv@verizonmedia.com> | 2022-06-08 11:45:28 +0200 |
commit | 19f9c783c2f1ca136a6ed874656e0a2c93b4adca (patch) | |
tree | fe0a0ccaff7e0ba5293d9950b1725d7e03c5237d | |
parent | 45a16605fe7caa2ebcbc1068fa2b48cbfa3b28c1 (diff) |
onnxModel to onnx in summary/matchfeatures
7 files changed, 21 insertions, 22 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java index c6c807f2dbb..cbf120e1ee0 100644 --- a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java @@ -267,7 +267,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement String modelConfigName = OnnxModelTransformer.getModelConfigName(reference); String modelOutput = OnnxModelTransformer.getModelOutput(reference, null); - reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); + reference = new Reference("onnx", new Arguments(new ReferenceNode(modelConfigName)), modelOutput); if ( ! featureTypes.containsKey(reference)) { throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'"); } diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java index 5479ecf323f..56786c733ec 100644 --- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java @@ -1089,11 +1089,11 @@ public class RankProfile implements Cloneable { Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context); TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes); - context.setType(new Reference("onnxModel", args, null), defaultOutputType); + context.setType(new Reference("onnx", args, null), defaultOutputType); for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) { TensorType type = model.getTensorType(mapping.getKey(), inputTypes); - context.setType(new Reference("onnxModel", args, mapping.getValue()), type); + context.setType(new Reference("onnx", args, mapping.getValue()), type); } } return context; diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java index 4c38c257602..8797deefcb6 100644 --- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java +++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java @@ -20,17 +20,16 @@ import java.util.List; /** * Transforms ONNX model features of the forms: * - * onnxModel(config_name) - * onnxModel(config_name).output - * onnxModel("path/to/model") - * onnxModel("path/to/model").output - * onnxModel("path/to/model", "path/to/output") - * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused - * onnx(...) // same as with onnxModel, onnx is an alias of onnxModel + * onnx(config_name) + * onnx(config_name).output + * onnx("path/to/model") + * onnx("path/to/model").output + * onnx("path/to/model", "path/to/output") + * onnx("path/to/model", "unused", "path/to/output") // signature is unused * * To the format expected by the backend: * - * onnxModel(config_name).output + * onnx(config_name).output * * @author lesters */ @@ -84,7 +83,7 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans throw new IllegalArgumentException(featureName + " argument '" + output + "' output not found in model '" + onnxModel.getFileName() + "'"); } - return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output); + return new ReferenceNode("onnx", List.of(new ReferenceNode(modelConfigName)), output); } public static String getModelConfigName(Reference reference) { diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java index 713e11fd608..1280895bfc0 100644 --- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java @@ -139,7 +139,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name()); assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name()); - assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); + assertEquals("onnx(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value()); assertEquals("test_generated_model_config", config.rankprofile(3).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name()); @@ -149,25 +149,25 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name()); assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name()); - assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); + assertEquals("onnx(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value()); assertEquals("test_summary_features", config.rankprofile(4).name()); assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name()); assertEquals("1", config.rankprofile(4).fef().property(3).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name()); - assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value()); + assertEquals("onnx(another_model).out", config.rankprofile(4).fef().property(4).value()); assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name()); - assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value()); + assertEquals("onnx(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value()); assertEquals("test_dynamic_model", config.rankprofile(5).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name()); - assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); + assertEquals("onnx(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value()); assertEquals("test_dynamic_model_2", config.rankprofile(6).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name()); - assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); + assertEquals("onnx(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value()); assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name()); @@ -176,7 +176,7 @@ public class RankingExpressionWithOnnxModelTestCase { assertEquals("test_unbound_model", config.rankprofile(8).name()); assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name()); assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(8).fef().property(3).name()); - assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value()); + assertEquals("onnx(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value()); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java index c60817704cd..694b908478d 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java @@ -170,7 +170,7 @@ public class ModelEvaluationTest { } private final String profile = - "rankingExpression(output).rankingScript: onnxModel(small_constants_and_functions).output\n" + + "rankingExpression(output).rankingScript: onnx(small_constants_and_functions).output\n" + "rankingExpression(output).type: tensor<float>(d0[3])\n"; private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 3314ecb23fc..d030108a17a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -260,7 +260,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { } /** - * Extract the feature used to evaluate the onnx model. e.g. onnxModel(name) and add + * Extract the feature used to evaluate the onnx model. e.g. onnx(name) and add * that as a bind target and argument. During evaluation, this will be evaluated before * the rest of the expression and the result is added to the context. Also extract the * inputs to the model and add them as bind targets and arguments. diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java index 144e3e6c03b..e8fc75824be 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java @@ -37,7 +37,7 @@ public class OnnxImporter extends ModelImporter { for (int i = 0; i < model.getGraph().getOutputCount(); ++i) { Onnx.ValueInfoProto output = model.getGraph().getOutput(i); String outputName = asValidIdentifier(output.getName()); - importedModel.expression(outputName, "onnxModel(" + modelName + ")." + outputName); + importedModel.expression(outputName, "onnx(" + modelName + ")." + outputName); } return importedModel; |