diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-22 14:27:58 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-22 14:27:58 +0100 |
commit | b288e61f7af7331656a1850fbdc58cc95fd1bbad (patch) | |
tree | 9d41fa770d2890585a902f41a89c41040ed764be /searchlib/src | |
parent | 3c4020645b13be560c14e60969e50e3ad41e3d3c (diff) |
Move all importing to model-integration
Diffstat (limited to 'searchlib/src')
28 files changed, 0 insertions, 2713 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamer.java deleted file mode 100644 index 86c4c287f05..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamer.java +++ /dev/null @@ -1,210 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation; - -import java.util.ArrayDeque; -import java.util.ArrayList; -import java.util.Collections; -import java.util.Deque; -import java.util.HashMap; -import java.util.Iterator; -import java.util.List; -import java.util.Map; -import java.util.Objects; -import java.util.Optional; - -/** - * A constraint satisfier to find suitable dimension names to reduce the - * amount of necessary renaming during evaluation of an imported model. - * - * @author lesters - */ -public class DimensionRenamer { - - private final String dimensionPrefix; - private final Map<String, List<Integer>> variables = new HashMap<>(); - private final Map<Arc, Constraint> constraints = new HashMap<>(); - private final Map<String, Integer> renames = new HashMap<>(); - - private int iterations = 0; - - public DimensionRenamer() { - this("d"); - } - - public DimensionRenamer(String dimensionPrefix) { - this.dimensionPrefix = dimensionPrefix; - } - - /** - * Add a dimension name variable. - */ - public void addDimension(String name) { - variables.computeIfAbsent(name, d -> new ArrayList<>()); - } - - /** - * Add a constraint between dimension names. - */ - public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { - Arc arc = new Arc(from, to, operation); - Arc opposite = arc.opposite(); - constraints.put(arc, pred); - constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric - } - - /** - * Retrieve resulting name of dimension after solving for constraints. - */ - public Optional<String> dimensionNameOf(String name) { - if (!renames.containsKey(name)) { - return Optional.empty(); - } - return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); - } - - /** - * Perform iterative arc consistency until we have found a solution. After - * an initial iteration, the variables (dimensions) will have multiple - * valid values. Find a single valid assignment by iteratively locking one - * dimension after another, and running the arc consistency algorithm - * multiple times. - * - * This requires having constraints that result in an absolute ordering: - * equals, lesserThan and greaterThan do that, but adding notEquals does - * not typically result in a guaranteed ordering. If that is needed, the - * algorithm below needs to be adapted with a backtracking (tree) search - * to find solutions. - */ - public void solve(int maxIterations) { - initialize(); - - // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts - - for (String dimension : variables.keySet()) { - List<Integer> values = variables.get(dimension); - if (values.size() > 1) { - if (!ac3()) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution."); - } - values.sort(Integer::compare); - variables.put(dimension, Collections.singletonList(values.get(0))); - } - renames.put(dimension, variables.get(dimension).get(0)); - if (iterations > maxIterations) { - throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + - maxIterations + " iterations"); - } - } - - // Todo: handle failure more gracefully: - // If a solution can't be found, look at the operation node in the arc - // with the most remaining constraints, and inject a rename operation. - // Then run this algorithm again. - } - - public void solve() { - solve(100000); - } - - private void initialize() { - for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { - List<Integer> values = variable.getValue(); - for (int i = 0; i < variables.size(); ++i) { - values.add(i); // invariant: values are in increasing order - } - } - } - - private boolean ac3() { - Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); - while (!workList.isEmpty()) { - Arc arc = workList.pop(); - iterations += 1; - if (revise(arc)) { - if (variables.get(arc.from).size() == 0) { - return false; // no solution found - } - for (Arc constraint : constraints.keySet()) { - if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { - workList.add(constraint); - } - } - } - } - return true; - } - - private boolean revise(Arc arc) { - boolean revised = false; - for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { - Integer from = fromIterator.next(); - boolean satisfied = false; - for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { - Integer to = toIterator.next(); - if (constraints.get(arc).test(from, to)) { - satisfied = true; - } - } - if (!satisfied) { - fromIterator.remove(); - revised = true; - } - } - return revised; - } - - public interface Constraint { - boolean test(Integer x, Integer y); - } - - public static boolean equals(Integer x, Integer y) { - return Objects.equals(x, y); - } - - public static boolean lesserThan(Integer x, Integer y) { - return x < y; - } - - public static boolean greaterThan(Integer x, Integer y) { - return x > y; - } - - private static class Arc { - - private final String from; - private final String to; - private final IntermediateOperation operation; - - Arc(String from, String to, IntermediateOperation operation) { - this.from = from; - this.to = to; - this.operation = operation; - } - - Arc opposite() { - return new Arc(to, from, operation); - } - - @Override - public int hashCode() { - return Objects.hash(from, to); - } - - @Override - public boolean equals(Object obj) { - if (obj == null || !(obj instanceof Arc)) { - return false; - } - Arc other = (Arc) obj; - return Objects.equals(from, other.from) && Objects.equals(to, other.to); - } - - @Override - public String toString() { - return String.format("%s -> %s", from, to); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java deleted file mode 100644 index e66d0ab6f35..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ /dev/null @@ -1,226 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.google.common.collect.ImmutableMap; -import com.yahoo.collections.Pair; -import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.HashMap; -import java.util.LinkedHashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.regex.Pattern; - -/** - * The result of importing an ML model into Vespa. - * - * @author bratseth - */ -public class ImportedModel { - - private static final String defaultSignatureName = "default"; - - private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); - private final String name; - private final String source; - - private final Map<String, Signature> signatures = new HashMap<>(); - private final Map<String, TensorType> inputs = new HashMap<>(); - private final Map<String, Tensor> smallConstants = new HashMap<>(); - private final Map<String, Tensor> largeConstants = new HashMap<>(); - private final Map<String, RankingExpression> expressions = new HashMap<>(); - private final Map<String, RankingExpression> functions = new HashMap<>(); - - /** - * Creates a new imported model. - * - * @param name the name of this mode, containing only characters in [A-Za-z0-9_] - * @param source the source path (directory or file) of this model - */ - public ImportedModel(String name, String source) { - if ( ! nameRegexp.matcher(name).matches()) - throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'"); - this.name = name; - this.source = source; - } - - /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ - public String name() { return name; } - - /** Returns the source path (directory or file) of this model */ - public String source() { return source; } - - /** Returns an immutable map of the inputs of this */ - public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } - - /** - * Returns an immutable map of the small constants of this. - * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. - */ - public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } - - /** - * Returns an immutable map of the large constants of this. - * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. - * For TensorFlow this corresponds to Variable files stored separately. - */ - public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } - - /** - * Returns an immutable map of the expressions of this - corresponding to graph nodes - * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants). - * Note that only nodes recursively referenced by a placeholder/input are added. - */ - public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } - - /** - * Returns an immutable map of the functions that are part of this model. - * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. - */ - public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } - - /** Returns an immutable map of the signatures of this */ - public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } - - /** Returns the given signature. If it does not already exist it is added to this. */ - public Signature signature(String name) { - return signatures.computeIfAbsent(name, Signature::new); - } - - /** Convenience method for returning a default signature */ - public Signature defaultSignature() { return signature(defaultSignatureName); } - - public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); } - public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } - public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } - public void expression(String name, RankingExpression expression) { expressions.put(name, expression); } - public void function(String name, RankingExpression expression) { functions.put(name, expression); } - - /** - * Returns all the output expressions of this indexed by name. The names consist of one or two parts - * separated by dot, where the first part is the signature name - * if signatures are used, or the expression name if signatures are not used and there are multiple - * expressions, and the second is the output name if signature names are used. - */ - public List<Pair<String, ExpressionFunction>> outputExpressions() { - List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); - for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { - for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) - expressions.add(new Pair<>(signatureEntry.getKey() + "" + outputEntry.getKey(), - signatureEntry.getValue().outputExpression(outputEntry.getKey()) - .withName(signatureEntry.getKey() + "" + outputEntry.getKey()))); - if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs - expressions.add(new Pair<>(signatureEntry.getKey(), - new ExpressionFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().values()), - expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap(), - Optional.empty()))); - } - if (signatures().isEmpty()) { // fallback for models without signatures - if (expressions().size() == 1) { - Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); - expressions.add(new Pair<>(singleEntry.getKey(), - new ExpressionFunction(singleEntry.getKey(), - new ArrayList<>(inputs.keySet()), - singleEntry.getValue(), - inputs, - Optional.empty()))); - } - else { - for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { - expressions.add(new Pair<>(expressionEntry.getKey(), - new ExpressionFunction(expressionEntry.getKey(), - new ArrayList<>(inputs.keySet()), - expressionEntry.getValue(), - inputs, - Optional.empty()))); - } - } - } - return expressions; - } - - /** - * A signature is a set of named inputs and outputs, where the inputs maps to input - * ("placeholder") names+types, and outputs maps to expressions nodes. - * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit - * concept of signatures. For now, we handle ONNX models as having a single signature. - */ - public class Signature { - - private final String name; - private final Map<String, String> inputs = new LinkedHashMap<>(); - private final Map<String, String> outputs = new LinkedHashMap<>(); - private final Map<String, String> skippedOutputs = new HashMap<>(); - private final List<String> importWarnings = new ArrayList<>(); - - public Signature(String name) { - this.name = name; - } - - public String name() { return name; } - - /** Returns the result this is part of */ - public ImportedModel owner() { return ImportedModel.this; } - - /** - * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name - * in this signature to input name in the owning model - */ - public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } - - /** Returns the name and type of all inputs in this signature as an immutable map */ - public Map<String, TensorType> inputMap() { - ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); - // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* - // in the model, as these are the names which must actually be bound, if we are to avoid creating an - // "input mapping" to accomodate this complexity in - for (Map.Entry<String, String> inputEntry : inputs().entrySet()) - inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue())); - return inputs.build(); - } - - /** Returns the type of the input this input references */ - public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); } - - /** Returns an immutable list of the expression names of this */ - public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } - - /** - * Returns an immutable list of the outputs of this which could not be imported, - * with a string detailing the reason for each - */ - public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } - - /** - * Returns an immutable list of possibly non-fatal warnings encountered during import. - */ - public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } - - /** Returns the expression this output references */ - public ExpressionFunction outputExpression(String outputName) { - return new ExpressionFunction(outputName, - new ArrayList<>(inputs.values()), - owner().expressions().get(outputs.get(outputName)), - inputMap(), - Optional.empty()); - } - - @Override - public String toString() { return "signature '" + name + "'"; } - - void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } - void output(String name, String expressionName) { outputs.put(name, expressionName); } - void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } - void importWarning(String warning) { importWarnings.add(warning); } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java deleted file mode 100644 index 896cd2e8d21..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModels.java +++ /dev/null @@ -1,109 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.google.common.collect.ImmutableMap; -import com.yahoo.path.Path; - -import java.io.File; -import java.util.Arrays; -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; -import java.util.Optional; - -/** - * All models imported from the models/ directory in the application package. - * If this is empty it may be due to either not having any models in the application package, - * or this being created for a ZooKeeper application package, which does not have imported models. - * - * @author bratseth - */ -public class ImportedModels { - - /** All imported models, indexed by their names */ - private final ImmutableMap<String, ImportedModel> importedModels; - - /** Create a null imported models */ - public ImportedModels() { - importedModels = ImmutableMap.of(); - } - - public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) { - Map<String, ImportedModel> models = new HashMap<>(); - - // Find all subdirectories recursively which contains a model we can read - importRecursively(modelsDirectory, models, importers); - importedModels = ImmutableMap.copyOf(models); - } - - private static void importRecursively(File dir, - Map<String, ImportedModel> models, - Collection<ModelImporter> importers) { - if ( ! dir.isDirectory()) return; - - Arrays.stream(dir.listFiles()).sorted().forEach(child -> { - Optional<ModelImporter> importer = findImporterOf(child, importers); - if (importer.isPresent()) { - String name = toName(child); - ImportedModel existing = models.get(name); - if (existing != null) - throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + - " both resolve to the model name '" + name + "'"); - models.put(name, importer.get().importModel(name, child)); - } - else { - importRecursively(child, models, importers); - } - }); - } - - private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) { - return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); - } - - /** - * Returns the model at the given location in the application package. - * - * @param modelPath the path to this model (file or directory, depending on model type) - * under the application package, both from the root or relative to the - * models directory works - * @return the model at this path or null if none - */ - public ImportedModel get(File modelPath) { - return importedModels.get(toName(modelPath)); - } - - public ImportedModel get(String modelName) { - return importedModels.get(modelName); - } - - /** Returns an immutable collection of all the imported models */ - public Collection<ImportedModel> all() { - return importedModels.values(); - } - - private static String toName(File modelFile) { - Path modelPath = Path.fromString(modelFile.toString()); - if (modelFile.isFile()) - modelPath = stripFileEnding(modelPath); - String localPath = concatenateAfterModelsDirectory(modelPath); - return localPath.replace('.', '_'); - } - - private static Path stripFileEnding(Path path) { - int dotIndex = path.last().lastIndexOf(""); - if (dotIndex <= 0) return path; - return path.withLast(path.last().substring(0, dotIndex)); - } - - private static String concatenateAfterModelsDirectory(Path path) { - boolean afterModels = false; - StringBuilder result = new StringBuilder(); - for (String element : path.elements()) { - if (afterModels) result.append(element).append("_"); - if (element.equals("models")) afterModels = true; - } - return result.substring(0, result.length()-1); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/IntermediateGraph.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/IntermediateGraph.java deleted file mode 100644 index 81c176707e5..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/IntermediateGraph.java +++ /dev/null @@ -1,107 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation; - -import java.util.Collection; -import java.util.HashMap; -import java.util.Map; -import java.util.Set; - -/** - * Holds an intermediate representation of an imported model graph. - * After this intermediate representation is constructed, it is used to - * simplify and optimize the computational graph and then converted into the - * final ImportedModel that holds the Vespa ranking expressions for the model. - * - * @author lesters - */ -public class IntermediateGraph { - - private final String modelName; - private final Map<String, IntermediateOperation> index = new HashMap<>(); - private final Map<String, GraphSignature> signatures = new HashMap<>(); - - private static class GraphSignature { - final Map<String, String> inputs = new HashMap<>(); - final Map<String, String> outputs = new HashMap<>(); - } - - public IntermediateGraph(String modelName) { - this.modelName = modelName; - } - - public String name() { - return modelName; - } - - public IntermediateOperation put(String key, IntermediateOperation operation) { - return index.put(key, operation); - } - - public IntermediateOperation get(String key) { - return index.get(key); - } - - public Set<String> signatures() { - return signatures.keySet(); - } - - public Map<String, String> inputs(String signature) { - return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs; - } - - public Map<String, String> outputs(String signature) { - return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs; - } - - public String defaultSignature() { - return "default"; - } - - public boolean alreadyImported(String key) { - return index.containsKey(key); - } - - public Collection<IntermediateOperation> operations() { - return index.values(); - } - - public void optimize() { - renameDimensions(); - } - - /** - * Find dimension names to avoid excessive renaming while evaluating the model. - */ - private void renameDimensions() { - DimensionRenamer renamer = new DimensionRenamer(); - for (String signature : signatures()) { - for (String output : outputs(signature).values()) { - addDimensionNameConstraints(index.get(output), renamer); - } - } - renamer.solve(); - for (String signature : signatures()) { - for (String output : outputs(signature).values()) { - renameDimensions(index.get(output), renamer); - } - } - } - - private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); - operation.addDimensionNameConstraints(renamer); - } - } - - private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { - if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); - operation.renameDimensions(renamer); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java deleted file mode 100644 index 47bf6a2a240..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ /dev/null @@ -1,243 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation; -import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; -import com.yahoo.yolean.Exceptions; - -import java.io.File; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.logging.Logger; - -/** - * Base class for importing ML models (ONNX/TensorFlow) as native Vespa - * ranking expressions. The general mechanism for import is for the - * specific ML platform import implementations to create an - * IntermediateGraph. This class offers common code to convert the - * IntermediateGraph to Vespa ranking expressions and functions. - * - * @author lesters - */ -public abstract class ModelImporter { - - private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); - - /** Returns whether the file or directory at the given path is of the type which can be imported by this */ - public abstract boolean canImport(String modelPath); - - /** Imports the given model */ - public abstract ImportedModel importModel(String modelName, String modelPath); - - public final ImportedModel importModel(String modelName, File modelPath) { - return importModel(modelName, modelPath.toString()); - } - - /** - * Takes an IntermediateGraph and converts it to a ImportedModel containing - * the actual Vespa ranking expressions. - */ - public static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { - ImportedModel model = new ImportedModel(graph.name(), modelSource); - - graph.optimize(); - - importSignatures(graph, model); - importExpressions(graph, model); - reportWarnings(graph, model); - logVariableTypes(graph); - - return model; - } - - private static void importSignatures(IntermediateGraph graph, ImportedModel model) { - for (String signatureName : graph.signatures()) { - ImportedModel.Signature signature = model.signature(signatureName); - for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) { - signature.input(input.getKey(), input.getValue()); - } - for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) { - signature.output(output.getKey(), output.getValue()); - } - } - } - - private static boolean isSignatureInput(ImportedModel model, IntermediateOperation operation) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String inputName : signature.inputs().values()) { - if (inputName.equals(operation.name())) { - return true; - } - } - } - return false; - } - - private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - if (outputName.equals(operation.name())) { - return true; - } - } - } - return false; - } - - /** - * Convert intermediate representation to Vespa ranking expressions. - */ - static void importExpressions(IntermediateGraph graph, ImportedModel model) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - try { - Optional<TensorFunction> function = importExpression(graph.get(outputName), model); - if (!function.isPresent()) { - signature.skippedOutput(outputName, "No valid output function could be found."); - } - } - catch (IllegalArgumentException e) { - signature.skippedOutput(outputName, Exceptions.toMessageString(e)); - } - } - } - } - - private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) { - if (!operation.type().isPresent()) { - return Optional.empty(); - } - if (operation.isConstant()) { - return importConstant(operation, model); - } - importExpressionInputs(operation, model); - importRankingExpression(operation, model); - importArgumentExpression(operation, model); - importFunctionExpression(operation, model); - - return operation.function(); - } - - private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) { - operation.inputs().forEach(input -> importExpression(input, model)); - } - - private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { - String name = operation.vespaName(); - if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { - return operation.function(); - } - - Value value = operation.getConstantValue().orElseThrow(() -> - new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + - "is constant but does not have a value.")); - if ( ! (value instanceof TensorValue)) { - return operation.function(); // scalar values are inserted directly into the expression - } - - Tensor tensor = value.asTensor(); - if (tensor.type().rank() == 0) { - model.smallConstant(name, tensor); - } else { - model.largeConstant(name, tensor); - } - return operation.function(); - } - - private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) { - if (operation.function().isPresent()) { - String name = operation.name(); - if ( ! model.expressions().containsKey(name)) { - TensorFunction function = operation.function().get(); - - if (isSignatureOutput(model, operation)) { - OrderedTensorType operationType = operation.type().get(); - OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); - if ( ! operationType.equals(standardNamingType)) { - List<String> renameFrom = operationType.dimensionNames(); - List<String> renameTo = standardNamingType.dimensionNames(); - function = new Rename(function, renameFrom, renameTo); - } - } - - try { - // We add all intermediate nodes imported as separate expressions. Only - // those referenced from the output will be used. We parse the - // TensorFunction here to convert it to a RankingExpression tree. - model.expression(name, new RankingExpression(name, function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Imported function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - } - - private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) { - if (operation.isInput()) { - // All inputs must have dimensions with standard naming convention: d0, d1, ... - OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); - model.input(operation.vespaName(), standardNamingConvention.type()); - } - } - - private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) { - if (operation.rankingExpressionFunction().isPresent()) { - TensorFunction function = operation.rankingExpressionFunction().get(); - try { - model.function(operation.rankingExpressionFunctionName(), - new RankingExpression(operation.rankingExpressionFunctionName(), - function.toString())); - } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); - } - } - } - - /** - * Add any import warnings to the signature in the ImportedModel. - */ - private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { - for (ImportedModel.Signature signature : model.signatures().values()) { - for (String outputName : signature.outputs().values()) { - reportWarnings(graph.get(outputName), model); - } - } - } - - private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { - for (String warning : operation.warnings()) { - model.defaultSignature().importWarning(warning); - } - for (IntermediateOperation input : operation.inputs()) { - reportWarnings(input, model); - } - } - - /** - * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. - * This allows users to learn the exact types (including dimension order after renaming) of the Variables - * such that these can be converted and fed to a parent document independently of the rest of the model - * for fast model weight updates. - */ - private static void logVariableTypes(IntermediateGraph graph) { - for (IntermediateOperation operation : graph.operations()) { - if ( ! (operation instanceof Constant)) continue; - if ( ! operation.type().isPresent()) continue; // will not happen - log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + - " of type " + operation.type().get()); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorType.java deleted file mode 100644 index dfd073f82b4..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorType.java +++ /dev/null @@ -1,235 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.TensorTypeParser; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * A Vespa tensor type is ordered by the lexicographical ordering of dimension - * names. Imported tensors have an explicit ordering of their dimensions. - * During import, we need to track the Vespa dimension that matches the - * corresponding imported dimension as the ordering can change after - * dimension renaming. That is the purpose of this class. - * - * @author lesters - */ -public class OrderedTensorType { - - private final TensorType type; - private final List<TensorType.Dimension> dimensions; - - private final long[] innerSizesOriginal; - private final long[] innerSizesVespa; - private final int[] dimensionMap; - - private OrderedTensorType(List<TensorType.Dimension> dimensions) { - this.dimensions = Collections.unmodifiableList(dimensions); - this.type = new TensorType.Builder(dimensions).build(); - this.innerSizesOriginal = new long[dimensions.size()]; - this.innerSizesVespa = new long[dimensions.size()]; - this.dimensionMap = createDimensionMap(); - } - - public TensorType type() { return this.type; } - - public int rank() { return dimensions.size(); } - - public List<TensorType.Dimension> dimensions() { - return dimensions; - } - - public List<String> dimensionNames() { - return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList()); - } - - private int[] createDimensionMap() { - int numDimensions = dimensions.size(); - if (numDimensions == 0) { - return null; - } - innerSizesOriginal[numDimensions - 1] = 1; - innerSizesVespa[numDimensions - 1] = 1; - for (int i = numDimensions - 1; --i >= 0; ) { - innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1]; - innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; - } - int[] mapping = new int[numDimensions]; - for (int i = 0; i < numDimensions; ++i) { - TensorType.Dimension dim1 = dimensions().get(i); - for (int j = 0; j < numDimensions; ++j) { - TensorType.Dimension dim2 = type.dimensions().get(j); - if (dim1.equals(dim2)) { - mapping[i] = j; - break; - } - } - } - return mapping; - } - - public int dimensionMap(int originalIndex) { - return dimensionMap[originalIndex]; - } - - /** - * When dimension ordering between Vespa and imported differs, i.e. - * after dimension renaming, use the dimension map to read in values - * so that they are correctly laid out in memory for Vespa. - * Used when importing tensors. - */ - public int toDirectIndex(int index) { - if (dimensions.size() == 0) { - return 0; - } - if (dimensionMap == null) { - throw new IllegalArgumentException("Dimension map is not available"); - } - int directIndex = 0; - long rest = index; - for (int i = 0; i < dimensions.size(); ++i) { - long address = rest / innerSizesOriginal[i]; - directIndex += innerSizesVespa[dimensionMap[i]] * address; - rest %= innerSizesOriginal[i]; - } - return directIndex; - } - - @Override - public boolean equals(Object obj) { - if (obj == null || !(obj instanceof OrderedTensorType)) { - return false; - } - OrderedTensorType other = (OrderedTensorType) obj; - if (dimensions.size() != dimensions.size()) { - return false; - } - List<TensorType.Dimension> thisDimensions = this.dimensions(); - List<TensorType.Dimension> otherDimensions = other.dimensions(); - for (int i = 0; i < thisDimensions.size(); ++i) { - if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { - return false; - } - } - return true; - } - - public OrderedTensorType rename(DimensionRenamer renamer) { - List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); - for (TensorType.Dimension dimension : dimensions) { - String oldName = dimension.name(); - Optional<String> newName = renamer.dimensionNameOf(oldName); - if (!newName.isPresent()) - return this; // presumably, already renamed - TensorType.Dimension.Type dimensionType = dimension.type(); - if (dimensionType == TensorType.Dimension.Type.indexedBound) { - renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); - } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) { - renamedDimensions.add(TensorType.Dimension.indexed(newName.get())); - } else if (dimensionType == TensorType.Dimension.Type.mapped) { - renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); - } - } - return new OrderedTensorType(renamedDimensions); - } - - public OrderedTensorType rename(String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < dimensions.size(); ++ i) { - String dimensionName = dimensionPrefix + i; - Optional<Long> dimSize = dimensions.get(i).size(); - if (dimSize.isPresent() && dimSize.get() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get())); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - - public static OrderedTensorType standardType(OrderedTensorType type) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < type.dimensions().size(); ++ i) { - TensorType.Dimension dim = type.dimensions().get(i); - String dimensionName = "d" + i; - if (dim.size().isPresent() && dim.size().get() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - - public static Long tensorSize(TensorType type) { - Long size = 1L; - for (TensorType.Dimension dimension : type.dimensions()) { - size *= dimensionSize(dimension); - } - return size; - } - - public static Long dimensionSize(TensorType.Dimension dim) { - return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); - } - - /** - * Returns a string representation of this: A standard tensor type string where dimensions - * are listed in the order of this rather than in the natural order of their names. - */ - @Override - public String toString() { - return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; - } - - /** - * Creates an instance from the string representation of this: A standard tensor type string - * where dimensions are listed in the order of this rather than the natural order of their names. - */ - public static OrderedTensorType fromSpec(String typeSpec) { - return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); - } - - public static OrderedTensorType fromDimensionList(List<Long> dims) { - return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... - } - - public static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < dims.size(); ++ i) { - String dimensionName = dimensionPrefix + i; - Long dimSize = dims.get(i); - if (dimSize >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - - public static class Builder { - - private final List<TensorType.Dimension> dimensions; - - public Builder() { - this.dimensions = new ArrayList<>(); - } - - public Builder add(TensorType.Dimension vespaDimension) { - this.dimensions.add(vespaDimension); - return this; - } - - public OrderedTensorType build() { - return new OrderedTensorType(dimensions); - } - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Argument.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Argument.java deleted file mode 100644 index 57e4ec13e45..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Argument.java +++ /dev/null @@ -1,57 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.Collections; -import java.util.List; - -public class Argument extends IntermediateOperation { - - private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... - - public Argument(String modelName, String nodeName, OrderedTensorType type) { - super(modelName, nodeName, Collections.emptyList()); - this.type = type.rename(vespaName() + "_"); - standardNamingType = OrderedTensorType.standardType(type); - } - - @Override - protected OrderedTensorType lazyGetType() { - return type; - } - - @Override - protected TensorFunction lazyGetFunction() { - TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); - if (!standardNamingType.equals(type)) { - List<String> renameFrom = standardNamingType.dimensionNames(); - List<String> renameTo = type.dimensionNames(); - output = new Rename(output, renameFrom, renameTo); - } - return output; - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - @Override - public boolean isInput() { - return true; - } - - @Override - public boolean isConstant() { - return false; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ConcatV2.java deleted file mode 100644 index 413b856e43d..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ConcatV2.java +++ /dev/null @@ -1,108 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; -import java.util.Optional; - -public class ConcatV2 extends IntermediateOperation { - - private String concatDimensionName; - - public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { - return null; - } - - IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input - if (!concatDimOp.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a constant."); - } - Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); - if (concatDimTensor.type().rank() != 0) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a scalar."); - } - - OrderedTensorType aType = inputs.get(0).type().get(); - - int concatDim = (int)concatDimTensor.asDouble(); - long concatDimSize = aType.dimensions().get(concatDim).size().orElse(-1L); - - for (int i = 1; i < inputs.size() - 1; ++i) { - OrderedTensorType bType = inputs.get(i).type().get(); - if (bType.rank() != aType.rank()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "inputs must have save rank."); - } - for (int j = 0; j < aType.rank(); ++j) { - long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); - long dimSizeB = bType.dimensions().get(j).size().orElse(-1L); - if (j == concatDim) { - concatDimSize += dimSizeB; - } else if (dimSizeA != dimSizeB) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "input dimension " + j + " differs in input tensors."); - } - } - } - - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); - int dimensionIndex = 0; - for (TensorType.Dimension dimension : aType.dimensions()) { - if (dimensionIndex == concatDim) { - concatDimensionName = dimension.name(); - typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize)); - } else { - typeBuilder.add(dimension); - } - dimensionIndex++; - } - return typeBuilder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { - return null; - } - TensorFunction result = inputs.get(0).function().get(); - for (int i = 1; i < inputs.size() - 1; ++i) { - TensorFunction b = inputs.get(i).function().get(); - result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName); - } - return result; - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { - return; - } - OrderedTensorType a = inputs.get(0).type().get(); - for (int i = 1; i < inputs.size() - 1; ++i) { - OrderedTensorType b = inputs.get(i).type().get(); - String bDim = b.dimensions().get(i).name(); - String aDim = a.dimensions().get(i).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); - } - } - - @Override - public void renameDimensions(DimensionRenamer renamer) { - super.renameDimensions(renamer); - concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Const.java deleted file mode 100644 index 49c0bb712c5..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Const.java +++ /dev/null @@ -1,89 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; -import java.util.Optional; - -public class Const extends IntermediateOperation { - - private final AttributeMap attributeMap; - - public Const(String modelName, - String nodeName, - List<IntermediateOperation> inputs, - AttributeMap attributeMap, - OrderedTensorType type) { - super(modelName, nodeName, inputs); - this.attributeMap = attributeMap; - this.type = type.rename(vespaName() + "_"); - setConstantValue(value()); - } - - @Override - protected OrderedTensorType lazyGetType() { - return type; - } - - @Override - public Optional<TensorFunction> function() { - if (function == null) { - function = lazyGetFunction(); - } - return Optional.ofNullable(function); - } - - @Override - protected TensorFunction lazyGetFunction() { - ExpressionNode expressionNode; - if (type.type().rank() == 0 && getConstantValue().isPresent()) { - expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue()); - } else { - expressionNode = new ReferenceNode(Reference.simple("constant", vespaName())); - } - return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); - } - - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + super.vespaName(); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - @Override - public void renameDimensions(DimensionRenamer renamer) { - super.renameDimensions(renamer); - setConstantValue(value()); - } - - @Override - public boolean isConstant() { - return true; - } - - private Value value() { - Optional<Value> value = attributeMap.get("value", type); - if ( ! value.isPresent()) { - throw new IllegalArgumentException("Node '" + name + "' of type " + - "const has missing or non-recognized 'value' attribute"); - } - return value.get(); - } -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Constant.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Constant.java deleted file mode 100644 index 670274376a3..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Constant.java +++ /dev/null @@ -1,61 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.Collections; -import java.util.Optional; - -public class Constant extends IntermediateOperation { - - private final String modelName; - - public Constant(String modelName, String nodeName, OrderedTensorType type) { - super(modelName, nodeName, Collections.emptyList()); - this.modelName = modelName; - this.type = type.rename(vespaName() + "_"); - } - - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + vespaName(name); - } - - @Override - protected OrderedTensorType lazyGetType() { - return type; - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; // will be added by function() since this is constant. - } - - /** - * Constant values are sent in via the constantValueFunction, as the - * dimension names and thus the data layout depends on the dimension - * renaming which happens after the conversion to intermediate graph. - */ - @Override - public Optional<Value> getConstantValue() { - return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type)); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - @Override - public boolean isConstant() { - return true; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ExpandDims.java deleted file mode 100644 index e037a9a2497..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/ExpandDims.java +++ /dev/null @@ -1,106 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; - -public class ExpandDims extends IntermediateOperation { - - private List<String> expandDimensions; - - public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - - IntermediateOperation axisOperation = inputs().get(1); - if (!axisOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis must be a constant."); - } - Tensor axis = axisOperation.getConstantValue().get().asTensor(); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis argument must be a scalar."); - } - - OrderedTensorType inputType = inputs.get(0).type().get(); - int dimensionToInsert = (int)axis.asDouble(); - if (dimensionToInsert < 0) { - dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; - } - - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); - expandDimensions = new ArrayList<>(); - int dimensionIndex = 0; - for (TensorType.Dimension dimension : inputType.dimensions()) { - if (dimensionIndex == dimensionToInsert) { - String name = String.format("%s_%d", vespaName(), dimensionIndex); - expandDimensions.add(name); - typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); - } - typeBuilder.add(dimension); - dimensionIndex++; - } - - return typeBuilder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(2)) { - return null; - } - - // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (String name : expandDimensions) { - typeBuilder.indexed(name, 1); - } - TensorType generatedType = typeBuilder.build(); - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, - new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - @Override - public void renameDimensions(DimensionRenamer renamer) { - super.renameDimensions(renamer); - List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); - for (String name : expandDimensions) { - Optional<String> newName = renamer.dimensionNameOf(name); - if (!newName.isPresent()) { - return; // presumably, already renamed - } - renamedDimensions.add(newName.get()); - } - expandDimensions = renamedDimensions; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Identity.java deleted file mode 100644 index c52b3357848..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Identity.java +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; - -public class Identity extends IntermediateOperation { - - public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - } - - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + super.vespaName(); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) - return null; - return inputs.get(0).type().orElse(null); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(1)) - return null; - return inputs.get(0).function().orElse(null); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/IntermediateOperation.java deleted file mode 100644 index 1b62ef67d71..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/IntermediateOperation.java +++ /dev/null @@ -1,191 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.Reference; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Optional; -import java.util.function.Function; - -/** - * Wraps an imported operation node and produces the respective Vespa tensor - * operation. During import, a graph of these operations are constructed. Then, - * the types are used to deduce sensible dimension names using the - * DimensionRenamer. After the types have been renamed, the proper Vespa - * expressions can be extracted. - * - * @author lesters - */ -public abstract class IntermediateOperation { - - public final static String FUNCTION_PREFIX = "imported_ml_function_"; - - protected final String name; - protected final String modelName; - protected final List<IntermediateOperation> inputs; - protected final List<IntermediateOperation> outputs = new ArrayList<>(); - - protected OrderedTensorType type; - protected TensorFunction function; - protected TensorFunction rankingExpressionFunction = null; - - private final List<String> importWarnings = new ArrayList<>(); - private Value constantValue = null; - private List<IntermediateOperation> controlInputs = Collections.emptyList(); - - protected Function<OrderedTensorType, Value> constantValueFunction = null; - - IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) { - this.name = name; - this.modelName = modelName; - this.inputs = Collections.unmodifiableList(inputs); - this.inputs.forEach(i -> i.outputs.add(this)); - } - - protected abstract OrderedTensorType lazyGetType(); - protected abstract TensorFunction lazyGetFunction(); - - /** Returns the Vespa tensor type of this operation if it exists */ - public Optional<OrderedTensorType> type() { - if (type == null) { - type = lazyGetType(); - } - return Optional.ofNullable(type); - } - - /** Returns the Vespa tensor function implementing all operations from this node with inputs */ - public Optional<TensorFunction> function() { - if (function == null) { - if (isConstant()) { - ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); - function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); - } else if (outputs.size() > 1) { - rankingExpressionFunction = lazyGetFunction(); - function = new VariableTensor(rankingExpressionFunctionName(), type.type()); - } else { - function = lazyGetFunction(); - } - } - return Optional.ofNullable(function); - } - - /** Returns original name of this operation node */ - public String name() { return name; } - - /** Return unmodifiable list of inputs */ - public List<IntermediateOperation> inputs() { return inputs; } - - /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a function. */ - public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } - - /** Returns a function that should be added as a ranking expression function */ - public Optional<TensorFunction> rankingExpressionFunction() { - return Optional.ofNullable(rankingExpressionFunction); - } - - /** Add dimension name constraints for this operation */ - public void addDimensionNameConstraints(DimensionRenamer renamer) { } - - /** Performs dimension rename for this operation */ - public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } - - /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ - public boolean isInput() { return false; } - - /** Return true if this node is constant */ - public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); } - - /** Sets the constant value */ - public void setConstantValue(Value value) { constantValue = value; } - - /** Gets the constant value if it exists */ - public Optional<Value> getConstantValue() { - if (constantValue != null) { - return Optional.of(constantValue); - } - if (constantValueFunction != null) { - return Optional.of(constantValueFunction.apply(type)); - } - return Optional.empty(); - } - - /** Set the constant value function */ - public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; } - - /** Sets the external control inputs */ - public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; } - - /** Retrieve the control inputs for this operation */ - public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } - - /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return vespaName(name); } - public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } - - /** Retrieve the valid Vespa name of this node if it is a ranking expression function */ - public String rankingExpressionFunctionName() { - return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; - } - - /** Retrieve the list of warnings produced during its lifetime */ - public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } - - /** Set an input warning */ - public void warning(String warning) { importWarnings.add(warning); } - - boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) { - if (inputs.size() != expected) { - throw new IllegalArgumentException("Expected " + expected + " inputs " + - "for '" + name + "', got " + inputs.size()); - } - return inputs.stream().map(func).allMatch(Optional::isPresent); - } - - boolean allInputTypesPresent(int expected) { - return verifyInputs(expected, IntermediateOperation::type); - } - - boolean allInputFunctionsPresent(int expected) { - return verifyInputs(expected, IntermediateOperation::function); - } - - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. - */ - public static String namePartOf(String name) { - name = name.startsWith("^") ? name.substring(1) : name; - return name.split(":")[0]; - } - - /** - * This return the output index part. Indexes are used for nodes with - * multiple outputs. - */ - public static int indexPartOf(String name) { - int i = name.indexOf(":"); - return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); - } - - /** - * An interface mapping operation attributes to Vespa Values. - * Adapter for differences in different model types. - */ - public interface AttributeMap { - Optional<Value> get(String key); - Optional<Value> get(String key, OrderedTensorType type); - Optional<List<Value>> getList(String key); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java deleted file mode 100644 index c98bcb43331..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Join.java +++ /dev/null @@ -1,120 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.ArrayList; -import java.util.List; -import java.util.function.DoubleBinaryOperator; - -public class Join extends IntermediateOperation { - - private final DoubleBinaryOperator operator; - - public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) { - super(modelName, nodeName, inputs); - this.operator = operator; - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType a = largestInput().type().get(); - OrderedTensorType b = smallestInput().type().get(); - - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - int sizeDifference = a.rank() - b.rank(); - for (int i = 0; i < a.rank(); ++i) { - TensorType.Dimension aDim = a.dimensions().get(i); - long size = aDim.size().orElse(-1L); - - if (i - sizeDifference >= 0) { - TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference); - size = Math.max(size, bDim.size().orElse(-1L)); - } - - if (aDim.type() == TensorType.Dimension.Type.indexedBound) { - builder.add(TensorType.Dimension.indexed(aDim.name(), size)); - } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) { - builder.add(TensorType.Dimension.indexed(aDim.name())); - } else if (aDim.type() == TensorType.Dimension.Type.mapped) { - builder.add(TensorType.Dimension.mapped(aDim.name())); - } - } - return builder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } - - IntermediateOperation a = largestInput(); - IntermediateOperation b = smallestInput(); - - List<String> aDimensionsToReduce = new ArrayList<>(); - List<String> bDimensionsToReduce = new ArrayList<>(); - int sizeDifference = a.type().get().rank() - b.type().get().rank(); - for (int i = 0; i < b.type().get().rank(); ++i) { - TensorType.Dimension bDim = b.type().get().dimensions().get(i); - TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference); - long bSize = bDim.size().orElse(-1L); - long aSize = aDim.size().orElse(-1L); - if (bSize == 1L && aSize != 1L) { - bDimensionsToReduce.add(bDim.name()); - } - if (aSize == 1L && bSize != 1L) { - aDimensionsToReduce.add(bDim.name()); - } - } - - TensorFunction aReducedFunction = a.function().get(); - if (aDimensionsToReduce.size() > 0) { - aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); - } - TensorFunction bReducedFunction = b.function().get(); - if (bDimensionsToReduce.size() > 0) { - bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); - } - - return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } - OrderedTensorType a = largestInput().type().get(); - OrderedTensorType b = smallestInput().type().get(); - int sizeDifference = a.rank() - b.rank(); - for (int i = 0; i < b.rank(); ++i) { - String bDim = b.dimensions().get(i).name(); - String aDim = a.dimensions().get(i + sizeDifference).name(); - renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); - } - } - - private IntermediateOperation largestInput() { - OrderedTensorType a = inputs.get(0).type().get(); - OrderedTensorType b = inputs.get(1).type().get(); - return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1); - } - - private IntermediateOperation smallestInput() { - OrderedTensorType a = inputs.get(0).type().get(); - OrderedTensorType b = inputs.get(1).type().get(); - return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Map.java deleted file mode 100644 index 9bf6836ea9b..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Map.java +++ /dev/null @@ -1,37 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; -import java.util.Optional; -import java.util.function.DoubleUnaryOperator; - -public class Map extends IntermediateOperation { - - private final DoubleUnaryOperator operator; - - public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) { - super(modelName, nodeName, inputs); - this.operator = operator; - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } - return inputs.get(0).type().get(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(1)) { - return null; - } - Optional<TensorFunction> input = inputs.get(0).function(); - return new com.yahoo.tensor.functions.Map(input.get(), operator); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/MatMul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/MatMul.java deleted file mode 100644 index 287c23080de..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/MatMul.java +++ /dev/null @@ -1,72 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; -import java.util.Optional; - -public class MatMul extends IntermediateOperation { - - public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); - typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); - typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); - return typeBuilder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType aType = inputs.get(0).type().get(); - OrderedTensorType bType = inputs.get(1).type().get(); - if (aType.type().rank() < 2 || bType.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (aType.type().rank() != bType.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - - Optional<TensorFunction> aFunction = inputs.get(0).function(); - Optional<TensorFunction> bFunction = inputs.get(1).function(); - if (!aFunction.isPresent() || !bFunction.isPresent()) { - return null; - } - return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } - List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); - List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); - - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); - - // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); - - // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); - - // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); - } -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Mean.java deleted file mode 100644 index f313831ea56..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Mean.java +++ /dev/null @@ -1,113 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Optional; - -public class Mean extends IntermediateOperation { - - private final AttributeMap attributeMap; - private List<String> reduceDimensions; - - public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { - super(modelName, nodeName, inputs); - this.attributeMap = attributeMap; - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - IntermediateOperation reductionIndices = inputs.get(1); - if (!reductionIndices.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Mean in " + name + ": " + - "reduction indices must be a constant."); - } - Tensor indices = reductionIndices.getConstantValue().get().asTensor(); - reduceDimensions = new ArrayList<>(); - - OrderedTensorType inputType = inputs.get(0).type().get(); - for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int dimensionIndex = cell.getValue().intValue(); - if (dimensionIndex < 0) { - dimensionIndex = inputType.dimensions().size() - dimensionIndex; - } - reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); - } - return reducedType(inputType, shouldKeepDimensions()); - } - - // optimization: if keepDims and one reduce dimension that has size 1: same as identity. - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - TensorFunction inputFunction = inputs.get(0).function().get(); - TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); - if (shouldKeepDimensions()) { - // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (String name : reduceDimensions) { - typeBuilder.indexed(name, 1); - } - TensorType generatedType = typeBuilder.build(); - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, - new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); - } - return output; - } - - @Override - public void renameDimensions(DimensionRenamer renamer) { - super.renameDimensions(renamer); - List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size()); - for (String name : reduceDimensions) { - Optional<String> newName = renamer.dimensionNameOf(name); - if (!newName.isPresent()) { - return; // presumably, already renamed - } - renamedDimensions.add(newName.get()); - } - reduceDimensions = renamedDimensions; - } - - private boolean shouldKeepDimensions() { - Optional<Value> keepDims = attributeMap.get("keep_dims"); - return keepDims.isPresent() && keepDims.get().asBoolean(); - } - - private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (TensorType.Dimension dimension: inputType.type().dimensions()) { - if (!reduceDimensions.contains(dimension.name())) { - builder.add(dimension); - } else if (keepDimensions) { - builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); - } - } - return builder.build(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Merge.java deleted file mode 100644 index 9ce0ea6151b..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Merge.java +++ /dev/null @@ -1,35 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; - -public class Merge extends IntermediateOperation { - - public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - } - - @Override - protected OrderedTensorType lazyGetType() { - for (IntermediateOperation operation : inputs) { - if (operation.type().isPresent()) { - return operation.type().get(); - } - } - return null; - } - - @Override - protected TensorFunction lazyGetFunction() { - for (IntermediateOperation operation : inputs) { - if (operation.function().isPresent()) { - return operation.function().get(); - } - } - return null; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/NoOp.java deleted file mode 100644 index f78c945d09f..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/NoOp.java +++ /dev/null @@ -1,26 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.Collections; -import java.util.List; - -public class NoOp extends IntermediateOperation { - - public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs - } - - @Override - protected OrderedTensorType lazyGetType() { - return null; - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/PlaceholderWithDefault.java deleted file mode 100644 index 5da3300cd7a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/PlaceholderWithDefault.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; -import java.util.Optional; - -public class PlaceholderWithDefault extends IntermediateOperation { - - public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } - return inputs().get(0).type().orElse(null); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(1)) { - return null; - } - // This should be a call to the function we add below, but for now - // we treat this as as identity function and just pass the constant. - return inputs.get(0).function().orElse(null); - } - - @Override - public Optional<TensorFunction> rankingExpressionFunction() { - // For now, it is much more efficient to assume we always will return - // the default value, as we can prune away large parts of the expression - // tree by having it calculated as a constant. If a case arises where - // it is important to support this, implement this. - return Optional.empty(); - } - - @Override - public boolean isConstant() { - return true; // not true if we add to function - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Reshape.java deleted file mode 100644 index 3165890ff03..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Reshape.java +++ /dev/null @@ -1,131 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.stream.Collectors; - -public class Reshape extends IntermediateOperation { - - public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - IntermediateOperation newShape = inputs.get(1); - if (!newShape.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Reshape in " + name + ": " + - "shape input must be a constant."); - } - Tensor shape = newShape.getConstantValue().get().asTensor(); - - OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); - if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / - OrderedTensorType.tensorSize(inputType.type()).intValue(); - } - outputTypeBuilder.add(TensorType.Dimension.indexed( - String.format("%s_%d", vespaName(), dimensionIndex), size)); - dimensionIndex++; - } - return outputTypeBuilder.build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } - OrderedTensorType inputType = inputs.get(0).type().get(); - TensorFunction inputFunction = inputs.get(0).function().get(); - return reshape(inputFunction, inputType.type(), type.type()); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - for (TensorType.Dimension dimension : type.type().dimensions()) { - renamer.addDimension(dimension.name()); - } - } - - public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) { - throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); - } - - // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, - // then use the dimension order of the new shape to roll back into a tensor. - // Here we create a transformation tensor that is multiplied with the from tensor to map into - // the new shape. We have to introduce temporary dimension names and rename back if dimension names - // in the new and old tensor type overlap. - - ExpressionNode unrollFrom = unrollTensorExpression(inputType); - ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo); - - TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); - Generate transformTensor = new Generate(transformationType, - new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - - TensorFunction outputFunction = new Reduce( - new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); - - return outputFunction; - } - - private static ExpressionNode unrollTensorExpression(TensorType type) { - if (type.rank() == 0) { - return new ConstantNode(DoubleValue.zero); - } - List<ExpressionNode> children = new ArrayList<>(); - List<ArithmeticOperator> operators = new ArrayList<>(); - int size = 1; - for (int i = type.dimensions().size() - 1; i >= 0; --i) { - TensorType.Dimension dimension = type.dimensions().get(i); - children.add(0, new ReferenceNode(dimension.name())); - if (size > 1) { - operators.add(0, ArithmeticOperator.MULTIPLY); - children.add(0, new ConstantNode(new DoubleValue(size))); - } - size *= OrderedTensorType.dimensionSize(dimension); - if (i > 0) { - operators.add(0, ArithmeticOperator.PLUS); - } - } - return new ArithmeticNode(children, operators); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Select.java deleted file mode 100644 index 176e6523e5f..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Select.java +++ /dev/null @@ -1,88 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; -import java.util.function.DoubleBinaryOperator; - -import static com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType.dimensionSize; -import static com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType.tensorSize; - -public class Select extends IntermediateOperation { - - public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(3)) { - return null; - } - OrderedTensorType a = inputs.get(1).type().get(); - OrderedTensorType b = inputs.get(2).type().get(); - if ((a.type().rank() != b.type().rank()) || !(tensorSize(a.type()).equals(tensorSize(b.type())))) { - throw new IllegalArgumentException("'Select': input tensors must have the same shape"); - } - return a; - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(3)) { - return null; - } - IntermediateOperation conditionOperation = inputs().get(0); - TensorFunction a = inputs().get(1).function().get(); - TensorFunction b = inputs().get(2).function().get(); - - // Shortcut: if we know during import which tensor to select, do that directly here. - if (conditionOperation.getConstantValue().isPresent()) { - Tensor condition = conditionOperation.getConstantValue().get().asTensor(); - if (condition.type().rank() == 0) { - return ((int) condition.asDouble() == 0) ? b : a; - } - if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { - return condition.cellIterator().next().getValue().intValue() == 0 ? b : a; - } - } - - // The task is to select cells from 'x' or 'y' based on 'condition'. - // If 'condition' is 0 (false), select from 'y', if 1 (true) select - // from 'x'. We do this by individually joining 'x' and 'y' with - // 'condition', and then joining the resulting two tensors. - - TensorFunction conditionFunction = conditionOperation.function().get(); - TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply()); - TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() { - @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } - @Override public String toString() { return "f(a,b)(a * (1-b))"; } - }); - return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add()); - } - - @Override - public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(3)) { - return; - } - List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions(); - List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions(); - - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); - - // These tensors should have the same dimension names - renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this); - renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Shape.java deleted file mode 100644 index 56f05936541..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Shape.java +++ /dev/null @@ -1,54 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; - -public class Shape extends IntermediateOperation { - - public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) { - super(modelName, nodeName, inputs); - createConstantValue(); - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } - OrderedTensorType inputType = inputs.get(0).type().get(); - return new OrderedTensorType.Builder() - .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) - .build(); - } - - @Override - protected TensorFunction lazyGetFunction() { - return null; // will be added by function() since this is constant. - } - - @Override - public boolean isConstant() { - return true; - } - - private void createConstantValue() { - if (!allInputTypesPresent(1)) { - return; - } - OrderedTensorType inputType = inputs.get(0).type().get(); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type()); - List<TensorType.Dimension> inputDimensions = inputType.dimensions(); - for (int i = 0; i < inputDimensions.size(); i++) { - builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L)); - } - this.setConstantValue(new TensorValue(builder.build())); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Squeeze.java deleted file mode 100644 index f0946064213..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Squeeze.java +++ /dev/null @@ -1,85 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.DimensionRenamer; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.ArrayList; -import java.util.List; -import java.util.Optional; -import java.util.stream.Collectors; - -public class Squeeze extends IntermediateOperation { - - private final AttributeMap attributeMap; - private List<String> squeezeDimensions; - - public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { - super(modelName, nodeName, inputs); - this.attributeMap = attributeMap; - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } - OrderedTensorType inputType = inputs.get(0).type().get(); - squeezeDimensions = new ArrayList<>(); - - Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims"); - if ( ! squeezeDimsAttr.isPresent()) { - squeezeDimensions = inputType.type().dimensions().stream(). - filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); - } else { - squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue). - map(i -> i < 0 ? inputType.type().dimensions().size() - i : i). - map(i -> inputType.type().dimensions().get(i)). - filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); - } - - return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType); - } - - @Override - protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(1)) { - return null; - } - TensorFunction inputFunction = inputs.get(0).function().get(); - return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); - } - - @Override - public void renameDimensions(DimensionRenamer renamer) { - super.renameDimensions(renamer); - List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size()); - for (String name : squeezeDimensions) { - Optional<String> newName = renamer.dimensionNameOf(name); - if (!newName.isPresent()) { - return; // presumably, already renamed - } - renamedDimensions.add(newName.get()); - } - squeezeDimensions = renamedDimensions; - } - - private OrderedTensorType reducedType(OrderedTensorType inputType) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (TensorType.Dimension dimension: inputType.type().dimensions()) { - if ( ! squeezeDimensions.contains(dimension.name())) { - builder.add(dimension); - } - } - return builder.build(); - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Switch.java deleted file mode 100644 index 24212ef175c..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/operations/Switch.java +++ /dev/null @@ -1,50 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml.operations; - -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.tensor.functions.TensorFunction; - -import java.util.List; -import java.util.Optional; - -public class Switch extends IntermediateOperation { - - private final int port; - - public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) { - super(modelName, nodeName, inputs); - this.port = port; - } - - @Override - protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - Optional<OrderedTensorType> predicate = inputs.get(1).type(); - if (predicate.get().type().rank() != 0) { - throw new IllegalArgumentException("Switch in " + name + ": " + - "predicate must be a scalar"); - } - return inputs.get(0).type().orElse(null); - } - - @Override - protected TensorFunction lazyGetFunction() { - IntermediateOperation predicateOperation = inputs().get(1); - if (!predicateOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Switch in " + name + ": " + - "predicate must be a constant"); - } - if (port < 0 || port > 1) { - throw new IllegalArgumentException("Switch in " + name + ": " + - "choice should be boolean"); - } - - double predicate = predicateOperation.getConstantValue().get().asDouble(); - return predicate == port ? inputs().get(0).function().get() : null; - } - -} - - diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java deleted file mode 100644 index 0eb782b5a6e..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/package-info.java +++ /dev/null @@ -1,8 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -/** - * Model integration - */ -@ExportPackage -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import com.yahoo.osgi.annotation.ExportPackage; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java deleted file mode 100644 index 04e9933e40b..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DimensionRenamerTest.java +++ /dev/null @@ -1,48 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import org.junit.Test; - -import static org.junit.Assert.assertTrue; - -public class DimensionRenamerTest { - - @Test - public void testMnistRenaming() { - DimensionRenamer renamer = new DimensionRenamer(); - - renamer.addDimension("first_dimension_of_x"); - renamer.addDimension("second_dimension_of_x"); - renamer.addDimension("first_dimension_of_w"); - renamer.addDimension("second_dimension_of_w"); - renamer.addDimension("first_dimension_of_b"); - - // which dimension to join on matmul - renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); - - // other dimensions in matmul can't be equal - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); - - // for efficiency, put dimension to join on innermost - renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); - renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); - - // bias - renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); - - renamer.solve(); - - String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get(); - String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get(); - String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get(); - String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get(); - String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get(); - - assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0); - assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0); - assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0); - assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0); - assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0); - } - -} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java deleted file mode 100644 index f22e7cb087a..00000000000 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OrderedTensorTypeTestCase.java +++ /dev/null @@ -1,21 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.ml; - -import org.junit.Test; - -import static org.junit.Assert.assertEquals; - -/** - * @author bratseth - */ -public class OrderedTensorTypeTestCase { - - @Test - public void testToFromSpec() { - String spec = "tensor(b[],c{},a[3])"; - OrderedTensorType type = OrderedTensorType.fromSpec(spec); - assertEquals(spec, type.toString()); - assertEquals("tensor(a[3],b[],c{})", type.type().toString()); - } - -} |