summaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java18
1 files changed, 15 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 03bbb436026..40a84a701ec 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -26,7 +26,7 @@ import java.util.stream.Collectors;
@Beta
public class Model {
- /** The prefix generated by mode-integration/../IntermediateOperation */
+ /** The prefix generated by model-integration/../IntermediateOperation */
private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";
private final String name;
@@ -50,25 +50,37 @@ public class Model {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
Collections.emptyMap(),
+ Collections.emptyList(),
Collections.emptyList());
}
Model(String name,
Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
- List<Constant> constants) {
+ List<Constant> constants,
+ List<OnnxModel> onnxModels) {
this.name = name;
// Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
- LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this);
+ LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this);
contextBuilder.put(function.getValue().getName(), context);
if ( ! function.getValue().returnType().isPresent()) {
functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty));
}
+ for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) {
+ String onnxFeature = entry.getKey();
+ OnnxModel onnxModel = entry.getValue();
+ for(Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) {
+ functions.put(function.getKey(), function.getValue().withArgument(input.getKey(), input.getValue()));
+ }
+ TensorType onnxOutputType = onnxModel.outputs().get(function.getKey().functionName());
+ functions.put(function.getKey(), function.getValue().withArgument(onnxFeature, onnxOutputType));
+ }
+
for (String argument : context.arguments()) {
if (function.getValue().getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) {
// Internal (generated) functions do not have type info - add arguments