diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-11-22 14:27:58 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-11-22 14:27:58 +0100 |
commit | b288e61f7af7331656a1850fbdc58cc95fd1bbad (patch) | |
tree | 9d41fa770d2890585a902f41a89c41040ed764be /model-integration | |
parent | 3c4020645b13be560c14e60969e50e3ad41e3d3c (diff) |
Move all importing to model-integration
Diffstat (limited to 'model-integration')
48 files changed, 2767 insertions, 60 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 28a00dcbdbc..da18d659060 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -20,24 +20,26 @@ <artifactId>junit</artifactId> <scope>test</scope> </dependency> + <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>component</artifactId> + <artifactId>annotations</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>vespajlib</artifactId> + <artifactId>searchlib</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> <dependency> <groupId>com.yahoo.vespa</groupId> - <artifactId>searchlib</artifactId> + <artifactId>vespajlib</artifactId> <version>${project.version}</version> <scope>provided</scope> </dependency> + <dependency> <groupId>com.google.guava</groupId> <artifactId>guava</artifactId> diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java new file mode 100644 index 00000000000..9e9f66be700 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -0,0 +1,210 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * A constraint satisfier to find suitable dimension names to reduce the + * amount of necessary renaming during evaluation of an imported model. + * + * @author lesters + */ +public class DimensionRenamer { + + private final String dimensionPrefix; + private final Map<String, List<Integer>> variables = new HashMap<>(); + private final Map<Arc, Constraint> constraints = new HashMap<>(); + private final Map<String, Integer> renames = new HashMap<>(); + + private int iterations = 0; + + public DimensionRenamer() { + this("d"); + } + + public DimensionRenamer(String dimensionPrefix) { + this.dimensionPrefix = dimensionPrefix; + } + + /** + * Add a dimension name variable. + */ + public void addDimension(String name) { + variables.computeIfAbsent(name, d -> new ArrayList<>()); + } + + /** + * Add a constraint between dimension names. + */ + public void addConstraint(String from, String to, Constraint pred, IntermediateOperation operation) { + Arc arc = new Arc(from, to, operation); + Arc opposite = arc.opposite(); + constraints.put(arc, pred); + constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + } + + /** + * Retrieve resulting name of dimension after solving for constraints. + */ + public Optional<String> dimensionNameOf(String name) { + if (!renames.containsKey(name)) { + return Optional.empty(); + } + return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); + } + + /** + * Perform iterative arc consistency until we have found a solution. After + * an initial iteration, the variables (dimensions) will have multiple + * valid values. Find a single valid assignment by iteratively locking one + * dimension after another, and running the arc consistency algorithm + * multiple times. + * + * This requires having constraints that result in an absolute ordering: + * equals, lesserThan and greaterThan do that, but adding notEquals does + * not typically result in a guaranteed ordering. If that is needed, the + * algorithm below needs to be adapted with a backtracking (tree) search + * to find solutions. + */ + private void solve(int maxIterations) { + initialize(); + + // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + + for (String dimension : variables.keySet()) { + List<Integer> values = variables.get(dimension); + if (values.size() > 1) { + if (!ac3()) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution."); + } + values.sort(Integer::compare); + variables.put(dimension, Collections.singletonList(values.get(0))); + } + renames.put(dimension, variables.get(dimension).get(0)); + if (iterations > maxIterations) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + + maxIterations + " iterations"); + } + } + + // Todo: handle failure more gracefully: + // If a solution can't be found, look at the operation node in the arc + // with the most remaining constraints, and inject a rename operation. + // Then run this algorithm again. + } + + void solve() { + solve(100000); + } + + private void initialize() { + for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { + List<Integer> values = variable.getValue(); + for (int i = 0; i < variables.size(); ++i) { + values.add(i); // invariant: values are in increasing order + } + } + } + + private boolean ac3() { + Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); + while (!workList.isEmpty()) { + Arc arc = workList.pop(); + iterations += 1; + if (revise(arc)) { + if (variables.get(arc.from).size() == 0) { + return false; // no solution found + } + for (Arc constraint : constraints.keySet()) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { + workList.add(constraint); + } + } + } + } + return true; + } + + private boolean revise(Arc arc) { + boolean revised = false; + for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { + Integer from = fromIterator.next(); + boolean satisfied = false; + for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { + Integer to = toIterator.next(); + if (constraints.get(arc).test(from, to)) { + satisfied = true; + } + } + if (!satisfied) { + fromIterator.remove(); + revised = true; + } + } + return revised; + } + + public interface Constraint { + boolean test(Integer x, Integer y); + } + + public static boolean equals(Integer x, Integer y) { + return Objects.equals(x, y); + } + + public static boolean lesserThan(Integer x, Integer y) { + return x < y; + } + + public static boolean greaterThan(Integer x, Integer y) { + return x > y; + } + + private static class Arc { + + private final String from; + private final String to; + private final IntermediateOperation operation; + + Arc(String from, String to, IntermediateOperation operation) { + this.from = from; + this.to = to; + this.operation = operation; + } + + Arc opposite() { + return new Arc(to, from, operation); + } + + @Override + public int hashCode() { + return Objects.hash(from, to); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof Arc)) { + return false; + } + Arc other = (Arc) obj; + return Objects.equals(from, other.from) && Objects.equals(to, other.to); + } + + @Override + public String toString() { + return String.format("%s -> %s", from, to); + } + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java new file mode 100644 index 00000000000..2866a2c76b2 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModel.java @@ -0,0 +1,226 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import com.google.common.collect.ImmutableMap; +import com.yahoo.collections.Pair; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.regex.Pattern; + +/** + * The result of importing an ML model into Vespa. + * + * @author bratseth + */ +public class ImportedModel { + + private static final String defaultSignatureName = "default"; + + private static final Pattern nameRegexp = Pattern.compile("[A-Za-z0-9_]*"); + private final String name; + private final String source; + + private final Map<String, Signature> signatures = new HashMap<>(); + private final Map<String, TensorType> inputs = new HashMap<>(); + private final Map<String, Tensor> smallConstants = new HashMap<>(); + private final Map<String, Tensor> largeConstants = new HashMap<>(); + private final Map<String, RankingExpression> expressions = new HashMap<>(); + private final Map<String, RankingExpression> functions = new HashMap<>(); + + /** + * Creates a new imported model. + * + * @param name the name of this mode, containing only characters in [A-Za-z0-9_] + * @param source the source path (directory or file) of this model + */ + public ImportedModel(String name, String source) { + if ( ! nameRegexp.matcher(name).matches()) + throw new IllegalArgumentException("An imported model name can only contain [A-Za-z0-9_], but is '" + name + "'"); + this.name = name; + this.source = source; + } + + /** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */ + public String name() { return name; } + + /** Returns the source path (directory or file) of this model */ + public String source() { return source; } + + /** Returns an immutable map of the inputs of this */ + public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); } + + /** + * Returns an immutable map of the small constants of this. + * These should have sizes up to a few kb at most, and correspond to constant values given in the source model. + */ + public Map<String, Tensor> smallConstants() { return Collections.unmodifiableMap(smallConstants); } + + /** + * Returns an immutable map of the large constants of this. + * These can have sizes in gigabytes and must be distributed to nodes separately from configuration. + * For TensorFlow this corresponds to Variable files stored separately. + */ + public Map<String, Tensor> largeConstants() { return Collections.unmodifiableMap(largeConstants); } + + /** + * Returns an immutable map of the expressions of this - corresponding to graph nodes + * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants). + * Note that only nodes recursively referenced by a placeholder/input are added. + */ + public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); } + + /** + * Returns an immutable map of the functions that are part of this model. + * Note that the functions themselves are *not* copies and *not* immutable - they must be copied before modification. + */ + public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); } + + /** Returns an immutable map of the signatures of this */ + public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } + + /** Returns the given signature. If it does not already exist it is added to this. */ + public Signature signature(String name) { + return signatures.computeIfAbsent(name, Signature::new); + } + + /** Convenience method for returning a default signature */ + public Signature defaultSignature() { return signature(defaultSignatureName); } + + public void input(String name, TensorType argumentType) { inputs.put(name, argumentType); } + public void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); } + public void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); } + public void expression(String name, RankingExpression expression) { expressions.put(name, expression); } + public void function(String name, RankingExpression expression) { functions.put(name, expression); } + + /** + * Returns all the output expressions of this indexed by name. The names consist of one or two parts + * separated by dot, where the first part is the signature name + * if signatures are used, or the expression name if signatures are not used and there are multiple + * expressions, and the second is the output name if signature names are used. + */ + public List<Pair<String, ExpressionFunction>> outputExpressions() { + List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); + for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { + for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) + expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), + signatureEntry.getValue().outputExpression(outputEntry.getKey()) + .withName(signatureEntry.getKey() + "." + outputEntry.getKey()))); + if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs + expressions.add(new Pair<>(signatureEntry.getKey(), + new ExpressionFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().values()), + expressions().get(signatureEntry.getKey()), + signatureEntry.getValue().inputMap(), + Optional.empty()))); + } + if (signatures().isEmpty()) { // fallback for models without signatures + if (expressions().size() == 1) { + Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); + expressions.add(new Pair<>(singleEntry.getKey(), + new ExpressionFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue(), + inputs, + Optional.empty()))); + } + else { + for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { + expressions.add(new Pair<>(expressionEntry.getKey(), + new ExpressionFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue(), + inputs, + Optional.empty()))); + } + } + } + return expressions; + } + + /** + * A signature is a set of named inputs and outputs, where the inputs maps to input + * ("placeholder") names+types, and outputs maps to expressions nodes. + * Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit + * concept of signatures. For now, we handle ONNX models as having a single signature. + */ + public class Signature { + + private final String name; + private final Map<String, String> inputs = new LinkedHashMap<>(); + private final Map<String, String> outputs = new LinkedHashMap<>(); + private final Map<String, String> skippedOutputs = new HashMap<>(); + private final List<String> importWarnings = new ArrayList<>(); + + Signature(String name) { + this.name = name; + } + + public String name() { return name; } + + /** Returns the result this is part of */ + ImportedModel owner() { return ImportedModel.this; } + + /** + * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name + * in this signature to input name in the owning model + */ + public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } + + /** Returns the name and type of all inputs in this signature as an immutable map */ + Map<String, TensorType> inputMap() { + ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); + // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* + // in the model, as these are the names which must actually be bound, if we are to avoid creating an + // "input mapping" to accomodate this complexity in + for (Map.Entry<String, String> inputEntry : inputs().entrySet()) + inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue())); + return inputs.build(); + } + + /** Returns the type of the input this input references */ + public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); } + + /** Returns an immutable list of the expression names of this */ + public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); } + + /** + * Returns an immutable list of the outputs of this which could not be imported, + * with a string detailing the reason for each + */ + public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + + /** + * Returns an immutable list of possibly non-fatal warnings encountered during import. + */ + public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } + + /** Returns the expression this output references */ + public ExpressionFunction outputExpression(String outputName) { + return new ExpressionFunction(outputName, + new ArrayList<>(inputs.values()), + owner().expressions().get(outputs.get(outputName)), + inputMap(), + Optional.empty()); + } + + @Override + public String toString() { return "signature '" + name + "'"; } + + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, String expressionName) { outputs.put(name, expressionName); } + void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } + + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java new file mode 100644 index 00000000000..1b7532631e1 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ImportedModels.java @@ -0,0 +1,109 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import com.google.common.collect.ImmutableMap; +import com.yahoo.path.Path; + +import java.io.File; +import java.util.Arrays; +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Optional; + +/** + * All models imported from the models/ directory in the application package. + * If this is empty it may be due to either not having any models in the application package, + * or this being created for a ZooKeeper application package, which does not have imported models. + * + * @author bratseth + */ +public class ImportedModels { + + /** All imported models, indexed by their names */ + private final ImmutableMap<String, ImportedModel> importedModels; + + /** Create a null imported models */ + public ImportedModels() { + importedModels = ImmutableMap.of(); + } + + public ImportedModels(File modelsDirectory, Collection<ModelImporter> importers) { + Map<String, ImportedModel> models = new HashMap<>(); + + // Find all subdirectories recursively which contains a model we can read + importRecursively(modelsDirectory, models, importers); + importedModels = ImmutableMap.copyOf(models); + } + + private static void importRecursively(File dir, + Map<String, ImportedModel> models, + Collection<ModelImporter> importers) { + if ( ! dir.isDirectory()) return; + + Arrays.stream(dir.listFiles()).sorted().forEach(child -> { + Optional<ModelImporter> importer = findImporterOf(child, importers); + if (importer.isPresent()) { + String name = toName(child); + ImportedModel existing = models.get(name); + if (existing != null) + throw new IllegalArgumentException("The models in " + child + " and " + existing.source() + + " both resolve to the model name '" + name + "'"); + models.put(name, importer.get().importModel(name, child)); + } + else { + importRecursively(child, models, importers); + } + }); + } + + private static Optional<ModelImporter> findImporterOf(File path, Collection<ModelImporter> importers) { + return importers.stream().filter(item -> item.canImport(path.toString())).findFirst(); + } + + /** + * Returns the model at the given location in the application package. + * + * @param modelPath the path to this model (file or directory, depending on model type) + * under the application package, both from the root or relative to the + * models directory works + * @return the model at this path or null if none + */ + public ImportedModel get(File modelPath) { + return importedModels.get(toName(modelPath)); + } + + public ImportedModel get(String modelName) { + return importedModels.get(modelName); + } + + /** Returns an immutable collection of all the imported models */ + public Collection<ImportedModel> all() { + return importedModels.values(); + } + + private static String toName(File modelFile) { + Path modelPath = Path.fromString(modelFile.toString()); + if (modelFile.isFile()) + modelPath = stripFileEnding(modelPath); + String localPath = concatenateAfterModelsDirectory(modelPath); + return localPath.replace('.', '_'); + } + + private static Path stripFileEnding(Path path) { + int dotIndex = path.last().lastIndexOf("."); + if (dotIndex <= 0) return path; + return path.withLast(path.last().substring(0, dotIndex)); + } + + private static String concatenateAfterModelsDirectory(Path path) { + boolean afterModels = false; + StringBuilder result = new StringBuilder(); + for (String element : path.elements()) { + if (afterModels) result.append(element).append("_"); + if (element.equals("models")) afterModels = true; + } + return result.substring(0, result.length()-1); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java new file mode 100644 index 00000000000..aec98d06874 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -0,0 +1,107 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.rankingexpression.importer; + +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; + +import java.util.Collection; +import java.util.HashMap; +import java.util.Map; +import java.util.Set; + +/** + * Holds an intermediate representation of an imported model graph. + * After this intermediate representation is constructed, it is used to + * simplify and optimize the computational graph and then converted into the + * final ImportedModel that holds the Vespa ranking expressions for the model. + * + * @author lesters + */ +public class IntermediateGraph { + + private final String modelName; + private final Map<String, IntermediateOperation> index = new HashMap<>(); + private final Map<String, GraphSignature> signatures = new HashMap<>(); + + private static class GraphSignature { + final Map<String, String> inputs = new HashMap<>(); + final Map<String, String> outputs = new HashMap<>(); + } + + public IntermediateGraph(String modelName) { + this.modelName = modelName; + } + + public String name() { + return modelName; + } + + public IntermediateOperation put(String key, IntermediateOperation operation) { + return index.put(key, operation); + } + + public IntermediateOperation get(String key) { + return index.get(key); + } + + public Set<String> signatures() { + return signatures.keySet(); + } + + public Map<String, String> inputs(String signature) { + return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs; + } + + public Map<String, String> outputs(String signature) { + return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs; + } + + public String defaultSignature() { + return "default"; + } + + public boolean alreadyImported(String key) { + return index.containsKey(key); + } + + public Collection<IntermediateOperation> operations() { + return index.values(); + } + + void optimize() { + renameDimensions(); + } + + /** + * Find dimension names to avoid excessive renaming while evaluating the model. + */ + private void renameDimensions() { + DimensionRenamer renamer = new DimensionRenamer(); + for (String signature : signatures()) { + for (String output : outputs(signature).values()) { + addDimensionNameConstraints(index.get(output), renamer); + } + } + renamer.solve(); + for (String signature : signatures()) { + for (String output : outputs(signature).values()) { + renameDimensions(index.get(output), renamer); + } + } + } + + private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.addDimensionNameConstraints(renamer); + } + } + + private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.renameDimensions(renamer); + } + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java new file mode 100644 index 00000000000..cb095e81147 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -0,0 +1,232 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.yolean.Exceptions; + +import java.io.File; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.logging.Logger; + +/** + * Base class for importing ML models (ONNX/TensorFlow etc.) as native Vespa + * ranking expressions. The general mechanism for import is for the + * specific ML platform import implementations to create an + * IntermediateGraph. This class offers common code to convert the + * IntermediateGraph to Vespa ranking expressions and functions. + * + * @author lesters + */ +public abstract class ModelImporter { + + private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); + + /** Returns whether the file or directory at the given path is of the type which can be imported by this */ + public abstract boolean canImport(String modelPath); + + /** Imports the given model */ + public abstract ImportedModel importModel(String modelName, String modelPath); + + final ImportedModel importModel(String modelName, File modelPath) { + return importModel(modelName, modelPath.toString()); + } + + /** + * Takes an IntermediateGraph and converts it to a ImportedModel containing + * the actual Vespa ranking expressions. + */ + protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { + ImportedModel model = new ImportedModel(graph.name(), modelSource); + + graph.optimize(); + + importSignatures(graph, model); + importExpressions(graph, model); + reportWarnings(graph, model); + logVariableTypes(graph); + + return model; + } + + private static void importSignatures(IntermediateGraph graph, ImportedModel model) { + for (String signatureName : graph.signatures()) { + ImportedModel.Signature signature = model.signature(signatureName); + for (Map.Entry<String, String> input : graph.inputs(signatureName).entrySet()) { + signature.input(input.getKey(), input.getValue()); + } + for (Map.Entry<String, String> output : graph.outputs(signatureName).entrySet()) { + signature.output(output.getKey(), output.getValue()); + } + } + } + + private static boolean isSignatureOutput(ImportedModel model, IntermediateOperation operation) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + if (outputName.equals(operation.name())) { + return true; + } + } + } + return false; + } + + /** + * Convert intermediate representation to Vespa ranking expressions. + */ + private static void importExpressions(IntermediateGraph graph, ImportedModel model) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + try { + Optional<TensorFunction> function = importExpression(graph.get(outputName), model); + if (!function.isPresent()) { + signature.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); + } + } + } + } + + private static Optional<TensorFunction> importExpression(IntermediateOperation operation, ImportedModel model) { + if (!operation.type().isPresent()) { + return Optional.empty(); + } + if (operation.isConstant()) { + return importConstant(operation, model); + } + importExpressionInputs(operation, model); + importRankingExpression(operation, model); + importArgumentExpression(operation, model); + importFunctionExpression(operation, model); + + return operation.function(); + } + + private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) { + operation.inputs().forEach(input -> importExpression(input, model)); + } + + private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) { + String name = operation.vespaName(); + if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + return operation.function(); + } + + Value value = operation.getConstantValue().orElseThrow(() -> + new IllegalArgumentException("Operation '" + operation.vespaName() + "' " + + "is constant but does not have a value.")); + if ( ! (value instanceof TensorValue)) { + return operation.function(); // scalar values are inserted directly into the expression + } + + Tensor tensor = value.asTensor(); + if (tensor.type().rank() == 0) { + model.smallConstant(name, tensor); + } else { + model.largeConstant(name, tensor); + } + return operation.function(); + } + + private static void importRankingExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.function().isPresent()) { + String name = operation.name(); + if ( ! model.expressions().containsKey(name)) { + TensorFunction function = operation.function().get(); + + if (isSignatureOutput(model, operation)) { + OrderedTensorType operationType = operation.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.standardType(operationType); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + function = new Rename(function, renameFrom, renameTo); + } + } + + try { + // We add all intermediate nodes imported as separate expressions. Only + // those referenced from the output will be used. We parse the + // TensorFunction here to convert it to a RankingExpression tree. + model.expression(name, new RankingExpression(name, function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Imported function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + } + + private static void importArgumentExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.isInput()) { + // All inputs must have dimensions with standard naming convention: d0, d1, ... + OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get()); + model.input(operation.vespaName(), standardNamingConvention.type()); + } + } + + private static void importFunctionExpression(IntermediateOperation operation, ImportedModel model) { + if (operation.rankingExpressionFunction().isPresent()) { + TensorFunction function = operation.rankingExpressionFunction().get(); + try { + model.function(operation.rankingExpressionFunctionName(), + new RankingExpression(operation.rankingExpressionFunctionName(), + function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Model function " + function + + " cannot be parsed as a ranking expression", e); + } + } + } + + /** + * Add any import warnings to the signature in the ImportedModel. + */ + private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { + for (ImportedModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + reportWarnings(graph.get(outputName), model); + } + } + } + + private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { + for (String warning : operation.warnings()) { + model.defaultSignature().importWarning(warning); + } + for (IntermediateOperation input : operation.inputs()) { + reportWarnings(input, model); + } + } + + /** + * Log all model Variables (i.e file constants) imported as part of this with their ordered type. + * This allows users to learn the exact types (including dimension order after renaming) of the Variables + * such that these can be converted and fed to a parent document independently of the rest of the model + * for fast model weight updates. + */ + private static void logVariableTypes(IntermediateGraph graph) { + for (IntermediateOperation operation : graph.operations()) { + if ( ! (operation instanceof Constant)) continue; + if ( ! operation.type().isPresent()) continue; // will not happen + log.info("Importing model variable " + operation.name() + " as " + operation.vespaName() + + " of type " + operation.type().get()); + } + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java new file mode 100644 index 00000000000..c4acfeb3235 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -0,0 +1,235 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.rankingexpression.importer; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TensorTypeParser; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A Vespa tensor type is ordered by the lexicographical ordering of dimension + * names. Imported tensors have an explicit ordering of their dimensions. + * During import, we need to track the Vespa dimension that matches the + * corresponding imported dimension as the ordering can change after + * dimension renaming. That is the purpose of this class. + * + * @author lesters + */ +public class OrderedTensorType { + + private final TensorType type; + private final List<TensorType.Dimension> dimensions; + + private final long[] innerSizesOriginal; + private final long[] innerSizesVespa; + private final int[] dimensionMap; + + private OrderedTensorType(List<TensorType.Dimension> dimensions) { + this.dimensions = Collections.unmodifiableList(dimensions); + this.type = new TensorType.Builder(dimensions).build(); + this.innerSizesOriginal = new long[dimensions.size()]; + this.innerSizesVespa = new long[dimensions.size()]; + this.dimensionMap = createDimensionMap(); + } + + public TensorType type() { return this.type; } + + public int rank() { return dimensions.size(); } + + public List<TensorType.Dimension> dimensions() { + return dimensions; + } + + public List<String> dimensionNames() { + return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList()); + } + + private int[] createDimensionMap() { + int numDimensions = dimensions.size(); + if (numDimensions == 0) { + return null; + } + innerSizesOriginal[numDimensions - 1] = 1; + innerSizesVespa[numDimensions - 1] = 1; + for (int i = numDimensions - 1; --i >= 0; ) { + innerSizesOriginal[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesOriginal[i+1]; + innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; + } + int[] mapping = new int[numDimensions]; + for (int i = 0; i < numDimensions; ++i) { + TensorType.Dimension dim1 = dimensions().get(i); + for (int j = 0; j < numDimensions; ++j) { + TensorType.Dimension dim2 = type.dimensions().get(j); + if (dim1.equals(dim2)) { + mapping[i] = j; + break; + } + } + } + return mapping; + } + + public int dimensionMap(int originalIndex) { + return dimensionMap[originalIndex]; + } + + /** + * When dimension ordering between Vespa and imported differs, i.e. + * after dimension renaming, use the dimension map to read in values + * so that they are correctly laid out in memory for Vespa. + * Used when importing tensors. + */ + public int toDirectIndex(int index) { + if (dimensions.size() == 0) { + return 0; + } + if (dimensionMap == null) { + throw new IllegalArgumentException("Dimension map is not available"); + } + int directIndex = 0; + long rest = index; + for (int i = 0; i < dimensions.size(); ++i) { + long address = rest / innerSizesOriginal[i]; + directIndex += innerSizesVespa[dimensionMap[i]] * address; + rest %= innerSizesOriginal[i]; + } + return directIndex; + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof OrderedTensorType)) { + return false; + } + OrderedTensorType other = (OrderedTensorType) obj; + if (dimensions.size() != dimensions.size()) { + return false; + } + List<TensorType.Dimension> thisDimensions = this.dimensions(); + List<TensorType.Dimension> otherDimensions = other.dimensions(); + for (int i = 0; i < thisDimensions.size(); ++i) { + if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { + return false; + } + } + return true; + } + + public OrderedTensorType rename(DimensionRenamer renamer) { + List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); + for (TensorType.Dimension dimension : dimensions) { + String oldName = dimension.name(); + Optional<String> newName = renamer.dimensionNameOf(oldName); + if (!newName.isPresent()) + return this; // presumably, already renamed + TensorType.Dimension.Type dimensionType = dimension.type(); + if (dimensionType == TensorType.Dimension.Type.indexedBound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); + } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get())); + } else if (dimensionType == TensorType.Dimension.Type.mapped) { + renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); + } + } + return new OrderedTensorType(renamedDimensions); + } + + public OrderedTensorType rename(String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < dimensions.size(); ++ i) { + String dimensionName = dimensionPrefix + i; + Optional<Long> dimSize = dimensions.get(i).size(); + if (dimSize.isPresent() && dimSize.get() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dimSize.get())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static OrderedTensorType standardType(OrderedTensorType type) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < type.dimensions().size(); ++ i) { + TensorType.Dimension dim = type.dimensions().get(i); + String dimensionName = "d" + i; + if (dim.size().isPresent() && dim.size().get() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dim.size().get())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static Long tensorSize(TensorType type) { + Long size = 1L; + for (TensorType.Dimension dimension : type.dimensions()) { + size *= dimensionSize(dimension); + } + return size; + } + + public static Long dimensionSize(TensorType.Dimension dim) { + return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); + } + + /** + * Returns a string representation of this: A standard tensor type string where dimensions + * are listed in the order of this rather than in the natural order of their names. + */ + @Override + public String toString() { + return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")"; + } + + /** + * Creates an instance from the string representation of this: A standard tensor type string + * where dimensions are listed in the order of this rather than the natural order of their names. + */ + public static OrderedTensorType fromSpec(String typeSpec) { + return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); + } + + public static OrderedTensorType fromDimensionList(List<Long> dims) { + return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... + } + + private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (int i = 0; i < dims.size(); ++ i) { + String dimensionName = dimensionPrefix + i; + Long dimSize = dims.get(i); + if (dimSize >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static class Builder { + + private final List<TensorType.Dimension> dimensions; + + public Builder() { + this.dimensions = new ArrayList<>(); + } + + public Builder add(TensorType.Dimension vespaDimension) { + this.dimensions.add(vespaDimension); + return this; + } + + public OrderedTensorType build() { + return new OrderedTensorType(dimensions); + } + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index caf66baef66..dd2add973e4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -3,19 +3,19 @@ package ai.vespa.rankingexpression.importer.onnx; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Argument; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.ConcatV2; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Identity; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Map; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.MatMul; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Reshape; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Shape; +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.Argument; +import ai.vespa.rankingexpression.importer.operations.ConcatV2; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.Identity; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.Join; +import ai.vespa.rankingexpression.importer.operations.Map; +import ai.vespa.rankingexpression.importer.operations.MatMul; +import ai.vespa.rankingexpression.importer.operations.NoOp; +import ai.vespa.rankingexpression.importer.operations.Reshape; +import ai.vespa.rankingexpression.importer.operations.Shape; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java index cb591475d40..0a8a797a847 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/OnnxImporter.java @@ -2,9 +2,9 @@ package ai.vespa.rankingexpression.importer.onnx; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter; +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.ModelImporter; import onnx.Onnx; import java.io.File; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java index a267411e8a9..f3d87d89c27 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TensorConverter.java @@ -3,7 +3,7 @@ package ai.vespa.rankingexpression.importer.onnx; import com.google.protobuf.ByteString; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import onnx.Onnx; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 3da7477fcef..f251a14213b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.onnx; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import onnx.Onnx; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java new file mode 100644 index 00000000000..d6ea00ca453 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -0,0 +1,57 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Collections; +import java.util.List; + +public class Argument extends IntermediateOperation { + + private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... + + public Argument(String modelName, String nodeName, OrderedTensorType type) { + super(modelName, nodeName, Collections.emptyList()); + this.type = type.rename(vespaName() + "_"); + standardNamingType = OrderedTensorType.standardType(type); + } + + @Override + protected OrderedTensorType lazyGetType() { + return type; + } + + @Override + protected TensorFunction lazyGetFunction() { + TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); + if (!standardNamingType.equals(type)) { + List<String> renameFrom = standardNamingType.dimensionNames(); + List<String> renameTo = type.dimensionNames(); + output = new Rename(output, renameFrom, renameTo); + } + return output; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isInput() { + return true; + } + + @Override + public boolean isConstant() { + return false; + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java new file mode 100644 index 00000000000..a21fc5ff2f7 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -0,0 +1,108 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.Optional; + +public class ConcatV2 extends IntermediateOperation { + + private String concatDimensionName; + + public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { + return null; + } + + IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input + if (!concatDimOp.getConstantValue().isPresent()) { + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + + "concat dimension must be a constant."); + } + Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); + if (concatDimTensor.type().rank() != 0) { + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + + "concat dimension must be a scalar."); + } + + OrderedTensorType aType = inputs.get(0).type().get(); + + int concatDim = (int)concatDimTensor.asDouble(); + long concatDimSize = aType.dimensions().get(concatDim).size().orElse(-1L); + + for (int i = 1; i < inputs.size() - 1; ++i) { + OrderedTensorType bType = inputs.get(i).type().get(); + if (bType.rank() != aType.rank()) { + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + + "inputs must have save rank."); + } + for (int j = 0; j < aType.rank(); ++j) { + long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); + long dimSizeB = bType.dimensions().get(j).size().orElse(-1L); + if (j == concatDim) { + concatDimSize += dimSizeB; + } else if (dimSizeA != dimSizeB) { + throw new IllegalArgumentException("ConcatV2 in " + name + ": " + + "input dimension " + j + " differs in input tensors."); + } + } + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + int dimensionIndex = 0; + for (TensorType.Dimension dimension : aType.dimensions()) { + if (dimensionIndex == concatDim) { + concatDimensionName = dimension.name(); + typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize)); + } else { + typeBuilder.add(dimension); + } + dimensionIndex++; + } + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!inputs.stream().map(IntermediateOperation::function).allMatch(Optional::isPresent)) { + return null; + } + TensorFunction result = inputs.get(0).function().get(); + for (int i = 1; i < inputs.size() - 1; ++i) { + TensorFunction b = inputs.get(i).function().get(); + result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName); + } + return result; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { + return; + } + OrderedTensorType a = inputs.get(0).type().get(); + for (int i = 1; i < inputs.size() - 1; ++i) { + OrderedTensorType b = inputs.get(i).type().get(); + String bDim = b.dimensions().get(i).name(); + String aDim = a.dimensions().get(i).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + } + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java new file mode 100644 index 00000000000..41d421b1f5a --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -0,0 +1,89 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.Optional; + +public class Const extends IntermediateOperation { + + private final AttributeMap attributeMap; + + public Const(String modelName, + String nodeName, + List<IntermediateOperation> inputs, + AttributeMap attributeMap, + OrderedTensorType type) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + this.type = type.rename(vespaName() + "_"); + setConstantValue(value()); + } + + @Override + protected OrderedTensorType lazyGetType() { + return type; + } + + @Override + public Optional<TensorFunction> function() { + if (function == null) { + function = lazyGetFunction(); + } + return Optional.ofNullable(function); + } + + @Override + protected TensorFunction lazyGetFunction() { + ExpressionNode expressionNode; + if (type.type().rank() == 0 && getConstantValue().isPresent()) { + expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue()); + } else { + expressionNode = new ReferenceNode(Reference.simple("constant", vespaName())); + } + return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); + } + + /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ + @Override + public String vespaName() { + return modelName + "_" + super.vespaName(); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + setConstantValue(value()); + } + + @Override + public boolean isConstant() { + return true; + } + + private Value value() { + Optional<Value> value = attributeMap.get("value", type); + if ( ! value.isPresent()) { + throw new IllegalArgumentException("Node '" + name + "' of type " + + "const has missing or non-recognized 'value' attribute"); + } + return value.get(); + } +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java new file mode 100644 index 00000000000..a1cc83296b0 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -0,0 +1,61 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Collections; +import java.util.Optional; + +public class Constant extends IntermediateOperation { + + private final String modelName; + + public Constant(String modelName, String nodeName, OrderedTensorType type) { + super(modelName, nodeName, Collections.emptyList()); + this.modelName = modelName; + this.type = type.rename(vespaName() + "_"); + } + + /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ + @Override + public String vespaName() { + return modelName + "_" + vespaName(name); + } + + @Override + protected OrderedTensorType lazyGetType() { + return type; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + /** + * Constant values are sent in via the constantValueFunction, as the + * dimension names and thus the data layout depends on the dimension + * renaming which happens after the conversion to intermediate graph. + */ + @Override + public Optional<Value> getConstantValue() { + return Optional.ofNullable(constantValueFunction).map(func -> func.apply(type)); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java new file mode 100644 index 00000000000..8ae6d81b8d4 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -0,0 +1,106 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +public class ExpandDims extends IntermediateOperation { + + private List<String> expandDimensions; + + public ExpandDims(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + + IntermediateOperation axisOperation = inputs().get(1); + if (!axisOperation.getConstantValue().isPresent()) { + throw new IllegalArgumentException("ExpandDims in " + name + ": " + + "axis must be a constant."); + } + Tensor axis = axisOperation.getConstantValue().get().asTensor(); + if (axis.type().rank() != 0) { + throw new IllegalArgumentException("ExpandDims in " + name + ": " + + "axis argument must be a scalar."); + } + + OrderedTensorType inputType = inputs.get(0).type().get(); + int dimensionToInsert = (int)axis.asDouble(); + if (dimensionToInsert < 0) { + dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + expandDimensions = new ArrayList<>(); + int dimensionIndex = 0; + for (TensorType.Dimension dimension : inputType.dimensions()) { + if (dimensionIndex == dimensionToInsert) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + typeBuilder.add(dimension); + dimensionIndex++; + } + + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(2)) { + return null; + } + + // multiply with a generated tensor created from the reduced dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (String name : expandDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); + for (String name : expandDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + expandDimensions = renamedDimensions; + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java new file mode 100644 index 00000000000..c2787aa14d4 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -0,0 +1,35 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class Identity extends IntermediateOperation { + + public Identity(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ + @Override + public String vespaName() { + return modelName + "_" + super.vespaName(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) + return null; + return inputs.get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) + return null; + return inputs.get(0).function().orElse(null); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java new file mode 100644 index 00000000000..60fba264635 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -0,0 +1,191 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +/** + * Wraps an imported operation node and produces the respective Vespa tensor + * operation. During import, a graph of these operations are constructed. Then, + * the types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper Vespa + * expressions can be extracted. + * + * @author lesters + */ +public abstract class IntermediateOperation { + + public final static String FUNCTION_PREFIX = "imported_ml_function_"; + + protected final String name; + protected final String modelName; + protected final List<IntermediateOperation> inputs; + protected final List<IntermediateOperation> outputs = new ArrayList<>(); + + protected OrderedTensorType type; + protected TensorFunction function; + protected TensorFunction rankingExpressionFunction = null; + + private final List<String> importWarnings = new ArrayList<>(); + private Value constantValue = null; + private List<IntermediateOperation> controlInputs = Collections.emptyList(); + + protected Function<OrderedTensorType, Value> constantValueFunction = null; + + IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) { + this.name = name; + this.modelName = modelName; + this.inputs = Collections.unmodifiableList(inputs); + this.inputs.forEach(i -> i.outputs.add(this)); + } + + protected abstract OrderedTensorType lazyGetType(); + protected abstract TensorFunction lazyGetFunction(); + + /** Returns the Vespa tensor type of this operation if it exists */ + public Optional<OrderedTensorType> type() { + if (type == null) { + type = lazyGetType(); + } + return Optional.ofNullable(type); + } + + /** Returns the Vespa tensor function implementing all operations from this node with inputs */ + public Optional<TensorFunction> function() { + if (function == null) { + if (isConstant()) { + ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); + function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); + } else if (outputs.size() > 1) { + rankingExpressionFunction = lazyGetFunction(); + function = new VariableTensor(rankingExpressionFunctionName(), type.type()); + } else { + function = lazyGetFunction(); + } + } + return Optional.ofNullable(function); + } + + /** Returns original name of this operation node */ + public String name() { return name; } + + /** Return unmodifiable list of inputs */ + public List<IntermediateOperation> inputs() { return inputs; } + + /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a function. */ + public List<IntermediateOperation> outputs() { return Collections.unmodifiableList(outputs); } + + /** Returns a function that should be added as a ranking expression function */ + public Optional<TensorFunction> rankingExpressionFunction() { + return Optional.ofNullable(rankingExpressionFunction); + } + + /** Add dimension name constraints for this operation */ + public void addDimensionNameConstraints(DimensionRenamer renamer) { } + + /** Performs dimension rename for this operation */ + public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } + + /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ + public boolean isInput() { return false; } + + /** Return true if this node is constant */ + public boolean isConstant() { return inputs.stream().allMatch(IntermediateOperation::isConstant); } + + /** Sets the constant value */ + public void setConstantValue(Value value) { constantValue = value; } + + /** Gets the constant value if it exists */ + public Optional<Value> getConstantValue() { + if (constantValue != null) { + return Optional.of(constantValue); + } + if (constantValueFunction != null) { + return Optional.of(constantValueFunction.apply(type)); + } + return Optional.empty(); + } + + /** Set the constant value function */ + public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; } + + /** Sets the external control inputs */ + public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; } + + /** Retrieve the control inputs for this operation */ + public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } + + /** Retrieve the valid Vespa name of this node */ + public String vespaName() { return vespaName(name); } + public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_') : null; } + + /** Retrieve the valid Vespa name of this node if it is a ranking expression function */ + public String rankingExpressionFunctionName() { + return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; + } + + /** Retrieve the list of warnings produced during its lifetime */ + public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } + + /** Set an input warning */ + public void warning(String warning) { importWarnings.add(warning); } + + boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) { + if (inputs.size() != expected) { + throw new IllegalArgumentException("Expected " + expected + " inputs " + + "for '" + name + "', got " + inputs.size()); + } + return inputs.stream().map(func).allMatch(Optional::isPresent); + } + + boolean allInputTypesPresent(int expected) { + return verifyInputs(expected, IntermediateOperation::type); + } + + boolean allInputFunctionsPresent(int expected) { + return verifyInputs(expected, IntermediateOperation::function); + } + + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. + */ + public static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; + return name.split(":")[0]; + } + + /** + * This return the output index part. Indexes are used for nodes with + * multiple outputs. + */ + public static int indexPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); + } + + /** + * An interface mapping operation attributes to Vespa Values. + * Adapter for differences in different model types. + */ + public interface AttributeMap { + Optional<Value> get(String key); + Optional<Value> get(String key, OrderedTensorType type); + Optional<List<Value>> getList(String key); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java new file mode 100644 index 00000000000..fed95e13bb7 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -0,0 +1,120 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.function.DoubleBinaryOperator; + +public class Join extends IntermediateOperation { + + private final DoubleBinaryOperator operator; + + public Join(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleBinaryOperator operator) { + super(modelName, nodeName, inputs); + this.operator = operator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType a = largestInput().type().get(); + OrderedTensorType b = smallestInput().type().get(); + + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + int sizeDifference = a.rank() - b.rank(); + for (int i = 0; i < a.rank(); ++i) { + TensorType.Dimension aDim = a.dimensions().get(i); + long size = aDim.size().orElse(-1L); + + if (i - sizeDifference >= 0) { + TensorType.Dimension bDim = b.dimensions().get(i - sizeDifference); + size = Math.max(size, bDim.size().orElse(-1L)); + } + + if (aDim.type() == TensorType.Dimension.Type.indexedBound) { + builder.add(TensorType.Dimension.indexed(aDim.name(), size)); + } else if (aDim.type() == TensorType.Dimension.Type.indexedUnbound) { + builder.add(TensorType.Dimension.indexed(aDim.name())); + } else if (aDim.type() == TensorType.Dimension.Type.mapped) { + builder.add(TensorType.Dimension.mapped(aDim.name())); + } + } + return builder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + if (!allInputFunctionsPresent(2)) { + return null; + } + + IntermediateOperation a = largestInput(); + IntermediateOperation b = smallestInput(); + + List<String> aDimensionsToReduce = new ArrayList<>(); + List<String> bDimensionsToReduce = new ArrayList<>(); + int sizeDifference = a.type().get().rank() - b.type().get().rank(); + for (int i = 0; i < b.type().get().rank(); ++i) { + TensorType.Dimension bDim = b.type().get().dimensions().get(i); + TensorType.Dimension aDim = a.type().get().dimensions().get(i + sizeDifference); + long bSize = bDim.size().orElse(-1L); + long aSize = aDim.size().orElse(-1L); + if (bSize == 1L && aSize != 1L) { + bDimensionsToReduce.add(bDim.name()); + } + if (aSize == 1L && bSize != 1L) { + aDimensionsToReduce.add(bDim.name()); + } + } + + TensorFunction aReducedFunction = a.function().get(); + if (aDimensionsToReduce.size() > 0) { + aReducedFunction = new Reduce(a.function().get(), Reduce.Aggregator.sum, aDimensionsToReduce); + } + TensorFunction bReducedFunction = b.function().get(); + if (bDimensionsToReduce.size() > 0) { + bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); + } + + return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + OrderedTensorType a = largestInput().type().get(); + OrderedTensorType b = smallestInput().type().get(); + int sizeDifference = a.rank() - b.rank(); + for (int i = 0; i < b.rank(); ++i) { + String bDim = b.dimensions().get(i).name(); + String aDim = a.dimensions().get(i + sizeDifference).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + } + } + + private IntermediateOperation largestInput() { + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); + return a.rank() >= b.rank() ? inputs.get(0) : inputs.get(1); + } + + private IntermediateOperation smallestInput() { + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); + return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java new file mode 100644 index 00000000000..e0842d820f9 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -0,0 +1,37 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.Optional; +import java.util.function.DoubleUnaryOperator; + +public class Map extends IntermediateOperation { + + private final DoubleUnaryOperator operator; + + public Map(String modelName, String nodeName, List<IntermediateOperation> inputs, DoubleUnaryOperator operator) { + super(modelName, nodeName, inputs); + this.operator = operator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + return inputs.get(0).type().get(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + Optional<TensorFunction> input = inputs.get(0).function(); + return new com.yahoo.tensor.functions.Map(input.get(), operator); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java new file mode 100644 index 00000000000..1dbfd6e40dc --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -0,0 +1,72 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.Optional; + +public class MatMul extends IntermediateOperation { + + public MatMul(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); + typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + if (aType.type().rank() < 2 || bType.type().rank() < 2) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); + if (aType.type().rank() != bType.type().rank()) + throw new IllegalArgumentException("Tensors in matmul must have the same rank"); + + Optional<TensorFunction> aFunction = inputs.get(0).function(); + Optional<TensorFunction> bFunction = inputs.get(1).function(); + if (!aFunction.isPresent() || !bFunction.isPresent()) { + return null; + } + return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); + List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + + String aDim0 = aDimensions.get(0).name(); + String aDim1 = aDimensions.get(1).name(); + String bDim0 = bDimensions.get(0).name(); + String bDim1 = bDimensions.get(1).name(); + + // The second dimension of a should have the same name as the first dimension of b + renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); + + // The first dimension of a should have a different name than the second dimension of b + renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); + + // For efficiency, the dimensions to join over should be innermost - soft constraint + renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); + } +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java new file mode 100644 index 00000000000..4be220db9d5 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -0,0 +1,113 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +public class Mean extends IntermediateOperation { + + private final AttributeMap attributeMap; + private List<String> reduceDimensions; + + public Mean(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + IntermediateOperation reductionIndices = inputs.get(1); + if (!reductionIndices.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Mean in " + name + ": " + + "reduction indices must be a constant."); + } + Tensor indices = reductionIndices.getConstantValue().get().asTensor(); + reduceDimensions = new ArrayList<>(); + + OrderedTensorType inputType = inputs.get(0).type().get(); + for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int dimensionIndex = cell.getValue().intValue(); + if (dimensionIndex < 0) { + dimensionIndex = inputType.dimensions().size() - dimensionIndex; + } + reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); + } + return reducedType(inputType, shouldKeepDimensions()); + } + + // optimization: if keepDims and one reduce dimension that has size 1: same as identity. + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); + if (shouldKeepDimensions()) { + // multiply with a generated tensor created from the reduced dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (String name : reduceDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + } + return output; + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size()); + for (String name : reduceDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + reduceDimensions = renamedDimensions; + } + + private boolean shouldKeepDimensions() { + Optional<Value> keepDims = attributeMap.get("keep_dims"); + return keepDims.isPresent() && keepDims.get().asBoolean(); + } + + private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (TensorType.Dimension dimension: inputType.type().dimensions()) { + if (!reduceDimensions.contains(dimension.name())) { + builder.add(dimension); + } else if (keepDimensions) { + builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); + } + } + return builder.build(); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java new file mode 100644 index 00000000000..ce0c58971d0 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Merge.java @@ -0,0 +1,35 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class Merge extends IntermediateOperation { + + public Merge(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + for (IntermediateOperation operation : inputs) { + if (operation.type().isPresent()) { + return operation.type().get(); + } + } + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + for (IntermediateOperation operation : inputs) { + if (operation.function().isPresent()) { + return operation.function().get(); + } + } + return null; + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java new file mode 100644 index 00000000000..4c5ce33b1b5 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/NoOp.java @@ -0,0 +1,26 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.Collections; +import java.util.List; + +public class NoOp extends IntermediateOperation { + + public NoOp(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, Collections.emptyList()); // don't propagate inputs + } + + @Override + protected OrderedTensorType lazyGetType() { + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java new file mode 100644 index 00000000000..e5e5c29f8f1 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/PlaceholderWithDefault.java @@ -0,0 +1,48 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.Optional; + +public class PlaceholderWithDefault extends IntermediateOperation { + + public PlaceholderWithDefault(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + return inputs().get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + // This should be a call to the function we add below, but for now + // we treat this as as identity function and just pass the constant. + return inputs.get(0).function().orElse(null); + } + + @Override + public Optional<TensorFunction> rankingExpressionFunction() { + // For now, it is much more efficient to assume we always will return + // the default value, as we can prune away large parts of the expression + // tree by having it calculated as a constant. If a case arises where + // it is important to support this, implement this. + return Optional.empty(); + } + + @Override + public boolean isConstant() { + return true; // not true if we add to function + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java new file mode 100644 index 00000000000..18f3cc1cc39 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -0,0 +1,131 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; + +public class Reshape extends IntermediateOperation { + + public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + IntermediateOperation newShape = inputs.get(1); + if (!newShape.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Reshape in " + name + ": " + + "shape input must be a constant."); + } + Tensor shape = newShape.getConstantValue().get().asTensor(); + + OrderedTensorType inputType = inputs.get(0).type().get(); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(); + int dimensionIndex = 0; + for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int size = cell.getValue().intValue(); + if (size < 0) { + size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / + OrderedTensorType.tensorSize(inputType.type()).intValue(); + } + outputTypeBuilder.add(TensorType.Dimension.indexed( + String.format("%s_%d", vespaName(), dimensionIndex), size)); + dimensionIndex++; + } + return outputTypeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + if (!allInputFunctionsPresent(2)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + TensorFunction inputFunction = inputs.get(0).function().get(); + return reshape(inputFunction, inputType.type(), type.type()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { + if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) { + throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); + } + + // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, + // then use the dimension order of the new shape to roll back into a tensor. + // Here we create a transformation tensor that is multiplied with the from tensor to map into + // the new shape. We have to introduce temporary dimension names and rename back if dimension names + // in the new and old tensor type overlap. + + ExpressionNode unrollFrom = unrollTensorExpression(inputType); + ExpressionNode unrollTo = unrollTensorExpression(outputType); + ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo); + + TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); + Generate transformTensor = new Generate(transformationType, + new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); + + TensorFunction outputFunction = new Reduce( + new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); + + return outputFunction; + } + + private static ExpressionNode unrollTensorExpression(TensorType type) { + if (type.rank() == 0) { + return new ConstantNode(DoubleValue.zero); + } + List<ExpressionNode> children = new ArrayList<>(); + List<ArithmeticOperator> operators = new ArrayList<>(); + int size = 1; + for (int i = type.dimensions().size() - 1; i >= 0; --i) { + TensorType.Dimension dimension = type.dimensions().get(i); + children.add(0, new ReferenceNode(dimension.name())); + if (size > 1) { + operators.add(0, ArithmeticOperator.MULTIPLY); + children.add(0, new ConstantNode(new DoubleValue(size))); + } + size *= OrderedTensorType.dimensionSize(dimension); + if (i > 0) { + operators.add(0, ArithmeticOperator.PLUS); + } + } + return new ArithmeticNode(children, operators); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java new file mode 100644 index 00000000000..dc690329a8d --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Select.java @@ -0,0 +1,88 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.function.DoubleBinaryOperator; + +import static ai.vespa.rankingexpression.importer.OrderedTensorType.dimensionSize; +import static ai.vespa.rankingexpression.importer.OrderedTensorType.tensorSize; + +public class Select extends IntermediateOperation { + + public Select(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(3)) { + return null; + } + OrderedTensorType a = inputs.get(1).type().get(); + OrderedTensorType b = inputs.get(2).type().get(); + if ((a.type().rank() != b.type().rank()) || !(tensorSize(a.type()).equals(tensorSize(b.type())))) { + throw new IllegalArgumentException("'Select': input tensors must have the same shape"); + } + return a; + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(3)) { + return null; + } + IntermediateOperation conditionOperation = inputs().get(0); + TensorFunction a = inputs().get(1).function().get(); + TensorFunction b = inputs().get(2).function().get(); + + // Shortcut: if we know during import which tensor to select, do that directly here. + if (conditionOperation.getConstantValue().isPresent()) { + Tensor condition = conditionOperation.getConstantValue().get().asTensor(); + if (condition.type().rank() == 0) { + return ((int) condition.asDouble() == 0) ? b : a; + } + if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { + return condition.cellIterator().next().getValue().intValue() == 0 ? b : a; + } + } + + // The task is to select cells from 'x' or 'y' based on 'condition'. + // If 'condition' is 0 (false), select from 'y', if 1 (true) select + // from 'x'. We do this by individually joining 'x' and 'y' with + // 'condition', and then joining the resulting two tensors. + + TensorFunction conditionFunction = conditionOperation.function().get(); + TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply()); + TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() { + @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } + @Override public String toString() { return "f(a,b)(a * (1-b))"; } + }); + return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(3)) { + return; + } + List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions(); + List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions(); + + String aDim0 = aDimensions.get(0).name(); + String aDim1 = aDimensions.get(1).name(); + String bDim0 = bDimensions.get(0).name(); + String bDim1 = bDimensions.get(1).name(); + + // These tensors should have the same dimension names + renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this); + renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java new file mode 100644 index 00000000000..361729a8c14 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -0,0 +1,54 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class Shape extends IntermediateOperation { + + public Shape(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + createConstantValue(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + return new OrderedTensorType.Builder() + .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) + .build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public boolean isConstant() { + return true; + } + + private void createConstantValue() { + if (!allInputTypesPresent(1)) { + return; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type()); + List<TensorType.Dimension> inputDimensions = inputType.dimensions(); + for (int i = 0; i < inputDimensions.size(); i++) { + builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L)); + } + this.setConstantValue(new TensorValue(builder.build())); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java new file mode 100644 index 00000000000..2eeefcbe8a2 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -0,0 +1,85 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +public class Squeeze extends IntermediateOperation { + + private final AttributeMap attributeMap; + private List<String> squeezeDimensions; + + public Squeeze(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributeMap) { + super(modelName, nodeName, inputs); + this.attributeMap = attributeMap; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + squeezeDimensions = new ArrayList<>(); + + Optional<List<Value>> squeezeDimsAttr = attributeMap.getList("squeeze_dims"); + if ( ! squeezeDimsAttr.isPresent()) { + squeezeDimensions = inputType.type().dimensions().stream(). + filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } else { + squeezeDimensions = squeezeDimsAttr.get().stream().map(Value::asDouble).map(Double::intValue). + map(i -> i < 0 ? inputType.type().dimensions().size() - i : i). + map(i -> inputType.type().dimensions().get(i)). + filter(dim -> OrderedTensorType.dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } + + return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + TensorFunction inputFunction = inputs.get(0).function().get(); + return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size()); + for (String name : squeezeDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + squeezeDimensions = renamedDimensions; + } + + private OrderedTensorType reducedType(OrderedTensorType inputType) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + for (TensorType.Dimension dimension: inputType.type().dimensions()) { + if ( ! squeezeDimensions.contains(dimension.name())) { + builder.add(dimension); + } + } + return builder.build(); + } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java new file mode 100644 index 00000000000..131af8de065 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Switch.java @@ -0,0 +1,50 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; +import java.util.Optional; + +public class Switch extends IntermediateOperation { + + private final int port; + + public Switch(String modelName, String nodeName, List<IntermediateOperation> inputs, int port) { + super(modelName, nodeName, inputs); + this.port = port; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + Optional<OrderedTensorType> predicate = inputs.get(1).type(); + if (predicate.get().type().rank() != 0) { + throw new IllegalArgumentException("Switch in " + name + ": " + + "predicate must be a scalar"); + } + return inputs.get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + IntermediateOperation predicateOperation = inputs().get(1); + if (!predicateOperation.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Switch in " + name + ": " + + "predicate must be a constant"); + } + if (port < 0 || port > 1) { + throw new IllegalArgumentException("Switch in " + name + ": " + + "choice should be boolean"); + } + + double predicate = predicateOperation.getConstantValue().get().asDouble(); + return predicate == port ? inputs().get(0).function().get() : null; + } + +} + + diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java new file mode 100644 index 00000000000..4473f306dcd --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/package-info.java @@ -0,0 +1,11 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +/** + * Model integration. + * + * CAUTION!: Config models depends on this API. It cannot be changed without ensuring compatibility with + * old config models. + */ +@ExportPackage +package ai.vespa.rankingexpression.importer; + +import com.yahoo.osgi.annotation.ExportPackage; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java index 978fc3ecf60..ecb67f93d69 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java @@ -5,8 +5,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.NodeDef; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java index e264b0daf6e..cb838cd67b1 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java @@ -3,27 +3,27 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Argument; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.ConcatV2; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Const; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Constant; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.ExpandDims; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Identity; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.IntermediateOperation; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Join; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Map; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.MatMul; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Mean; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Merge; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.NoOp; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.PlaceholderWithDefault; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Reshape; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Select; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Shape; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Squeeze; -import com.yahoo.searchlib.rankingexpression.integration.ml.operations.Switch; +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import ai.vespa.rankingexpression.importer.operations.Argument; +import ai.vespa.rankingexpression.importer.operations.ConcatV2; +import ai.vespa.rankingexpression.importer.operations.Const; +import ai.vespa.rankingexpression.importer.operations.Constant; +import ai.vespa.rankingexpression.importer.operations.ExpandDims; +import ai.vespa.rankingexpression.importer.operations.Identity; +import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; +import ai.vespa.rankingexpression.importer.operations.Join; +import ai.vespa.rankingexpression.importer.operations.Map; +import ai.vespa.rankingexpression.importer.operations.MatMul; +import ai.vespa.rankingexpression.importer.operations.Mean; +import ai.vespa.rankingexpression.importer.operations.Merge; +import ai.vespa.rankingexpression.importer.operations.NoOp; +import ai.vespa.rankingexpression.importer.operations.PlaceholderWithDefault; +import ai.vespa.rankingexpression.importer.operations.Reshape; +import ai.vespa.rankingexpression.importer.operations.Select; +import ai.vespa.rankingexpression.importer.operations.Shape; +import ai.vespa.rankingexpression.importer.operations.Squeeze; +import ai.vespa.rankingexpression.importer.operations.Switch; import com.yahoo.tensor.functions.ScalarFunctions; import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index 80c23f2af69..6c92ffa6055 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index a8b453b7a1c..2a406f92756 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -1,9 +1,9 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.IntermediateGraph; -import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter; +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.IntermediateGraph; +import ai.vespa.rankingexpression.importer.ModelImporter; import org.tensorflow.SavedModelBundle; import java.io.File; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index d2430d34711..63a605ce97a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.NodeDef; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java index c6bc889053a..31cb60b5509 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java @@ -1,7 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package ai.vespa.rankingexpression.importer.tensorflow; -import com.yahoo.searchlib.rankingexpression.integration.ml.OrderedTensorType; +import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.serialization.JsonFormat; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java index 35895995d3b..ac462cc39eb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImporter.java @@ -2,8 +2,8 @@ package ai.vespa.rankingexpression.importer.xgboost; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; -import com.yahoo.searchlib.rankingexpression.integration.ml.ModelImporter; +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.ModelImporter; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import java.io.File; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java new file mode 100644 index 00000000000..cf8dd6e8e71 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/DimensionRenamerTest.java @@ -0,0 +1,48 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import org.junit.Test; + +import static org.junit.Assert.assertTrue; + +public class DimensionRenamerTest { + + @Test + public void testMnistRenaming() { + DimensionRenamer renamer = new DimensionRenamer(); + + renamer.addDimension("first_dimension_of_x"); + renamer.addDimension("second_dimension_of_x"); + renamer.addDimension("first_dimension_of_w"); + renamer.addDimension("second_dimension_of_w"); + renamer.addDimension("first_dimension_of_b"); + + // which dimension to join on matmul + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); + + // other dimensions in matmul can't be equal + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); + + // for efficiency, put dimension to join on innermost + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); + + // bias + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); + + renamer.solve(); + + String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get(); + String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get(); + String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get(); + String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get(); + String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get(); + + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0); + assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0); + assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0); + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0); + assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0); + } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java new file mode 100644 index 00000000000..afe699d6e05 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java @@ -0,0 +1,21 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class OrderedTensorTypeTestCase { + + @Test + public void testToFromSpec() { + String spec = "tensor(b[],c{},a[3])"; + OrderedTensorType type = OrderedTensorType.fromSpec(spec); + assertEquals(spec, type.toString()); + assertEquals("tensor(a[3],b[],c{})", type.type().toString()); + } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index d86e7d6dd8e..d3996da9b58 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -6,7 +6,7 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java index d112a3fa9f2..1a072f54c89 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BatchNormImportTestCase.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Test; import static org.junit.Assert.assertEquals; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java index fa89e060006..37104ab43db 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/BlogEvaluationBenchmark.java @@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java index b3559a0a5f6..5e20be051ea 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/DropoutImportTestCase.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.tensor.TensorType; import org.junit.Assert; import org.junit.Test; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java index 7e717c204f8..28b91b3797a 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/MnistImportTestCase.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Assert; import org.junit.Test; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java index f98b37b7e55..6215997d8f9 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowMnistSoftmaxImportTestCase.java @@ -2,7 +2,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Assert; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java index faa2c7acc18..c3b82cccb46 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/TestableTensorFlowModel.java @@ -7,7 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java index 30b50c025d0..965d5eb8577 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/xgboost/XGBoostImportTestCase.java @@ -1,7 +1,7 @@ package ai.vespa.rankingexpression.importer.xgboost; import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.integration.ml.ImportedModel; +import ai.vespa.rankingexpression.importer.ImportedModel; import org.junit.Test; import static org.junit.Assert.assertEquals; |