aboutsummaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 09:25:04 +0100
committerGitHub <noreply@github.com>2022-12-01 09:25:04 +0100
commitf578da98634e6c148a360a9ac4ec2313ba1a3033 (patch)
tree2e7d52df9d5c87c5ff9c1fb77b85102486deb564 /model-evaluation
parent6826fbab9fc00e4a76d52f8aa6b489f55ef8a3ac (diff)
Revert "- Reduce usage of guava."
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/pom.xml5
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java8
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java53
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java30
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java5
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java5
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java11
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java9
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java22
9 files changed, 71 insertions, 77 deletions
diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml
index c0600872666..caf28199c3d 100644
--- a/model-evaluation/pom.xml
+++ b/model-evaluation/pom.xml
@@ -74,6 +74,11 @@
<scope>provided</scope>
</dependency>
<dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>org.lz4</groupId>
<artifactId>lz4-java</artifactId>
</dependency>
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index 1d3da73a509..6af33e29e62 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -101,11 +101,9 @@ public class FunctionEvaluator {
}
public Tensor evaluate() {
- function.argumentTypes().keySet().stream().sorted()
- .forEach(name -> {
- var type = function.argumentTypes().get(name);
- checkArgument(name, type);
- });
+ for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
+ checkArgument(argument.getKey(), argument.getValue());
+ }
evaluated = true;
evaluateOnnxModels();
return function.getBody().evaluate(context).asTensor();
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index 81325740218..d030108a17a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -1,7 +1,8 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.models.evaluation;
-import com.yahoo.lang.MutableInteger;
+import com.google.common.collect.ImmutableMap;
+import com.google.common.collect.ImmutableSet;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -13,7 +14,6 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -24,7 +24,6 @@ import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
-import java.util.stream.Collectors;
/**
* An array context supporting functions invocations implemented as lazy values.
@@ -152,16 +151,16 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private static class IndexedBindings {
/** The mapping from variable name to index */
- private final Map<String, Integer> nameToIndex;
+ private final ImmutableMap<String, Integer> nameToIndex;
/** The names which needs to be bound externally when invoking this (i.e not constant or invocation */
- private final Set<String> arguments;
+ private final ImmutableSet<String> arguments;
/** The current values set */
private final Value[] values;
/** ONNX models indexed by rank feature that calls them */
- private final Map<String, OnnxModel> onnxModels;
+ private final ImmutableMap<String, OnnxModel> onnxModels;
/** The object instance which encodes "no value is set". The actual value of this is never used. */
private static final Value missing = new DoubleValue(Double.NaN).freeze();
@@ -170,14 +169,14 @@ public final class LazyArrayContext extends Context implements ContextIndex {
private Value missingValue = new DoubleValue(Double.NaN).freeze();
- private IndexedBindings(Map<String, Integer> nameToIndex,
+ private IndexedBindings(ImmutableMap<String, Integer> nameToIndex,
Value[] values,
- Set<String> arguments,
- Map<String, OnnxModel> onnxModels) {
- this.nameToIndex = Map.copyOf(nameToIndex);
+ ImmutableSet<String> arguments,
+ ImmutableMap<String, OnnxModel> onnxModels) {
+ this.nameToIndex = nameToIndex;
this.values = values;
this.arguments = arguments;
- this.onnxModels = Map.copyOf(onnxModels);
+ this.onnxModels = onnxModels;
}
/**
@@ -196,14 +195,16 @@ public final class LazyArrayContext extends Context implements ContextIndex {
Map<String, OnnxModel> onnxModelsInUse = new HashMap<>();
extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse);
- this.onnxModels = Map.copyOf(onnxModelsInUse);
- this.arguments = Set.copyOf(arguments);
+ this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse);
+ this.arguments = ImmutableSet.copyOf(arguments);
values = new Value[bindTargets.size()];
Arrays.fill(values, missing);
- MutableInteger nextIndex = new MutableInteger(0);
- nameToIndex = Map.copyOf(bindTargets.stream()
- .collect(CustomCollectors.toLinkedMap(name -> name, name -> nextIndex.next())));
+ int i = 0;
+ ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>();
+ for (String variable : bindTargets)
+ nameToIndexBuilder.put(variable, i++);
+ nameToIndex = nameToIndexBuilder.build();
// 2. Bind the bind targets
for (Constant constant : constants) {
@@ -251,7 +252,8 @@ public final class LazyArrayContext extends Context implements ContextIndex {
bindTargets.add(node.toString());
arguments.add(node.toString());
}
- else if (node instanceof CompositeNode cNode) {
+ else if (node instanceof CompositeNode) {
+ CompositeNode cNode = (CompositeNode)node;
for (ExpressionNode child : cNode.children())
extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse);
}
@@ -289,14 +291,16 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
private Optional<String> getArgument(ExpressionNode node) {
- if (node instanceof ReferenceNode reference) {
+ if (node instanceof ReferenceNode) {
+ ReferenceNode reference = (ReferenceNode) node;
if (reference.getArguments().size() > 0) {
if (reference.getArguments().expressions().get(0) instanceof ConstantNode) {
ExpressionNode constantNode = reference.getArguments().expressions().get(0);
return Optional.of(stripQuotes(constantNode.toString()));
}
- if (reference.getArguments().expressions().get(0) instanceof ReferenceNode refNode) {
- return Optional.of(refNode.getName());
+ if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) {
+ ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0);
+ return Optional.of(referenceNode.getName());
}
}
}
@@ -312,17 +316,20 @@ public final class LazyArrayContext extends Context implements ContextIndex {
}
private boolean isFunctionReference(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode reference)) return false;
+ if ( ! (node instanceof ReferenceNode)) return false;
+ ReferenceNode reference = (ReferenceNode)node;
return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1;
}
private boolean isOnnx(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode reference)) return false;
+ if ( ! (node instanceof ReferenceNode)) return false;
+ ReferenceNode reference = (ReferenceNode) node;
return reference.getName().equals("onnx") || reference.getName().equals("onnxModel");
}
private boolean isConstant(ExpressionNode node) {
- if ( ! (node instanceof ReferenceNode reference)) return false;
+ if ( ! (node instanceof ReferenceNode)) return false;
+ ReferenceNode reference = (ReferenceNode)node;
return reference.getName().equals("constant") && reference.getArguments().size() == 1;
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index ffcfb5e9379..1ecec4108a3 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -2,16 +2,15 @@
package ai.vespa.models.evaluation;
import com.yahoo.api.annotations.Beta;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
-import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
+import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -36,10 +35,10 @@ public class Model {
private final List<ExpressionFunction> publicFunctions;
/** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */
- private final Map<FunctionReference, ExpressionFunction> referencedFunctions;
+ private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions;
/** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */
- private final Map<String, LazyArrayContext> contextPrototypes;
+ private final ImmutableMap<String, LazyArrayContext> contextPrototypes;
private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer();
@@ -47,9 +46,9 @@ public class Model {
public Model(String name, Collection<ExpressionFunction> functions) {
this(name,
functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)),
- Map.of(),
- List.of(),
- List.of());
+ Collections.emptyMap(),
+ Collections.emptyList(),
+ Collections.emptyList());
}
Model(String name,
@@ -60,7 +59,7 @@ public class Model {
this.name = name;
// Build context and add missing function arguments (missing because it is legal to omit scalar type arguments)
- Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>();
+ ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this);
@@ -93,14 +92,19 @@ public class Model {
throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e);
}
}
- this.contextPrototypes = Map.copyOf(contextBuilder);
+ this.contextPrototypes = contextBuilder.build();
this.functions = List.copyOf(functions.values());
this.publicFunctions = functions.values().stream()
.filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList();
// Optimize functions
- this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream()
- .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName())))));
+ ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>();
+ for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) {
+ ExpressionFunction optimizedFunction = optimize(function.getValue(),
+ contextPrototypes.get(function.getKey().functionName()));
+ functionsBuilder.put(function.getKey(), optimizedFunction);
+ }
+ this.referencedFunctions = functionsBuilder.build();
}
/** Returns an optimized version of the given function */
@@ -138,7 +142,7 @@ public class Model {
}
/** Returns an immutable map of the referenced function instances of this */
- Map<FunctionReference, ExpressionFunction> referencedFunctions() { return Map.copyOf(referencedFunctions); }
+ Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; }
/** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */
ExpressionFunction requireReferencedFunction(FunctionReference reference) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index 40a503e0212..88843fd99ab 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -2,6 +2,7 @@
package ai.vespa.models.evaluation;
import com.yahoo.api.annotations.Beta;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.component.annotation.Inject;
import com.yahoo.component.AbstractComponent;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
@@ -22,7 +23,7 @@ import java.util.Map;
@Beta
public class ModelsEvaluator extends AbstractComponent {
- private final Map<String, Model> models;
+ private final ImmutableMap<String, Model> models;
@Inject
public ModelsEvaluator(RankProfilesConfig config,
@@ -42,7 +43,7 @@ public class ModelsEvaluator extends AbstractComponent {
}
public ModelsEvaluator(Map<String, Model> models) {
- this.models = Map.copyOf(models);
+ this.models = ImmutableMap.copyOf(models);
}
/** Returns the models of this as an immutable map */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index 78addf0328a..2661b9c2eb2 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -102,8 +102,9 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
private HttpResponse listAllModels(HttpRequest request) {
Slime slime = new Slime();
Cursor root = slime.setObject();
- modelsEvaluator.models().keySet().stream().sorted()
- .forEach(name -> root.setString(name, baseUrl(request) + name));
+ for (String modelName: modelsEvaluator.models().keySet()) {
+ root.setString(modelName, baseUrl(request) + modelName);
+ }
return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime));
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index f09bac63085..c4e859bec9f 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -75,7 +75,7 @@ public class ModelsEvaluatorTest {
evaluator.evaluate();
}
catch (IllegalStateException e) {
- assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])",
+ assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})",
Exceptions.toMessageString(e));
}
@@ -88,15 +88,6 @@ public class ModelsEvaluatorTest {
assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])",
Exceptions.toMessageString(e));
}
- try { // Just the other binding
- FunctionEvaluator evaluator = model.evaluatorOf("test");
- evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}"));
- evaluator.evaluate();
- }
- catch (IllegalStateException e) {
- assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})",
- Exceptions.toMessageString(e));
- }
try { // Wrong binding argument
FunctionEvaluator evaluator = model.evaluatorOf("test");
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
index 3b16be311a0..fc05a9936a9 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.serialization.JsonFormat;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.nio.charset.StandardCharsets;
+import java.util.Collections;
import java.util.Map;
import java.util.concurrent.Executors;
@@ -24,19 +25,19 @@ class HandlerTester {
}
void assertResponse(String url, int expectedCode) {
- assertResponse(url, Map.of(), expectedCode, (String)null);
+ assertResponse(url, Collections.emptyMap(), expectedCode, (String)null);
}
void assertResponse(String url, int expectedCode, String expectedResult) {
- assertResponse(url, Map.of(), expectedCode, expectedResult);
+ assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult);
}
void assertResponse(String url, int expectedCode, String expectedResult, Map<String, String> headers) {
- assertResponse(url, Map.of(), expectedCode, expectedResult, headers);
+ assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult, headers);
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) {
- assertResponse(url, properties, expectedCode, expectedResult, Map.of());
+ assertResponse(url, properties, expectedCode, expectedResult, Collections.emptyMap());
}
void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index d804e50c67d..0de8ce5f061 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -49,30 +49,16 @@ public class ModelsEvaluationHandlerTest {
@Test
public void testListModels() {
String url = "http://localhost/model-evaluation/v1";
- String expected = "{" +
- "\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"," +
- "\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\"," +
- "\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\"," +
- "\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\"," +
- "\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\"," +
- "\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"," +
- "\"xgboost_non_standalone\":\"http://localhost/model-evaluation/v1/xgboost_non_standalone\"" +
- "}";
+ String expected =
+ "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"xgboost_non_standalone\":\"http://localhost/model-evaluation/v1/xgboost_non_standalone\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}";
handler.assertResponse(url, 200, expected);
}
@Test
public void testListModelsWithDifferentHost() {
String url = "http://localhost/model-evaluation/v1";
- String expected = "{" +
- "\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"," +
- "\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\"," +
- "\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\"," +
- "\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\"," +
- "\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\"," +
- "\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\"," +
- "\"xgboost_non_standalone\":\"http://localhost:8088/model-evaluation/v1/xgboost_non_standalone\"" +
- "}";
+ String expected =
+ "{\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\",\"xgboost_non_standalone\":\"http://localhost:8088/model-evaluation/v1/xgboost_non_standalone\",\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"}";
handler.assertResponse(url, 200, expected, Map.of("Host", "localhost:8088"));
}