diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2022-12-01 09:25:04 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2022-12-01 09:25:04 +0100 |
commit | f578da98634e6c148a360a9ac4ec2313ba1a3033 (patch) | |
tree | 2e7d52df9d5c87c5ff9c1fb77b85102486deb564 /model-evaluation | |
parent | 6826fbab9fc00e4a76d52f8aa6b489f55ef8a3ac (diff) |
Revert "- Reduce usage of guava."
Diffstat (limited to 'model-evaluation')
9 files changed, 71 insertions, 77 deletions
diff --git a/model-evaluation/pom.xml b/model-evaluation/pom.xml index c0600872666..caf28199c3d 100644 --- a/model-evaluation/pom.xml +++ b/model-evaluation/pom.xml @@ -74,6 +74,11 @@ <scope>provided</scope> </dependency> <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + <scope>provided</scope> + </dependency> + <dependency> <groupId>org.lz4</groupId> <artifactId>lz4-java</artifactId> </dependency> diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java index 1d3da73a509..6af33e29e62 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java @@ -101,11 +101,9 @@ public class FunctionEvaluator { } public Tensor evaluate() { - function.argumentTypes().keySet().stream().sorted() - .forEach(name -> { - var type = function.argumentTypes().get(name); - checkArgument(name, type); - }); + for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) { + checkArgument(argument.getKey(), argument.getValue()); + } evaluated = true; evaluateOnnxModels(); return function.getBody().evaluate(context).asTensor(); diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java index 81325740218..d030108a17a 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java @@ -1,7 +1,8 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.models.evaluation; -import com.yahoo.lang.MutableInteger; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.ImmutableSet; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -13,7 +14,6 @@ import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -24,7 +24,6 @@ import java.util.List; import java.util.Map; import java.util.Optional; import java.util.Set; -import java.util.stream.Collectors; /** * An array context supporting functions invocations implemented as lazy values. @@ -152,16 +151,16 @@ public final class LazyArrayContext extends Context implements ContextIndex { private static class IndexedBindings { /** The mapping from variable name to index */ - private final Map<String, Integer> nameToIndex; + private final ImmutableMap<String, Integer> nameToIndex; /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */ - private final Set<String> arguments; + private final ImmutableSet<String> arguments; /** The current values set */ private final Value[] values; /** ONNX models indexed by rank feature that calls them */ - private final Map<String, OnnxModel> onnxModels; + private final ImmutableMap<String, OnnxModel> onnxModels; /** The object instance which encodes "no value is set". The actual value of this is never used. */ private static final Value missing = new DoubleValue(Double.NaN).freeze(); @@ -170,14 +169,14 @@ public final class LazyArrayContext extends Context implements ContextIndex { private Value missingValue = new DoubleValue(Double.NaN).freeze(); - private IndexedBindings(Map<String, Integer> nameToIndex, + private IndexedBindings(ImmutableMap<String, Integer> nameToIndex, Value[] values, - Set<String> arguments, - Map<String, OnnxModel> onnxModels) { - this.nameToIndex = Map.copyOf(nameToIndex); + ImmutableSet<String> arguments, + ImmutableMap<String, OnnxModel> onnxModels) { + this.nameToIndex = nameToIndex; this.values = values; this.arguments = arguments; - this.onnxModels = Map.copyOf(onnxModels); + this.onnxModels = onnxModels; } /** @@ -196,14 +195,16 @@ public final class LazyArrayContext extends Context implements ContextIndex { Map<String, OnnxModel> onnxModelsInUse = new HashMap<>(); extractBindTargets(function.getBody().getRoot(), referencedFunctions, bindTargets, arguments, onnxModels, onnxModelsInUse); - this.onnxModels = Map.copyOf(onnxModelsInUse); - this.arguments = Set.copyOf(arguments); + this.onnxModels = ImmutableMap.copyOf(onnxModelsInUse); + this.arguments = ImmutableSet.copyOf(arguments); values = new Value[bindTargets.size()]; Arrays.fill(values, missing); - MutableInteger nextIndex = new MutableInteger(0); - nameToIndex = Map.copyOf(bindTargets.stream() - .collect(CustomCollectors.toLinkedMap(name -> name, name -> nextIndex.next()))); + int i = 0; + ImmutableMap.Builder<String, Integer> nameToIndexBuilder = new ImmutableMap.Builder<>(); + for (String variable : bindTargets) + nameToIndexBuilder.put(variable, i++); + nameToIndex = nameToIndexBuilder.build(); // 2. Bind the bind targets for (Constant constant : constants) { @@ -251,7 +252,8 @@ public final class LazyArrayContext extends Context implements ContextIndex { bindTargets.add(node.toString()); arguments.add(node.toString()); } - else if (node instanceof CompositeNode cNode) { + else if (node instanceof CompositeNode) { + CompositeNode cNode = (CompositeNode)node; for (ExpressionNode child : cNode.children()) extractBindTargets(child, functions, bindTargets, arguments, onnxModels, onnxModelsInUse); } @@ -289,14 +291,16 @@ public final class LazyArrayContext extends Context implements ContextIndex { } private Optional<String> getArgument(ExpressionNode node) { - if (node instanceof ReferenceNode reference) { + if (node instanceof ReferenceNode) { + ReferenceNode reference = (ReferenceNode) node; if (reference.getArguments().size() > 0) { if (reference.getArguments().expressions().get(0) instanceof ConstantNode) { ExpressionNode constantNode = reference.getArguments().expressions().get(0); return Optional.of(stripQuotes(constantNode.toString())); } - if (reference.getArguments().expressions().get(0) instanceof ReferenceNode refNode) { - return Optional.of(refNode.getName()); + if (reference.getArguments().expressions().get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) reference.getArguments().expressions().get(0); + return Optional.of(referenceNode.getName()); } } } @@ -312,17 +316,20 @@ public final class LazyArrayContext extends Context implements ContextIndex { } private boolean isFunctionReference(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode reference)) return false; + if ( ! (node instanceof ReferenceNode)) return false; + ReferenceNode reference = (ReferenceNode)node; return reference.getName().equals("rankingExpression") && reference.getArguments().size() == 1; } private boolean isOnnx(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode reference)) return false; + if ( ! (node instanceof ReferenceNode)) return false; + ReferenceNode reference = (ReferenceNode) node; return reference.getName().equals("onnx") || reference.getName().equals("onnxModel"); } private boolean isConstant(ExpressionNode node) { - if ( ! (node instanceof ReferenceNode reference)) return false; + if ( ! (node instanceof ReferenceNode)) return false; + ReferenceNode reference = (ReferenceNode)node; return reference.getName().equals("constant") && reference.getArguments().size() == 1; } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index ffcfb5e9379..1ecec4108a3 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -2,16 +2,15 @@ package ai.vespa.models.evaluation; import com.yahoo.api.annotations.Beta; +import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; -import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.TensorType; import java.util.Arrays; import java.util.Collection; -import java.util.HashMap; -import java.util.LinkedHashMap; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.stream.Collectors; @@ -36,10 +35,10 @@ public class Model { private final List<ExpressionFunction> publicFunctions; /** Instances of each usage of the above function, where variables (if any) are replaced by their bindings */ - private final Map<FunctionReference, ExpressionFunction> referencedFunctions; + private final ImmutableMap<FunctionReference, ExpressionFunction> referencedFunctions; /** Context prototypes, indexed by function name (as all invocations of the same function share the same context prototype) */ - private final Map<String, LazyArrayContext> contextPrototypes; + private final ImmutableMap<String, LazyArrayContext> contextPrototypes; private final ExpressionOptimizer expressionOptimizer = new ExpressionOptimizer(); @@ -47,9 +46,9 @@ public class Model { public Model(String name, Collection<ExpressionFunction> functions) { this(name, functions.stream().collect(Collectors.toMap(f -> FunctionReference.fromName(f.getName()), f -> f)), - Map.of(), - List.of(), - List.of()); + Collections.emptyMap(), + Collections.emptyList(), + Collections.emptyList()); } Model(String name, @@ -60,7 +59,7 @@ public class Model { this.name = name; // Build context and add missing function arguments (missing because it is legal to omit scalar type arguments) - Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>(); + ImmutableMap.Builder<String, LazyArrayContext> contextBuilder = new ImmutableMap.Builder<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { LazyArrayContext context = new LazyArrayContext(function.getValue(), referencedFunctions, constants, onnxModels, this); @@ -93,14 +92,19 @@ public class Model { throw new IllegalArgumentException("Could not prepare an evaluation context for " + function, e); } } - this.contextPrototypes = Map.copyOf(contextBuilder); + this.contextPrototypes = contextBuilder.build(); this.functions = List.copyOf(functions.values()); this.publicFunctions = functions.values().stream() .filter(f -> !f.getName().startsWith(INTERMEDIATE_OPERATION_FUNCTION_PREFIX)).toList(); // Optimize functions - this.referencedFunctions = Map.copyOf(referencedFunctions.entrySet().stream() - .collect(CustomCollectors.toLinkedMap(f -> f.getKey(), f -> optimize(f.getValue(), contextPrototypes.get(f.getKey().functionName()))))); + ImmutableMap.Builder<FunctionReference, ExpressionFunction> functionsBuilder = new ImmutableMap.Builder<>(); + for (Map.Entry<FunctionReference, ExpressionFunction> function : referencedFunctions.entrySet()) { + ExpressionFunction optimizedFunction = optimize(function.getValue(), + contextPrototypes.get(function.getKey().functionName())); + functionsBuilder.put(function.getKey(), optimizedFunction); + } + this.referencedFunctions = functionsBuilder.build(); } /** Returns an optimized version of the given function */ @@ -138,7 +142,7 @@ public class Model { } /** Returns an immutable map of the referenced function instances of this */ - Map<FunctionReference, ExpressionFunction> referencedFunctions() { return Map.copyOf(referencedFunctions); } + Map<FunctionReference, ExpressionFunction> referencedFunctions() { return referencedFunctions; } /** Returns the given referred function, or throws a IllegalArgumentException if it does not exist */ ExpressionFunction requireReferencedFunction(FunctionReference reference) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java index 40a503e0212..88843fd99ab 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java @@ -2,6 +2,7 @@ package ai.vespa.models.evaluation; import com.yahoo.api.annotations.Beta; +import com.google.common.collect.ImmutableMap; import com.yahoo.component.annotation.Inject; import com.yahoo.component.AbstractComponent; import com.yahoo.filedistribution.fileacquirer.FileAcquirer; @@ -22,7 +23,7 @@ import java.util.Map; @Beta public class ModelsEvaluator extends AbstractComponent { - private final Map<String, Model> models; + private final ImmutableMap<String, Model> models; @Inject public ModelsEvaluator(RankProfilesConfig config, @@ -42,7 +43,7 @@ public class ModelsEvaluator extends AbstractComponent { } public ModelsEvaluator(Map<String, Model> models) { - this.models = Map.copyOf(models); + this.models = ImmutableMap.copyOf(models); } /** Returns the models of this as an immutable map */ diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java index 78addf0328a..2661b9c2eb2 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java +++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java @@ -102,8 +102,9 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler { private HttpResponse listAllModels(HttpRequest request) { Slime slime = new Slime(); Cursor root = slime.setObject(); - modelsEvaluator.models().keySet().stream().sorted() - .forEach(name -> root.setString(name, baseUrl(request) + name)); + for (String modelName: modelsEvaluator.models().keySet()) { + root.setString(modelName, baseUrl(request) + modelName); + } return new Response(200, com.yahoo.slime.JsonFormat.toJsonBytes(slime)); } diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java index f09bac63085..c4e859bec9f 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java @@ -75,7 +75,7 @@ public class ModelsEvaluatorTest { evaluator.evaluate(); } catch (IllegalStateException e) { - assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])", + assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})", Exceptions.toMessageString(e)); } @@ -88,15 +88,6 @@ public class ModelsEvaluatorTest { assertEquals("Argument 'arg1' must be bound to a value of type tensor(d0[1])", Exceptions.toMessageString(e)); } - try { // Just the other binding - FunctionEvaluator evaluator = model.evaluatorOf("test"); - evaluator.bind("arg1", Tensor.from(TensorType.fromSpec("tensor(d0[1])"), "{{d0:0}:0.1}")); - evaluator.evaluate(); - } - catch (IllegalStateException e) { - assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})", - Exceptions.toMessageString(e)); - } try { // Wrong binding argument FunctionEvaluator evaluator = model.evaluatorOf("test"); diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java index 3b16be311a0..fc05a9936a9 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/HandlerTester.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.serialization.JsonFormat; import java.io.ByteArrayOutputStream; import java.io.IOException; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.Map; import java.util.concurrent.Executors; @@ -24,19 +25,19 @@ class HandlerTester { } void assertResponse(String url, int expectedCode) { - assertResponse(url, Map.of(), expectedCode, (String)null); + assertResponse(url, Collections.emptyMap(), expectedCode, (String)null); } void assertResponse(String url, int expectedCode, String expectedResult) { - assertResponse(url, Map.of(), expectedCode, expectedResult); + assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult); } void assertResponse(String url, int expectedCode, String expectedResult, Map<String, String> headers) { - assertResponse(url, Map.of(), expectedCode, expectedResult, headers); + assertResponse(url, Collections.emptyMap(), expectedCode, expectedResult, headers); } void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult) { - assertResponse(url, properties, expectedCode, expectedResult, Map.of()); + assertResponse(url, properties, expectedCode, expectedResult, Collections.emptyMap()); } void assertResponse(String url, Map<String, String> properties, int expectedCode, String expectedResult, Map<String, String> headers) { diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index d804e50c67d..0de8ce5f061 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -49,30 +49,16 @@ public class ModelsEvaluationHandlerTest { @Test public void testListModels() { String url = "http://localhost/model-evaluation/v1"; - String expected = "{" + - "\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"," + - "\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\"," + - "\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\"," + - "\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\"," + - "\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\"," + - "\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\"," + - "\"xgboost_non_standalone\":\"http://localhost/model-evaluation/v1/xgboost_non_standalone\"" + - "}"; + String expected = + "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"xgboost_non_standalone\":\"http://localhost/model-evaluation/v1/xgboost_non_standalone\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}"; handler.assertResponse(url, 200, expected); } @Test public void testListModelsWithDifferentHost() { String url = "http://localhost/model-evaluation/v1"; - String expected = "{" + - "\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"," + - "\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\"," + - "\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\"," + - "\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\"," + - "\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\"," + - "\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\"," + - "\"xgboost_non_standalone\":\"http://localhost:8088/model-evaluation/v1/xgboost_non_standalone\"" + - "}"; + String expected = + "{\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\",\"xgboost_non_standalone\":\"http://localhost:8088/model-evaluation/v1/xgboost_non_standalone\",\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"}"; handler.assertResponse(url, 200, expected, Map.of("Host", "localhost:8088")); } |