summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 07:36:44 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 07:36:44 +0100
commit00e7d63e41842231528343a6e80ede595d997ff5 (patch)
treed611749f67d8ac3201b1a39b516339755715f236 /model-evaluation
parentc42b104ac2a231cb120719dd904d5ad2ac31fbeb (diff)
- Reduce usage of guava.
- Ensure that tests relying on order are determinsitic.
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/pom.xml5
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java8
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java53
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java30
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java5
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java11
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java9
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java22
9 files changed, 77 insertions, 71 deletions
diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml
index caf28199c3d..c0600872666 100644
--- a/model-evaluation/pom.xml
+++ b/model-evaluation/pom.xml
@@ -74,11 +74,6 @@
<scope>provided</scope>
</dependency>
<dependency>
- <groupId>com.google.guava</groupId>
- <artifactId>guava</artifactId>
- <scope>provided</scope>
- </dependency>
- <dependency>
<groupId>org.lz4</groupId>
<artifactId>lz4-java</artifactId>
</dependency>
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 6af33e29e62..1d3da73a509 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -101,9 +101,11 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
- for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- checkArgument(argument.getKey(), argument.getValue());
- }
+ function.argumentTypes().keySet().stream().sorted()
+ .forEach(name -> {
+ var type = function.argumentTypes().get(name);
+ checkArgument(name, type);
+ });
evaluated = true;
evaluateOnnxModels();
return function.getBody().evaluate(context).asTensor();
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index d030108a17a..81325740218 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -1,8 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import com.google.common.collect.ImmutableMap;
-import com.google.common.collect.ImmutableSet;
+import com.yahoo.lang.MutableInteger;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -14,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -24,6 +24,7 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
+import java.util.stream.Collectors;
/**
* An array context supporting functions invocations implemented as lazy values.
@@ -151,16 +152,16 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private static class IndexedBindings {
/** The mapping from variable name to index */
- private final ImmutableMap<String, Integer> nameToIndex;
+ private final Map<String, Integer> nameToIndex;
/** The names which needs to be bound externally when invoking this (i.e not constant or invocation */
- private final ImmutableSet<String> arguments;
+ private final Set<String> arguments;
/** The current values set */
private final Value[] values;
/** ONNX models indexed by rank feature that calls them */
- private final ImmutableMap<String, OnnxModel> onnxModels;
+ private final Map<String, OnnxModel> onnxModels;
/** The object instance which encodes "no value is set". The actual value of this is never used. */
private static final Value missing = new DoubleValue(Double.NaN).freeze();
@@ -169,14 +170,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private Value missingValue = new DoubleValue(Double.NaN).freeze();
- private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
+ private IndexedBindings(Map<String, Integer> nameToIndex,
Value[] values,
- ImmutableSet<String> arguments,
- ImmutableMap<String, OnnxModel> onnxModels) {
- this.nameToIndex = nameToIndex;
+ Set<String> arguments,
+ Map<String, OnnxModel> onnxModels) {
+ this.nameToIndex = Map.copyOf(nameToIndex);
this.values = values;
this.arguments = arguments;
- this.onnxModels = onnxModels;
+ this.onnxModels = Map.copyOf(onnxModels);
}
/**
@@ -195,16 +196,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Map<String, OnnxModel> onnxModelsInUse = new HashMap<>();
extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse);
- this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse);
- this.arguments = ImmutableSet.copyOf(arguments);
+ this.onnxModels = Map.copyOf(onnxModelsInUse);
+ this.arguments = Set.copyOf(arguments);
values = new Value[bindTargets.size()];
Arrays.fill(values, missing);
- int i = 0;
- ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
- for (String variable : bindTargets)
- nameToIndexBuilder.put(variable, i++);
- nameToIndex = nameToIndexBuilder.build();
+ MutableInteger nextIndex = new MutableInteger(0);
+ nameToIndex = Map.copyOf(bindTargets.stream()
+ .collect(CustomCollectors.toLinkedMap(name -> name, name -> nextIndex.next())));
// 2. Bind the bind targets
for (Constant constant : constants) {
@@ -252,8 +251,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
bindTargets.add(node.toString());
arguments.add(node.toString());
}
- else if (node instanceof CompositeNode) {
- CompositeNode cNode = (CompositeNode)node;
+ else if (node instanceof CompositeNode cNode) {
for (ExpressionNode child : cNode.children())
extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
}
@@ -291,16 +289,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
private Optional<String> getArgument(ExpressionNode node) {
- if (node instanceof ReferenceNode) {
- ReferenceNode reference = (ReferenceNode) node;
+ if (node instanceof ReferenceNode reference) {
if (reference.getArguments().size() > 0) {
if (reference.getArguments().expressions().get(0) instanceof ConstantNode) {
ExpressionNode constantNode = reference.getArguments().expressions().get(0);
return Optional.of(stripQuotes(constantNode.toString()));
}
- if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0);
- return Optional.of(referenceNode.getName());
+ if (reference.getArguments().expressions().get(0) instanceof ReferenceNode refNode) {
+ return Optional.of(refNode.getName());
}
}
}
@@ -316,20 +312,17 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
private boolean isFunctionReference(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode)) return false;
- ReferenceNode reference = (ReferenceNode)node;
+ if ( ! (node instanceof ReferenceNode reference)) return false;
return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
}
private boolean isOnnx(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode)) return false;
- ReferenceNode reference = (ReferenceNode) node;
+ if ( ! (node instanceof ReferenceNode reference)) return false;
return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
}
private boolean isConstant(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode)) return false;
- ReferenceNode reference = (ReferenceNode)node;
+ if ( ! (node instanceof ReferenceNode reference)) return false;
return reference.getName().equals("constant") && reference.getArguments().size() == 1;
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 1ecec4108a3..ffcfb5e9379 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -2,15 +2,16 @@
package ai.vespa.models.evaluation;
import com.yahoo.api.annotations.Beta;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
-import java.util.Collections;
+import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -35,10 +36,10 @@ public class Model {
private final List<ExpressionFunction> publicFunctions;
/** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */
- private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions;
+ private final Map<FunctionReference, ExpressionFunction> referencedFunctions;
/** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */
- private final ImmutableMap<String, LazyArrayContext> contextPrototypes;
+ private final Map<String, LazyArrayContext> contextPrototypes;
private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();
@@ -46,9 +47,9 @@ public class Model {
public Model(String name, Collection<ExpressionFunction> functions) {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
- Collections.emptyMap(),
- Collections.emptyList(),
- Collections.emptyList());
+ Map.of(),
+ List.of(),
+ List.of());
}
Model(String name,
@@ -59,7 +60,7 @@ public class Model {
this.name = name;
// Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
- ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
+ Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this);
@@ -92,19 +93,14 @@ public class Model {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
}
}
- this.contextPrototypes = contextBuilder.build();
+ this.contextPrototypes = Map.copyOf(contextBuilder);
this.functions = List.copyOf(functions.values());
this.publicFunctions = functions.values().stream()
.filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList();
// Optimize functions
- ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
- for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) {
- ExpressionFunction optimizedFunction = optimize(function.getValue(),
- contextPrototypes.get(function.getKey().functionName()));
- functionsBuilder.put(function.getKey(), optimizedFunction);
- }
- this.referencedFunctions = functionsBuilder.build();
+ this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream()
+ .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName())))));
}
/** Returns an optimized version of the given function */
@@ -142,7 +138,7 @@ public class Model {
}
/** Returns an immutable map of the referenced function instances of this */
- Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; }
+ Map<FunctionReference, ExpressionFunction> referencedFunctions() { return Map.copyOf(referencedFunctions); }
/** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */
ExpressionFunction requireReferencedFunction(FunctionReference reference) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 88843fd99ab..40a503e0212 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -2,7 +2,6 @@
package ai.vespa.models.evaluation;
import com.yahoo.api.annotations.Beta;
-import com.google.common.collect.ImmutableMap;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.AbstractComponent;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -23,7 +22,7 @@ import java.util.Map;
@Beta
public class ModelsEvaluator extends AbstractComponent {
- private final ImmutableMap<String, Model> models;
+ private final Map<String, Model> models;
@Inject
public ModelsEvaluator(RankProfilesConfig config,
@@ -43,7 +42,7 @@ public class ModelsEvaluator extends AbstractComponent {
}
public ModelsEvaluator(Map<String, Model> models) {
- this.models = ImmutableMap.copyOf(models);
+ this.models = Map.copyOf(models);
}
/** Returns the models of this as an immutable map */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index 2661b9c2eb2..78addf0328a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -102,9 +102,8 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
private HttpResponse listAllModels(HttpRequest request) {
Slime slime = new Slime();
Cursor root = slime.setObject();
- for (String modelName: modelsEvaluator.models().keySet()) {
- root.setString(modelName, baseUrl(request) + modelName);
- }
+ modelsEvaluator.models().keySet().stream().sorted()
+ .forEach(name -> root.setString(name, baseUrl(request) + name));
return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index c4e859bec9f..f09bac63085 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -75,7 +75,7 @@ public class ModelsEvaluatorTest {
evaluator.evaluate();
}
catch (IllegalStateException e) {
- assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})",
+ assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])",
Exceptions.toMessageString(e));
}
@@ -88,6 +88,15 @@ public class ModelsEvaluatorTest {
assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])",
Exceptions.toMessageString(e));
}
+ try { // Just the other binding
+ FunctionEvaluator evaluator = model.evaluatorOf("test");
+ evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}"));
+ evaluator.evaluate();
+ }
+ catch (IllegalStateException e) {
+ assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})",
+ Exceptions.toMessageString(e));
+ }
try { // Wrong binding argument
FunctionEvaluator evaluator = model.evaluatorOf("test");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index fc05a9936a9..3b16be311a0 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -10,7 +10,6 @@ import com.yahoo.tensor.serialization.JsonFormat;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
-import java.util.Collections;
import java.util.Map;
import java.util.concurrent.Executors;
@@ -25,19 +24,19 @@ class HandlerTester {
}
void assertResponse(String url, int expectedCode) {
- assertResponse(url, Collections.emptyMap(), expectedCode, (String)null);
+ assertResponse(url, Map.of(), expectedCode, (String)null);
}
void assertResponse(String url, int expectedCode, String expectedResult) {
- assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult);
+ assertResponse(url, Map.of(), expectedCode, expectedResult);
}
void assertResponse(String url, int expectedCode, String expectedResult, Map<String, String> headers) {
- assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult, headers);
+ assertResponse(url, Map.of(), expectedCode, expectedResult, headers);
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) {
- assertResponse(url, properties, expectedCode, expectedResult, Collections.emptyMap());
+ assertResponse(url, properties, expectedCode, expectedResult, Map.of());
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 0de8ce5f061..d804e50c67d 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -49,16 +49,30 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testListModels() {
String url = "http://localhost/model-evaluation/v1";
- String expected =
- "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"xgboost_non_standalone\":\"http://localhost/model-evaluation/v1/xgboost_non_standalone\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}";
+ String expected = "{" +
+ "\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"," +
+ "\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\"," +
+ "\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\"," +
+ "\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\"," +
+ "\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\"," +
+ "\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"," +
+ "\"xgboost_non_standalone\":\"http://localhost/model-evaluation/v1/xgboost_non_standalone\"" +
+ "}";
handler.assertResponse(url, 200, expected);
}
@Test
public void testListModelsWithDifferentHost() {
String url = "http://localhost/model-evaluation/v1";
- String expected =
- "{\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\",\"xgboost_non_standalone\":\"http://localhost:8088/model-evaluation/v1/xgboost_non_standalone\",\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"}";
+ String expected = "{" +
+ "\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"," +
+ "\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\"," +
+ "\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\"," +
+ "\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\"," +
+ "\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\"," +
+ "\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\"," +
+ "\"xgboost_non_standalone\":\"http://localhost:8088/model-evaluation/v1/xgboost_non_standalone\"" +
+ "}";
handler.assertResponse(url, 200, expected, Map.of("Host", "localhost:8088"));
}