summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java2
-rw-r--r--config-model/src/main/java/com/yahoo/schema/RankProfile.java4
-rw-r--r--config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java17
-rw-r--r--config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java14
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java2
7 files changed, 21 insertions, 22 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java
index c6c807f2dbb..cbf120e1ee0 100644
--- a/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/schema/MapEvaluationTypeContext.java
@@ -267,7 +267,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
String modelConfigName = OnnxModelTransformer.getModelConfigName(reference);
String modelOutput = OnnxModelTransformer.getModelOutput(reference, null);
- reference = new Reference("onnxModel", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
+ reference = new Reference("onnx", new Arguments(new ReferenceNode(modelConfigName)), modelOutput);
if ( ! featureTypes.containsKey(reference)) {
throw new IllegalArgumentException("Missing onnx-model config for '" + configOrFileName + "'");
}
diff --git a/config-model/src/main/java/com/yahoo/schema/RankProfile.java b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
index 5479ecf323f..56786c733ec 100644
--- a/config-model/src/main/java/com/yahoo/schema/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/schema/RankProfile.java
@@ -1089,11 +1089,11 @@ public class RankProfile implements Cloneable {
Map<String, TensorType> inputTypes = resolveOnnxInputTypes(model, context);
TensorType defaultOutputType = model.getTensorType(model.getDefaultOutput(), inputTypes);
- context.setType(new Reference("onnxModel", args, null), defaultOutputType);
+ context.setType(new Reference("onnx", args, null), defaultOutputType);
for (Map.Entry<String, String> mapping : model.getOutputMap().entrySet()) {
TensorType type = model.getTensorType(mapping.getKey(), inputTypes);
- context.setType(new Reference("onnxModel", args, mapping.getValue()), type);
+ context.setType(new Reference("onnx", args, mapping.getValue()), type);
}
}
return context;
diff --git a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java
index 4c38c257602..8797deefcb6 100644
--- a/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java
+++ b/config-model/src/main/java/com/yahoo/schema/expressiontransforms/OnnxModelTransformer.java
@@ -20,17 +20,16 @@ import java.util.List;
/**
* Transforms ONNX model features of the forms:
*
- * onnxModel(config_name)
- * onnxModel(config_name).output
- * onnxModel("path/to/model")
- * onnxModel("path/to/model").output
- * onnxModel("path/to/model", "path/to/output")
- * onnxModel("path/to/model", "unused", "path/to/output") // signature is unused
- * onnx(...) // same as with onnxModel, onnx is an alias of onnxModel
+ * onnx(config_name)
+ * onnx(config_name).output
+ * onnx("path/to/model")
+ * onnx("path/to/model").output
+ * onnx("path/to/model", "path/to/output")
+ * onnx("path/to/model", "unused", "path/to/output") // signature is unused
*
* To the format expected by the backend:
*
- * onnxModel(config_name).output
+ * onnx(config_name).output
*
* @author lesters
*/
@@ -84,7 +83,7 @@ public class OnnxModelTransformer extends ExpressionTransformer<RankProfileTrans
throw new IllegalArgumentException(featureName + " argument '" + output +
"' output not found in model '" + onnxModel.getFileName() + "'");
}
- return new ReferenceNode("onnxModel", List.of(new ReferenceNode(modelConfigName)), output);
+ return new ReferenceNode("onnx", List.of(new ReferenceNode(modelConfigName)), output);
}
public static String getModelConfigName(Reference reference) {
diff --git a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
index 713e11fd608..1280895bfc0 100644
--- a/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
+++ b/config-model/src/test/java/com/yahoo/schema/processing/RankingExpressionWithOnnxModelTestCase.java
@@ -139,7 +139,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("vespa.rank.firstphase", config.rankprofile(2).fef().property(2).name());
assertEquals("rankingExpression(firstphase)", config.rankprofile(2).fef().property(2).value());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(2).fef().property(3).name());
- assertEquals("onnxModel(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value());
+ assertEquals("onnx(my_model).out{d0:1}", config.rankprofile(2).fef().property(3).value());
assertEquals("test_generated_model_config", config.rankprofile(3).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(3).fef().property(0).name());
@@ -149,25 +149,25 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("vespa.rank.firstphase", config.rankprofile(3).fef().property(8).name());
assertEquals("rankingExpression(firstphase)", config.rankprofile(3).fef().property(8).value());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(3).fef().property(9).name());
- assertEquals("onnxModel(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value());
+ assertEquals("onnx(files_model_onnx).path_to_output_1{d0:1}", config.rankprofile(3).fef().property(9).value());
assertEquals("test_summary_features", config.rankprofile(4).name());
assertEquals("rankingExpression(another_function).rankingScript", config.rankprofile(4).fef().property(0).name());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(4).fef().property(3).name());
assertEquals("1", config.rankprofile(4).fef().property(3).value());
assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(4).name());
- assertEquals("onnxModel(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(4).value());
+ assertEquals("onnx(another_model).out", config.rankprofile(4).fef().property(4).value());
assertEquals("vespa.summary.feature", config.rankprofile(4).fef().property(5).name());
- assertEquals("onnxModel(another_model).out", config.rankprofile(4).fef().property(5).value());
+ assertEquals("onnx(files_summary_model_onnx).path_to_output_2", config.rankprofile(4).fef().property(5).value());
assertEquals("test_dynamic_model", config.rankprofile(5).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(5).fef().property(0).name());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(5).fef().property(3).name());
- assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value());
+ assertEquals("onnx(dynamic_model).my_output{d0:0, d1:1}", config.rankprofile(5).fef().property(3).value());
assertEquals("test_dynamic_model_2", config.rankprofile(6).name());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(6).fef().property(5).name());
- assertEquals("onnxModel(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value());
+ assertEquals("onnx(dynamic_model).my_output{d0:0, d1:2}", config.rankprofile(6).fef().property(5).value());
assertEquals("test_dynamic_model_with_transformer_tokens", config.rankprofile(7).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(7).fef().property(1).name());
@@ -176,7 +176,7 @@ public class RankingExpressionWithOnnxModelTestCase {
assertEquals("test_unbound_model", config.rankprofile(8).name());
assertEquals("rankingExpression(my_function).rankingScript", config.rankprofile(8).fef().property(0).name());
assertEquals("rankingExpression(firstphase).rankingScript", config.rankprofile(8).fef().property(3).name());
- assertEquals("onnxModel(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value());
+ assertEquals("onnx(unbound_model).my_output{d0:0, d1:1}", config.rankprofile(8).fef().property(3).value());
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
index c60817704cd..694b908478d 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/ml/ModelEvaluationTest.java
@@ -170,7 +170,7 @@ public class ModelEvaluationTest {
}
private final String profile =
- "rankingExpression(output).rankingScript: onnxModel(small_constants_and_functions).output\n" +
+ "rankingExpression(output).rankingScript: onnx(small_constants_and_functions).output\n" +
"rankingExpression(output).type: tensor<float>(d0[3])\n";
private RankProfilesConfig.Rankprofile.Fef findProfile(String name, RankProfilesConfig config) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 3314ecb23fc..d030108a17a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -260,7 +260,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
/**
- * Extract the feature used to evaluate the onnx model. e.g. onnxModel(name) and add
+ * Extract the feature used to evaluate the onnx model. e.g. onnx(name) and add
* that as a bind target and argument. During evaluation, this will be evaluated before
* the rest of the expression and the result is added to the context. Also extract the
* inputs to the model and add them as bind targets and arguments.
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
index 144e3e6c03b..e8fc75824be 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java
@@ -37,7 +37,7 @@ public class OnnxImporter extends ModelImporter {
for (int i = 0; i < model.getGraph().getOutputCount(); ++i) {
Onnx.ValueInfoProto output = model.getGraph().getOutput(i);
String outputName = asValidIdentifier(output.getName());
- importedModel.expression(outputName, "onnxModel(" + modelName + ")." + outputName);
+ importedModel.expression(outputName, "onnx(" + modelName + ")." + outputName);
}
return importedModel;