diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-12-01 09:25:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-01 09:25:04 +0100 |
commit | f578da98634e6c148a360a9ac4ec2313ba1a3033 (patch) | |
tree | 2e7d52df9d5c87c5ff9c1fb77b85102486deb564 /model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | |
parent | 6826fbab9fc00e4a76d52f8aa6b489f55ef8a3ac (diff) |
Revert "- Reduce usage of guava."
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java')
-rw-r--r-- | model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java | 30 |
1 files changed, 17 insertions, 13 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index ffcfb5e9379..1ecec4108a3 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -2,16 +2,15 @@ package ai.vespa.models.evaluation; import com.yahoo.api.annotations.Beta; +import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; -import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; -import java.util.HashMap; -import java.util.LinkedHashMap; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -36,10 +35,10 @@ public class Model { private final List<ExpressionFunction> publicFunctions; /** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */ - private final Map<FunctionReference, ExpressionFunction> referencedFunctions; + private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions; /** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */ - private final Map<String, LazyArrayContext> contextPrototypes; + private final ImmutableMap<String, LazyArrayContext> contextPrototypes; private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); @@ -47,9 +46,9 @@ public class Model { public Model(String name, Collection<ExpressionFunction> functions) { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), - Map.of(), - List.of(), - List.of()); + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList()); } Model(String name, @@ -60,7 +59,7 @@ public class Model { this.name = name; // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) - Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>(); + ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this); @@ -93,14 +92,19 @@ public class Model { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); } } - this.contextPrototypes = Map.copyOf(contextBuilder); + this.contextPrototypes = contextBuilder.build(); this.functions = List.copyOf(functions.values()); this.publicFunctions = functions.values().stream() .filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList(); // Optimize functions - this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream() - .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName()))))); + ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); + for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) { + ExpressionFunction optimizedFunction = optimize(function.getValue(), + contextPrototypes.get(function.getKey().functionName())); + functionsBuilder.put(function.getKey(), optimizedFunction); + } + this.referencedFunctions = functionsBuilder.build(); } /** Returns an optimized version of the given function */ @@ -138,7 +142,7 @@ public class Model { } /** Returns an immutable map of the referenced function instances of this */ - Map<FunctionReference, ExpressionFunction> referencedFunctions() { return Map.copyOf(referencedFunctions); } + Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; } /** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */ ExpressionFunction requireReferencedFunction(FunctionReference reference) { |