summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 07:36:44 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2022-12-01 07:36:44 +0100
commit00e7d63e41842231528343a6e80ede595d997ff5 (patch)
treed611749f67d8ac3201b1a39b516339755715f236 /model-integration
parentc42b104ac2a231cb120719dd904d5ad2ac31fbeb (diff)
- Reduce usage of guava.
- Ensure that tests relying on order are determinsitic.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/pom.xml23
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java22
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java12
3 files changed, 29 insertions, 28 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml
index 7d3ab3f7a5f..63232b61106 100644
--- a/model-integration/pom.xml
+++ b/model-integration/pom.xml
@@ -16,12 +16,6 @@
<packaging>container-plugin</packaging>
<dependencies>
<dependency>
- <groupId>junit</groupId>
- <artifactId>junit</artifactId>
- <scope>test</scope>
- </dependency>
-
- <dependency>
<groupId>com.yahoo.vespa</groupId>
<artifactId>annotations</artifactId>
<version>${project.version}</version>
@@ -59,12 +53,6 @@
</dependency>
<dependency>
- <groupId>com.google.guava</groupId>
- <artifactId>guava</artifactId>
- <scope>provided</scope>
- </dependency>
-
- <dependency>
<groupId>com.microsoft.onnxruntime</groupId>
<artifactId>onnxruntime</artifactId>
</dependency>
@@ -72,6 +60,17 @@
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
</dependency>
+
+ <dependency>
+ <groupId>junit</groupId>
+ <artifactId>junit</artifactId>
+ <scope>test</scope>
+ </dependency>
+ <dependency>
+ <groupId>com.google.guava</groupId>
+ <artifactId>guava</artifactId>
+ <scope>test</scope>
+ </dependency>
</dependencies>
<build>
<plugins>
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
index 35c409a637c..8c55e6793c0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java
@@ -1,20 +1,19 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
-import com.google.common.collect.ImmutableMap;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import com.yahoo.config.application.api.ApplicationPackage;
import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.io.File;
import java.io.IOException;
import java.util.ArrayList;
-import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
@@ -73,7 +72,7 @@ public class ImportedModel implements ImportedMlModel {
public String toString() { return "imported model '" + name + "' from " + source; }
/** Returns an immutable map of the inputs of this */
- public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
+ public Map<String, TensorType> inputs() { return Map.copyOf(inputs); }
@Override
public Optional<String> inputTypeSpec(String input) {
@@ -121,7 +120,7 @@ public class ImportedModel implements ImportedMlModel {
* which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants).
* Note that only nodes recursively referenced by a placeholder/input are added.
*/
- public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
+ public Map<String, RankingExpression> expressions() { return Map.copyOf(expressions); }
/**
* Returns an immutable map of the functions that are part of this model.
@@ -130,7 +129,7 @@ public class ImportedModel implements ImportedMlModel {
public Map<String, String> functions() { return asExpressionStrings(functions); }
/** Returns an immutable map of the signatures of this */
- public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
+ public Map<String, Signature> signatures() { return Map.copyOf(signatures); }
/** Returns the given signature. If it does not already exist it is added to this. */
public Signature signature(String name) {
@@ -270,30 +269,29 @@ public class ImportedModel implements ImportedMlModel {
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
* in this signature to input name in the owning model
*/
- public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+ public Map<String, String> inputs() { return Map.copyOf(inputs); }
/** Returns the name and type of all inputs in this signature as an immutable map */
Map<String, TensorType> inputMap() {
- ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
// Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
// in the model, as these are the names which must actually be bound, if we are to avoid creating an
// "input mapping" to accommodate this complexity
- for (Map.Entry<String, String> inputEntry : inputs().entrySet())
- inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
- return inputs.build();
+ return Map.copyOf(inputs.entrySet()
+ .stream()
+ .collect(CustomCollectors.toLinkedMap(Map.Entry::getValue, e -> owner().inputs.get(e.getValue()))));
}
/** Returns the type of the input this input references */
public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); }
/** Returns an immutable list of the expression names of this */
- public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
+ public Map<String, String> outputs() { return Map.copyOf(outputs); }
/**
* Returns an immutable list of the outputs of this which could not be imported,
* with a string detailing the reason for each
*/
- public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+ public Map<String, String> skippedOutputs() { return Map.copyOf(skippedOutputs); }
/** Returns the expression this output references as an imported function */
public ImportedMlFunction outputFunction(String outputName, String functionName) {
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
index d9c7e67c946..75e31d66e5b 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/vespa/VespaImportTestCase.java
@@ -10,7 +10,8 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import org.junit.Test;
-import java.util.List;
+import java.util.Map;
+import java.util.stream.Collectors;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -51,9 +52,12 @@ public class VespaImportTestCase {
assertEquals("reduce(reduce(input1 * input2, sum, name) * constant(constant1asLarge), max, x) * constant2",
model.expressions().get("foo2").getRoot().toString());
- List<ImportedMlFunction> functions = model.outputExpressions();
- assertEquals(2, functions.size());
- ImportedMlFunction foo1Function = functions.get(0);
+ Map<String, ImportedMlFunction> byName = model.outputExpressions().stream()
+ .collect(Collectors.toUnmodifiableMap(ImportedMlFunction::name, f -> f));
+ assertEquals(2, byName.size());
+ assertTrue(byName.containsKey("foo1"));
+ assertTrue(byName.containsKey("foo2"));
+ ImportedMlFunction foo1Function = byName.get("foo1");
assertEquals("foo1", foo1Function.name());
assertEquals("reduce(reduce(input1 * input2, sum, name) * constant1, max, x) * constant2", foo1Function.expression());
assertEquals("tensor():{202.5}", evaluate(foo1Function, "{{name:a, x:0}: 1, {name:a, x:1}: 2, {name:a, x:2}: 3}").toString());