summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-01 14:43:28 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-01 14:44:39 +0000
commit890df240c421e516396d1327cccde4296fd1366a (patch)
treedb3618b71472a29df6487e728603535ca6e97e7b /model-evaluation
parent0d63497a0ce084191ba08625aa8413175844b6cb (diff)
only optimize functions that have a contextPrototype
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java12
1 files changed, 8 insertions, 4 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 84c8e2b1e38..1da8121ba8e 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -99,13 +99,17 @@ public class Model implements AutoCloseable {
}
}
this.contextPrototypes = Map.copyOf(contextBuilder);
- this.functions = List.copyOf(functions.values());
+ // Optimize free functions
+ this.functions = List.copyOf(functions.entrySet()
+ .stream()
+ .map(f -> optimize(f.getValue(),
+ contextPrototypes.get(f.getKey().functionName())))
+ .collect(Collectors.toList()));
+
this.publicFunctions = functions.values().stream()
.filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList();
- // Optimize functions
- this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream()
- .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName())))));
+ this.referencedFunctions = Map.copyOf(referencedFunctions);
this.closeActions = onnxModels.stream().map(o -> (Runnable)o::close).toList();
}