aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java8
1 files changed, 6 insertions, 2 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index c317cdc5922..84c8e2b1e38 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -49,6 +49,7 @@ public class Model implements AutoCloseable {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
Map.of(),
+ Map.of(),
List.of(),
List.of());
}
@@ -56,6 +57,7 @@ public class Model implements AutoCloseable {
Model(String name,
Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
+ Map<String, TensorType> declaredTypes,
List<Constant> constants,
List<OnnxModel> onnxModels) {
this.name = name;
@@ -85,8 +87,10 @@ public class Model implements AutoCloseable {
}
else {
// External functions have type info (when not scalar) - add argument types
- if (function.getValue().getArgumentType(argument) == null)
- functions.put(function.getKey(), function.getValue().withArgument(argument, TensorType.empty));
+ if (function.getValue().getArgumentType(argument) == null) {
+ TensorType type = declaredTypes.getOrDefault(argument, TensorType.empty);
+ functions.put(function.getKey(), function.getValue().withArgument(argument, type));
+ }
}
}
}