From 00e7d63e41842231528343a6e80ede595d997ff5 Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Thu, 1 Dec 2022 07:36:44 +0100 Subject: - Reduce usage of guava. - Ensure that tests relying on order are determinsitic. --- .../rankingexpression/importer/ImportedModel.java | 22 ++++++++++------------ .../importer/vespa/VespaImportTestCase.java | 12 ++++++++---- 2 files changed, 18 insertions(+), 16 deletions(-) (limited to 'model-integration/src') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java index 35c409a637c..8c55e6793c0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -1,20 +1,19 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer; -import com.google.common.collect.ImmutableMap; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import com.yahoo.config.application.api.ApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import java.io.File; import java.io.IOException; import java.util.ArrayList; -import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; @@ -73,7 +72,7 @@ public class ImportedModel implements ImportedMlModel { public String toString() { return "imported model '" + name + "' from " + source; } /** Returns an immutable map of the inputs of this */ - public Map inputs() { return Collections.unmodifiableMap(inputs); } + public Map inputs() { return Map.copyOf(inputs); } @Override public Optional inputTypeSpec(String input) { @@ -121,7 +120,7 @@ public class ImportedModel implements ImportedMlModel { * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants). * Note that only nodes recursively referenced by a placeholder/input are added. */ - public Map expressions() { return Collections.unmodifiableMap(expressions); } + public Map expressions() { return Map.copyOf(expressions); } /** * Returns an immutable map of the functions that are part of this model. @@ -130,7 +129,7 @@ public class ImportedModel implements ImportedMlModel { public Map functions() { return asExpressionStrings(functions); } /** Returns an immutable map of the signatures of this */ - public Map signatures() { return Collections.unmodifiableMap(signatures); } + public Map signatures() { return Map.copyOf(signatures); } /** Returns the given signature. If it does not already exist it is added to this. */ public Signature signature(String name) { @@ -270,30 +269,29 @@ public class ImportedModel implements ImportedMlModel { * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name * in this signature to input name in the owning model */ - public Map inputs() { return Collections.unmodifiableMap(inputs); } + public Map inputs() { return Map.copyOf(inputs); } /** Returns the name and type of all inputs in this signature as an immutable map */ Map inputMap() { - ImmutableMap.Builder inputs = new ImmutableMap.Builder<>(); // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* // in the model, as these are the names which must actually be bound, if we are to avoid creating an // "input mapping" to accommodate this complexity - for (Map.Entry inputEntry : inputs().entrySet()) - inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue())); - return inputs.build(); + return Map.copyOf(inputs.entrySet() + .stream() + .collect(CustomCollectors.toLinkedMap(Map.Entry::getValue, e -> owner().inputs.get(e.getValue())))); } /** Returns the type of the input this input references */ public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); } /** Returns an immutable list of the expression names of this */ - public Map outputs() { return Collections.unmodifiableMap(outputs); } + public Map outputs() { return Map.copyOf(outputs); } /** * Returns an immutable list of the outputs of this which could not be imported, * with a string detailing the reason for each */ - public Map skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + public Map skippedOutputs() { return Map.copyOf(skippedOutputs); } /** Returns the expression this output references as an imported function */ public ImportedMlFunction outputFunction(String outputName, String functionName) { diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java index d9c7e67c946..75e31d66e5b 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java @@ -10,7 +10,8 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import org.junit.Test; -import java.util.List; +import java.util.Map; +import java.util.stream.Collectors; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -51,9 +52,12 @@ public class VespaImportTestCase { assertEquals("reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * constant2", model.expressions().get("foo2").getRoot().toString()); - List functions = model.outputExpressions(); - assertEquals(2, functions.size()); - ImportedMlFunction foo1Function = functions.get(0); + Map byName = model.outputExpressions().stream() + .collect(Collectors.toUnmodifiableMap(ImportedMlFunction::name, f -> f)); + assertEquals(2, byName.size()); + assertTrue(byName.containsKey("foo1")); + assertTrue(byName.containsKey("foo2")); + ImportedMlFunction foo1Function = byName.get("foo1"); assertEquals("foo1", foo1Function.name()); assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2", foo1Function.expression()); assertEquals("tensor():{202.5}", evaluate(foo1Function, "{{name:a, x:0}: 1, {name:a, x:1}: 2, {name:a, x:2}: 3}").toString()); -- cgit v1.2.3