diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-12-01 07:36:44 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2022-12-01 07:36:44 +0100 |
commit | 00e7d63e41842231528343a6e80ede595d997ff5 (patch) | |
tree | d611749f67d8ac3201b1a39b516339755715f236 /model-evaluation/src/main/java/ai | |
parent | c42b104ac2a231cb120719dd904d5ad2ac31fbeb (diff) |
- Reduce usage of guava.
- Ensure that tests relying on order are determinsitic.
Diffstat (limited to 'model-evaluation/src/main/java/ai')
5 files changed, 45 insertions, 56 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 6af33e29e62..1d3da73a509 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -101,9 +101,11 @@ public class FunctionEvaluator { } public Tensor evaluate() { - for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - checkArgument(argument.getKey(), argument.getValue()); - } + function.argumentTypes().keySet().stream().sorted() + .forEach(name -> { + var type = function.argumentTypes().get(name); + checkArgument(name, type); + }); evaluated = true; evaluateOnnxModels(); return function.getBody().evaluate(context).asTensor(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index d030108a17a..81325740218 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -1,8 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; +import com.yahoo.lang.MutableInteger; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -14,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * An array context supporting functions invocations implemented as lazy values. @@ -151,16 +152,16 @@ public final class LazyArrayContext extends Context implements ContextIndex { private static class IndexedBindings { /** The mapping from variable name to index */ - private final ImmutableMap<String, Integer> nameToIndex; + private final Map<String, Integer> nameToIndex; /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */ - private final ImmutableSet<String> arguments; + private final Set<String> arguments; /** The current values set */ private final Value[] values; /** ONNX models indexed by rank feature that calls them */ - private final ImmutableMap<String, OnnxModel> onnxModels; + private final Map<String, OnnxModel> onnxModels; /** The object instance which encodes "no value is set". The actual value of this is never used. */ private static final Value missing = new DoubleValue(Double.NaN).freeze(); @@ -169,14 +170,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { private Value missingValue = new DoubleValue(Double.NaN).freeze(); - private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, + private IndexedBindings(Map<String, Integer> nameToIndex, Value[] values, - ImmutableSet<String> arguments, - ImmutableMap<String, OnnxModel> onnxModels) { - this.nameToIndex = nameToIndex; + Set<String> arguments, + Map<String, OnnxModel> onnxModels) { + this.nameToIndex = Map.copyOf(nameToIndex); this.values = values; this.arguments = arguments; - this.onnxModels = onnxModels; + this.onnxModels = Map.copyOf(onnxModels); } /** @@ -195,16 +196,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { Map<String, OnnxModel> onnxModelsInUse = new HashMap<>(); extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse); - this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse); - this.arguments = ImmutableSet.copyOf(arguments); + this.onnxModels = Map.copyOf(onnxModelsInUse); + this.arguments = Set.copyOf(arguments); values = new Value[bindTargets.size()]; Arrays.fill(values, missing); - int i = 0; - ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); - for (String variable : bindTargets) - nameToIndexBuilder.put(variable, i++); - nameToIndex = nameToIndexBuilder.build(); + MutableInteger nextIndex = new MutableInteger(0); + nameToIndex = Map.copyOf(bindTargets.stream() + .collect(CustomCollectors.toLinkedMap(name -> name, name -> nextIndex.next()))); // 2. Bind the bind targets for (Constant constant : constants) { @@ -252,8 +251,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { bindTargets.add(node.toString()); arguments.add(node.toString()); } - else if (node instanceof CompositeNode) { - CompositeNode cNode = (CompositeNode)node; + else if (node instanceof CompositeNode cNode) { for (ExpressionNode child : cNode.children()) extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); } @@ -291,16 +289,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { } private Optional<String> getArgument(ExpressionNode node) { - if (node instanceof ReferenceNode) { - ReferenceNode reference = (ReferenceNode) node; + if (node instanceof ReferenceNode reference) { if (reference.getArguments().size() > 0) { if (reference.getArguments().expressions().get(0) instanceof ConstantNode) { ExpressionNode constantNode = reference.getArguments().expressions().get(0); return Optional.of(stripQuotes(constantNode.toString())); } - if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0); - return Optional.of(referenceNode.getName()); + if (reference.getArguments().expressions().get(0) instanceof ReferenceNode refNode) { + return Optional.of(refNode.getName()); } } } @@ -316,20 +312,17 @@ public final class LazyArrayContext extends Context implements ContextIndex { } private boolean isFunctionReference(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode)) return false; - ReferenceNode reference = (ReferenceNode)node; + if ( ! (node instanceof ReferenceNode reference)) return false; return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1; } private boolean isOnnx(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode)) return false; - ReferenceNode reference = (ReferenceNode) node; + if ( ! (node instanceof ReferenceNode reference)) return false; return reference.getName().equals("onnx") || reference.getName().equals("onnxModel"); } private boolean isConstant(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode)) return false; - ReferenceNode reference = (ReferenceNode)node; + if ( ! (node instanceof ReferenceNode reference)) return false; return reference.getName().equals("constant") && reference.getArguments().size() == 1; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 1ecec4108a3..ffcfb5e9379 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -2,15 +2,16 @@ package ai.vespa.models.evaluation; import com.yahoo.api.annotations.Beta; -import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -35,10 +36,10 @@ public class Model { private final List<ExpressionFunction> publicFunctions; /** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */ - private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions; + private final Map<FunctionReference, ExpressionFunction> referencedFunctions; /** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */ - private final ImmutableMap<String, LazyArrayContext> contextPrototypes; + private final Map<String, LazyArrayContext> contextPrototypes; private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); @@ -46,9 +47,9 @@ public class Model { public Model(String name, Collection<ExpressionFunction> functions) { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), - Collections.emptyMap(), - Collections.emptyList(), - Collections.emptyList()); + Map.of(), + List.of(), + List.of()); } Model(String name, @@ -59,7 +60,7 @@ public class Model { this.name = name; // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) - ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); + Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this); @@ -92,19 +93,14 @@ public class Model { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); } } - this.contextPrototypes = contextBuilder.build(); + this.contextPrototypes = Map.copyOf(contextBuilder); this.functions = List.copyOf(functions.values()); this.publicFunctions = functions.values().stream() .filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList(); // Optimize functions - ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); - for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) { - ExpressionFunction optimizedFunction = optimize(function.getValue(), - contextPrototypes.get(function.getKey().functionName())); - functionsBuilder.put(function.getKey(), optimizedFunction); - } - this.referencedFunctions = functionsBuilder.build(); + this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream() + .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName()))))); } /** Returns an optimized version of the given function */ @@ -142,7 +138,7 @@ public class Model { } /** Returns an immutable map of the referenced function instances of this */ - Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; } + Map<FunctionReference, ExpressionFunction> referencedFunctions() { return Map.copyOf(referencedFunctions); } /** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */ ExpressionFunction requireReferencedFunction(FunctionReference reference) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 88843fd99ab..40a503e0212 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -2,7 +2,6 @@ package ai.vespa.models.evaluation; import com.yahoo.api.annotations.Beta; -import com.google.common.collect.ImmutableMap; import com.yahoo.component.annotation.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -23,7 +22,7 @@ import java.util.Map; @Beta public class ModelsEvaluator extends AbstractComponent { - private final ImmutableMap<String, Model> models; + private final Map<String, Model> models; @Inject public ModelsEvaluator(RankProfilesConfig config, @@ -43,7 +42,7 @@ public class ModelsEvaluator extends AbstractComponent { } public ModelsEvaluator(Map<String, Model> models) { - this.models = ImmutableMap.copyOf(models); + this.models = Map.copyOf(models); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 2661b9c2eb2..78addf0328a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -102,9 +102,8 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { private HttpResponse listAllModels(HttpRequest request) { Slime slime = new Slime(); Cursor root = slime.setObject(); - for (String modelName: modelsEvaluator.models().keySet()) { - root.setString(modelName, baseUrl(request) + modelName); - } + modelsEvaluator.models().keySet().stream().sorted() + .forEach(name -> root.setString(name, baseUrl(request) + name)); return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); } |