aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation/src/main/java/ai/vespa/models/evaluation
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 09:32:05 +0100
committerGitHub <noreply@github.com>2022-12-01 09:32:05 +0100
commit1eb22cc4a24973f52b344c3033cff394c724cbe4 (patch)
tree98fdcb5bed45fc1199400988d45cf6bb47e413f2 /model-evaluation/src/main/java/ai/vespa/models/evaluation
parent2925f225b34ad7fa3eb515bbddcc8c774e514131 (diff)
Revert "Revert "- Reduce usage of guava.""
Diffstat (limited to 'model-evaluation/src/main/java/ai/vespa/models/evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java8
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java53
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java30
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java5
4 files changed, 43 insertions, 53 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 6af33e29e62..1d3da73a509 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -101,9 +101,11 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
- for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- checkArgument(argument.getKey(), argument.getValue());
- }
+ function.argumentTypes().keySet().stream().sorted()
+ .forEach(name -> {
+ var type = function.argumentTypes().get(name);
+ checkArgument(name, type);
+ });
evaluated = true;
evaluateOnnxModels();
return function.getBody().evaluate(context).asTensor();
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index d030108a17a..81325740218 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -1,8 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
+import com.yahoo.lang.MutableInteger;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -14,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -24,6 +24,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* An array context supporting functions invocations implemented as lazy values.
@@ -151,16 +152,16 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private static class IndexedBindings {
/** The mapping from variable name to index */
- private final ImmutableMap<String, Integer> nameToIndex;
+ private final Map<String, Integer> nameToIndex;
/** The names which needs to be bound externally when invoking this (i.e not constant or invocation */
- private final ImmutableSet<String> arguments;
+ private final Set<String> arguments;
/** The current values set */
private final Value[] values;
/** ONNX models indexed by rank feature that calls them */
- private final ImmutableMap<String, OnnxModel> onnxModels;
+ private final Map<String, OnnxModel> onnxModels;
/** The object instance which encodes "no value is set". The actual value of this is never used. */
private static final Value missing = new DoubleValue(Double.NaN).freeze();
@@ -169,14 +170,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private Value missingValue = new DoubleValue(Double.NaN).freeze();
- private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
+ private IndexedBindings(Map<String, Integer> nameToIndex,
Value[] values,
- ImmutableSet<String> arguments,
- ImmutableMap<String, OnnxModel> onnxModels) {
- this.nameToIndex = nameToIndex;
+ Set<String> arguments,
+ Map<String, OnnxModel> onnxModels) {
+ this.nameToIndex = Map.copyOf(nameToIndex);
this.values = values;
this.arguments = arguments;
- this.onnxModels = onnxModels;
+ this.onnxModels = Map.copyOf(onnxModels);
}
/**
@@ -195,16 +196,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Map<String, OnnxModel> onnxModelsInUse = new HashMap<>();
extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse);
- this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse);
- this.arguments = ImmutableSet.copyOf(arguments);
+ this.onnxModels = Map.copyOf(onnxModelsInUse);
+ this.arguments = Set.copyOf(arguments);
values = new Value[bindTargets.size()];
Arrays.fill(values, missing);
- int i = 0;
- ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
- for (String variable : bindTargets)
- nameToIndexBuilder.put(variable, i++);
- nameToIndex = nameToIndexBuilder.build();
+ MutableInteger nextIndex = new MutableInteger(0);
+ nameToIndex = Map.copyOf(bindTargets.stream()
+ .collect(CustomCollectors.toLinkedMap(name -> name, name -> nextIndex.next())));
// 2. Bind the bind targets
for (Constant constant : constants) {
@@ -252,8 +251,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
bindTargets.add(node.toString());
arguments.add(node.toString());
}
- else if (node instanceof CompositeNode) {
- CompositeNode cNode = (CompositeNode)node;
+ else if (node instanceof CompositeNode cNode) {
for (ExpressionNode child : cNode.children())
extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
}
@@ -291,16 +289,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
private Optional<String> getArgument(ExpressionNode node) {
- if (node instanceof ReferenceNode) {
- ReferenceNode reference = (ReferenceNode) node;
+ if (node instanceof ReferenceNode reference) {
if (reference.getArguments().size() > 0) {
if (reference.getArguments().expressions().get(0) instanceof ConstantNode) {
ExpressionNode constantNode = reference.getArguments().expressions().get(0);
return Optional.of(stripQuotes(constantNode.toString()));
}
- if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0);
- return Optional.of(referenceNode.getName());
+ if (reference.getArguments().expressions().get(0) instanceof ReferenceNode refNode) {
+ return Optional.of(refNode.getName());
}
}
}
@@ -316,20 +312,17 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
private boolean isFunctionReference(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode)) return false;
- ReferenceNode reference = (ReferenceNode)node;
+ if ( ! (node instanceof ReferenceNode reference)) return false;
return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
}
private boolean isOnnx(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode)) return false;
- ReferenceNode reference = (ReferenceNode) node;
+ if ( ! (node instanceof ReferenceNode reference)) return false;
return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
}
private boolean isConstant(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode)) return false;
- ReferenceNode reference = (ReferenceNode)node;
+ if ( ! (node instanceof ReferenceNode reference)) return false;
return reference.getName().equals("constant") && reference.getArguments().size() == 1;
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 1ecec4108a3..ffcfb5e9379 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -2,15 +2,16 @@
package ai.vespa.models.evaluation;
import com.yahoo.api.annotations.Beta;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
-import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -35,10 +36,10 @@ public class Model {
private final List<ExpressionFunction> publicFunctions;
/** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */
- private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions;
+ private final Map<FunctionReference, ExpressionFunction> referencedFunctions;
/** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */
- private final ImmutableMap<String, LazyArrayContext> contextPrototypes;
+ private final Map<String, LazyArrayContext> contextPrototypes;
private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();
@@ -46,9 +47,9 @@ public class Model {
public Model(String name, Collection<ExpressionFunction> functions) {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
- Collections.emptyMap(),
- Collections.emptyList(),
- Collections.emptyList());
+ Map.of(),
+ List.of(),
+ List.of());
}
Model(String name,
@@ -59,7 +60,7 @@ public class Model {
this.name = name;
// Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
- ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
+ Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this);
@@ -92,19 +93,14 @@ public class Model {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
}
}
- this.contextPrototypes = contextBuilder.build();
+ this.contextPrototypes = Map.copyOf(contextBuilder);
this.functions = List.copyOf(functions.values());
this.publicFunctions = functions.values().stream()
.filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList();
// Optimize functions
- ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
- for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) {
- ExpressionFunction optimizedFunction = optimize(function.getValue(),
- contextPrototypes.get(function.getKey().functionName()));
- functionsBuilder.put(function.getKey(), optimizedFunction);
- }
- this.referencedFunctions = functionsBuilder.build();
+ this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream()
+ .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName())))));
}
/** Returns an optimized version of the given function */
@@ -142,7 +138,7 @@ public class Model {
}
/** Returns an immutable map of the referenced function instances of this */
- Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; }
+ Map<FunctionReference, ExpressionFunction> referencedFunctions() { return Map.copyOf(referencedFunctions); }
/** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */
ExpressionFunction requireReferencedFunction(FunctionReference reference) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 88843fd99ab..40a503e0212 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -2,7 +2,6 @@
package ai.vespa.models.evaluation;
import com.yahoo.api.annotations.Beta;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.AbstractComponent;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -23,7 +22,7 @@ import java.util.Map;
@Beta
public class ModelsEvaluator extends AbstractComponent {
- private final ImmutableMap<String, Model> models;
+ private final Map<String, Model> models;
@Inject
public ModelsEvaluator(RankProfilesConfig config,
@@ -43,7 +42,7 @@ public class ModelsEvaluator extends AbstractComponent {
}
public ModelsEvaluator(Map<String, Model> models) {
- this.models = ImmutableMap.copyOf(models);
+ this.models = Map.copyOf(models);
}
/** Returns the models of this as an immutable map */