aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 09:25:04 +0100
committerGitHub <noreply@github.com>2022-12-01 09:25:04 +0100
commitf578da98634e6c148a360a9ac4ec2313ba1a3033 (patch)
tree2e7d52df9d5c87c5ff9c1fb77b85102486deb564 /model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
parent6826fbab9fc00e4a76d52f8aa6b489f55ef8a3ac (diff)
Revert "- Reduce usage of guava."
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java30
1 files changed, 17 insertions, 13 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index ffcfb5e9379..1ecec4108a3 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -2,16 +2,15 @@
package ai.vespa.models.evaluation;
import com.yahoo.api.annotations.Beta;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
-import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -36,10 +35,10 @@ public class Model {
private final List<ExpressionFunction> publicFunctions;
/** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */
- private final Map<FunctionReference, ExpressionFunction> referencedFunctions;
+ private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions;
/** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */
- private final Map<String, LazyArrayContext> contextPrototypes;
+ private final ImmutableMap<String, LazyArrayContext> contextPrototypes;
private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();
@@ -47,9 +46,9 @@ public class Model {
public Model(String name, Collection<ExpressionFunction> functions) {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
- Map.of(),
- List.of(),
- List.of());
+ Collections.emptyMap(),
+ Collections.emptyList(),
+ Collections.emptyList());
}
Model(String name,
@@ -60,7 +59,7 @@ public class Model {
this.name = name;
// Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
- Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>();
+ ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this);
@@ -93,14 +92,19 @@ public class Model {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
}
}
- this.contextPrototypes = Map.copyOf(contextBuilder);
+ this.contextPrototypes = contextBuilder.build();
this.functions = List.copyOf(functions.values());
this.publicFunctions = functions.values().stream()
.filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList();
// Optimize functions
- this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream()
- .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName())))));
+ ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
+ for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) {
+ ExpressionFunction optimizedFunction = optimize(function.getValue(),
+ contextPrototypes.get(function.getKey().functionName()));
+ functionsBuilder.put(function.getKey(), optimizedFunction);
+ }
+ this.referencedFunctions = functionsBuilder.build();
}
/** Returns an optimized version of the given function */
@@ -138,7 +142,7 @@ public class Model {
}
/** Returns an immutable map of the referenced function instances of this */
- Map<FunctionReference, ExpressionFunction> referencedFunctions() { return Map.copyOf(referencedFunctions); }
+ Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; }
/** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */
ExpressionFunction requireReferencedFunction(FunctionReference reference) {