diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-12-01 07:36:44 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2022-12-01 07:36:44 +0100 |
commit | 00e7d63e41842231528343a6e80ede595d997ff5 (patch) | |
tree | d611749f67d8ac3201b1a39b516339755715f236 /model-evaluation | |
parent | c42b104ac2a231cb120719dd904d5ad2ac31fbeb (diff) |
- Reduce usage of guava.
- Ensure that tests relying on order are determinsitic.
Diffstat (limited to 'model-evaluation')
9 files changed, 77 insertions, 71 deletions
diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml index caf28199c3d..c0600872666 100644 --- a/model-evaluation/pom.xml +++ b/model-evaluation/pom.xml @@ -74,11 +74,6 @@ <scope>provided</scope> </dependency> <dependency> - <groupId>com.google.guava</groupId> - <artifactId>guava</artifactId> - <scope>provided</scope> - </dependency> - <dependency> <groupId>org.lz4</groupId> <artifactId>lz4-java</artifactId> </dependency> diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 6af33e29e62..1d3da73a509 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -101,9 +101,11 @@ public class FunctionEvaluator { } public Tensor evaluate() { - for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { - checkArgument(argument.getKey(), argument.getValue()); - } + function.argumentTypes().keySet().stream().sorted() + .forEach(name -> { + var type = function.argumentTypes().get(name); + checkArgument(name, type); + }); evaluated = true; evaluateOnnxModels(); return function.getBody().evaluate(context).asTensor(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index d030108a17a..81325740218 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -1,8 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import com.google.common.collect.ImmutableMap; -import com.google.common.collect.ImmutableSet; +import com.yahoo.lang.MutableInteger; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -14,6 +13,7 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -24,6 +24,7 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * An array context supporting functions invocations implemented as lazy values. @@ -151,16 +152,16 @@ public final class LazyArrayContext extends Context implements ContextIndex { private static class IndexedBindings { /** The mapping from variable name to index */ - private final ImmutableMap<String, Integer> nameToIndex; + private final Map<String, Integer> nameToIndex; /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */ - private final ImmutableSet<String> arguments; + private final Set<String> arguments; /** The current values set */ private final Value[] values; /** ONNX models indexed by rank feature that calls them */ - private final ImmutableMap<String, OnnxModel> onnxModels; + private final Map<String, OnnxModel> onnxModels; /** The object instance which encodes "no value is set". The actual value of this is never used. */ private static final Value missing = new DoubleValue(Double.NaN).freeze(); @@ -169,14 +170,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { private Value missingValue = new DoubleValue(Double.NaN).freeze(); - private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, + private IndexedBindings(Map<String, Integer> nameToIndex, Value[] values, - ImmutableSet<String> arguments, - ImmutableMap<String, OnnxModel> onnxModels) { - this.nameToIndex = nameToIndex; + Set<String> arguments, + Map<String, OnnxModel> onnxModels) { + this.nameToIndex = Map.copyOf(nameToIndex); this.values = values; this.arguments = arguments; - this.onnxModels = onnxModels; + this.onnxModels = Map.copyOf(onnxModels); } /** @@ -195,16 +196,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { Map<String, OnnxModel> onnxModelsInUse = new HashMap<>(); extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse); - this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse); - this.arguments = ImmutableSet.copyOf(arguments); + this.onnxModels = Map.copyOf(onnxModelsInUse); + this.arguments = Set.copyOf(arguments); values = new Value[bindTargets.size()]; Arrays.fill(values, missing); - int i = 0; - ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); - for (String variable : bindTargets) - nameToIndexBuilder.put(variable, i++); - nameToIndex = nameToIndexBuilder.build(); + MutableInteger nextIndex = new MutableInteger(0); + nameToIndex = Map.copyOf(bindTargets.stream() + .collect(CustomCollectors.toLinkedMap(name -> name, name -> nextIndex.next()))); // 2. Bind the bind targets for (Constant constant : constants) { @@ -252,8 +251,7 @@ public final class LazyArrayContext extends Context implements ContextIndex { bindTargets.add(node.toString()); arguments.add(node.toString()); } - else if (node instanceof CompositeNode) { - CompositeNode cNode = (CompositeNode)node; + else if (node instanceof CompositeNode cNode) { for (ExpressionNode child : cNode.children()) extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); } @@ -291,16 +289,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { } private Optional<String> getArgument(ExpressionNode node) { - if (node instanceof ReferenceNode) { - ReferenceNode reference = (ReferenceNode) node; + if (node instanceof ReferenceNode reference) { if (reference.getArguments().size() > 0) { if (reference.getArguments().expressions().get(0) instanceof ConstantNode) { ExpressionNode constantNode = reference.getArguments().expressions().get(0); return Optional.of(stripQuotes(constantNode.toString())); } - if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0); - return Optional.of(referenceNode.getName()); + if (reference.getArguments().expressions().get(0) instanceof ReferenceNode refNode) { + return Optional.of(refNode.getName()); } } } @@ -316,20 +312,17 @@ public final class LazyArrayContext extends Context implements ContextIndex { } private boolean isFunctionReference(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode)) return false; - ReferenceNode reference = (ReferenceNode)node; + if ( ! (node instanceof ReferenceNode reference)) return false; return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1; } private boolean isOnnx(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode)) return false; - ReferenceNode reference = (ReferenceNode) node; + if ( ! (node instanceof ReferenceNode reference)) return false; return reference.getName().equals("onnx") || reference.getName().equals("onnxModel"); } private boolean isConstant(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode)) return false; - ReferenceNode reference = (ReferenceNode)node; + if ( ! (node instanceof ReferenceNode reference)) return false; return reference.getName().equals("constant") && reference.getArguments().size() == 1; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index 1ecec4108a3..ffcfb5e9379 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -2,15 +2,16 @@ package ai.vespa.models.evaluation; import com.yahoo.api.annotations.Beta; -import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; +import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; -import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -35,10 +36,10 @@ public class Model { private final List<ExpressionFunction> publicFunctions; /** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */ - private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions; + private final Map<FunctionReference, ExpressionFunction> referencedFunctions; /** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */ - private final ImmutableMap<String, LazyArrayContext> contextPrototypes; + private final Map<String, LazyArrayContext> contextPrototypes; private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); @@ -46,9 +47,9 @@ public class Model { public Model(String name, Collection<ExpressionFunction> functions) { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), - Collections.emptyMap(), - Collections.emptyList(), - Collections.emptyList()); + Map.of(), + List.of(), + List.of()); } Model(String name, @@ -59,7 +60,7 @@ public class Model { this.name = name; // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) - ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); + Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this); @@ -92,19 +93,14 @@ public class Model { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); } } - this.contextPrototypes = contextBuilder.build(); + this.contextPrototypes = Map.copyOf(contextBuilder); this.functions = List.copyOf(functions.values()); this.publicFunctions = functions.values().stream() .filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList(); // Optimize functions - ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); - for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) { - ExpressionFunction optimizedFunction = optimize(function.getValue(), - contextPrototypes.get(function.getKey().functionName())); - functionsBuilder.put(function.getKey(), optimizedFunction); - } - this.referencedFunctions = functionsBuilder.build(); + this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream() + .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName()))))); } /** Returns an optimized version of the given function */ @@ -142,7 +138,7 @@ public class Model { } /** Returns an immutable map of the referenced function instances of this */ - Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; } + Map<FunctionReference, ExpressionFunction> referencedFunctions() { return Map.copyOf(referencedFunctions); } /** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */ ExpressionFunction requireReferencedFunction(FunctionReference reference) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 88843fd99ab..40a503e0212 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -2,7 +2,6 @@ package ai.vespa.models.evaluation; import com.yahoo.api.annotations.Beta; -import com.google.common.collect.ImmutableMap; import com.yahoo.component.annotation.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -23,7 +22,7 @@ import java.util.Map; @Beta public class ModelsEvaluator extends AbstractComponent { - private final ImmutableMap<String, Model> models; + private final Map<String, Model> models; @Inject public ModelsEvaluator(RankProfilesConfig config, @@ -43,7 +42,7 @@ public class ModelsEvaluator extends AbstractComponent { } public ModelsEvaluator(Map<String, Model> models) { - this.models = ImmutableMap.copyOf(models); + this.models = Map.copyOf(models); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 2661b9c2eb2..78addf0328a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -102,9 +102,8 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { private HttpResponse listAllModels(HttpRequest request) { Slime slime = new Slime(); Cursor root = slime.setObject(); - for (String modelName: modelsEvaluator.models().keySet()) { - root.setString(modelName, baseUrl(request) + modelName); - } + modelsEvaluator.models().keySet().stream().sorted() + .forEach(name -> root.setString(name, baseUrl(request) + name)); return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index c4e859bec9f..f09bac63085 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -75,7 +75,7 @@ public class ModelsEvaluatorTest { evaluator.evaluate(); } catch (IllegalStateException e) { - assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})", + assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])", Exceptions.toMessageString(e)); } @@ -88,6 +88,15 @@ public class ModelsEvaluatorTest { assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])", Exceptions.toMessageString(e)); } + try { // Just the other binding + FunctionEvaluator evaluator = model.evaluatorOf("test"); + evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}")); + evaluator.evaluate(); + } + catch (IllegalStateException e) { + assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})", + Exceptions.toMessageString(e)); + } try { // Wrong binding argument FunctionEvaluator evaluator = model.evaluatorOf("test"); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java index fc05a9936a9..3b16be311a0 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java @@ -10,7 +10,6 @@ import com.yahoo.tensor.serialization.JsonFormat; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; -import java.util.Collections; import java.util.Map; import java.util.concurrent.Executors; @@ -25,19 +24,19 @@ class HandlerTester { } void assertResponse(String url, int expectedCode) { - assertResponse(url, Collections.emptyMap(), expectedCode, (String)null); + assertResponse(url, Map.of(), expectedCode, (String)null); } void assertResponse(String url, int expectedCode, String expectedResult) { - assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult); + assertResponse(url, Map.of(), expectedCode, expectedResult); } void assertResponse(String url, int expectedCode, String expectedResult, Map<String, String> headers) { - assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult, headers); + assertResponse(url, Map.of(), expectedCode, expectedResult, headers); } void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) { - assertResponse(url, properties, expectedCode, expectedResult, Collections.emptyMap()); + assertResponse(url, properties, expectedCode, expectedResult, Map.of()); } void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 0de8ce5f061..d804e50c67d 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -49,16 +49,30 @@ public class ModelsEvaluationHandlerTest { @Test public void testListModels() { String url = "http://localhost/model-evaluation/v1"; - String expected = - "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"xgboost_non_standalone\":\"http://localhost/model-evaluation/v1/xgboost_non_standalone\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}"; + String expected = "{" + + "\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"," + + "\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\"," + + "\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\"," + + "\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\"," + + "\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\"," + + "\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"," + + "\"xgboost_non_standalone\":\"http://localhost/model-evaluation/v1/xgboost_non_standalone\"" + + "}"; handler.assertResponse(url, 200, expected); } @Test public void testListModelsWithDifferentHost() { String url = "http://localhost/model-evaluation/v1"; - String expected = - "{\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\",\"xgboost_non_standalone\":\"http://localhost:8088/model-evaluation/v1/xgboost_non_standalone\",\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"}"; + String expected = "{" + + "\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"," + + "\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\"," + + "\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\"," + + "\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\"," + + "\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\"," + + "\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\"," + + "\"xgboost_non_standalone\":\"http://localhost:8088/model-evaluation/v1/xgboost_non_standalone\"" + + "}"; handler.assertResponse(url, 200, expected, Map.of("Host", "localhost:8088")); } |