summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-11-22 14:27:58 +0100
committerJon Bratseth <bratseth@oath.com>2018-11-22 14:27:58 +0100
commitb288e61f7af7331656a1850fbdc58cc95fd1bbad (patch)
tree9d41fa770d2890585a902f41a89c41040ed764be /searchlib
parent3c4020645b13be560c14e60969e50e3ad41e3d3c (diff)
Move all importing to model-integration
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/pom.xml6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamer.java210
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java226
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java109
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/IntermediateGraph.java107
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java243
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorType.java235
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Argument.java57
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ConcatV2.java108
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Const.java89
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Constant.java61
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ExpandDims.java106
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Identity.java35
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/IntermediateOperation.java191
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java120
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Map.java37
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/MatMul.java72
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Mean.java113
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Merge.java35
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/NoOp.java26
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/PlaceholderWithDefault.java48
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Reshape.java131
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Select.java88
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Shape.java54
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Squeeze.java85
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Switch.java50
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java48
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java21
29 files changed, 6 insertions, 2713 deletions
diff --git a/searchlib/pom.xml b/searchlib/pom.xml
index e0ce822e593..87058f8dfa1 100644
--- a/searchlib/pom.xml
+++ b/searchlib/pom.xml
@@ -35,6 +35,12 @@
<version>${project.version}</version>
</dependency>
<dependency>
+ <groupId>com.yahoo.vespa</groupId>
+ <artifactId>config-model-api</artifactId>
+ <version>${project.version}</version>
+ <scope>provided</scope>
+ </dependency>
+ <dependency>
<groupId>com.fasterxml.jackson.core</groupId>
<artifactId>jackson-core</artifactId>
<scope>provided</scope>
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamer.java
deleted file mode 100644
index 86c4c287f05..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamer.java
+++ /dev/null
@@ -1,210 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation;
-
-import java.util.ArrayDeque;
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.Deque;
-import java.util.HashMap;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Map;
-import java.util.Objects;
-import java.util.Optional;
-
-/**
- * A constraint satisfier to find suitable dimension names to reduce the
- * amount of necessary renaming during evaluation of an imported model.
- *
- * @author lesters
- */
-public class DimensionRenamer {
-
- private final String dimensionPrefix;
- private final Map<String, List<Integer>> variables = new HashMap<>();
- private final Map<Arc, Constraint> constraints = new HashMap<>();
- private final Map<String, Integer> renames = new HashMap<>();
-
- private int iterations = 0;
-
- public DimensionRenamer() {
- this("d");
- }
-
- public DimensionRenamer(String dimensionPrefix) {
- this.dimensionPrefix = dimensionPrefix;
- }
-
- /**
- * Add a dimension name variable.
- */
- public void addDimension(String name) {
- variables.computeIfAbsent(name, d -> new ArrayList<>());
- }
-
- /**
- * Add a constraint between dimension names.
- */
- public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) {
- Arc arc = new Arc(from, to, operation);
- Arc opposite = arc.opposite();
- constraints.put(arc, pred);
- constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
- }
-
- /**
- * Retrieve resulting name of dimension after solving for constraints.
- */
- public Optional<String> dimensionNameOf(String name) {
- if (!renames.containsKey(name)) {
- return Optional.empty();
- }
- return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
- }
-
- /**
- * Perform iterative arc consistency until we have found a solution. After
- * an initial iteration, the variables (dimensions) will have multiple
- * valid values. Find a single valid assignment by iteratively locking one
- * dimension after another, and running the arc consistency algorithm
- * multiple times.
- *
- * This requires having constraints that result in an absolute ordering:
- * equals, lesserThan and greaterThan do that, but adding notEquals does
- * not typically result in a guaranteed ordering. If that is needed, the
- * algorithm below needs to be adapted with a backtracking (tree) search
- * to find solutions.
- */
- public void solve(int maxIterations) {
- initialize();
-
- // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
-
- for (String dimension : variables.keySet()) {
- List<Integer> values = variables.get(dimension);
- if (values.size() > 1) {
- if (!ac3()) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
- }
- values.sort(Integer::compare);
- variables.put(dimension, Collections.singletonList(values.get(0)));
- }
- renames.put(dimension, variables.get(dimension).get(0));
- if (iterations > maxIterations) {
- throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
- maxIterations + " iterations");
- }
- }
-
- // Todo: handle failure more gracefully:
- // If a solution can't be found, look at the operation node in the arc
- // with the most remaining constraints, and inject a rename operation.
- // Then run this algorithm again.
- }
-
- public void solve() {
- solve(100000);
- }
-
- private void initialize() {
- for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
- List<Integer> values = variable.getValue();
- for (int i = 0; i < variables.size(); ++i) {
- values.add(i); // invariant: values are in increasing order
- }
- }
- }
-
- private boolean ac3() {
- Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
- while (!workList.isEmpty()) {
- Arc arc = workList.pop();
- iterations += 1;
- if (revise(arc)) {
- if (variables.get(arc.from).size() == 0) {
- return false; // no solution found
- }
- for (Arc constraint : constraints.keySet()) {
- if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
- workList.add(constraint);
- }
- }
- }
- }
- return true;
- }
-
- private boolean revise(Arc arc) {
- boolean revised = false;
- for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
- Integer from = fromIterator.next();
- boolean satisfied = false;
- for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) {
- Integer to = toIterator.next();
- if (constraints.get(arc).test(from, to)) {
- satisfied = true;
- }
- }
- if (!satisfied) {
- fromIterator.remove();
- revised = true;
- }
- }
- return revised;
- }
-
- public interface Constraint {
- boolean test(Integer x, Integer y);
- }
-
- public static boolean equals(Integer x, Integer y) {
- return Objects.equals(x, y);
- }
-
- public static boolean lesserThan(Integer x, Integer y) {
- return x < y;
- }
-
- public static boolean greaterThan(Integer x, Integer y) {
- return x > y;
- }
-
- private static class Arc {
-
- private final String from;
- private final String to;
- private final IntermediateOperation operation;
-
- Arc(String from, String to, IntermediateOperation operation) {
- this.from = from;
- this.to = to;
- this.operation = operation;
- }
-
- Arc opposite() {
- return new Arc(to, from, operation);
- }
-
- @Override
- public int hashCode() {
- return Objects.hash(from, to);
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof Arc)) {
- return false;
- }
- Arc other = (Arc) obj;
- return Objects.equals(from, other.from) && Objects.equals(to, other.to);
- }
-
- @Override
- public String toString() {
- return String.format("%s -> %s", from, to);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
deleted file mode 100644
index e66d0ab6f35..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ /dev/null
@@ -1,226 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.collections.Pair;
-import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.HashMap;
-import java.util.LinkedHashMap;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.regex.Pattern;
-
-/**
- * The result of importing an ML model into Vespa.
- *
- * @author bratseth
- */
-public class ImportedModel {
-
- private static final String defaultSignatureName = "default";
-
- private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*");
- private final String name;
- private final String source;
-
- private final Map<String, Signature> signatures = new HashMap<>();
- private final Map<String, TensorType> inputs = new HashMap<>();
- private final Map<String, Tensor> smallConstants = new HashMap<>();
- private final Map<String, Tensor> largeConstants = new HashMap<>();
- private final Map<String, RankingExpression> expressions = new HashMap<>();
- private final Map<String, RankingExpression> functions = new HashMap<>();
-
- /**
- * Creates a new imported model.
- *
- * @param name the name of this mode, containing only characters in [A-Za-z0-9_]
- * @param source the source path (directory or file) of this model
- */
- public ImportedModel(String name, String source) {
- if ( ! nameRegexp.matcher(name).matches())
- throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'");
- this.name = name;
- this.source = source;
- }
-
- /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
- public String name() { return name; }
-
- /** Returns the source path (directory or file) of this model */
- public String source() { return source; }
-
- /** Returns an immutable map of the inputs of this */
- public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
-
- /**
- * Returns an immutable map of the small constants of this.
- * These should have sizes up to a few kb at most, and correspond to constant values given in the source model.
- */
- public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); }
-
- /**
- * Returns an immutable map of the large constants of this.
- * These can have sizes in gigabytes and must be distributed to nodes separately from configuration.
- * For TensorFlow this corresponds to Variable files stored separately.
- */
- public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); }
-
- /**
- * Returns an immutable map of the expressions of this - corresponding to graph nodes
- * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants).
- * Note that only nodes recursively referenced by a placeholder/input are added.
- */
- public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
-
- /**
- * Returns an immutable map of the functions that are part of this model.
- * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification.
- */
- public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); }
-
- /** Returns an immutable map of the signatures of this */
- public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
-
- /** Returns the given signature. If it does not already exist it is added to this. */
- public Signature signature(String name) {
- return signatures.computeIfAbsent(name, Signature::new);
- }
-
- /** Convenience method for returning a default signature */
- public Signature defaultSignature() { return signature(defaultSignatureName); }
-
- public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
- public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
- public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
- public void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
- public void function(String name, RankingExpression expression) { functions.put(name, expression); }
-
- /**
- * Returns all the output expressions of this indexed by name. The names consist of one or two parts
- * separated by dot, where the first part is the signature name
- * if signatures are used, or the expression name if signatures are not used and there are multiple
- * expressions, and the second is the output name if signature names are used.
- */
- public List<Pair<String, ExpressionFunction>> outputExpressions() {
- List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
- for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
- for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
- expressions.add(new Pair<>(signatureEntry.getKey() + "" + outputEntry.getKey(),
- signatureEntry.getValue().outputExpression(outputEntry.getKey())
- .withName(signatureEntry.getKey() + "" + outputEntry.getKey())));
- if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
- expressions.add(new Pair<>(signatureEntry.getKey(),
- new ExpressionFunction(signatureEntry.getKey(),
- new ArrayList<>(signatureEntry.getValue().inputs().values()),
- expressions().get(signatureEntry.getKey()),
- signatureEntry.getValue().inputMap(),
- Optional.empty())));
- }
- if (signatures().isEmpty()) { // fallback for models without signatures
- if (expressions().size() == 1) {
- Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
- expressions.add(new Pair<>(singleEntry.getKey(),
- new ExpressionFunction(singleEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- singleEntry.getValue(),
- inputs,
- Optional.empty())));
- }
- else {
- for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- expressions.add(new Pair<>(expressionEntry.getKey(),
- new ExpressionFunction(expressionEntry.getKey(),
- new ArrayList<>(inputs.keySet()),
- expressionEntry.getValue(),
- inputs,
- Optional.empty())));
- }
- }
- }
- return expressions;
- }
-
- /**
- * A signature is a set of named inputs and outputs, where the inputs maps to input
- * ("placeholder") names+types, and outputs maps to expressions nodes.
- * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
- * concept of signatures. For now, we handle ONNX models as having a single signature.
- */
- public class Signature {
-
- private final String name;
- private final Map<String, String> inputs = new LinkedHashMap<>();
- private final Map<String, String> outputs = new LinkedHashMap<>();
- private final Map<String, String> skippedOutputs = new HashMap<>();
- private final List<String> importWarnings = new ArrayList<>();
-
- public Signature(String name) {
- this.name = name;
- }
-
- public String name() { return name; }
-
- /** Returns the result this is part of */
- public ImportedModel owner() { return ImportedModel.this; }
-
- /**
- * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
- * in this signature to input name in the owning model
- */
- public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
-
- /** Returns the name and type of all inputs in this signature as an immutable map */
- public Map<String, TensorType> inputMap() {
- ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
- // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
- // in the model, as these are the names which must actually be bound, if we are to avoid creating an
- // "input mapping" to accomodate this complexity in
- for (Map.Entry<String, String> inputEntry : inputs().entrySet())
- inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
- return inputs.build();
- }
-
- /** Returns the type of the input this input references */
- public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); }
-
- /** Returns an immutable list of the expression names of this */
- public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
-
- /**
- * Returns an immutable list of the outputs of this which could not be imported,
- * with a string detailing the reason for each
- */
- public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
-
- /**
- * Returns an immutable list of possibly non-fatal warnings encountered during import.
- */
- public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
-
- /** Returns the expression this output references */
- public ExpressionFunction outputExpression(String outputName) {
- return new ExpressionFunction(outputName,
- new ArrayList<>(inputs.values()),
- owner().expressions().get(outputs.get(outputName)),
- inputMap(),
- Optional.empty());
- }
-
- @Override
- public String toString() { return "signature '" + name + "'"; }
-
- void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
- void output(String name, String expressionName) { outputs.put(name, expressionName); }
- void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
- void importWarning(String warning) { importWarnings.add(warning); }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
deleted file mode 100644
index 896cd2e8d21..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java
+++ /dev/null
@@ -1,109 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.google.common.collect.ImmutableMap;
-import com.yahoo.path.Path;
-
-import java.io.File;
-import java.util.Arrays;
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Optional;
-
-/**
- * All models imported from the models/ directory in the application package.
- * If this is empty it may be due to either not having any models in the application package,
- * or this being created for a ZooKeeper application package, which does not have imported models.
- *
- * @author bratseth
- */
-public class ImportedModels {
-
- /** All imported models, indexed by their names */
- private final ImmutableMap<String, ImportedModel> importedModels;
-
- /** Create a null imported models */
- public ImportedModels() {
- importedModels = ImmutableMap.of();
- }
-
- public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) {
- Map<String, ImportedModel> models = new HashMap<>();
-
- // Find all subdirectories recursively which contains a model we can read
- importRecursively(modelsDirectory, models, importers);
- importedModels = ImmutableMap.copyOf(models);
- }
-
- private static void importRecursively(File dir,
- Map<String, ImportedModel> models,
- Collection<ModelImporter> importers) {
- if ( ! dir.isDirectory()) return;
-
- Arrays.stream(dir.listFiles()).sorted().forEach(child -> {
- Optional<ModelImporter> importer = findImporterOf(child, importers);
- if (importer.isPresent()) {
- String name = toName(child);
- ImportedModel existing = models.get(name);
- if (existing != null)
- throw new IllegalArgumentException("The models in " + child + " and " + existing.source() +
- " both resolve to the model name '" + name + "'");
- models.put(name, importer.get().importModel(name, child));
- }
- else {
- importRecursively(child, models, importers);
- }
- });
- }
-
- private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) {
- return importers.stream().filter(item -> item.canImport(path.toString())).findFirst();
- }
-
- /**
- * Returns the model at the given location in the application package.
- *
- * @param modelPath the path to this model (file or directory, depending on model type)
- * under the application package, both from the root or relative to the
- * models directory works
- * @return the model at this path or null if none
- */
- public ImportedModel get(File modelPath) {
- return importedModels.get(toName(modelPath));
- }
-
- public ImportedModel get(String modelName) {
- return importedModels.get(modelName);
- }
-
- /** Returns an immutable collection of all the imported models */
- public Collection<ImportedModel> all() {
- return importedModels.values();
- }
-
- private static String toName(File modelFile) {
- Path modelPath = Path.fromString(modelFile.toString());
- if (modelFile.isFile())
- modelPath = stripFileEnding(modelPath);
- String localPath = concatenateAfterModelsDirectory(modelPath);
- return localPath.replace('.', '_');
- }
-
- private static Path stripFileEnding(Path path) {
- int dotIndex = path.last().lastIndexOf("");
- if (dotIndex <= 0) return path;
- return path.withLast(path.last().substring(0, dotIndex));
- }
-
- private static String concatenateAfterModelsDirectory(Path path) {
- boolean afterModels = false;
- StringBuilder result = new StringBuilder();
- for (String element : path.elements()) {
- if (afterModels) result.append(element).append("_");
- if (element.equals("models")) afterModels = true;
- }
- return result.substring(0, result.length()-1);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/IntermediateGraph.java
deleted file mode 100644
index 81c176707e5..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/IntermediateGraph.java
+++ /dev/null
@@ -1,107 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation;
-
-import java.util.Collection;
-import java.util.HashMap;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * Holds an intermediate representation of an imported model graph.
- * After this intermediate representation is constructed, it is used to
- * simplify and optimize the computational graph and then converted into the
- * final ImportedModel that holds the Vespa ranking expressions for the model.
- *
- * @author lesters
- */
-public class IntermediateGraph {
-
- private final String modelName;
- private final Map<String, IntermediateOperation> index = new HashMap<>();
- private final Map<String, GraphSignature> signatures = new HashMap<>();
-
- private static class GraphSignature {
- final Map<String, String> inputs = new HashMap<>();
- final Map<String, String> outputs = new HashMap<>();
- }
-
- public IntermediateGraph(String modelName) {
- this.modelName = modelName;
- }
-
- public String name() {
- return modelName;
- }
-
- public IntermediateOperation put(String key, IntermediateOperation operation) {
- return index.put(key, operation);
- }
-
- public IntermediateOperation get(String key) {
- return index.get(key);
- }
-
- public Set<String> signatures() {
- return signatures.keySet();
- }
-
- public Map<String, String> inputs(String signature) {
- return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs;
- }
-
- public Map<String, String> outputs(String signature) {
- return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs;
- }
-
- public String defaultSignature() {
- return "default";
- }
-
- public boolean alreadyImported(String key) {
- return index.containsKey(key);
- }
-
- public Collection<IntermediateOperation> operations() {
- return index.values();
- }
-
- public void optimize() {
- renameDimensions();
- }
-
- /**
- * Find dimension names to avoid excessive renaming while evaluating the model.
- */
- private void renameDimensions() {
- DimensionRenamer renamer = new DimensionRenamer();
- for (String signature : signatures()) {
- for (String output : outputs(signature).values()) {
- addDimensionNameConstraints(index.get(output), renamer);
- }
- }
- renamer.solve();
- for (String signature : signatures()) {
- for (String output : outputs(signature).values()) {
- renameDimensions(index.get(output), renamer);
- }
- }
- }
-
- private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
- operation.addDimensionNameConstraints(renamer);
- }
- }
-
- private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
- if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
- operation.renameDimensions(renamer);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
deleted file mode 100644
index 47bf6a2a240..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ /dev/null
@@ -1,243 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Constant;
-import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation;
-import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-import com.yahoo.yolean.Exceptions;
-
-import java.io.File;
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.logging.Logger;
-
-/**
- * Base class for importing ML models (ONNX/TensorFlow) as native Vespa
- * ranking expressions. The general mechanism for import is for the
- * specific ML platform import implementations to create an
- * IntermediateGraph. This class offers common code to convert the
- * IntermediateGraph to Vespa ranking expressions and functions.
- *
- * @author lesters
- */
-public abstract class ModelImporter {
-
- private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
-
- /** Returns whether the file or directory at the given path is of the type which can be imported by this */
- public abstract boolean canImport(String modelPath);
-
- /** Imports the given model */
- public abstract ImportedModel importModel(String modelName, String modelPath);
-
- public final ImportedModel importModel(String modelName, File modelPath) {
- return importModel(modelName, modelPath.toString());
- }
-
- /**
- * Takes an IntermediateGraph and converts it to a ImportedModel containing
- * the actual Vespa ranking expressions.
- */
- public static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
- ImportedModel model = new ImportedModel(graph.name(), modelSource);
-
- graph.optimize();
-
- importSignatures(graph, model);
- importExpressions(graph, model);
- reportWarnings(graph, model);
- logVariableTypes(graph);
-
- return model;
- }
-
- private static void importSignatures(IntermediateGraph graph, ImportedModel model) {
- for (String signatureName : graph.signatures()) {
- ImportedModel.Signature signature = model.signature(signatureName);
- for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) {
- signature.input(input.getKey(), input.getValue());
- }
- for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) {
- signature.output(output.getKey(), output.getValue());
- }
- }
- }
-
- private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) {
- for (ImportedModel.Signature signature : model.signatures().values()) {
- for (String inputName : signature.inputs().values()) {
- if (inputName.equals(operation.name())) {
- return true;
- }
- }
- }
- return false;
- }
-
- private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) {
- for (ImportedModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- if (outputName.equals(operation.name())) {
- return true;
- }
- }
- }
- return false;
- }
-
- /**
- * Convert intermediate representation to Vespa ranking expressions.
- */
- static void importExpressions(IntermediateGraph graph, ImportedModel model) {
- for (ImportedModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- try {
- Optional<TensorFunction> function = importExpression(graph.get(outputName), model);
- if (!function.isPresent()) {
- signature.skippedOutput(outputName, "No valid output function could be found.");
- }
- }
- catch (IllegalArgumentException e) {
- signature.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
- }
- }
- }
-
- private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) {
- if (!operation.type().isPresent()) {
- return Optional.empty();
- }
- if (operation.isConstant()) {
- return importConstant(operation, model);
- }
- importExpressionInputs(operation, model);
- importRankingExpression(operation, model);
- importArgumentExpression(operation, model);
- importFunctionExpression(operation, model);
-
- return operation.function();
- }
-
- private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
- operation.inputs().forEach(input -> importExpression(input, model));
- }
-
- private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
- String name = operation.vespaName();
- if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
- return operation.function();
- }
-
- Value value = operation.getConstantValue().orElseThrow(() ->
- new IllegalArgumentException("Operation '" + operation.vespaName() + "' " +
- "is constant but does not have a value."));
- if ( ! (value instanceof TensorValue)) {
- return operation.function(); // scalar values are inserted directly into the expression
- }
-
- Tensor tensor = value.asTensor();
- if (tensor.type().rank() == 0) {
- model.smallConstant(name, tensor);
- } else {
- model.largeConstant(name, tensor);
- }
- return operation.function();
- }
-
- private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) {
- if (operation.function().isPresent()) {
- String name = operation.name();
- if ( ! model.expressions().containsKey(name)) {
- TensorFunction function = operation.function().get();
-
- if (isSignatureOutput(model, operation)) {
- OrderedTensorType operationType = operation.type().get();
- OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType);
- if ( ! operationType.equals(standardNamingType)) {
- List<String> renameFrom = operationType.dimensionNames();
- List<String> renameTo = standardNamingType.dimensionNames();
- function = new Rename(function, renameFrom, renameTo);
- }
- }
-
- try {
- // We add all intermediate nodes imported as separate expressions. Only
- // those referenced from the output will be used. We parse the
- // TensorFunction here to convert it to a RankingExpression tree.
- model.expression(name, new RankingExpression(name, function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Imported function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
- }
-
- private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) {
- if (operation.isInput()) {
- // All inputs must have dimensions with standard naming convention: d0, d1, ...
- OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
- model.input(operation.vespaName(), standardNamingConvention.type());
- }
- }
-
- private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) {
- if (operation.rankingExpressionFunction().isPresent()) {
- TensorFunction function = operation.rankingExpressionFunction().get();
- try {
- model.function(operation.rankingExpressionFunctionName(),
- new RankingExpression(operation.rankingExpressionFunctionName(),
- function.toString()));
- }
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
- }
- }
- }
-
- /**
- * Add any import warnings to the signature in the ImportedModel.
- */
- private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
- for (ImportedModel.Signature signature : model.signatures().values()) {
- for (String outputName : signature.outputs().values()) {
- reportWarnings(graph.get(outputName), model);
- }
- }
- }
-
- private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
- for (String warning : operation.warnings()) {
- model.defaultSignature().importWarning(warning);
- }
- for (IntermediateOperation input : operation.inputs()) {
- reportWarnings(input, model);
- }
- }
-
- /**
- * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
- * This allows users to learn the exact types (including dimension order after renaming) of the Variables
- * such that these can be converted and fed to a parent document independently of the rest of the model
- * for fast model weight updates.
- */
- private static void logVariableTypes(IntermediateGraph graph) {
- for (IntermediateOperation operation : graph.operations()) {
- if ( ! (operation instanceof Constant)) continue;
- if ( ! operation.type().isPresent()) continue; // will not happen
- log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() +
- " of type " + operation.type().get());
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorType.java
deleted file mode 100644
index dfd073f82b4..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorType.java
+++ /dev/null
@@ -1,235 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.TensorTypeParser;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * A Vespa tensor type is ordered by the lexicographical ordering of dimension
- * names. Imported tensors have an explicit ordering of their dimensions.
- * During import, we need to track the Vespa dimension that matches the
- * corresponding imported dimension as the ordering can change after
- * dimension renaming. That is the purpose of this class.
- *
- * @author lesters
- */
-public class OrderedTensorType {
-
- private final TensorType type;
- private final List<TensorType.Dimension> dimensions;
-
- private final long[] innerSizesOriginal;
- private final long[] innerSizesVespa;
- private final int[] dimensionMap;
-
- private OrderedTensorType(List<TensorType.Dimension> dimensions) {
- this.dimensions = Collections.unmodifiableList(dimensions);
- this.type = new TensorType.Builder(dimensions).build();
- this.innerSizesOriginal = new long[dimensions.size()];
- this.innerSizesVespa = new long[dimensions.size()];
- this.dimensionMap = createDimensionMap();
- }
-
- public TensorType type() { return this.type; }
-
- public int rank() { return dimensions.size(); }
-
- public List<TensorType.Dimension> dimensions() {
- return dimensions;
- }
-
- public List<String> dimensionNames() {
- return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
- }
-
- private int[] createDimensionMap() {
- int numDimensions = dimensions.size();
- if (numDimensions == 0) {
- return null;
- }
- innerSizesOriginal[numDimensions - 1] = 1;
- innerSizesVespa[numDimensions - 1] = 1;
- for (int i = numDimensions - 1; --i >= 0; ) {
- innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1];
- innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
- }
- int[] mapping = new int[numDimensions];
- for (int i = 0; i < numDimensions; ++i) {
- TensorType.Dimension dim1 = dimensions().get(i);
- for (int j = 0; j < numDimensions; ++j) {
- TensorType.Dimension dim2 = type.dimensions().get(j);
- if (dim1.equals(dim2)) {
- mapping[i] = j;
- break;
- }
- }
- }
- return mapping;
- }
-
- public int dimensionMap(int originalIndex) {
- return dimensionMap[originalIndex];
- }
-
- /**
- * When dimension ordering between Vespa and imported differs, i.e.
- * after dimension renaming, use the dimension map to read in values
- * so that they are correctly laid out in memory for Vespa.
- * Used when importing tensors.
- */
- public int toDirectIndex(int index) {
- if (dimensions.size() == 0) {
- return 0;
- }
- if (dimensionMap == null) {
- throw new IllegalArgumentException("Dimension map is not available");
- }
- int directIndex = 0;
- long rest = index;
- for (int i = 0; i < dimensions.size(); ++i) {
- long address = rest / innerSizesOriginal[i];
- directIndex += innerSizesVespa[dimensionMap[i]] * address;
- rest %= innerSizesOriginal[i];
- }
- return directIndex;
- }
-
- @Override
- public boolean equals(Object obj) {
- if (obj == null || !(obj instanceof OrderedTensorType)) {
- return false;
- }
- OrderedTensorType other = (OrderedTensorType) obj;
- if (dimensions.size() != dimensions.size()) {
- return false;
- }
- List<TensorType.Dimension> thisDimensions = this.dimensions();
- List<TensorType.Dimension> otherDimensions = other.dimensions();
- for (int i = 0; i < thisDimensions.size(); ++i) {
- if (!thisDimensions.get(i).equals(otherDimensions.get(i))) {
- return false;
- }
- }
- return true;
- }
-
- public OrderedTensorType rename(DimensionRenamer renamer) {
- List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
- for (TensorType.Dimension dimension : dimensions) {
- String oldName = dimension.name();
- Optional<String> newName = renamer.dimensionNameOf(oldName);
- if (!newName.isPresent())
- return this; // presumably, already renamed
- TensorType.Dimension.Type dimensionType = dimension.type();
- if (dimensionType == TensorType.Dimension.Type.indexedBound) {
- renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
- } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
- renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
- } else if (dimensionType == TensorType.Dimension.Type.mapped) {
- renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
- }
- }
- return new OrderedTensorType(renamedDimensions);
- }
-
- public OrderedTensorType rename(String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < dimensions.size(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Optional<Long> dimSize = dimensions.get(i).size();
- if (dimSize.isPresent() && dimSize.get() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
- public static OrderedTensorType standardType(OrderedTensorType type) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < type.dimensions().size(); ++ i) {
- TensorType.Dimension dim = type.dimensions().get(i);
- String dimensionName = "d" + i;
- if (dim.size().isPresent() && dim.size().get() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
- public static Long tensorSize(TensorType type) {
- Long size = 1L;
- for (TensorType.Dimension dimension : type.dimensions()) {
- size *= dimensionSize(dimension);
- }
- return size;
- }
-
- public static Long dimensionSize(TensorType.Dimension dim) {
- return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
- }
-
- /**
- * Returns a string representation of this: A standard tensor type string where dimensions
- * are listed in the order of this rather than in the natural order of their names.
- */
- @Override
- public String toString() {
- return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
- }
-
- /**
- * Creates an instance from the string representation of this: A standard tensor type string
- * where dimensions are listed in the order of this rather than the natural order of their names.
- */
- public static OrderedTensorType fromSpec(String typeSpec) {
- return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
- }
-
- public static OrderedTensorType fromDimensionList(List<Long> dims) {
- return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
- }
-
- public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < dims.size(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- Long dimSize = dims.get(i);
- if (dimSize >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
- public static class Builder {
-
- private final List<TensorType.Dimension> dimensions;
-
- public Builder() {
- this.dimensions = new ArrayList<>();
- }
-
- public Builder add(TensorType.Dimension vespaDimension) {
- this.dimensions.add(vespaDimension);
- return this;
- }
-
- public OrderedTensorType build() {
- return new OrderedTensorType(dimensions);
- }
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Argument.java
deleted file mode 100644
index 57e4ec13e45..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Argument.java
+++ /dev/null
@@ -1,57 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.Collections;
-import java.util.List;
-
-public class Argument extends IntermediateOperation {
-
- private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ...
-
- public Argument(String modelName, String nodeName, OrderedTensorType type) {
- super(modelName, nodeName, Collections.emptyList());
- this.type = type.rename(vespaName() + "_");
- standardNamingType = OrderedTensorType.standardType(type);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return type;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type());
- if (!standardNamingType.equals(type)) {
- List<String> renameFrom = standardNamingType.dimensionNames();
- List<String> renameTo = type.dimensionNames();
- output = new Rename(output, renameFrom, renameTo);
- }
- return output;
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public boolean isInput() {
- return true;
- }
-
- @Override
- public boolean isConstant() {
- return false;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ConcatV2.java
deleted file mode 100644
index 413b856e43d..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ConcatV2.java
+++ /dev/null
@@ -1,108 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-import java.util.Optional;
-
-public class ConcatV2 extends IntermediateOperation {
-
- private String concatDimensionName;
-
- public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
- return null;
- }
-
- IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
- if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a constant.");
- }
- Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
- if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a scalar.");
- }
-
- OrderedTensorType aType = inputs.get(0).type().get();
-
- int concatDim = (int)concatDimTensor.asDouble();
- long concatDimSize = aType.dimensions().get(concatDim).size().orElse(-1L);
-
- for (int i = 1; i < inputs.size() - 1; ++i) {
- OrderedTensorType bType = inputs.get(i).type().get();
- if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "inputs must have save rank.");
- }
- for (int j = 0; j < aType.rank(); ++j) {
- long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
- long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
- if (j == concatDim) {
- concatDimSize += dimSizeB;
- } else if (dimSizeA != dimSizeB) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "input dimension " + j + " differs in input tensors.");
- }
- }
- }
-
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
- int dimensionIndex = 0;
- for (TensorType.Dimension dimension : aType.dimensions()) {
- if (dimensionIndex == concatDim) {
- concatDimensionName = dimension.name();
- typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize));
- } else {
- typeBuilder.add(dimension);
- }
- dimensionIndex++;
- }
- return typeBuilder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) {
- return null;
- }
- TensorFunction result = inputs.get(0).function().get();
- for (int i = 1; i < inputs.size() - 1; ++i) {
- TensorFunction b = inputs.get(i).function().get();
- result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName);
- }
- return result;
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
- return;
- }
- OrderedTensorType a = inputs.get(0).type().get();
- for (int i = 1; i < inputs.size() - 1; ++i) {
- OrderedTensorType b = inputs.get(i).type().get();
- String bDim = b.dimensions().get(i).name();
- String aDim = a.dimensions().get(i).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
- }
- }
-
- @Override
- public void renameDimensions(DimensionRenamer renamer) {
- super.renameDimensions(renamer);
- concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Const.java
deleted file mode 100644
index 49c0bb712c5..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Const.java
+++ /dev/null
@@ -1,89 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-import java.util.Optional;
-
-public class Const extends IntermediateOperation {
-
- private final AttributeMap attributeMap;
-
- public Const(String modelName,
- String nodeName,
- List<IntermediateOperation> inputs,
- AttributeMap attributeMap,
- OrderedTensorType type) {
- super(modelName, nodeName, inputs);
- this.attributeMap = attributeMap;
- this.type = type.rename(vespaName() + "_");
- setConstantValue(value());
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return type;
- }
-
- @Override
- public Optional<TensorFunction> function() {
- if (function == null) {
- function = lazyGetFunction();
- }
- return Optional.ofNullable(function);
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- ExpressionNode expressionNode;
- if (type.type().rank() == 0 && getConstantValue().isPresent()) {
- expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue());
- } else {
- expressionNode = new ReferenceNode(Reference.simple("constant", vespaName()));
- }
- return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
- }
-
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + super.vespaName();
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public void renameDimensions(DimensionRenamer renamer) {
- super.renameDimensions(renamer);
- setConstantValue(value());
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
- private Value value() {
- Optional<Value> value = attributeMap.get("value", type);
- if ( ! value.isPresent()) {
- throw new IllegalArgumentException("Node '" + name + "' of type " +
- "const has missing or non-recognized 'value' attribute");
- }
- return value.get();
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Constant.java
deleted file mode 100644
index 670274376a3..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Constant.java
+++ /dev/null
@@ -1,61 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.Collections;
-import java.util.Optional;
-
-public class Constant extends IntermediateOperation {
-
- private final String modelName;
-
- public Constant(String modelName, String nodeName, OrderedTensorType type) {
- super(modelName, nodeName, Collections.emptyList());
- this.modelName = modelName;
- this.type = type.rename(vespaName() + "_");
- }
-
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + vespaName(name);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return type;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null; // will be added by function() since this is constant.
- }
-
- /**
- * Constant values are sent in via the constantValueFunction, as the
- * dimension names and thus the data layout depends on the dimension
- * renaming which happens after the conversion to intermediate graph.
- */
- @Override
- public Optional<Value> getConstantValue() {
- return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type));
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ExpandDims.java
deleted file mode 100644
index e037a9a2497..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ExpandDims.java
+++ /dev/null
@@ -1,106 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Optional;
-
-public class ExpandDims extends IntermediateOperation {
-
- private List<String> expandDimensions;
-
- public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
-
- IntermediateOperation axisOperation = inputs().get(1);
- if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis must be a constant.");
- }
- Tensor axis = axisOperation.getConstantValue().get().asTensor();
- if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis argument must be a scalar.");
- }
-
- OrderedTensorType inputType = inputs.get(0).type().get();
- int dimensionToInsert = (int)axis.asDouble();
- if (dimensionToInsert < 0) {
- dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
- }
-
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
- expandDimensions = new ArrayList<>();
- int dimensionIndex = 0;
- for (TensorType.Dimension dimension : inputType.dimensions()) {
- if (dimensionIndex == dimensionToInsert) {
- String name = String.format("%s_%d", vespaName(), dimensionIndex);
- expandDimensions.add(name);
- typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
- }
- typeBuilder.add(dimension);
- dimensionIndex++;
- }
-
- return typeBuilder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
-
- // multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (String name : expandDimensions) {
- typeBuilder.indexed(name, 1);
- }
- TensorType generatedType = typeBuilder.build();
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
- new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- @Override
- public void renameDimensions(DimensionRenamer renamer) {
- super.renameDimensions(renamer);
- List<String> renamedDimensions = new ArrayList<>(expandDimensions.size());
- for (String name : expandDimensions) {
- Optional<String> newName = renamer.dimensionNameOf(name);
- if (!newName.isPresent()) {
- return; // presumably, already renamed
- }
- renamedDimensions.add(newName.get());
- }
- expandDimensions = renamedDimensions;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Identity.java
deleted file mode 100644
index c52b3357848..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Identity.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-
-public class Identity extends IntermediateOperation {
-
- public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- }
-
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + super.vespaName();
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1))
- return null;
- return inputs.get(0).type().orElse(null);
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(1))
- return null;
- return inputs.get(0).function().orElse(null);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/IntermediateOperation.java
deleted file mode 100644
index 1b62ef67d71..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/IntermediateOperation.java
+++ /dev/null
@@ -1,191 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.Reference;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.ArrayList;
-import java.util.Collections;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.Function;
-
-/**
- * Wraps an imported operation node and produces the respective Vespa tensor
- * operation. During import, a graph of these operations are constructed. Then,
- * the types are used to deduce sensible dimension names using the
- * DimensionRenamer. After the types have been renamed, the proper Vespa
- * expressions can be extracted.
- *
- * @author lesters
- */
-public abstract class IntermediateOperation {
-
- public final static String FUNCTION_PREFIX = "imported_ml_function_";
-
- protected final String name;
- protected final String modelName;
- protected final List<IntermediateOperation> inputs;
- protected final List<IntermediateOperation> outputs = new ArrayList<>();
-
- protected OrderedTensorType type;
- protected TensorFunction function;
- protected TensorFunction rankingExpressionFunction = null;
-
- private final List<String> importWarnings = new ArrayList<>();
- private Value constantValue = null;
- private List<IntermediateOperation> controlInputs = Collections.emptyList();
-
- protected Function<OrderedTensorType, Value> constantValueFunction = null;
-
- IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
- this.name = name;
- this.modelName = modelName;
- this.inputs = Collections.unmodifiableList(inputs);
- this.inputs.forEach(i -> i.outputs.add(this));
- }
-
- protected abstract OrderedTensorType lazyGetType();
- protected abstract TensorFunction lazyGetFunction();
-
- /** Returns the Vespa tensor type of this operation if it exists */
- public Optional<OrderedTensorType> type() {
- if (type == null) {
- type = lazyGetType();
- }
- return Optional.ofNullable(type);
- }
-
- /** Returns the Vespa tensor function implementing all operations from this node with inputs */
- public Optional<TensorFunction> function() {
- if (function == null) {
- if (isConstant()) {
- ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
- function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
- } else if (outputs.size() > 1) {
- rankingExpressionFunction = lazyGetFunction();
- function = new VariableTensor(rankingExpressionFunctionName(), type.type());
- } else {
- function = lazyGetFunction();
- }
- }
- return Optional.ofNullable(function);
- }
-
- /** Returns original name of this operation node */
- public String name() { return name; }
-
- /** Return unmodifiable list of inputs */
- public List<IntermediateOperation> inputs() { return inputs; }
-
- /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a function. */
- public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); }
-
- /** Returns a function that should be added as a ranking expression function */
- public Optional<TensorFunction> rankingExpressionFunction() {
- return Optional.ofNullable(rankingExpressionFunction);
- }
-
- /** Add dimension name constraints for this operation */
- public void addDimensionNameConstraints(DimensionRenamer renamer) { }
-
- /** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
-
- /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
- public boolean isInput() { return false; }
-
- /** Return true if this node is constant */
- public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); }
-
- /** Sets the constant value */
- public void setConstantValue(Value value) { constantValue = value; }
-
- /** Gets the constant value if it exists */
- public Optional<Value> getConstantValue() {
- if (constantValue != null) {
- return Optional.of(constantValue);
- }
- if (constantValueFunction != null) {
- return Optional.of(constantValueFunction.apply(type));
- }
- return Optional.empty();
- }
-
- /** Set the constant value function */
- public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
-
- /** Sets the external control inputs */
- public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
-
- /** Retrieve the control inputs for this operation */
- public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
-
- /** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(name); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; }
-
- /** Retrieve the valid Vespa name of this node if it is a ranking expression function */
- public String rankingExpressionFunctionName() {
- return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null;
- }
-
- /** Retrieve the list of warnings produced during its lifetime */
- public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
-
- /** Set an input warning */
- public void warning(String warning) { importWarnings.add(warning); }
-
- boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) {
- if (inputs.size() != expected) {
- throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + name + "', got " + inputs.size());
- }
- return inputs.stream().map(func).allMatch(Optional::isPresent);
- }
-
- boolean allInputTypesPresent(int expected) {
- return verifyInputs(expected, IntermediateOperation::type);
- }
-
- boolean allInputFunctionsPresent(int expected) {
- return verifyInputs(expected, IntermediateOperation::function);
- }
-
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- public static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
- }
-
- /**
- * This return the output index part. Indexes are used for nodes with
- * multiple outputs.
- */
- public static int indexPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
- }
-
- /**
- * An interface mapping operation attributes to Vespa Values.
- * Adapter for differences in different model types.
- */
- public interface AttributeMap {
- Optional<Value> get(String key);
- Optional<Value> get(String key, OrderedTensorType type);
- Optional<List<Value>> getList(String key);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java
deleted file mode 100644
index c98bcb43331..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.function.DoubleBinaryOperator;
-
-public class Join extends IntermediateOperation {
-
- private final DoubleBinaryOperator operator;
-
- public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) {
- super(modelName, nodeName, inputs);
- this.operator = operator;
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType a = largestInput().type().get();
- OrderedTensorType b = smallestInput().type().get();
-
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- int sizeDifference = a.rank() - b.rank();
- for (int i = 0; i < a.rank(); ++i) {
- TensorType.Dimension aDim = a.dimensions().get(i);
- long size = aDim.size().orElse(-1L);
-
- if (i - sizeDifference >= 0) {
- TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference);
- size = Math.max(size, bDim.size().orElse(-1L));
- }
-
- if (aDim.type() == TensorType.Dimension.Type.indexedBound) {
- builder.add(TensorType.Dimension.indexed(aDim.name(), size));
- } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) {
- builder.add(TensorType.Dimension.indexed(aDim.name()));
- } else if (aDim.type() == TensorType.Dimension.Type.mapped) {
- builder.add(TensorType.Dimension.mapped(aDim.name()));
- }
- }
- return builder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
-
- IntermediateOperation a = largestInput();
- IntermediateOperation b = smallestInput();
-
- List<String> aDimensionsToReduce = new ArrayList<>();
- List<String> bDimensionsToReduce = new ArrayList<>();
- int sizeDifference = a.type().get().rank() - b.type().get().rank();
- for (int i = 0; i < b.type().get().rank(); ++i) {
- TensorType.Dimension bDim = b.type().get().dimensions().get(i);
- TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference);
- long bSize = bDim.size().orElse(-1L);
- long aSize = aDim.size().orElse(-1L);
- if (bSize == 1L && aSize != 1L) {
- bDimensionsToReduce.add(bDim.name());
- }
- if (aSize == 1L && bSize != 1L) {
- aDimensionsToReduce.add(bDim.name());
- }
- }
-
- TensorFunction aReducedFunction = a.function().get();
- if (aDimensionsToReduce.size() > 0) {
- aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce);
- }
- TensorFunction bReducedFunction = b.function().get();
- if (bDimensionsToReduce.size() > 0) {
- bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
- }
-
- return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
- OrderedTensorType a = largestInput().type().get();
- OrderedTensorType b = smallestInput().type().get();
- int sizeDifference = a.rank() - b.rank();
- for (int i = 0; i < b.rank(); ++i) {
- String bDim = b.dimensions().get(i).name();
- String aDim = a.dimensions().get(i + sizeDifference).name();
- renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
- }
- }
-
- private IntermediateOperation largestInput() {
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1);
- }
-
- private IntermediateOperation smallestInput() {
- OrderedTensorType a = inputs.get(0).type().get();
- OrderedTensorType b = inputs.get(1).type().get();
- return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Map.java
deleted file mode 100644
index 9bf6836ea9b..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Map.java
+++ /dev/null
@@ -1,37 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-import java.util.Optional;
-import java.util.function.DoubleUnaryOperator;
-
-public class Map extends IntermediateOperation {
-
- private final DoubleUnaryOperator operator;
-
- public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) {
- super(modelName, nodeName, inputs);
- this.operator = operator;
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
- return inputs.get(0).type().get();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(1)) {
- return null;
- }
- Optional<TensorFunction> input = inputs.get(0).function();
- return new com.yahoo.tensor.functions.Map(input.get(), operator);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/MatMul.java
deleted file mode 100644
index 287c23080de..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/MatMul.java
+++ /dev/null
@@ -1,72 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-import java.util.Optional;
-
-public class MatMul extends IntermediateOperation {
-
- public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
- typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
- typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
- return typeBuilder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType aType = inputs.get(0).type().get();
- OrderedTensorType bType = inputs.get(1).type().get();
- if (aType.type().rank() < 2 || bType.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (aType.type().rank() != bType.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- Optional<TensorFunction> aFunction = inputs.get(0).function();
- Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
- return null;
- }
- return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
- List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
-
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
-
- // The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
-
- // The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
-
- // For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
- }
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Mean.java
deleted file mode 100644
index f313831ea56..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Mean.java
+++ /dev/null
@@ -1,113 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Optional;
-
-public class Mean extends IntermediateOperation {
-
- private final AttributeMap attributeMap;
- private List<String> reduceDimensions;
-
- public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
- super(modelName, nodeName, inputs);
- this.attributeMap = attributeMap;
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- IntermediateOperation reductionIndices = inputs.get(1);
- if (!reductionIndices.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Mean in " + name + ": " +
- "reduction indices must be a constant.");
- }
- Tensor indices = reductionIndices.getConstantValue().get().asTensor();
- reduceDimensions = new ArrayList<>();
-
- OrderedTensorType inputType = inputs.get(0).type().get();
- for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int dimensionIndex = cell.getValue().intValue();
- if (dimensionIndex < 0) {
- dimensionIndex = inputType.dimensions().size() - dimensionIndex;
- }
- reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
- }
- return reducedType(inputType, shouldKeepDimensions());
- }
-
- // optimization: if keepDims and one reduce dimension that has size 1: same as identity.
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- TensorFunction inputFunction = inputs.get(0).function().get();
- TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
- if (shouldKeepDimensions()) {
- // multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (String name : reduceDimensions) {
- typeBuilder.indexed(name, 1);
- }
- TensorType generatedType = typeBuilder.build();
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
- new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
- }
- return output;
- }
-
- @Override
- public void renameDimensions(DimensionRenamer renamer) {
- super.renameDimensions(renamer);
- List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size());
- for (String name : reduceDimensions) {
- Optional<String> newName = renamer.dimensionNameOf(name);
- if (!newName.isPresent()) {
- return; // presumably, already renamed
- }
- renamedDimensions.add(newName.get());
- }
- reduceDimensions = renamedDimensions;
- }
-
- private boolean shouldKeepDimensions() {
- Optional<Value> keepDims = attributeMap.get("keep_dims");
- return keepDims.isPresent() && keepDims.get().asBoolean();
- }
-
- private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (TensorType.Dimension dimension: inputType.type().dimensions()) {
- if (!reduceDimensions.contains(dimension.name())) {
- builder.add(dimension);
- } else if (keepDimensions) {
- builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
- }
- }
- return builder.build();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Merge.java
deleted file mode 100644
index 9ce0ea6151b..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Merge.java
+++ /dev/null
@@ -1,35 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-
-public class Merge extends IntermediateOperation {
-
- public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- for (IntermediateOperation operation : inputs) {
- if (operation.type().isPresent()) {
- return operation.type().get();
- }
- }
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- for (IntermediateOperation operation : inputs) {
- if (operation.function().isPresent()) {
- return operation.function().get();
- }
- }
- return null;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/NoOp.java
deleted file mode 100644
index f78c945d09f..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/NoOp.java
+++ /dev/null
@@ -1,26 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.Collections;
-import java.util.List;
-
-public class NoOp extends IntermediateOperation {
-
- public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- return null;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/PlaceholderWithDefault.java
deleted file mode 100644
index 5da3300cd7a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/PlaceholderWithDefault.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-import java.util.Optional;
-
-public class PlaceholderWithDefault extends IntermediateOperation {
-
- public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
- return inputs().get(0).type().orElse(null);
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(1)) {
- return null;
- }
- // This should be a call to the function we add below, but for now
- // we treat this as as identity function and just pass the constant.
- return inputs.get(0).function().orElse(null);
- }
-
- @Override
- public Optional<TensorFunction> rankingExpressionFunction() {
- // For now, it is much more efficient to assume we always will return
- // the default value, as we can prune away large parts of the expression
- // tree by having it calculated as a constant. If a case arises where
- // it is important to support this, implement this.
- return Optional.empty();
- }
-
- @Override
- public boolean isConstant() {
- return true; // not true if we add to function
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Reshape.java
deleted file mode 100644
index 3165890ff03..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Reshape.java
+++ /dev/null
@@ -1,131 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.stream.Collectors;
-
-public class Reshape extends IntermediateOperation {
-
- public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- IntermediateOperation newShape = inputs.get(1);
- if (!newShape.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Reshape in " + name + ": " +
- "shape input must be a constant.");
- }
- Tensor shape = newShape.getConstantValue().get().asTensor();
-
- OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
- int dimensionIndex = 0;
- for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int size = cell.getValue().intValue();
- if (size < 0) {
- size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
- OrderedTensorType.tensorSize(inputType.type()).intValue();
- }
- outputTypeBuilder.add(TensorType.Dimension.indexed(
- String.format("%s_%d", vespaName(), dimensionIndex), size));
- dimensionIndex++;
- }
- return outputTypeBuilder.build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
- OrderedTensorType inputType = inputs.get(0).type().get();
- TensorFunction inputFunction = inputs.get(0).function().get();
- return reshape(inputFunction, inputType.type(), type.type());
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- for (TensorType.Dimension dimension : type.type().dimensions()) {
- renamer.addDimension(dimension.name());
- }
- }
-
- public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
- throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
- }
-
- // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
- // then use the dimension order of the new shape to roll back into a tensor.
- // Here we create a transformation tensor that is multiplied with the from tensor to map into
- // the new shape. We have to introduce temporary dimension names and rename back if dimension names
- // in the new and old tensor type overlap.
-
- ExpressionNode unrollFrom = unrollTensorExpression(inputType);
- ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
-
- TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
- Generate transformTensor = new Generate(transformationType,
- new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
-
- TensorFunction outputFunction = new Reduce(
- new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
-
- return outputFunction;
- }
-
- private static ExpressionNode unrollTensorExpression(TensorType type) {
- if (type.rank() == 0) {
- return new ConstantNode(DoubleValue.zero);
- }
- List<ExpressionNode> children = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
- int size = 1;
- for (int i = type.dimensions().size() - 1; i >= 0; --i) {
- TensorType.Dimension dimension = type.dimensions().get(i);
- children.add(0, new ReferenceNode(dimension.name()));
- if (size > 1) {
- operators.add(0, ArithmeticOperator.MULTIPLY);
- children.add(0, new ConstantNode(new DoubleValue(size)));
- }
- size *= OrderedTensorType.dimensionSize(dimension);
- if (i > 0) {
- operators.add(0, ArithmeticOperator.PLUS);
- }
- }
- return new ArithmeticNode(children, operators);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Select.java
deleted file mode 100644
index 176e6523e5f..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Select.java
+++ /dev/null
@@ -1,88 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-import java.util.function.DoubleBinaryOperator;
-
-import static com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType.dimensionSize;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType.tensorSize;
-
-public class Select extends IntermediateOperation {
-
- public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(3)) {
- return null;
- }
- OrderedTensorType a = inputs.get(1).type().get();
- OrderedTensorType b = inputs.get(2).type().get();
- if ((a.type().rank() != b.type().rank()) || !(tensorSize(a.type()).equals(tensorSize(b.type())))) {
- throw new IllegalArgumentException("'Select': input tensors must have the same shape");
- }
- return a;
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(3)) {
- return null;
- }
- IntermediateOperation conditionOperation = inputs().get(0);
- TensorFunction a = inputs().get(1).function().get();
- TensorFunction b = inputs().get(2).function().get();
-
- // Shortcut: if we know during import which tensor to select, do that directly here.
- if (conditionOperation.getConstantValue().isPresent()) {
- Tensor condition = conditionOperation.getConstantValue().get().asTensor();
- if (condition.type().rank() == 0) {
- return ((int) condition.asDouble() == 0) ? b : a;
- }
- if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
- return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
- }
- }
-
- // The task is to select cells from 'x' or 'y' based on 'condition'.
- // If 'condition' is 0 (false), select from 'y', if 1 (true) select
- // from 'x'. We do this by individually joining 'x' and 'y' with
- // 'condition', and then joining the resulting two tensors.
-
- TensorFunction conditionFunction = conditionOperation.function().get();
- TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply());
- TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() {
- @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
- @Override public String toString() { return "f(a,b)(a * (1-b))"; }
- });
- return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add());
- }
-
- @Override
- public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(3)) {
- return;
- }
- List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions();
-
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
-
- // These tensors should have the same dimension names
- renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this);
- renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this);
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Shape.java
deleted file mode 100644
index 56f05936541..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Shape.java
+++ /dev/null
@@ -1,54 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-
-public class Shape extends IntermediateOperation {
-
- public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
- super(modelName, nodeName, inputs);
- createConstantValue();
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
- OrderedTensorType inputType = inputs.get(0).type().get();
- return new OrderedTensorType.Builder()
- .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
- .build();
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- return null; // will be added by function() since this is constant.
- }
-
- @Override
- public boolean isConstant() {
- return true;
- }
-
- private void createConstantValue() {
- if (!allInputTypesPresent(1)) {
- return;
- }
- OrderedTensorType inputType = inputs.get(0).type().get();
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type());
- List<TensorType.Dimension> inputDimensions = inputType.dimensions();
- for (int i = 0; i < inputDimensions.size(); i++) {
- builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L));
- }
- this.setConstantValue(new TensorValue(builder.build()));
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Squeeze.java
deleted file mode 100644
index f0946064213..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Squeeze.java
+++ /dev/null
@@ -1,85 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer;
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.ArrayList;
-import java.util.List;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-public class Squeeze extends IntermediateOperation {
-
- private final AttributeMap attributeMap;
- private List<String> squeezeDimensions;
-
- public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) {
- super(modelName, nodeName, inputs);
- this.attributeMap = attributeMap;
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
- OrderedTensorType inputType = inputs.get(0).type().get();
- squeezeDimensions = new ArrayList<>();
-
- Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims");
- if ( ! squeezeDimsAttr.isPresent()) {
- squeezeDimensions = inputType.type().dimensions().stream().
- filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
- map(TensorType.Dimension::name).
- collect(Collectors.toList());
- } else {
- squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue).
- map(i -> i < 0 ? inputType.type().dimensions().size() - i : i).
- map(i -> inputType.type().dimensions().get(i)).
- filter(dim -> OrderedTensorType.dimensionSize(dim) == 1).
- map(TensorType.Dimension::name).
- collect(Collectors.toList());
- }
-
- return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(1)) {
- return null;
- }
- TensorFunction inputFunction = inputs.get(0).function().get();
- return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
- }
-
- @Override
- public void renameDimensions(DimensionRenamer renamer) {
- super.renameDimensions(renamer);
- List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size());
- for (String name : squeezeDimensions) {
- Optional<String> newName = renamer.dimensionNameOf(name);
- if (!newName.isPresent()) {
- return; // presumably, already renamed
- }
- renamedDimensions.add(newName.get());
- }
- squeezeDimensions = renamedDimensions;
- }
-
- private OrderedTensorType reducedType(OrderedTensorType inputType) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (TensorType.Dimension dimension: inputType.type().dimensions()) {
- if ( ! squeezeDimensions.contains(dimension.name())) {
- builder.add(dimension);
- }
- }
- return builder.build();
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Switch.java
deleted file mode 100644
index 24212ef175c..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Switch.java
+++ /dev/null
@@ -1,50 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml.operations;
-
-import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-import java.util.List;
-import java.util.Optional;
-
-public class Switch extends IntermediateOperation {
-
- private final int port;
-
- public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) {
- super(modelName, nodeName, inputs);
- this.port = port;
- }
-
- @Override
- protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- Optional<OrderedTensorType> predicate = inputs.get(1).type();
- if (predicate.get().type().rank() != 0) {
- throw new IllegalArgumentException("Switch in " + name + ": " +
- "predicate must be a scalar");
- }
- return inputs.get(0).type().orElse(null);
- }
-
- @Override
- protected TensorFunction lazyGetFunction() {
- IntermediateOperation predicateOperation = inputs().get(1);
- if (!predicateOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Switch in " + name + ": " +
- "predicate must be a constant");
- }
- if (port < 0 || port > 1) {
- throw new IllegalArgumentException("Switch in " + name + ": " +
- "choice should be boolean");
- }
-
- double predicate = predicateOperation.getConstantValue().get().asDouble();
- return predicate == port ? inputs().get(0).function().get() : null;
- }
-
-}
-
-
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
deleted file mode 100644
index 0eb782b5a6e..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java
+++ /dev/null
@@ -1,8 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-/**
- * Model integration
- */
-@ExportPackage
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import com.yahoo.osgi.annotation.ExportPackage;
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
deleted file mode 100644
index 04e9933e40b..00000000000
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import org.junit.Test;
-
-import static org.junit.Assert.assertTrue;
-
-public class DimensionRenamerTest {
-
- @Test
- public void testMnistRenaming() {
- DimensionRenamer renamer = new DimensionRenamer();
-
- renamer.addDimension("first_dimension_of_x");
- renamer.addDimension("second_dimension_of_x");
- renamer.addDimension("first_dimension_of_w");
- renamer.addDimension("second_dimension_of_w");
- renamer.addDimension("first_dimension_of_b");
-
- // which dimension to join on matmul
- renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
-
- // other dimensions in matmul can't be equal
- renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
-
- // for efficiency, put dimension to join on innermost
- renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
- renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
-
- // bias
- renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
-
- renamer.solve();
-
- String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get();
- String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get();
- String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get();
- String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get();
- String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get();
-
- assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0);
- assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0);
- assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0);
- assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0);
- assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0);
- }
-
-}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
deleted file mode 100644
index f22e7cb087a..00000000000
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java
+++ /dev/null
@@ -1,21 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.ml;
-
-import org.junit.Test;
-
-import static org.junit.Assert.assertEquals;
-
-/**
- * @author bratseth
- */
-public class OrderedTensorTypeTestCase {
-
- @Test
- public void testToFromSpec() {
- String spec = "tensor(b[],c{},a[3])";
- OrderedTensorType type = OrderedTensorType.fromSpec(spec);
- assertEquals(spec, type.toString());
- assertEquals("tensor(a[3],b[],c{})", type.type().toString());
- }
-
-}