diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | 18 |
1 files changed, 15 insertions, 3 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 03bbb436026..40a84a701ec 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -26,7 +26,7 @@ import java.util.stream.Collectors; @Beta public class Model { - /** The prefix generated by mode-integration/../IntermediateOperation */ + /** The prefix generated by model-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; private final String name; @@ -50,25 +50,37 @@ public class Model { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), Collections.emptyMap(), + Collections.emptyList(), Collections.emptyList()); } Model(String name, Map<FunctionReference, ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, - List<Constant> constants) { + List<Constant> constants, + List<OnnxModel> onnxModels) { this.name = name; // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { - LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, this); + LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this); contextBuilder.put(function.getValue().getName(), context); if ( ! function.getValue().returnType().isPresent()) { functions.put(function.getKey(), function.getValue().withReturnType(TensorType.empty)); } + for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) { + String onnxFeature = entry.getKey(); + OnnxModel onnxModel = entry.getValue(); + for(Map.Entry<String, TensorType> input : onnxModel.inputs().entrySet()) { + functions.put(function.getKey(), function.getValue().withArgument(input.getKey(), input.getValue())); + } + TensorType onnxOutputType = onnxModel.outputs().get(function.getKey().functionName()); + functions.put(function.getKey(), function.getValue().withArgument(onnxFeature, onnxOutputType)); + } + for (String argument : context.arguments()) { if (function.getValue().getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)) { // Internal (generated) functions do not have type info - add arguments |