diff options
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | 8 |
1 files changed, 6 insertions, 2 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index c317cdc5922..84c8e2b1e38 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -49,6 +49,7 @@ public class Model implements AutoCloseable { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), Map.of(), + Map.of(), List.of(), List.of()); } @@ -56,6 +57,7 @@ public class Model implements AutoCloseable { Model(String name, Map<FunctionReference, ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, + Map<String, TensorType> declaredTypes, List<Constant> constants, List<OnnxModel> onnxModels) { this.name = name; @@ -85,8 +87,10 @@ public class Model implements AutoCloseable { } else { // External functions have type info (when not scalar) - add argument types - if (function.getValue().getArgumentType(argument) == null) - functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty)); + if (function.getValue().getArgumentType(argument) == null) { + TensorType type = declaredTypes.getOrDefault(argument, TensorType.empty); + functions.put(function.getKey(), function.getValue().withArgument(argument, type)); + } } } } |