From 3789127189224d6cbd6f109b9a95f848869ea6cc Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Fri, 3 Apr 2020 11:29:43 +0200 Subject: for testing only --- .../searchdefinition/MapEvaluationTypeContext.java | 43 +++- .../com/yahoo/searchdefinition/RankProfile.java | 5 +- .../searchdefinition/derived/RawRankProfile.java | 2 + .../processing/RankingExpressionTypeResolver.java | 1 + .../application/validation/RankSetupValidator.java | 11 +- .../com/yahoo/vespa/model/ml/ConvertedModel.java | 260 +++++++++---------- .../RankingExpressionConstantsTestCase.java | 15 ++ .../processing/RankProfileSearchFixture.java | 8 + .../RankingExpressionWithBertTestCase.java | 196 ++++++++++++++ .../RankingExpressionWithTensorFlowTestCase.java | 13 + .../importer/DimensionRenamer.java | 4 +- .../importer/IntermediateGraph.java | 25 +- .../rankingexpression/importer/ModelImporter.java | 25 +- .../importer/NamingConstraintSolver.java | 4 +- .../importer/onnx/GraphImporter.java | 56 +++- .../importer/operations/Const.java | 8 +- .../importer/operations/Constant.java | 12 +- .../importer/operations/Identity.java | 6 - .../importer/operations/IntermediateOperation.java | 92 ++++++- .../importer/operations/Join.java | 7 + .../importer/operations/MatMul.java | 131 +++++++--- .../importer/operations/Rename.java | 2 +- .../importer/operations/Reshape.java | 60 ++++- .../importer/operations/Slice.java | 1 - .../importer/operations/Softmax.java | 9 +- .../importer/operations/Split.java | 119 +++++++++ .../importer/operations/Tile.java | 100 ++++++++ .../importer/operations/Transpose.java | 54 ++++ .../importer/onnx/BertImportTestCase.java | 281 +++++++++++++++++++++ .../importer/onnx/OnnxOperationsTestCase.java | 140 +++++++++- .../importer/onnx/SimpleImportTestCase.java | 22 ++ .../tensorflow/LesterTensorflowImportTestCase.java | 162 ++++++++++++ .../src/test/models/onnx/simple/concat.onnx | Bin 0 -> 135 bytes .../src/test/models/onnx/simple/concat.py | 25 ++ .../src/test/models/onnx/simple/const.onnx | Bin 0 -> 97 bytes .../src/test/models/onnx/simple/const.py | 26 ++ .../src/test/models/onnx/simple/gather.onnx | Bin 150 -> 150 bytes .../src/test/models/onnx/simple/simple.onnx | 4 +- searchlib/abi-spec.json | 21 +- .../rankingexpression/ExpressionFunction.java | 2 + .../rule/FunctionReferenceContext.java | 31 ++- .../rankingexpression/rule/ReferenceNode.java | 49 +++- .../rule/SerializationContext.java | 8 +- .../rankingexpression/rule/TensorFunctionNode.java | 46 +++- .../transform/ConstantDereferencer.java | 20 +- .../evaluation/EvaluationTestCase.java | 120 +++++++++ vespajlib/abi-spec.json | 2 + .../java/com/yahoo/tensor/IndexedDoubleTensor.java | 7 +- .../java/com/yahoo/tensor/functions/Generate.java | 4 + .../java/com/yahoo/tensor/functions/Slice.java | 4 + 50 files changed, 1982 insertions(+), 261 deletions(-) create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java create mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java create mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java create mode 100644 model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java create mode 100644 model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/BertImportTestCase.java create mode 100644 model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/LesterTensorflowImportTestCase.java create 
mode 100644 model-integration/src/test/models/onnx/simple/concat.onnx create mode 100755 model-integration/src/test/models/onnx/simple/concat.py create mode 100644 model-integration/src/test/models/onnx/simple/const.onnx create mode 100755 model-integration/src/test/models/onnx/simple/const.py diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 6de7c985326..65443117c0a 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -47,12 +47,15 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement private final SortedSet queryFeaturesNotDeclared; private boolean tensorsAreUsed; + private final MapEvaluationTypeContext parent; + MapEvaluationTypeContext(Collection functions, Map featureTypes) { super(functions); this.featureTypes.putAll(featureTypes); this.currentResolutionCallStack = new ArrayDeque<>(); this.queryFeaturesNotDeclared = new TreeSet<>(); tensorsAreUsed = false; + parent = null; } private MapEvaluationTypeContext(Map functions, @@ -60,12 +63,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement Map featureTypes, Deque currentResolutionCallStack, SortedSet queryFeaturesNotDeclared, - boolean tensorsAreUsed) { + boolean tensorsAreUsed, + MapEvaluationTypeContext parent) { super(functions, bindings); this.featureTypes.putAll(featureTypes); this.currentResolutionCallStack = currentResolutionCallStack; this.queryFeaturesNotDeclared = queryFeaturesNotDeclared; this.tensorsAreUsed = tensorsAreUsed; + this.parent = parent; } public void setType(Reference reference, TensorType type) { @@ -82,16 +87,45 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement resolvedTypes.clear(); } + private TensorType resolvedType(Reference reference, int depth) { +// System.out.println(indent + "In resolvedtype - resolving type for " + reference.toString()); + TensorType resolvedType = resolvedTypes.get(reference); + if (resolvedType != null) { +// System.out.println("Found previously resolved type for " + reference + " at depth " + depth + ": (" + resolvedType + ")"); + return resolvedType; + } + if (parent != null) return parent.resolvedType(reference, depth + 1); // what about argument types? Careful with this! +// System.out.println("Could NOT find type for " + reference + " - down to depth " + depth); + return null; + } + + private MapEvaluationTypeContext findOriginalParent() { + if (parent != null) + return parent.findOriginalParent(); + return this; + } + @Override public TensorType getType(Reference reference) { // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: - TensorType resolvedType = resolvedTypes.get(reference); + // TensorType resolvedType = resolvedTypes.get(reference); + TensorType resolvedType = resolvedType(reference, 0); if (resolvedType != null) return resolvedType; resolvedType = resolveType(reference); if (resolvedType == null) return defaultTypeOf(reference); // Don't store fallback to default as we may know more later - resolvedTypes.put(reference, resolvedType); + +// System.out.println("Resolved type of " + reference + ": (" + resolvedType + ")"); + + // Need a concept of global or local here. + // For globals - put it in the lowest parent!
+ MapEvaluationTypeContext originalParent = findOriginalParent(); + if (originalParent == null) { + originalParent = this; + } + originalParent.resolvedTypes.put(reference, resolvedType); + if (resolvedType.rank() > 0) tensorsAreUsed = true; return resolvedType; @@ -103,6 +137,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) + " -> " + reference); + // Bound to a function argument, and not to a same-named identifier (which would lead to a loop)? Optional binding = boundIdentifier(reference); if (binding.isPresent() && ! binding.get().equals(reference.toString())) { @@ -254,7 +289,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement featureTypes, currentResolutionCallStack, queryFeaturesNotDeclared, - tensorsAreUsed); + tensorsAreUsed, this); } } diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 23eb814de81..ea126123a25 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -680,11 +680,12 @@ public class RankProfile implements Cloneable { Map inlineFunctions = compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms); + firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); + // Function compiling second pass: compile all functions and insert previously compiled inline functions functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms); - firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); - secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms); } private void checkNameCollisions(Map functions, Map constants) { diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 1a22b98fd9f..3578cc786ed 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -222,7 +222,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer { List functionExpressions) { SerializationContext context = new SerializationContext(functionExpressions); for (Map.Entry e : functions.entrySet()) { + System.out.println("Deriving: " + e.getKey()); String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString(); + System.out.println("-> Done deriving: " + e.getKey() + ": " + expressionString); context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString); for (Map.Entry argumentType : e.getValue().function().argumentTypes().entrySet()) diff --git 
a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java index 89b8889b4ae..8d9098a10f1 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java @@ -79,6 +79,7 @@ public class RankingExpressionTypeResolver extends Processor { } context.forgetResolvedTypes(); + System.out.println("Resolving type for " + function.getKey()); TensorType type = resolveType(expressionFunction.getBody(), "function '" + function.getKey() + "'", context); function.getValue().setReturnType(type); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java index b6f7ab4ff62..d3e029b8de5 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java @@ -102,7 +102,8 @@ public class RankSetupValidator extends Validator { } private void deleteTempDir(File dir) { - IOUtils.recursiveDeleteDir(dir); + System.out.println("Here we were supposed to delete tmpdir: " + dir.getAbsolutePath()); +// IOUtils.recursiveDeleteDir(dir); } private void writeConfigs(String dir, AbstractConfigProducer producer) throws IOException { @@ -133,7 +134,13 @@ public class RankSetupValidator extends Validator { } private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException { - IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false); + + String output = StringUtilities.implodeMultiline(ConfigInstance.serialize(config)); + System.out.println("Writing config for in " + dir + " for configName '" + configName + "' "); + System.out.println(output); + IOUtils.writeFile(dir + configName, output, false); + +// IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false); } private boolean execValidate(String configId, SearchCluster sc, String sdName, DeployLogger deployLogger) { diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index c3d6f457ce8..9f649bc820a 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -386,138 +386,138 @@ public class ConvertedModel { */ private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { - MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); - - // Add any missing inputs for type resolution - Set functionNames = new HashSet<>(); - addFunctionNamesIn(expression.getRoot(), functionNames, model); - for (String functionName : functionNames) { - Optional requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); - if (requiredType.isPresent()) { - Reference ref = Reference.fromIdentifier(functionName); - if (typeContext.getType(ref).equals(TensorType.empty)) { - typeContext.setType(ref, requiredType.get()); - } - } - } - 
typeContext.forgetResolvedTypes(); - - TensorType typeBeforeReducing = expression.getRoot().type(typeContext); - - // Check generated functions for inputs to reduce - for (String functionName : functionNames) { - if ( ! model.functions().containsKey(functionName)) continue; - - RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); - if (rankingExpressionFunction == null) { - throw new IllegalArgumentException("Model refers to generated function '" + functionName + - "but this function is not present in " + profile); - } - RankingExpression functionExpression = rankingExpressionFunction.function().getBody(); - functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext)); - } - - // Check expression for inputs to reduce - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, model, typeContext); - TensorType typeAfterReducing = root.type(typeContext); - root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); - expression.setRoot(root); - } - - private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, - MapEvaluationTypeContext typeContext) { - if (node instanceof TensorFunctionNode) { - TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); - if (tensorFunction instanceof Rename) { - List children = ((TensorFunctionNode)node).children(); - if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(tensorFunction, typeContext); - } - } - // Modify any renames in expression to disregard batch dimension - else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { - TensorFunction childFunction = (((TensorFunctionNode) children.get(0)).function()); - TensorType childType = childFunction.type(typeContext); - Rename rename = (Rename) tensorFunction; - List from = new ArrayList<>(); - List to = new ArrayList<>(); - for (TensorType.Dimension dimension : childType.dimensions()) { - int i = rename.fromDimensions().indexOf(dimension.name()); - if (i < 0) { - throw new IllegalArgumentException("Rename does not contain dimension '" + - dimension + "' in child expression type: " + childType); - } - from.add((String)rename.fromDimensions().get(i)); - to.add((String)rename.toDimensions().get(i)); - } - return new TensorFunctionNode(new Rename<>(childFunction, from, to)); - } - } - } - if (node instanceof ReferenceNode) { - ReferenceNode referenceNode = (ReferenceNode) node; - if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext); - } - } - if (node instanceof CompositeNode) { - List children = ((CompositeNode)node).children(); - List transformedChildren = new ArrayList<>(children.size()); - for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); - } - return ((CompositeNode)node).setChildren(transformedChildren); - } - return node; - } - - private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) { - TensorFunction result = function; - TensorType type = function.type(context); - if (type.dimensions().size() > 1) { - List reduceDimensions = new ArrayList<>(); - for 
(TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } - } - if (reduceDimensions.size() > 0) { - result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); - context.forgetResolvedTypes(); // We changed types - } - } - return new TensorFunctionNode(result); +// MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles); +// +// // Add any missing inputs for type resolution +// Set functionNames = new HashSet<>(); +// addFunctionNamesIn(expression.getRoot(), functionNames, model); +// for (String functionName : functionNames) { +// Optional requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec); +// if (requiredType.isPresent()) { +// Reference ref = Reference.fromIdentifier(functionName); +// if (typeContext.getType(ref).equals(TensorType.empty)) { +// typeContext.setType(ref, requiredType.get()); +// } +// } +// } +// typeContext.forgetResolvedTypes(); +// +// TensorType typeBeforeReducing = expression.getRoot().type(typeContext); +// +// // Check generated functions for inputs to reduce +// for (String functionName : functionNames) { +// if ( ! model.functions().containsKey(functionName)) continue; +// +// RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName); +// if (rankingExpressionFunction == null) { +// throw new IllegalArgumentException("Model refers to generated function '" + functionName + +// "but this function is not present in " + profile); +// } +// RankingExpression functionExpression = rankingExpressionFunction.function().getBody(); +// functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext)); +// } +// +// // Check expression for inputs to reduce +// ExpressionNode root = expression.getRoot(); +// root = reduceBatchDimensionsAtInput(root, model, typeContext); +// TensorType typeAfterReducing = root.type(typeContext); +// root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); +// expression.setRoot(root); } - /** - * If batch dimensions have been reduced away above, bring them back here - * for any following computation of the tensor. - */ - // TODO: determine when this is not necessary! 
- private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { - if (after.equals(before)) return node; - - TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); - for (TensorType.Dimension dimension : before.dimensions()) { - if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { - typeBuilder.indexed(dimension.name(), 1); - } - } - TensorType expandDimensionsType = typeBuilder.build(); - if (expandDimensionsType.dimensions().size() > 0) { - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); - Generate generatedFunction = new Generate(expandDimensionsType, - new GeneratorLambdaFunctionNode(expandDimensionsType, - generatedExpression) - .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - return node; - } +// private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model, +// MapEvaluationTypeContext typeContext) { +// if (node instanceof TensorFunctionNode) { +// TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); +// if (tensorFunction instanceof Rename) { +// List children = ((TensorFunctionNode)node).children(); +// if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { +// ReferenceNode referenceNode = (ReferenceNode) children.get(0); +// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { +// return reduceBatchDimensionExpression(tensorFunction, typeContext); +// } +// } +// // Modify any renames in expression to disregard batch dimension +// else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { +// TensorFunction childFunction = (((TensorFunctionNode) children.get(0)).function()); +// TensorType childType = childFunction.type(typeContext); +// Rename rename = (Rename) tensorFunction; +// List from = new ArrayList<>(); +// List to = new ArrayList<>(); +// for (TensorType.Dimension dimension : childType.dimensions()) { +// int i = rename.fromDimensions().indexOf(dimension.name()); +// if (i < 0) { +// throw new IllegalArgumentException("Rename does not contain dimension '" + +// dimension + "' in child expression type: " + childType); +// } +// from.add((String)rename.fromDimensions().get(i)); +// to.add((String)rename.toDimensions().get(i)); +// } +// return new TensorFunctionNode(new Rename<>(childFunction, from, to)); +// } +// } +// } +// if (node instanceof ReferenceNode) { +// ReferenceNode referenceNode = (ReferenceNode) node; +// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { +// return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext); +// } +// } +// if (node instanceof CompositeNode) { +// List children = ((CompositeNode)node).children(); +// List transformedChildren = new ArrayList<>(children.size()); +// for (ExpressionNode child : children) { +// transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); +// } +// return ((CompositeNode)node).setChildren(transformedChildren); +// } +// return node; +// } +// +// private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) { +// TensorFunction result = function; +// TensorType type = function.type(context); +// if (type.dimensions().size() > 1) { +// List reduceDimensions = new ArrayList<>(); +// for 
(TensorType.Dimension dimension : type.dimensions()) { +// if (dimension.size().orElse(-1L) == 1) { +// reduceDimensions.add(dimension.name()); +// } +// } +// if (reduceDimensions.size() > 0) { +// result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); +// context.forgetResolvedTypes(); // We changed types +// } +// } +// return new TensorFunctionNode(result); +// } +// +// /** +// * If batch dimensions have been reduced away above, bring them back here +// * for any following computation of the tensor. +// */ +// // TODO: determine when this is not necessary! +// private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { +// if (after.equals(before)) return node; +// +// TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType()); +// for (TensorType.Dimension dimension : before.dimensions()) { +// if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { +// typeBuilder.indexed(dimension.name(), 1); +// } +// } +// TensorType expandDimensionsType = typeBuilder.build(); +// if (expandDimensionsType.dimensions().size() > 0) { +// ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0)); +// Generate generatedFunction = new Generate(expandDimensionsType, +// new GeneratorLambdaFunctionNode(expandDimensionsType, +// generatedExpression) +// .asLongListToDoubleOperator()); +// Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply()); +// return new TensorFunctionNode(expand); +// } +// return node; +// } /** * If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions. diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java index d84d967a184..8bc9040577b 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java @@ -2,8 +2,11 @@ package com.yahoo.searchdefinition; import com.yahoo.collections.Pair; +import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.search.query.profile.QueryProfileRegistry; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; +import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.yolean.Exceptions; import com.yahoo.searchdefinition.derived.AttributeFields; import com.yahoo.searchdefinition.derived.RawRankProfile; @@ -82,6 +85,18 @@ public class RankingExpressionConstantsTestCase extends SchemaTestCase { new AttributeFields(s)).configProperties(); assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString()); assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString()); + + try { + DerivedConfiguration config = new DerivedConfiguration(s, + new BaseDeployLogger(), + new TestProperties(), + rankProfileRegistry, + queryProfileRegistry, + new ImportedMlModels()); + config.export("/Users/lesters/temp/bert/idea/"); + } catch (Exception e) { + throw new IllegalArgumentException(e); + } } @Test diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java 
b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java index 0cd6674751e..d5638da224c 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java @@ -41,6 +41,14 @@ class RankProfileSearchFixture { private Search search; private Map compiledRankProfiles = new HashMap<>(); + // TEMP + public RankProfileRegistry getRankProfileRegistry() { + return rankProfileRegistry; + } + public QueryProfileRegistry getQueryProfileRegistry() { + return queryProfileRegistry; + } + RankProfileSearchFixture(String rankProfiles) throws ParseException { this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java new file mode 100644 index 00000000000..2c0620a0c52 --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java @@ -0,0 +1,196 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; +import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; +import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter; +import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; +import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter; +import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter; +import com.google.common.collect.ImmutableList; +import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.config.model.deploy.TestProperties; +import com.yahoo.io.IOUtils; +import com.yahoo.path.Path; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchdefinition.derived.DerivedConfiguration; +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage; +import com.yahoo.vespa.model.VespaModel; +import com.yahoo.vespa.model.ml.ImportedModelTester; +import com.yahoo.yolean.Exceptions; +import org.junit.After; +import org.junit.Ignore; +import org.junit.Test; + +import java.io.IOException; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.fail; + +public class RankingExpressionWithBertTestCase { + + private final Path applicationDir = Path.fromString("src/test/integration/bert/"); + + /** The model name */ + private final static String name = "bertsquad8"; + + private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))"; + + @After + public void removeGeneratedModelFiles() { + 
IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + } + + + @Ignore + @Test + public void testGlobalBertModel() throws IOException { + ImportedModelTester tester = new ImportedModelTester(name, applicationDir); + VespaModel model = tester.createVespaModel(); +// tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L)); +// tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedAppDir = applicationDir.append("copy"); + try { + storedAppDir.toFile().mkdirs(); + IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString()); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir); + VespaModel storedModel = storedTester.createVespaModel(); +// tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L)); +// tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L)); + } + finally { + IOUtils.recursiveDeleteDir(storedAppDir.toFile()); + } + } + + @Ignore + @Test + public void testBertRankProfile() throws Exception { + StoringApplicationPackage application = new StoringApplicationPackage((applicationDir)); + + ImmutableList importers = ImmutableList.of(new TensorFlowImporter(), + new OnnxImporter(), + new LightGBMImporter(), + new XGBoostImporter()); + + String rankProfiles = " rank-profile my_profile {\n" + + " first-phase {\n" + + " expression: onnx('bertsquad8.onnx', 'default', 'unstack')" + + " }\n" + + " }"; + + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + QueryProfileRegistry queryProfileRegistry = application.getQueryProfiles(); + + SearchBuilder builder = new SearchBuilder(application, rankProfileRegistry, queryProfileRegistry); + String sdContent = "search test {\n" + + " document test {\n" + + " field unique_ids type tensor(d0[1]) {\n" + + " indexing: summary | attribute\n" + + " }\n" + + " field input_ids type tensor(d0[1],d1[256]) {\n" + + " indexing: summary | attribute\n" + + " }\n" + + " field input_mask type tensor(d0[1],d1[256]) {\n" + + " indexing: summary | attribute\n" + + " }\n" + + " field segment_ids type tensor(d0[1],d1[256]) {\n" + + " indexing: summary | attribute\n" + + " }" + + " }\n" + + " rank-profile my_profile inherits default {\n" + + " function inline unique_ids_raw_output___9() {\n" + + " expression: attribute(unique_ids)\n" + + " }\n" + + " function inline input_ids() {\n" + + " expression: attribute(input_ids)\n" + + " }\n" + + " function inline input_mask() {\n" + + " expression: attribute(input_mask)\n" + + " }\n" + + " function inline segment_ids() {\n" + + " expression: attribute(segment_ids)\n" + + " }\n" + + " first-phase {\n" + + " expression: onnx(\"bertsquad8.onnx\", \"default\", \"unstack\") \n" + + " }\n" + + " }" + + "}"; + builder.importString(sdContent); + builder.build(); + Search search = builder.getSearch(); + + RankProfile compiled = rankProfileRegistry.get(search, "my_profile") + .compile(queryProfileRegistry, + new ImportedMlModels(applicationDir.toFile(), importers)); + + DerivedConfiguration config = new DerivedConfiguration(search, + new BaseDeployLogger(), + new TestProperties(), + rankProfileRegistry, + 
queryProfileRegistry, + new ImportedMlModels()); + + config.export("/Users/lesters/temp/bert/idea/"); + +// fixture.assertFirstPhaseExpression(vespaExpression, "my_profile"); + System.out.println("Joda"); + } + + private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression, + String constant, String field) { + return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder", + new StoringApplicationPackage(applicationDir)); + } + + private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture(application, application.getQueryProfiles(), + rankProfile, null, null); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + + private RankProfileSearchFixture fixtureWith(String functionExpression, + String firstPhaseExpression, + String constant, + String field, + String functionName, + StoringApplicationPackage application) { + try { + RankProfileSearchFixture fixture = new RankProfileSearchFixture( + application, + application.getQueryProfiles(), + " rank-profile my_profile {\n" + + " function " + functionName + "() {\n" + + " expression: " + functionExpression + + " }\n" + + " first-phase {\n" + + " expression: " + firstPhaseExpression + + " }\n" + + " }", + constant, + field); + fixture.compileRankProfile("my_profile", applicationDir.append("models")); + return fixture; + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index cba931e81f0..c444bf8d7dc 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -1,8 +1,11 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.searchdefinition.processing; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels; import com.yahoo.config.application.api.ApplicationFile; import com.yahoo.config.application.api.ApplicationPackage; +import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.io.IOUtils; @@ -10,6 +13,7 @@ import com.yahoo.io.reader.NamedReader; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; import com.yahoo.searchdefinition.RankingConstant; +import com.yahoo.searchdefinition.derived.DerivedConfiguration; import com.yahoo.searchdefinition.parser.ParseException; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; @@ -385,6 +389,15 @@ public class RankingExpressionWithTensorFlowTestCase { finally { IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); } + + DerivedConfiguration config = new DerivedConfiguration(search.search(), + new BaseDeployLogger(), + new TestProperties(), + search.getRankProfileRegistry(), + search.getQueryProfileRegistry(), + new ImportedMlModels()); + config.export("/Users/lesters/temp/bert/idea/"); + } private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java index c7f320ed3b4..87f7c1c71f8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java @@ -66,7 +66,7 @@ public class DimensionRenamer { void solve() { log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints)); - renames = solve(100000); + renames = solve(100000000); log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames)); } @@ -86,7 +86,7 @@ public class DimensionRenamer { private Map solveWithOrWithoutSoftConstraints(int maxIterations) { Map solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations); - if ( solution == null) { + if (solution == null) { ListMap hardConstraints = new ListMap<>(); boolean anyRemoved = copyHard(constraints, hardConstraints); if (anyRemoved) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index 14aa3ebf84e..3c8a6bde232 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -74,6 +75,8 @@ public class IntermediateGraph { renameDimensions(); } + static int counter = 0; + /** * Find dimension names to avoid excessive renaming while evaluating the model. 
*/ @@ -93,16 +96,34 @@ public class IntermediateGraph { } private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) { + Set operations = new HashSet<>(); + addDimensionNameConstraints(operation, renamer, operations); + } + + private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set operations) { + if (operations.contains(operation.name())) { + return; + } if (operation.type().isPresent()) { - operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, operations)); operation.addDimensionNameConstraints(renamer); + operations.add(operation.name()); } } private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) { + Set operations = new HashSet<>(); + renameDimensions(operation, renamer, operations); + } + + private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set operations) { + if (operations.contains(operation.name())) { + return; + } if (operation.type().isPresent()) { - operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.inputs().forEach(input -> renameDimensions(input, renamer, operations)); operation.renameDimensions(renamer); + operations.add(operation.name()); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 3774e64c886..7fad077ceb2 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -3,11 +3,14 @@ package ai.vespa.rankingexpression.importer; import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import ai.vespa.rankingexpression.importer.operations.Constant; import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; @@ -15,9 +18,11 @@ import com.yahoo.text.ExpressionFormatter; import com.yahoo.yolean.Exceptions; import java.io.File; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Level; import java.util.logging.Logger; @@ -122,8 +127,16 @@ public abstract class ModelImporter implements MlModelImporter { return operation.function(); } + private static boolean isImported(IntermediateOperation operation, ImportedModel model) { + return model.expressions().containsKey(operation.name()); // test for others? + } + private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) { - operation.inputs().forEach(input -> importExpression(input, model)); + operation.inputs().forEach(input -> { + if ( ! 
isImported(operation, model)) { + importExpression(input, model); + } + }); } private static Optional importConstant(IntermediateOperation operation, ImportedModel model) { @@ -206,18 +219,22 @@ public abstract class ModelImporter implements MlModelImporter { private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { for (ImportedModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { - reportWarnings(graph.get(outputName), model); + reportWarnings(graph.get(outputName), model, new HashSet()); } } } - private static void reportWarnings(IntermediateOperation operation, ImportedModel model) { + private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set reported) { + if (reported.contains(operation.name())) { + return; + } for (String warning : operation.warnings()) { // If we want to report warnings, that code goes here } for (IntermediateOperation input : operation.inputs()) { - reportWarnings(input, model); + reportWarnings(input, model, reported); } + reported.add(operation.name()); } /** diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java index 21cc6b27dad..9a7fcc85ee1 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java @@ -37,7 +37,8 @@ class NamingConstraintSolver { private static ListMap allPossibilities(Set dimensions) { ListMap all = new ListMap<>(); for (String dimension : dimensions) { - for (int i = 0; i < dimensions.size(); ++i) + // 20 (different dimension names) should be enough for most problems. 
+ for (int i = 0; i < Math.min(dimensions.size(), 20); ++i) all.put(dimension, i); } return all; @@ -89,6 +90,7 @@ class NamingConstraintSolver { workList.add(constraint); } } + if (iterations > maxIterations) return false; } return true; } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index ffc64c38f16..c98a5c7d4f5 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -2,7 +2,6 @@ package ai.vespa.rankingexpression.importer.onnx; -import ai.vespa.rankingexpression.importer.operations.ExpandDims; import ai.vespa.rankingexpression.importer.operations.Gather; import ai.vespa.rankingexpression.importer.operations.OnnxCast; import ai.vespa.rankingexpression.importer.operations.Gemm; @@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce; import ai.vespa.rankingexpression.importer.operations.Select; import ai.vespa.rankingexpression.importer.operations.Slice; import ai.vespa.rankingexpression.importer.operations.Softmax; +import ai.vespa.rankingexpression.importer.operations.Split; import ai.vespa.rankingexpression.importer.operations.Squeeze; +import ai.vespa.rankingexpression.importer.operations.Tile; +import ai.vespa.rankingexpression.importer.operations.Transpose; import ai.vespa.rankingexpression.importer.operations.Unsqueeze; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; +import java.util.Collection; +import java.util.HashMap; import java.util.HashSet; import java.util.List; import java.util.Optional; @@ -53,19 +57,21 @@ class GraphImporter { private static IntermediateOperation mapOperation(Onnx.NodeProto node, List inputs, - IntermediateGraph graph) { + IntermediateGraph graph, + int outputIndex) { String type = node.getOpType(); String modelName = graph.name(); String nodeName = getNodeName(node); AttributeConverter attributes = AttributeConverter.convert(node); - return mapOperation(type, inputs, modelName, nodeName, attributes); + return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex); } static IntermediateOperation mapOperation(String opType, List inputs, String modelName, String nodeName, - AttributeConverter attributes) { + AttributeConverter attributes, + int outputIndex) { switch (opType.toLowerCase()) { case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); @@ -115,17 +121,21 @@ class GraphImporter { case "slice": return new Slice(modelName, nodeName, inputs, attributes); case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); + case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex); case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); case "where": return new 
Select(modelName, nodeName, inputs); case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); + case "tile": return new Tile(modelName, nodeName, inputs); + case "transpose": return new Transpose(modelName, nodeName, inputs, attributes); case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes); } IntermediateOperation op = new NoOp(modelName, nodeName, inputs); op.warning("Operation '" + opType + "' is currently not implemented"); + System.out.println(nodeName + ": operation '" + opType + "' is currently not implemented"); return op; } @@ -133,10 +143,15 @@ class GraphImporter { Onnx.GraphProto onnxGraph = model.getGraph(); IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); + System.out.println("Importing operations..."); importOperations(onnxGraph, intermediateGraph); + System.out.println("Verifying no warnings..."); verifyNoWarnings(intermediateGraph); + System.out.println("Verifying output types..."); verifyOutputTypes(onnxGraph, intermediateGraph); + System.out.println("Ok..."); + return intermediateGraph; } @@ -150,8 +165,10 @@ class GraphImporter { Onnx.GraphProto onnxGraph, IntermediateGraph intermediateGraph) { if (intermediateGraph.alreadyImported(name)) { +// System.out.println("Trying to import '" + name + "' but is was already imported."); return intermediateGraph.get(name); } +// System.out.println("Importing '" + name + "' ..."); IntermediateOperation operation; if (isArgumentTensor(name, onnxGraph)) { Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph); @@ -163,16 +180,21 @@ class GraphImporter { intermediateGraph.inputs(intermediateGraph.defaultSignature()) .put(IntermediateOperation.namePartOf(name), operation.vespaName()); +// System.out.println(" '" + name + "' imported as argument..."); + } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); +// System.out.println(" '" + name + "' imported as constant..."); + } else { Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph); + int outputIndex = getOutputIndex(node, name); List inputs = importOperationInputs(node, onnxGraph, intermediateGraph); - operation = mapOperation(node, inputs, intermediateGraph); + operation = mapOperation(node, inputs, intermediateGraph, outputIndex); // propagate constant values if all inputs are constant if (operation.isConstant()) { @@ -183,8 +205,12 @@ class GraphImporter { intermediateGraph.outputs(intermediateGraph.defaultSignature()) .put(IntermediateOperation.namePartOf(name), operation.name()); } + +// System.out.println(" '" + name + "' imported as normal..."); + } intermediateGraph.put(operation.name(), operation); + intermediateGraph.put(name, operation); return operation; } @@ -262,7 +288,8 @@ class GraphImporter { Onnx.ValueInfoProto onnxNode = getOutputNode(output.getKey(), onnxGraph); OrderedTensorType type = operation.type().orElseThrow( () -> new IllegalArgumentException("Output of '" + output.getValue() + "' has no type.")); - TypeConverter.verifyType(onnxNode.getType(), type); + System.out.println(onnxNode.getType() + " vs. 
" + type); + //TypeConverter.verifyType(onnxNode.getType(), type); } } @@ -296,6 +323,10 @@ class GraphImporter { return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst(); } + private static int getOutputIndex(Onnx.NodeProto node, String outputName) { + return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0); + } + private static String getNodeName(Onnx.NodeProto node) { String nodeName = node.getName(); if (nodeName.length() > 0) @@ -307,11 +338,14 @@ class GraphImporter { } private static Set getWarnings(IntermediateOperation op) { - Set warnings = new HashSet<>(op.warnings()); - for (IntermediateOperation input : op.inputs()) { - warnings.addAll(getWarnings(input)); - } - return warnings; + java.util.Map> warnings = new HashMap<>(); + getWarnings(op, warnings); + return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet()); } + private static void getWarnings(IntermediateOperation op, java.util.Map> warnings) { + if (warnings.containsKey(op.name())) return; + op.inputs().forEach(input -> getWarnings(input, warnings)); + warnings.put(op.name(), new HashSet<>(op.warnings())); + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 01fd7ee55bd..956d727fbad 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -54,10 +54,10 @@ public class Const extends IntermediateOperation { } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + super.vespaName(); - } +// @Override +// public String vespaName() { +// return modelName + "_" + super.vespaName(); +// } @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index ad56eefe5f2..b12f83f274b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -22,10 +22,10 @@ public class Constant extends IntermediateOperation { } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + vespaName(name); - } +// @Override +// public String vespaName() { +// return modelName + "_" + vespaName(name); +// } @Override protected OrderedTensorType lazyGetType() { @@ -61,7 +61,9 @@ public class Constant extends IntermediateOperation { public Constant withInputs(List inputs) { if ( ! 
inputs.isEmpty()) throw new IllegalArgumentException("Constant cannot take inputs"); - return new Constant(modelName(), name(), type); + Constant constant = new Constant(modelName(), name(), type); + constant.setConstantValueFunction(constantValueFunction); + return constant; } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index 5463f645355..af192fcec38 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -12,12 +12,6 @@ public class Identity extends IntermediateOperation { super(modelName, nodeName, inputs); } - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + super.vespaName(); - } - @Override protected OrderedTensorType lazyGetType() { if (!allInputTypesPresent(1)) diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 2aa8b2a0d48..83e15a4081a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; @@ -47,6 +49,8 @@ public abstract class IntermediateOperation { protected TensorFunction rankingExpressionFunction = null; protected boolean exportAsRankingFunction = false; + private boolean hasRenamedDimensions = false; + private final List importWarnings = new ArrayList<>(); private Value constantValue = null; private List controlInputs = Collections.emptyList(); @@ -121,7 +125,10 @@ public abstract class IntermediateOperation { } /** Performs dimension rename for this operation */ - public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } + public void renameDimensions(DimensionRenamer renamer) { + type = type.rename(renamer); + hasRenamedDimensions = true; + } /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ public boolean isInput() { return false; } @@ -144,7 +151,11 @@ public abstract class IntermediateOperation { } /** Set the constant value function */ - public void setConstantValueFunction(Function func) { this.constantValueFunction = func; } + public void setConstantValueFunction(Function func) { + 
this.constantValueFunction = func; + } + + public boolean hasConstantValueFunction() { return constantValueFunction != null; } /** Sets the external control inputs */ public void setControlInputs(List inputs) { this.controlInputs = inputs; } @@ -153,12 +164,23 @@ public abstract class IntermediateOperation { public List getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return vespaName(name); } - public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; } + public String vespaName() { + if (isConstant()) + return modelName + "_" + vespaName(name); + return vespaName(name); + } + + public String vespaName(String name) { + return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; + } /** Retrieve the valid Vespa name of this node if it is a ranking expression function */ public String rankingExpressionFunctionName() { - return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; + String vespaName = vespaName(); + if (vespaName == null) { + return null; + } + return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName; } /** Retrieve the list of warnings produced during its lifetime */ @@ -185,30 +207,80 @@ public abstract class IntermediateOperation { /** Recursively evaluates this operation's constant value to avoid doing it run-time. */ public Value evaluateAsConstant(OrderedTensorType type) { +// System.out.println("Starting constant evaluation for " + name); if ( ! isConstant() ) { throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant."); } - Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN)); - if (type != null && ! val.asTensor().type().equals(type.type()) ) { + if (type == null) { + System.out.println("Evaluating as constant for " + name + " with type null! Probably an error."); + } + + IntermediateOperation evaluateOn = this; + if ( ! hasRenamedDimensions) { + // make a copy of the tree, perform renaming and evaluate + IntermediateOperation copy = copyTree(0); + optimizeAndRename(copy); + evaluateOn = copy; + } + Value val = evaluateOn.evaluateAsConstant(new MapContext(DoubleValue.NaN), 0); + + if (type == null) { + return val; + } + Tensor tensor = val.asTensor(); //.withType(type.type()); + if ( ! tensor.type().isRenamableTo(type.type()) ) { throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. 
" + "Expected: " + type.type() + " Got: " + val.asTensor().type()); } - return val; + // set constant value so we don't have to re-evaluate + setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type()))); +// System.out.println("Returning constant evaluation for " + name); + return new TensorValue(tensor.withType(type.type())); + } + + private IntermediateOperation copyTree(int indent) { + String indentString = ""; for (int i = 0; i < indent; ++i) indentString += " "; +// System.out.println(indentString + "Copying " + name); + List in = new ArrayList<>(); + if (constantValue != null) { +// System.out.println(indentString + name + " has a constant value"); + IntermediateOperation constant = new Constant(modelName, name, type); + constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type()))); + return constant; + } + inputs.forEach(i -> in.add(i.copyTree(indent + 1))); + IntermediateOperation copy = withInputs(in); + if (constantValueFunction != null) { + copy.constantValueFunction = constantValueFunction; // works? + } + return copy; + } + + private TensorFunction optimizeAndRename(IntermediateOperation op) { + IntermediateGraph graph = new IntermediateGraph(modelName); + graph.put(name, op); + graph.outputs(graph.defaultSignature()).put(name, name); + graph.optimize(); + return op.function().get(); } - private Value evaluateAsConstant(Context context) { + private Value evaluateAsConstant(Context context, int indent) { + String in = ""; for (int i = 0; i < indent; ++i) in += " "; +// System.out.println(in + "Constant evaluating for " + name); String constantName = "constant(" + vespaName() + ")"; Value result = context.get(constantName); if (result == DoubleValue.NaN) { if (constantValue != null) { +// System.out.println(in + name + " has constant value."); result = constantValue; } else if (inputs.size() == 0) { +// System.out.println(in + name + " has no inputs."); if (getConstantValue().isEmpty()) { throw new IllegalArgumentException("Error in evaluating constant for " + name); } result = getConstantValue().get(); } else { - inputs.forEach(i -> i.evaluateAsConstant(context)); + inputs.forEach(i -> i.evaluateAsConstant(context, indent+1)); result = new TensorValue(lazyGetFunction().evaluate(context)); } context.put(constantName, result); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index adb54474812..3211a44fa68 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -82,6 +82,13 @@ public class Join extends IntermediateOperation { bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); } + // retain order of inputs + if (a == inputs.get(1)) { + TensorFunction temp = bReducedFunction; + bReducedFunction = aReducedFunction; + aReducedFunction = temp; + } + return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 6849e64641e..1eb21eb2a5e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ 
b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -4,6 +4,9 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.text.ExpressionFormatter; @@ -20,64 +23,126 @@ public class MatMul extends IntermediateOperation { protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + + // add some more checks here + if (aType.type().rank() < 1 || bType.type().rank() < 1) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1"); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); - typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); - typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + OrderedTensorType largestRankType = aType.rank() >= bType.rank() ? aType : bType; + for (int i = 0; i < largestRankType.rank() - 2; ++i) { + typeBuilder.add(largestRankType.dimensions().get(i)); + } + if (aType.rank() >= 2) { + typeBuilder.add(aType.dimensions().get(aType.rank() - 2)); + } + if (bType.rank() >= 2) { + typeBuilder.add(bType.dimensions().get(bType.rank() - 1)); + } return typeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; OrderedTensorType aType = inputs.get(0).type().get(); - OrderedTensorType bType = inputs.get(1).type().get(); - if (aType.type().rank() < 2 || bType.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (aType.type().rank() != bType.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - Optional aFunction = inputs.get(0).function(); Optional bFunction = inputs.get(1).function(); - if (!aFunction.isPresent() || !bFunction.isPresent()) { - return null; - } - return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + + // only change to this is for dimensions with size 1 - check in getType + + return new com.yahoo.tensor.functions.Reduce(new Join(aFunction.get(), bFunction.get(), ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + aType.dimensions().get(aType.rank() - 1).name()); } @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { if ( ! allInputTypesPresent(2)) return; - List aDimensions = inputs.get(0).type().get().dimensions(); - List bDimensions = inputs.get(1).type().get().dimensions(); + /* + * A: a1, a2, a3, a4 + * B: b1, b2, b3, b4 + * + * a4 == b3 + * a3 < b4 + * a3 < a4 + * b4 < b3 + * + * a1 == b1 -> but also in terms of size. 
+ * a2 == b2 + * etc + */ + + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); + + String lastDimA = typeA.dimensions().get(typeA.rank()-1).name(); + String lastDimB = typeB.dimensions().get(typeB.rank()-1).name(); + String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name(); + String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name(); + + // The last dimension of A should have the same name as the second-to-last dimension of B + renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this); - assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); - assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + // For efficiency, the dimensions to join over should be innermost - soft constraint + if (typeA.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this); + } + if (typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this); + } - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); + // The second-to-last dimension of a should have a different name than the last dimension of b + if (typeA.rank() >= 2 && typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this); + } - // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); + // a1 < a2 < a3 < a4 + OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB; + for (int i = 0; i < largestRankType.rank() - 2; ++i) { + String iDim = largestRankType.dimensionNames().get(i); + for (int j = i+1; j < largestRankType.rank() - 2; ++j) { + String jDim = largestRankType.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this); + } + } + + // TODO: handle non similar sizes + + // a1 == b1 etc + if (typeA.rank() == typeB.rank()) { + for (int i = 0; i < typeA.rank() - 2; ++i) { + renamer.addConstraint(typeA.dimensionNames().get(i), typeB.dimensionNames().get(i), DimensionRenamer.Constraint.equal(false), this); + } + } - // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); - // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); - } - private void assertTwoDimensions(List dimensions, IntermediateOperation supplier, String inputDescription) { - if (dimensions.size() >= 2) return; - throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + - " but got just " + dimensions + " from\n" + - ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); + + // So, what about the other dimensions? 
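As an aside on the MatMul rewrite above: a general (batched) matrix product is expressed here as an elementwise multiply (Join) over the shared dimension followed by a sum (Reduce) that removes it. A minimal plain-Java sketch of that identity, using made-up names and ordinary arrays rather than Vespa tensor types, for a 2x3 times 3x2 product:

// Sketch only: matmul as "multiply elementwise over the shared index k, then sum k away",
// mirroring Reduce(Join(a, b, multiply), sum, lastDimensionOfA) in the new lazyGetFunction.
public class MatMulAsJoinReduceSketch {
    public static void main(String[] args) {
        double[][] a = {{1, 2, 3}, {4, 5, 6}};        // 2x3
        double[][] b = {{7, 8}, {9, 10}, {11, 12}};   // 3x2
        double[][] c = new double[2][2];
        for (int i = 0; i < 2; i++)
            for (int j = 0; j < 2; j++)
                for (int k = 0; k < 3; k++)           // k is the shared (reduced) dimension
                    c[i][j] += a[i][k] * b[k][j];
        System.out.println(java.util.Arrays.deepToString(c)); // [[58.0, 64.0], [139.0, 154.0]]
    }
}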
+// if (aDimensions.size() > 2) { +// for (int i = 1; i < aDimensions.size(); ++i) { +// renamer.addConstraint(aDimensions.get(0).name(), aDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this); +// } +// for (int i = 0; i < bDimensions.size(); ++i) { +// renamer.addConstraint(aDimensions.get(0).name(), bDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this); +// } +// } + } +// private void assertTwoDimensions(List dimensions, IntermediateOperation supplier, String inputDescription) { +// if (dimensions.size() >= 2) return; +// throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + +// " but got just " + dimensions + " from\n" + +// ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); +// } + @Override public MatMul withInputs(List inputs) { return new MatMul(modelName(), name(), inputs); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java index e040ae62149..07ac457cca8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -54,7 +54,7 @@ public class Rename extends IntermediateOperation { } public void renameDimensions(DimensionRenamer renamer) { - type = type.rename(renamer); + super.renameDimensions(renamer); from = renamer.dimensionNameOf(from).orElse(from); to = renamer.dimensionNameOf(to).orElse(to); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index c88fc18e6c6..f96dd420d30 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -2,8 +2,10 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; @@ -11,8 +13,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.Function; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -27,6 +32,8 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import static 
com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + public class Reshape extends IntermediateOperation { private final AttributeMap attributeMap; @@ -38,6 +45,10 @@ public class Reshape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + if (inputs.size() == 2) { return typeWithShapeAsInput(); } else if (inputs.size() == 1) { @@ -126,10 +137,54 @@ public class Reshape extends IntermediateOperation { return new Reshape(modelName(), name(), inputs, attributeMap); } - public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); + IntermediateOperation input = inputs.get(0); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List> dimensionValues = new ArrayList<>(); + + // ala (d0 * 2 + d1) + ExpressionNode unrolled = new EmbracedNode(unrollTensorExpression(outputType)); + + long innerSize = 1; + for (int dim = 0; dim < inputType.rank(); ++dim) { + innerSize *= inputType.dimensions().get(dim).size().get(); + } + + for (int dim = 0; dim < inputType.rank(); ++dim) { + String inputDimensionName = inputType.dimensions().get(dim).name(); + long inputDimensionSize = inputType.dimensions().get(dim).size().get(); + long previousInnerSize = innerSize; + innerSize /= inputDimensionSize; + + ExpressionNode inputDimensionExpression; + if (inputDimensionSize == 1) { + inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); + } else if (dim == (inputType.rank() - 1)) { + ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); + ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size); + inputDimensionExpression = new EmbracedNode(div); + } else { + ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); + ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); + ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize); + ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size); + inputDimensionExpression = new EmbracedNode(new FunctionNode(Function.floor, div)); + } + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); + } + + TensorFunction inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(outputType.type(), wrapScalar(sliceExpression)); + return generate; + + /* // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, // then use the dimension order of the new shape to roll back into a tensor. 
// Here we create a transformation tensor that is multiplied with the from tensor to map into @@ -168,11 +223,14 @@ public class Reshape extends IntermediateOperation { result = new Rename(result, to, from); } return result; + */ } + /* private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) { return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent()); } + */ private static ExpressionNode unrollTensorExpression(OrderedTensorType type) { if (type.rank() == 0) diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index e5463291ef8..8dd1e3ff33d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -182,7 +182,6 @@ public class Slice extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - // Todo: what to do? for (int i = 0; i < type.dimensions().size(); i++) { renamer.addDimension(type.dimensions().get(i).name()); for (int j = i + 1; j < type.dimensions().size(); j++) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 83086926316..e2b83246bfc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Map; import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; @@ -28,6 +29,10 @@ public class Softmax extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(1)) return null; + + // input is referenced twice due to avoidance of overflow, so make this its own function. + inputs.get(0).exportAsRankingFunction = true; + return inputs.get(0).type().get(); } @@ -50,7 +55,9 @@ public class Softmax extends IntermediateOperation { } TensorFunction input = inputs.get(0).function().get(); - TensorFunction exp = new Map(input, ScalarFunctions.exp()); + TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions); + TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow + TensorFunction exp = new Map(cap, ScalarFunctions.exp()); TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions); TensorFunction div = new Join(exp, sum, ScalarFunctions.divide()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java new file mode 100644 index 00000000000..02d780c52cd --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -0,0 +1,119 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
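A note on the Softmax change above: subtracting the per-reduce-dimension max before exponentiating leaves the softmax result unchanged but keeps the arguments to exp() non-positive, so large activations no longer overflow. A small standalone Java sketch of that identity (class and variable names here are illustrative only, not Vespa APIs):

// Sketch only: softmax(x) == softmax(x - max(x)); the shift avoids overflow in exp().
public class SoftmaxShiftSketch {
    static double[] softmax(double[] x, double shift) {
        double[] out = new double[x.length];
        double sum = 0;
        for (int i = 0; i < x.length; i++) { out[i] = Math.exp(x[i] - shift); sum += out[i]; }
        for (int i = 0; i < x.length; i++) out[i] /= sum;
        return out;
    }
    public static void main(String[] args) {
        double[] x = {1000, 1001, 1002};                      // exp(1000) overflows a double
        double max = Math.max(x[0], Math.max(x[1], x[2]));
        System.out.println(java.util.Arrays.toString(softmax(x, 0)));   // [NaN, NaN, NaN]
        System.out.println(java.util.Arrays.toString(softmax(x, max))); // finite values summing to 1
    }
}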
+package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class Split extends IntermediateOperation { + + private final AttributeMap attributes; + private final int output; + + private final int axis; + private int start; + private int end; + + public Split(String modelName, String nodeName, List inputs, AttributeMap attributes, int output) { + super(modelName, nodeName, inputs); + this.attributes = attributes; + this.output = output; + axis = (int) attributes.get("axis").orElse(DoubleValue.zero).asDouble(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) + return null; + OrderedTensorType inputType = inputs.get(0).type().get(); + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + int axisSize = inputType.dimensions().get(axis).size().get().intValue(); + start = 0; + end = axisSize; + + if (attributes.getList("split").isPresent()) { + List splitList = attributes.getList("split").get(); + if (output > splitList.size()) { + throw new IllegalArgumentException("Split in " + name + ": output out of range of split list"); + } + for (int i = 0; i < output; ++i) { + start += (int) splitList.get(i).asDouble(); + } + if (output < splitList.size()) { + end = start + (int) splitList.get(output).asDouble(); + } + } else { + start = axisSize / 2 * output; + end = start + axisSize / 2; + } + + if (start >= axisSize || start < 0) { + throw new IllegalArgumentException("Split in " + name + ": split start index out of range (" + start + ")"); + } + if (end > axisSize || end < 0) { + throw new IllegalArgumentException("Split in " + name + ": split end index out of range (" + end + ")"); + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < inputType.rank(); ++i) { + TensorType.Dimension inputDimension = inputType.dimensions().get(i); + long dimSize = i == axis ? 
end - start : inputDimension.size().get(); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), dimSize)); + } + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) return null; + + IntermediateOperation input = inputs.get(0); + OrderedTensorType inputType = input.type().get(); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List> dimensionValues = new ArrayList<>(); + + for (int i = 0; i < inputType.rank(); ++i) { + String inputDimensionName = inputType.dimensions().get(i).name(); + ExpressionNode reference = new ReferenceNode(inputDimensionName); + ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset)))); + } + + TensorFunction inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + return generate; + } + + @Override + public Split withInputs(List inputs) { + return new Split(modelName(), name(), inputs, attributes, output); + } + + @Override + public String operationName() { return "Split"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java new file mode 100644 index 00000000000..8d3468f3d04 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -0,0 +1,100 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/** + * Onnx tile operation. 
+ */ +public class Tile extends IntermediateOperation { + + public Tile(String modelName, String nodeName, List inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) return null; + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + IntermediateOperation repeats = inputs.get(1); + if (repeats.getConstantValue().isEmpty()) + throw new IllegalArgumentException("Tile " + name + ": repeats input must be a constant."); + + Tensor shape = repeats.getConstantValue().get().asTensor(); + if (shape.type().rank() != 1) + throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor."); + + OrderedTensorType inputType = inputs.get(0).type().get(); + if (shape.type().dimensions().get(0).size().get() != inputType.rank()) + throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank."); + + List dimSizes = new ArrayList<>(inputType.rank()); + shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue())); + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dimSizes.size(); ++i) { + TensorType.Dimension inputDimension = inputType.dimensions().get(i); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get() * dimSizes.get(i))); + } + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(2)) return null; + + IntermediateOperation input = inputs.get(0); + OrderedTensorType inputType = input.type().get(); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List> dimensionValues = new ArrayList<>(); + + for (int axis = 0; axis < inputType.rank(); ++axis) { + String inputDimensionName = inputType.dimensions().get(axis).name(); + long inputDimensionSize = inputType.dimensions().get(axis).size().get(); + + ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); + ExpressionNode reference = new ReferenceNode(inputDimensionName); + ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size); + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod)))); + } + + TensorFunction inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + return generate; + } + + @Override + public Tile withInputs(List inputs) { + return new Tile(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Tile"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java new file mode 100644 index 00000000000..178759fbf2a --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java @@ -0,0 +1,54 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
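A note on the Tile operation above: repeating a tensor along an axis can be implemented by reading the input at (output index modulo input size), which is what the Slice expression wrapped in Generate encodes per dimension. A tiny standalone Java sketch of the idea for a single axis (names are made up for illustration):

// Sketch only: tiling one axis by indexing the input with (i % inputSize).
public class TileModuloSketch {
    public static void main(String[] args) {
        double[] input = {1, 2, 3};
        int repeats = 2;
        double[] output = new double[input.length * repeats];
        for (int i = 0; i < output.length; i++)
            output[i] = input[i % input.length];
        System.out.println(java.util.Arrays.toString(output)); // [1.0, 2.0, 3.0, 1.0, 2.0, 3.0]
    }
}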
+package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class Transpose extends IntermediateOperation { + + private final AttributeMap attributes; + + public Transpose(String modelName, String nodeName, List inputs, AttributeMap attributes) { + super(modelName, nodeName, inputs); + this.attributes = attributes; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().get(); + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < inputType.rank(); ++i) { + int inputIndex = inputType.rank() - 1 - i; + if (attributes.getList("perm").isPresent()) { + inputIndex = (int) attributes.getList("perm").get().get(i).asDouble(); + } + TensorType.Dimension inputDimension = inputType.dimensions().get(inputIndex); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get())); + } + OrderedTensorType result = typeBuilder.build(); + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) + return null; + return inputs.get(0).function().orElse(null); + } + + @Override + public Transpose withInputs(List inputs) { + return new Transpose(modelName(), name(), inputs, attributes); + } + + @Override + public String operationName() { return "Transpose"; } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/BertImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/BertImportTestCase.java new file mode 100644 index 00000000000..f4ed2f1b64d --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/BertImportTestCase.java @@ -0,0 +1,281 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.onnx; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import com.yahoo.io.IOUtils; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.ScalarFunction; +import com.yahoo.tensor.functions.Slice; +import org.junit.Ignore; +import org.junit.Test; +import org.tensorflow.op.core.Rank; + +import java.io.BufferedReader; +import java.io.IOException; +import java.sql.Ref; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static org.junit.Assert.assertEquals; + +/** + * @author lesters + */ +public class BertImportTestCase extends TestableModel { + + @Test + public void test() throws Exception { + String filename = "/Users/lesters/github/onnx-models/text/machine_comprehension/bert-squad/java.txt"; + List lines = IOUtils.getLines(filename); + Tensor tensor = Tensor.from(lines.get(0)); + TestableModelContext context = new TestableModelContext(); + context.put("test", new TensorValue(tensor)); + + // Tensor: tensor(d1[1],d2[256],d4[12],d5[64]) + + String expr = "tensor(d0[256],d1[768])" + + "((test{" + + "d1:(floor(0.0)), " + + "d2:(floor(((768.0 * d0 + d1) % 196608) / 768.0)), " + + "d4:(floor(((768.0 * d0 + d1) % 768.0) / 64.0)), " + + "d5:(floor((768.0 * d0 + d1) % 64.0))" + + "}))"; + Tensor result = new RankingExpression(expr).evaluate(context).asTensor(); + + assertEquals(result.sum(), -6074.247); + } + + @Ignore + @Test + public void testBertImport() { + ImportedModel model = new OnnxImporter().importModel("test", "/Users/lesters/github/onnx-models/text/machine_comprehension/bert-squad/bertsquad8_modified.onnx"); +// ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/bert/bertsquad8_modified.onnx"); +// ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/bert/bertsquad10.onnx"); +// assertEquals(0, model.signature("default").skippedOutputs().size()); +// Tensor onnxResult = evaluateVespa(model, "output", model.inputs()); +// assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult); + + String filename = "/Users/lesters/github/onnx-models/text/machine_comprehension/bert-squad/context.vespa"; + + // bert/encoder/layer_0/attention/self/mul_2 + assert null != model.largeConstants().get("test_bert_encoder_layer_0_attention_self_Reshape_3__294"); + + TestableModelContext context; + if (true) { + // inputs + Tensor unique_ids_raw_output__9 = Tensor.from("tensor(d0[1]):[1]"); + Tensor input_ids = 
Tensor.from("tensor(d0[1],d1[256]):[101,2073,2003,1996,5661,10549,2000,2175,1029,102,1999,2049,2220,2086,1010,1996,2047,4680,2415,3478,2000,3113,5270,1998,6599,10908,1012,1031,2260,1033,2011,2526,1010,2116,13773,3028,5661,2020,10549,1996,2172,3469,9587,9363,2638,2415,1999,2624,3799,2058,1996,2624,4560,4680,2415,2349,2000,1996,3732,1005,1055,3132,2686,1012,1037,10428,5468,2000,5446,2019,4935,3081,1037,3309,4171,3478,2000,3362,1996,3223,2048,1011,12263,3484,2000,3413,1012,1999,2238,2384,1010,2136,2624,4560,2328,1996,2148,2534,1010,1037,1002,1020,1012,6255,2454,1010,2630,1998,2317,9311,1010,5815,3770,1010,2199,2675,2519,1006,1021,1010,4278,25525,1007,1997,8327,2686,102,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]"); + Tensor input_mask = Tensor.from("tensor(d0[1],d1[256]):[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]"); + Tensor segment_ids = Tensor.from("tensor(d0[1],d1[256]):[0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]"); + + context = contextFrom(model); + context.put("unique_ids_raw_output___9", new TensorValue(unique_ids_raw_output__9)); + context.put("input_ids", new TensorValue(input_ids)); + context.put("input_mask", new TensorValue(input_mask)); + context.put("segment_ids", new TensorValue(segment_ids)); + + context.write(filename); + } else { + context = TestableModelContext.read(filename); + } + + // expected outputs from onnxruntime + Tensor unique_ids = Tensor.from("tensor(d0[1]):[1]"); + Tensor unstack_0 = 
Tensor.from("tensor(d0[1],d1[256]):[-7.169589,-8.165145,-8.795558,-8.276284,-8.408593,-8.313643,-8.421538,-8.771402,-8.71111,-8.014886,-6.4415646,-7.5764513,-7.7209125,-8.689668,-8.05441,-4.357495,-4.082243,-2.4557219,-6.421309,-7.8627315,-8.612887,-8.122109,-8.072487,-8.678347,-8.467162,-8.818881,-8.143695,-6.412044,-7.765201,-8.125683,-2.5739796,-4.2929254,-7.8812947,-2.4893682,-1.5166948,-6.2354407,-4.039099,-5.6837378,-0.41342238,3.0958412,1.5454307,-0.89450985,6.0985346,-4.108738,-4.67186,-2.6797965,-0.65007347,4.2300944,-2.9132476,-4.853151,1.1584995,4.041984,-3.5257776,-2.3050616,-4.363427,-7.1510825,-8.426602,-6.6682553,-7.027374,-8.076435,-8.3017435,-6.9958987,-7.243815,-7.1347113,-7.5506253,-7.771371,-8.606251,-7.472072,-7.902196,-7.563202,-7.330995,-7.5767503,-7.8097973,-6.645113,-8.927777,-8.438513,-8.708496,-8.474434,-8.231956,-8.635139,-7.8764973,-8.80273,-9.103729,-9.07057,-8.610826,-9.084642,-8.795743,-7.2711506,-7.733648,-8.708181,-8.020964,-7.1652384,-7.9469404,-9.461184,-8.624146,-7.2252526,-6.4015207,-9.220176,-9.195709,-8.228707,-7.9646325,-8.685807,-8.980191,-8.858017,-9.290145,-8.921865,-7.656322,-8.872562,-8.898288,-8.683226,-9.219653,-8.371141,-7.130355,-8.930712,-9.05438,-8.771264,-9.621703,-8.550959,-7.327657,-9.138217,-9.377564,-9.111144,-9.653343,-8.726485,-8.215803,-9.300696,-8.044907,-8.641199,-8.641449,-8.640882,-8.641207,-8.648375,-8.645785,-8.639973,-8.650788,-8.660226,-8.6503525,-8.6601925,-8.647732,-8.652576,-8.665123,-8.653585,-8.653888,-8.661093,-8.661934,-8.6514845,-8.662573,-8.671499,-8.661195,-8.667901,-8.666959,-8.659721,-8.673244,-8.678537,-8.66441,-8.651034,-8.660175,-8.659063,-8.657169,-8.6603565,-8.6569,-8.649067,-8.651927,-8.6421995,-8.649052,-8.6478615,-8.6426935,-8.646153,-8.646865,-8.636821,-8.643324,-8.645994,-8.639597,-8.647679,-8.655649,-8.6609745,-8.654906,-8.6613455,-8.656511,-8.663024,-8.675192,-8.663131,-8.665018,-8.652522,-8.661668,-8.66894,-8.670112,-8.67217,-8.657303,-8.651893,-8.652592,-8.650168,-8.640702,-8.636455,-8.647628,-8.638621,-8.648,-8.656844,-8.649821,-8.657603,-8.648884,-8.661986,-8.663507,-8.652322,-8.662775,-8.664504,-8.662872,-8.668943,-8.6559105,-8.655738,-8.671845,-8.6666,-8.659552,-8.679308,-8.659756,-8.664594,-8.6688175,-8.666396,-8.673796,-8.65924,-8.664916,-8.6703005,-8.6611395,-8.660061,-8.660967,-8.672797,-8.66394,-8.657039,-8.671023,-8.663469,-8.659371,-8.6713705,-8.659359,-8.649764,-8.6620035,-8.656843,-8.654225,-8.661666,-8.647326,-8.652874,-8.650523,-8.644273,-8.649993,-8.65307,-8.645219,-8.6537075,-8.655814,-8.654312,-8.658724,-8.666763,-8.654713,-8.662302,-8.672376,-8.661079,-8.659652,-8.661736]"); + Tensor unstack_1 = 
Tensor.from("tensor(d0[1],d1[256]):[-5.1743593,-8.167716,-8.096918,-8.610186,-8.627197,-8.518608,-8.413071,-8.04796,-8.405228,-5.775467,-8.891069,-8.499419,-8.482899,-7.4575906,-5.1060586,-9.029796,-7.9796743,-7.411322,-0.62632525,-8.209348,-8.202109,-8.436105,-8.226212,-8.245562,-7.7150273,-5.4672513,-6.134469,-8.531252,-7.390566,-6.717802,-7.9110403,-5.084878,-5.02966,-7.6901536,-7.6643076,-0.42670453,-4.2289968,-6.957412,-5.192218,-6.1616683,-6.4489427,-3.5914042,-3.7853065,-6.857571,-2.3781726,6.1620126,-3.007885,-4.688912,6.258016,-5.2202945,-6.7945094,-5.1450105,0.7468612,-4.919924,5.489712,-7.307814,-7.952,-9.152897,-6.4863043,-8.328119,-7.9448185,-8.395245,-3.6581624,-0.8252581,-8.731679,-8.624653,-7.61354,-8.755644,-8.341698,-8.758186,-5.954141,-8.560192,-8.833243,-7.6137505,-5.96118,-8.43961,-8.188338,-8.373185,-8.683964,-8.246368,-8.824446,-8.05728,-7.623751,-7.56998,-8.277908,-6.8986,-6.8709283,-9.279125,-8.84588,-7.6791453,-5.2976,-9.191589,-8.797903,-6.440836,-8.179676,-9.236156,-8.972708,-5.8724217,-6.928253,-8.685118,-8.84946,-8.293621,-7.8572874,-8.053903,-7.398021,-7.549705,-9.004784,-8.060446,-7.950672,-7.2188964,-6.497633,-8.454956,-9.045556,-7.8463507,-7.771165,-8.067679,-5.9176393,-8.09684,-8.5619955,-7.5696144,-7.100621,-7.0136676,-6.464568,-8.108538,-8.516457,-5.488856,-5.853514,-8.457255,-8.457188,-8.454844,-8.448273,-8.447864,-8.447953,-8.445874,-8.444903,-8.442595,-8.448225,-8.443758,-8.451776,-8.447646,-8.440473,-8.447313,-8.44705,-8.443977,-8.442994,-8.449743,-8.441086,-8.433916,-8.43898,-8.435363,-8.434594,-8.431777,-8.433416,-8.433545,-8.442987,-8.453411,-8.450068,-8.4503565,-8.451651,-8.450909,-8.454222,-8.456041,-8.452284,-8.449699,-8.454986,-8.455096,-8.459543,-8.458114,-8.458371,-8.4632635,-8.458183,-8.457299,-8.458008,-8.452067,-8.444335,-8.442348,-8.445211,-8.441855,-8.443939,-8.441303,-8.436119,-8.442878,-8.439337,-8.446676,-8.441184,-8.438475,-8.440033,-8.4386015,-8.447922,-8.455316,-8.452563,-8.454967,-8.459164,-8.460839,-8.453004,-8.451543,-8.446279,-8.441412,-8.448481,-8.446184,-8.448539,-8.445241,-8.444487,-8.450539,-8.446448,-8.446319,-8.447268,-8.440758,-8.448286,-8.447366,-8.437631,-8.441085,-8.444475,-8.431786,-8.441355,-8.436929,-8.432141,-8.436456,-8.435032,-8.445299,-8.442143,-8.438964,-8.445743,-8.445099,-8.444958,-8.438029,-8.439503,-8.446831,-8.43919,-8.442334,-8.446472,-8.442076,-8.449043,-8.451941,-8.449556,-8.454564,-8.455859,-8.452123,-8.461076,-8.45802,-8.456931,-8.458485,-8.45496,-8.4508295,-8.453123,-8.451649,-8.451098,-8.450148,-8.446929,-8.44253,-8.44839,-8.444667,-8.437894,-8.444409,-8.444666,-8.441956]"); + + +// model.functions().forEach((k, v) -> { +// evaluateFunction(context, model, k, ""); +// }); + +// RankingExpression e = model.expressions().get("unique_ids_graph_outputs_Identity__10"); +// evaluateFunctionDependencies(context, model, e.getRoot(), ""); +// Tensor result = e.evaluate(context).asTensor(); +// assertEquals(result, unique_ids); + + RankingExpression e = model.expressions().get("bert/encoder/layer_0/output/LayerNorm/batchnorm/add_1"); + + evaluateFunctionDependencies(context, model, e.getRoot(), ""); + context.write(filename); + Tensor result = e.evaluate(context).asTensor(); + double sum = result.sum().asDouble(); + System.out.println(sum); + + Tensor matmul1 = model.expressions().get("bert/encoder/layer_0/attention/self/MatMul_1").evaluate(context).asTensor(); + + Tensor transpose = model.expressions().get("bert/encoder/layer_0/attention/self/transpose_3").evaluate(context).asTensor(); + String cast = 
model.largeConstants().get("test_bert_encoder_layer_0_attention_self_Reshape_3__294"); + Tensor reshape = model.expressions().get("bert/encoder/layer_0/attention/self/Reshape_3").evaluate(context).asTensor(); + Tensor matmul = model.expressions().get("bert/encoder/layer_0/attention/output/dense/MatMul").evaluate(context).asTensor(); + Tensor add = model.expressions().get("bert/encoder/layer_0/attention/output/dense/BiasAdd").evaluate(context).asTensor(); + + Tensor add1 = model.expressions().get("bert/encoder/layer_0/attention/output/add").evaluate(context).asTensor(); + Tensor add2 = model.expressions().get("bert/encoder/layer_0/attention/output/LayerNorm/batchnorm/add_1").evaluate(context).asTensor(); + Tensor add3 = model.expressions().get("bert/encoder/layer_0/output/add").evaluate(context).asTensor(); + Tensor add4 = model.expressions().get("bert/encoder/layer_0/output/LayerNorm/batchnorm/add_1").evaluate(context).asTensor(); + + assertEquals(result, unique_ids); + +// Tensor result = model.expressions().get("unique_ids_graph_outputs_Identity__10").evaluate(context).asTensor(); +// assertEquals(result, unique_ids); + +// result = model.expressions().get("unstack_graph_outputs_Identity__7").evaluate(context).asTensor(); // or map from signature outputs +// assertEquals(result, unstack_0); + + // an error here in outputs: there is only one: unstack, but we need two: unstack:0 and unstack:1 + + } + + + private void evaluateFunction(Context context, ImportedModel model, String functionName, String in) { + if (!context.names().contains(functionName)) { + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); + System.out.println(in + "Looking for dependencies of function " + functionName + ": " + e.toString()); + evaluateFunctionDependencies(context, model, e.getRoot(), in); + System.out.println(in + "Evaluating function " + functionName + ": " + e.toString()); + long start = System.currentTimeMillis(); + Tensor result = e.evaluate(context).asTensor(); + context.put(functionName, new TensorValue(result)); + long end = System.currentTimeMillis(); + System.out.println(in + "[" + (end - start) + "] completed " + functionName + " (" + result.type() + "), context is: " + context.names().size() + " " + contextSize(context)); + } else { + System.out.println(in + "Function " + functionName + " already evaluated..."); + } + } + + private long contextSize(Context context) { + long size = 0; + for (String name : context.names()) { + Tensor val = context.getTensor(name); + if (val != null) size += val.size(); + } + return size; + } + + private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node, String in) { + if (node instanceof ReferenceNode) { + String name = node.toString(); + ReferenceNode ref = (ReferenceNode) node; + if (ref.getName().equals("constant")) { + String constant = ref.getArguments().expressions().get(0).toString(); + if (!context.names().contains(constant)) { + String value = null; + if (model.smallConstants().containsKey(constant)) { + value = model.smallConstants().get(constant); + } + if (model.largeConstants().containsKey(constant)) { + value = model.largeConstants().get(constant); + } + if (value != null) { + System.out.println(in + "Adding constant: " + name); + long start = System.currentTimeMillis(); + Tensor val = Tensor.from(value); + context.put(name, new TensorValue(val)); + long end = System.currentTimeMillis(); + System.out.println(in + "Added constant: " + name + " (" + val.type() + ") in [" + (end - start) + 
"]"); + } + } + } + if (model.functions().containsKey(name)) { + evaluateFunction(context, model, name, in + " "); + } + } + else if (node instanceof CompositeNode) { + if (node instanceof TensorFunctionNode && ((TensorFunctionNode)node).function() instanceof Generate) { + Generate generate = (Generate) ((TensorFunctionNode)node).function(); + TensorFunctionNode.ExpressionScalarFunction func = (TensorFunctionNode.ExpressionScalarFunction) generate.getBoundGenerator(); + if (func != null) { + ExpressionNode bound = func.getExpression(); + if (bound.toString().contains("imported_ml_")) { + System.out.println(in + "Found expression inside generator: " + bound.toString()); + evaluateFunctionDependencies(context, model, bound, in); + } + } + } + else if (node instanceof TensorFunctionNode && ((TensorFunctionNode)node).function() instanceof Slice) { + Slice slice = (Slice) ((TensorFunctionNode)node).function(); + for (Slice.DimensionValue value : slice.getSubspaceAddress()) { + TensorFunctionNode.ExpressionScalarFunction func = (TensorFunctionNode.ExpressionScalarFunction) value.index().orElse(null); + if (func != null) { + ExpressionNode bound = func.getExpression(); + if (bound.toString().contains("imported_ml_")) { + System.out.println(in + "Found expression inside slice: " + bound.toString()); + evaluateFunctionDependencies(context, model, bound, in); + } + } + } + } + for (ExpressionNode child : ((CompositeNode)node).children()) { + evaluateFunctionDependencies(context, model, child, in); + } + } + } + + static TestableModelContext contextFrom(ImportedModel result) { + TestableModelContext context = new TestableModelContext(); + if (result != null) { + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + } + return context; + } + + private static class TestableModelContext extends MapContext implements ContextIndex { + @Override + public int size() { + return bindings().size(); + } + @Override + public int getIndex(String name) { + throw new UnsupportedOperationException(this + " does not support index lookup by name"); + } + + public void write(String filename) { + try { + for (Map.Entry entry: bindings().entrySet()) { + String line = entry.getKey() + "\t" + entry.getValue().asTensor() + "\n"; + IOUtils.writeFile(filename, line, true); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public static TestableModelContext read(String filename) { + System.out.println("Reading content from " + filename); + TestableModelContext context = new TestableModelContext(); + try (BufferedReader reader = IOUtils.createReader(filename)) { + String line; + while (null != (line = reader.readLine())) { + String[] strings = line.trim().split("\t"); + String name = strings[0]; + Tensor tensor = Tensor.from(strings[1]); + context.put(name, new TensorValue(tensor)); + } + } catch (IOException e) { + throw new RuntimeException(e); + } + System.out.println("Done reading context"); + return context; + } + } + +} diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java index 94c5577357b..0c9acc9b372 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java +++ 
b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java @@ -107,6 +107,18 @@ public class OnnxOperationsTestCase { assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y)); assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y)); assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y)); + + // broadcasting - opposite order + x = evaluate("random(d0[4]) + 1"); + y = evaluate("random(d0[2],d1[3],d2[4]) + 1"); + assertEval("add", x, y, evaluate("rename(x, d0, d2) + y", x, y)); + assertEval("sub", x, y, evaluate("rename(x, d0, d2) - y", x, y)); + assertEval("mul", x, y, evaluate("rename(x, d0, d2) * y", x, y)); + assertEval("div", x, y, evaluate("rename(x, d0, d2) / y", x, y)); + assertEval("greater", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a > b))", x, y)); + assertEval("less", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a < b))", x, y)); + assertEval("equal", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a == b))", x, y)); + assertEval("pow", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(pow(a,b)))", x, y)); } @Test @@ -185,9 +197,55 @@ public class OnnxOperationsTestCase { @Test public void testMatMul1() throws ParseException { - Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]"); - Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]"); - assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]")); + Tensor a = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + Tensor b = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("91")); + + a = evaluate("tensor(d0[3]):[1,2,3]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2]):[22, 28]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3]):[1,2,3]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2]):[14, 32]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); +// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]"); +// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]")); + + a = evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]"); + b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); +// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]")); + + a = evaluate("tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]")); + + +// a = evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]"); 
+// b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); +// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]")); + +// a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]"); +// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,58,64,139,154,22,28,49,64,58,64,139,154]")); + +// a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); +// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,58,64,139,154,22,28,49,64,58,64,139,154]")); + +// a = evaluate("tensor(d0[3]]):[1,2,3]"); +// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2]):[22,28,58,64,22,28,58,64]")); } @Test @@ -217,6 +275,10 @@ public class OnnxOperationsTestCase { y = evaluate("tensor(d0[4]):[3,2,-1,1]"); assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]")); + + x = evaluate("tensor(d0[1],d1[2],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12]"); + y = evaluate("tensor(d0[2]):[2,6]"); + assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[6]):[1,2,3,4,5,6,7,8,9,10,11,12]")); } @Test @@ -435,6 +497,50 @@ public class OnnxOperationsTestCase { } + @Test + public void testTranspose1() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[3]):[[1,2,3],[4,5,6]]"); + assertEval("transpose", x, evaluate("tensor(d0[3],d1[2]):[[1,4],[2,5],[3,6]]")); + } + + @Test + public void testTile6() throws ParseException { + Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + Tensor y = evaluate("tensor(d0[2]):[1,2]"); + assertEval("tile", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]")); + + x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]"); + y = evaluate("tensor(d0[2]):[3,1]"); + assertEval("tile", x, y, evaluate("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]")); + + x = evaluate("tensor(d0[1],d1[1],d2[1]):[1]"); + y = evaluate("tensor(d0[3]):[1,6,1]"); + assertEval("tile", x, y, evaluate("tensor(d0[1],d1[6],d2[1]):[1,1,1,1,1,1]")); + + } + + + @Test + public void testSplit2() throws ParseException { + Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]"); + assertEval("split", x, evaluate("tensor(d0[3]):[1,2,3]"), 0); + assertEval("split", x, evaluate("tensor(d0[3]):[4,5,6]"), 1); + assertEval("split", x, evaluate("tensor(d0[2]):[1,2]"), createAttribute("split", new int[] {2}), 0); + assertEval("split", x, evaluate("tensor(d0[4]):[3,4,5,6]"), createAttribute("split", new int[] {2}), 1); + assertEval("split", x, evaluate("tensor(d0[3]):[3,4,5]"), createAttribute("split", new int[] {2,3}), 1); + assertEval("split", x, evaluate("tensor(d0[1]):[6]"), createAttribute("split", new int[] {2,3}), 2); + + x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]"); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]")); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), 0); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), 1); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), createAttribute("split", new int[] {1}), 0); + assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), createAttribute("split", new int[] {1}), 1); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[1,4]"), createAttribute("axis", 1), 0); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[2,5]"), createAttribute("axis", 1), 1); + assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2); + } + private Tensor evaluate(String expr) 
throws ParseException { return evaluate(expr, null, null, null); } @@ -461,41 +567,49 @@ public class OnnxOperationsTestCase { } private void assertEval(String opName, Tensor x, Tensor expected) { - assertEval(opName, x, null, null, null, null, expected, null); + assertEval(opName, x, null, null, null, null, expected, null, 0); + } + + private void assertEval(String opName, Tensor x, Tensor expected, int output) { + assertEval(opName, x, null, null, null, null, expected, null, output); } private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, null, null, null, null, expected, attr); + assertEval(opName, x, null, null, null, null, expected, attr, 0); + } + + private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr, int output) { + assertEval(opName, x, null, null, null, null, expected, attr, output); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, y, null, null, null, expected, attr); + assertEval(opName, x, y, null, null, null, expected, attr, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) { - assertEval(opName, x, y, null, null, null, expected, null); + assertEval(opName, x, y, null, null, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) { - assertEval(opName, x, y, z, null, null, expected, null); + assertEval(opName, x, y, z, null, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) { - assertEval(opName, x, y, z, null, null, expected, attr); + assertEval(opName, x, y, z, null, null, expected, attr, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) { - assertEval(opName, x, y, z, q, null, expected, null); + assertEval(opName, x, y, z, q, null, expected, null, 0); } private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) { - assertEval(opName, x, y, z, q, r, expected, null); + assertEval(opName, x, y, z, q, r, expected, null, 0); } - private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) { + private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr, int output) { Context context = new MapContext(DoubleValue.NaN); List inputs = createInputs(context, x, y, z, q, r); - IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build()); + IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? 
attr : createAttributes().build(), output); optimizeAndRename(opName, op); Tensor result = evaluate(op); assertEquals(expected, result); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java index 9631bddd93d..abecf4f5cb4 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java @@ -34,6 +34,15 @@ public class SimpleImportTestCase { assertEquals(result, Tensor.from("tensor(d0[1],d1[1]):{{d0:0,d1:0}:1.3}")); } + @Test + public void testConstant() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/const.onnx"); + + MapContext context = new MapContext(); + Tensor result = model.expressions().get("output").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor():0.42")); + } + @Test public void testGather() { ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx"); @@ -48,6 +57,19 @@ public class SimpleImportTestCase { assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]")); } + @Test + public void testConcat() { + ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx"); + + MapContext context = new MapContext(); + context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]"))); + context.put("j", new TensorValue(Tensor.from("tensor(d0[1]):[2]"))); + context.put("k", new TensorValue(Tensor.from("tensor(d0[1]):[3]"))); + + Tensor result = model.expressions().get("y").evaluate(context).asTensor(); + assertEquals(result, Tensor.from("tensor(d0[3]):[1, 2, 3]")); + } + private void evaluateFunction(Context context, ImportedModel model, String functionName) { if (!context.names().contains(functionName)) { RankingExpression e = RankingExpression.from(model.functions().get(functionName)); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/LesterTensorflowImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/LesterTensorflowImportTestCase.java new file mode 100644 index 00000000000..66af131aa36 --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/LesterTensorflowImportTestCase.java @@ -0,0 +1,162 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+package ai.vespa.rankingexpression.importer.tensorflow;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import com.yahoo.collections.Pair;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.system.ProcessExecuter;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.io.IOException;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+public class LesterTensorflowImportTestCase {
+
+    @Test
+    @Ignore
+    public void testPyTorchExport() {
+        ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/test.onnx");
+        Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
+        assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.2134095202835272, -0.08556838456161658]]"), onnxResult);
+    }
+
+    @Test
+    @Ignore
+    public void testBERT() {
+        ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/bert/bertsquad10.onnx");
+    }
+
+    private Tensor evaluateVespa(ImportedModel model, String operationName, Map<String, TensorType> inputs) {
+        Context context = contextFrom(model);
+        for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
+            Tensor argument = vespaInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue());
+            context.put(entry.getKey(), new TensorValue(argument));
+        }
+        model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
+        RankingExpression expression = model.expressions().get(operationName);
+        ExpressionOptimizer optimizer = new ExpressionOptimizer();
+        optimizer.optimize(expression, (ContextIndex)context);
+        return expression.evaluate(context).asTensor();
+    }
+
+    @Test
+    @Ignore
+    public void testModelImport() {
+
+        // Must change to tf 2.0 in Java!
+ + String modelDir = "src/test/models/tensorflow/tf2/saved_model/"; + // output function + String operationName = "out"; + + // Import TF + SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve"); + ImportedModel model = new TensorFlowImporter().importModel("test", modelDir, tensorFlowModel); + ImportedModel.Signature signature = model.signature("serving_default"); + assertEquals("Should have no skipped outputs", 0, model.signature("serving_default").skippedOutputs().size()); + ImportedMlFunction output = signature.outputFunction("output", operationName); + assertNotNull(output); + + // Test TF + Session.Runner runner = tensorFlowModel.session().runner(); + runner.feed("x", tensorFlowFloatInputArgument(1, 4)); + List> results = runner.fetch(operationName).run(); + assertEquals(1, results.size()); + Tensor tfResult = TensorConverter.toVespaTensor(results.get(0)); + + // Test Vespa + Context context = contextFrom(model); + context.put("x", new TensorValue(vespaInputArgument(1, 4))); + model.functions().forEach((k, v) -> evaluateFunction(context, model, k)); + RankingExpression expression = model.expressions().get(operationName); + ExpressionOptimizer optimizer = new ExpressionOptimizer(); + optimizer.optimize(expression, (ContextIndex)context); + Tensor vespaResult = expression.evaluate(context).asTensor(); + + // Equal result? + System.out.println(tfResult); + System.out.println(vespaResult); + assertEquals(tfResult, vespaResult); + } + + private org.tensorflow.Tensor tensorFlowFloatInputArgument(int d0Size, int d1Size) { + FloatBuffer fb1 = FloatBuffer.allocate(d0Size * d1Size); + int i = 0; + for (int d0 = 0; d0 < d0Size; d0++) + for (int d1 = 0; d1 < d1Size; ++d1) + fb1.put(i++, (float)(d1 * 1.0 / d1Size)); + return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1); + } + + private Tensor vespaInputArgument(int d0Size, int d1Size) { + Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); + for (int d0 = 0; d0 < d0Size; d0++) + for (int d1 = 0; d1 < d1Size; d1++) + b.cell(d1 * 1.0 / d1Size, d0, d1); + return b.build(); + } + + static Context contextFrom(ImportedModel result) { + TestableModelContext context = new TestableModelContext(); + result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor)))); + return context; + } + + private void evaluateFunction(Context context, ImportedModel model, String functionName) { + if (!context.names().contains(functionName)) { + RankingExpression e = RankingExpression.from(model.functions().get(functionName)); + evaluateFunctionDependencies(context, model, e.getRoot()); + context.put(functionName, new TensorValue(e.evaluate(context).asTensor())); + } + } + + private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) { + if (node instanceof ReferenceNode) { + String name = node.toString(); + if (model.functions().containsKey(name)) { + evaluateFunction(context, model, name); + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) { + evaluateFunctionDependencies(context, model, child); + } + } + } + + private static class TestableModelContext extends MapContext implements ContextIndex { + @Override + public int size() { + return bindings().size(); + } + @Override + 
public int getIndex(String name) { + throw new UnsupportedOperationException(this + " does not support index lookup by name"); + } + } + +} diff --git a/model-integration/src/test/models/onnx/simple/concat.onnx b/model-integration/src/test/models/onnx/simple/concat.onnx new file mode 100644 index 00000000000..945bc3c9445 Binary files /dev/null and b/model-integration/src/test/models/onnx/simple/concat.onnx differ diff --git a/model-integration/src/test/models/onnx/simple/concat.py b/model-integration/src/test/models/onnx/simple/concat.py new file mode 100755 index 00000000000..b77cf5decc1 --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/concat.py @@ -0,0 +1,25 @@ +# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +import onnx +import numpy as np +from onnx import helper, TensorProto + +i_type = helper.make_tensor_value_info('i', TensorProto.FLOAT, [1]) +j_type = helper.make_tensor_value_info('j', TensorProto.FLOAT, [1]) +k_type = helper.make_tensor_value_info('k', TensorProto.FLOAT, [1]) + +output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3]) + +node = onnx.helper.make_node( + 'Concat', + inputs=['i', 'j', 'k'], + outputs=['y'], + axis=0, +) +graph_def = onnx.helper.make_graph( + nodes = [node], + name = 'concat_test', + inputs = [i_type, j_type, k_type], + outputs = [output_type] +) +model_def = helper.make_model(graph_def, producer_name='concat.py') +onnx.save(model_def, 'concat.onnx') diff --git a/model-integration/src/test/models/onnx/simple/const.onnx b/model-integration/src/test/models/onnx/simple/const.onnx new file mode 100644 index 00000000000..c75a92ff12c Binary files /dev/null and b/model-integration/src/test/models/onnx/simple/const.onnx differ diff --git a/model-integration/src/test/models/onnx/simple/const.py b/model-integration/src/test/models/onnx/simple/const.py new file mode 100755 index 00000000000..35d6dee1346 --- /dev/null +++ b/model-integration/src/test/models/onnx/simple/const.py @@ -0,0 +1,26 @@ +# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
+import onnx +import numpy as np +from onnx import helper, TensorProto + +output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, []) + +node = onnx.helper.make_node( + 'Constant', + inputs=[], + outputs=['y'], + value=onnx.helper.make_tensor( + name='const_tensor', + data_type=onnx.TensorProto.FLOAT, + dims=(), + vals=[0.42] + ), +) +graph_def = onnx.helper.make_graph( + nodes = [node], + name = 'constant_test', + inputs = [], + outputs = [output_type] +) +model_def = helper.make_model(graph_def, producer_name='const.py') +onnx.save(model_def, 'const.onnx') diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx index 62451ad953d..0647d86ed0f 100644 Binary files a/model-integration/src/test/models/onnx/simple/gather.onnx and b/model-integration/src/test/models/onnx/simple/gather.onnx differ diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx index 1c746c90efa..41b458451d0 100644 --- a/model-integration/src/test/models/onnx/simple/simple.onnx +++ b/model-integration/src/test/models/onnx/simple/simple.onnx @@ -1,4 +1,4 @@ - simple.py:ã + simple.py:ã 0 query_tensor attribute_tensormatmul"MatMul @@ -20,4 +20,4 @@ output   -B +B \ No newline at end of file diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index d058104bf1b..39642f5cb50 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1413,6 +1413,7 @@ "public void (java.util.Collection, java.util.Map)", "public void (java.util.Map)", "public void (java.util.Map, java.util.Map)", + "public void (java.util.Map, java.util.Map, com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext)", "public com.yahoo.searchlib.rankingexpression.ExpressionFunction getFunction(java.lang.String)", "protected com.google.common.collect.ImmutableMap functions()", "public java.lang.String getBinding(java.lang.String)", @@ -1568,7 +1569,7 @@ "public void (java.util.Map)", "public void (java.util.Collection, java.util.Map)", "public void (java.util.Collection, java.util.Map, java.util.Map)", - "public void (com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map)", + "public void (com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map, com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext)", "public void addFunctionSerialization(java.lang.String, java.lang.String)", "public void addArgumentTypeSerialization(java.lang.String, java.lang.String, com.yahoo.tensor.TensorType)", "public void addFunctionTypeSerialization(java.lang.String, com.yahoo.tensor.TensorType)", @@ -1597,6 +1598,24 @@ ], "fields": [] }, + "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionScalarFunction": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.tensor.functions.ScalarFunction" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void (com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", + "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getExpression()", + "public java.lang.Double apply(com.yahoo.tensor.evaluation.EvaluationContext)", + "public java.lang.String toString()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public bridge synthetic java.lang.Object apply(java.lang.Object)" + ], + "fields": [] + }, "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction": { "superClass": 
"com.yahoo.tensor.functions.PrimitiveTensorFunction", "interfaces": [], diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 674571ff73e..f2f8799b342 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -134,6 +134,8 @@ public class ExpressionFunction { for (int i = 0; i < arguments.size() && i < argumentValues.size(); ++i) { argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(new StringBuilder(), context, path, null).toString()); } + String symbol = toSymbol(argumentBindings); + System.out.println("Expanding function " + symbol); return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java index 83aabada8f0..9d094ce06f4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java @@ -22,6 +22,8 @@ public class FunctionReferenceContext { /** Mapping from argument names to the expressions they resolve to */ private final Map bindings = new HashMap<>(); + private final FunctionReferenceContext parent; + /** Create a context for a single serialization task */ public FunctionReferenceContext() { this(Collections.emptyList()); @@ -43,9 +45,14 @@ public class FunctionReferenceContext { /** Create a context for a single serialization task */ public FunctionReferenceContext(Map functions, Map bindings) { + this(functions, bindings, null); + } + + public FunctionReferenceContext(Map functions, Map bindings, FunctionReferenceContext parent) { this.functions = ImmutableMap.copyOf(functions); if (bindings != null) this.bindings.putAll(bindings); + this.parent = parent; } private static ImmutableMap toMap(Collection list) { @@ -56,16 +63,34 @@ public class FunctionReferenceContext { } /** Returns a function or null if it isn't defined in this context */ - public ExpressionFunction getFunction(String name) { return functions.get(name); } + public ExpressionFunction getFunction(String name) { + ExpressionFunction function = functions.get(name); + if (function != null) { + return function; + } + if (parent != null) { + return parent.getFunction(name); + } + return null; + } protected ImmutableMap functions() { return functions; } /** Returns the resolution of an identifier, or null if it isn't defined in this context */ - public String getBinding(String name) { return bindings.get(name); } + public String getBinding(String name) { + String binding = bindings.get(name); + if (binding != null) { + return binding; + } + if (parent != null) { + return parent.getBinding(name); + } + return null; + } /** Returns a new context with the bindings replaced by the given bindings */ public FunctionReferenceContext withBindings(Map bindings) { - return new FunctionReferenceContext(this.functions, bindings); + return new FunctionReferenceContext(this.functions, bindings, this); } } diff --git 
a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 8fec3603f3e..a994f5247b7 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -74,20 +74,49 @@ public final class ReferenceNode extends CompositeNode { return string.append(context.getBinding(getName())); } + String name = getName(); // A reference to a function? ExpressionFunction function = context.getFunction(getName()); if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) { // a function reference: replace by the referenced function wrapped in rankingExpression - if (path == null) - path = new ArrayDeque<>(); - String myPath = getName() + getArguments().expressions(); - if (path.contains(myPath)) - throw new IllegalStateException("Cycle in ranking expression function: " + path); - path.addLast(myPath); - ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); - path.removeLast(); - context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); - return string.append("rankingExpression(").append(instance.getName()).append(')'); +// if (path == null) +// path = new ArrayDeque<>(); +// String myPath = getName() + getArguments().expressions(); +// if (path.contains(myPath)) +// throw new IllegalStateException("Cycle in ranking expression function: " + path); +// path.addLast(myPath); +// ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); +// path.removeLast(); +// context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); +// return string.append("rankingExpression(").append(instance.getName()).append(')'); + +// return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString()); + + // hack for testing: + // So, this worked. Meaning that when expanding we could probably cut down on the context tree? +// String expression = function.getBody().toString(); +// context.addFunctionSerialization(RankingExpression.propertyName(getName()), expression); // <- actually set by deriveFunctionProperties - this will only overwrite + + String prefix = string.toString(); // incredibly ugly hack - for testing this + + // so problem here with input values + if (prefix.startsWith("attribute")) { + if (name.equals("segment_ids") || name.equals("input_mask") || name.equals("input_ids")) { + return string.append(getName()); + // TODO: divine this! 
+ } + } + + // so, in one case +// rankprofile[2].fef.property[35].name "rankingExpression(imported_ml_function_bertsquad8_input_ids).rankingScript" +// rankprofile[2].fef.property[35].value "input_ids" + // vs +// rankprofile[2].fef.property[2].name "rankingExpression(input_ids).rankingScript" +// rankprofile[2].fef.property[2].value "attribute(input_ids)" + // uppermost is wrong, then we need the below + + return string.append("rankingExpression(").append(getName()).append(')'); + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index d7807caa2b6..c79f5556e03 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -50,7 +50,7 @@ public class SerializationContext extends FunctionReferenceContext { */ public SerializationContext(Collection functions, Map bindings, Map serializedFunctions) { - this(toMap(functions), bindings, serializedFunctions); + this(toMap(functions), bindings, serializedFunctions, null); } private static ImmutableMap toMap(Collection list) { @@ -69,8 +69,8 @@ public class SerializationContext extends FunctionReferenceContext { * is transferred to this and will be modified in it */ public SerializationContext(ImmutableMap functions, Map bindings, - Map serializedFunctions) { - super(functions, bindings); + Map serializedFunctions, FunctionReferenceContext root) { + super(functions, bindings, root); this.serializedFunctions = serializedFunctions; } @@ -92,7 +92,7 @@ public class SerializationContext extends FunctionReferenceContext { @Override public SerializationContext withBindings(Map bindings) { - return new SerializationContext(functions(), bindings, this.serializedFunctions); + return new SerializationContext(functions(), bindings, this.serializedFunctions, this); } public Map serializedFunctions() { return serializedFunctions; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 6e1cdf52158..1ab9702367a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -143,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode { return new ExpressionScalarFunction(node); } - private static class ExpressionScalarFunction implements ScalarFunction { + public static class ExpressionScalarFunction implements ScalarFunction { private final ExpressionNode expression; @@ -151,6 +151,10 @@ public class TensorFunctionNode extends CompositeNode { this.expression = expression; } + public ExpressionNode getExpression() { + return expression; + } + @Override public Double apply(EvaluationContext context) { return expression.evaluate(new ContextWrapper(context)).asDouble(); @@ -321,13 +325,45 @@ public class TensorFunctionNode extends CompositeNode { public ToStringContext parent() { return wrappedToStringContext; } + private int contextNodes() { + int nodes = 0; + if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) { + nodes += ((ExpressionToStringContext)wrappedToStringContext).contextNodes(); + } else if (wrappedToStringContext != 
null) { + nodes += 1; + } + if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) { + nodes += ((ExpressionToStringContext)wrappedSerializationContext).contextNodes(); + } else if (wrappedSerializationContext != null) { + nodes += 1; + } + return nodes + 1; + } + + private int contextDepth() { + int depth = 0; + if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) { + depth += ((ExpressionToStringContext)wrappedToStringContext).contextDepth(); + } + if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) { + int d = ((ExpressionToStringContext)wrappedSerializationContext).contextDepth(); + depth = Math.max(depth, d); + } + return depth + 1; + } + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ @Override public String getBinding(String name) { - if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null) - return wrappedToStringContext.getBinding(name); - else - return wrappedSerializationContext.getBinding(name); +// System.out.println("getBinding for " + name + " with node count " + contextNodes() + " and max depth " + contextDepth()); + String binding; + if (wrappedToStringContext != null) { + binding = wrappedToStringContext.getBinding(name); + if (binding != null) { + return binding; + } + } + return wrappedSerializationContext.getBinding(name); } /** Returns a new context with the bindings replaced by the given bindings */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index a541eac2421..95652bb0e15 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. 
package com.yahoo.searchlib.rankingexpression.transform; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -28,8 +29,16 @@ public class ConstantDereferencer extends ExpressionTransformer 0) { return node; // not a number constant reference } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 807eb3aa7ce..7b246f22cf2 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -389,6 +389,126 @@ public class EvaluationTestCase { } + @Test + public void testTile() { + EvaluationTester tester = new EvaluationTester(); + + tester.assertEvaluates("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]", + "tensor(d0[2],d1[4])(tensor0{input0:(d0 % 2), input1:(d1 % 2) } )", + "tensor(input0[2],input1[2]):[1, 2, 3, 4]", + "tensor(repeats0[2]):[1,2]"); + + tester.assertEvaluates("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]", + "tensor(d0[6],d1[2])(tensor0{input0:(d0 % 2), input1:(d1 % 2) } )", + "tensor(input0[2],input1[2]):[1, 2, 3, 4]", + "tensor(repeats0[2]):[3,1]"); + } + + @Test + public void testReshape() { + EvaluationTester tester = new EvaluationTester(); + + tester.assertEvaluates("tensor(d0[4]):[1,2,3,4]", + "tensor(d0[4])(tensor0{a0:(d0 / 2), a1:(d0 % 2)})", + "tensor(a0[2],a1[2]):[1,2,3,4]", + "tensor(d0[1]):[4]"); + + tester.assertEvaluates("tensor(d0[2],d1[2]):[1,2,3,4]", + "tensor(d0[2],d1[2])(tensor0{a0:(d0), a1:(d1)})", + "tensor(a0[2],a1[2]):[1,2,3,4]", + "tensor(d0[2]):[2,2]"); + + tester.assertEvaluates("tensor(d0[2],d1[1],d2[2]):[1,2,3,4]", + "tensor(d0[2],d1[1],d2[2])(tensor0{a0:(d0), a1:(d2)})", + "tensor(a0[2],a1[2]):[1,2,3,4]", + "tensor(d0[3]):[2,1,2]"); + + tester.assertEvaluates("tensor(d0[3],d1[2]):[1,2,3,4,5,6]", + "tensor(d0[3],d1[2])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })", + "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]", + "tensor(d0[2]):[3,2]"); + + tester.assertEvaluates("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]", + "tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })", + "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]", + "tensor(d0[4]):[3,2,-1,1]"); + + } + + @Test + public void testMatmul() { + EvaluationTester tester = new EvaluationTester(); + + tester.assertEvaluates("tensor():{91}", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)", + "tensor(d0[6]):[1,2,3,4,5,6]", + "tensor(d0[6]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d1[2]):[22, 28]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)", + "tensor(d0[3]):[1,2,3]", + "tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d1[2]):[22, 28]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)", + "tensor(d0[3],d1[2]):[1,2,3,4,5,6]", + "tensor(d0[3]):[1,2,3]"); + + tester.assertEvaluates("tensor(d0[2],d2[2]):[22,28,49,64]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d1)", + "tensor(d0[2],d1[3]):[1,2,3,4,5,6]", + "tensor(d1[3],d2[2]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d0[1],d1[2],d3[2]):[22,28,49,64]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d2)", + 
"tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]", + "tensor(d2[3],d3[2]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d0[1],d1[2],d3[2]):[22,28,49,64]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d2)", + "tensor(d1[2],d2[3]):[1,2,3,4,5,6]", + "tensor(d0[1],d2[3],d3[2]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d0[1],d1[4],d2[2],d4[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]", + "reduce(join(tensor0{d1:0}, tensor1, f(x,y)(x*y)), sum, d3)", // notice peek + "tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]", + "tensor(d0[1],d1[4],d3[3],d4[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + + tester.assertEvaluates("tensor(d0[1],d1[4],d2[2],d4[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d3)", + "tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]", + "tensor(d0[1],d1[4],d3[3],d4[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + + } + + @Test + public void testSplit() { + EvaluationTester tester = new EvaluationTester(); + + tester.assertEvaluates("tensor(d0[3]):[1,2,3]", + "tensor(d0[3])(tensor0{input0:(d0)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + tester.assertEvaluates("tensor(d0[3]):[4,5,6]", + "tensor(d0[3])(tensor0{input0:(d0+3)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + tester.assertEvaluates("tensor(d0[4]):[3,4,5,6]", + "tensor(d0[4])(tensor0{input0:(d0+2)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + tester.assertEvaluates("tensor(d0[2]):[3,4]", + "tensor(d0[2])(tensor0{input0:(d0+2)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + tester.assertEvaluates("tensor(d0[2]):[5,6]", + "tensor(d0[2])(tensor0{input0:(d0+4)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d0[1],d1[3]):[1,2,3]", + "tensor(d0[1],d1[3])(tensor0{input0:(d0), input1:(d1)} )", + "tensor(input0[2],input1[3]):[[1,2,3],[4,5,6]]"); + tester.assertEvaluates("tensor(d0[1],d1[3]):[4,5,6]", + "tensor(d0[1],d1[3])(tensor0{input0:(d0+1), input1:(d1)} )", + "tensor(input0[2],input1[3]):[[1,2,3],[4,5,6]]"); + } + @Test public void testTake() { EvaluationTester tester = new EvaluationTester(); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 04f859e2802..2e13fe0de26 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1653,6 +1653,7 @@ "public void (com.yahoo.tensor.TensorType, java.util.function.Function)", "public static com.yahoo.tensor.functions.Generate free(com.yahoo.tensor.TensorType, java.util.function.Function)", "public static com.yahoo.tensor.functions.Generate bound(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)", + "public com.yahoo.tensor.functions.ScalarFunction getBoundGenerator()", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", @@ -2548,6 +2549,7 @@ "public void (com.yahoo.tensor.functions.TensorFunction, java.util.List)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.Slice withArguments(java.util.List)", + "public java.util.List getSubspaceAddress()", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", diff --git 
a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index 219a3fa2278..68c4aa4c809 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -58,7 +58,12 @@ class IndexedDoubleTensor extends IndexedTensor { @Override public IndexedTensor.BoundBuilder cell(double value, long ... indexes) { - values[(int)toValueIndex(indexes, sizes())] = value; + int index = (int)toValueIndex(indexes, sizes()); + if (index < 0 || index >= values.length) { + System.out.println("Argh"); + } + values[index] = value; +// values[(int)toValueIndex(indexes, sizes())] = value; return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index fa3d70a4ddf..2a2551c9a58 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -71,6 +71,10 @@ public class Generate extends PrimitiveTensorFunction getBoundGenerator() { + return boundGenerator; + } + @Override public List> arguments() { return Collections.emptyList(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index bccd66acd31..a0a3552eb92 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -53,6 +53,10 @@ public class Slice extends PrimitiveTensorFunction(arguments.get(0), subspaceAddress); } + public List> getSubspaceAddress() { + return subspaceAddress; + } + @Override public PrimitiveTensorFunction toPrimitive() { return this; } -- cgit v1.2.3
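Note: the EvaluationTestCase additions in this patch express ONNX MatMul as a join over the shared dimension followed by a sum-reduce, which appears to be the form the updated MatMul import targets. The standalone sketch below reruns one of those test expressions outside the test harness; the class name MatMulAsReduceJoinExample and its main method are illustrative only, while the expression string and tensor literals are taken from the tests above.

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;

public class MatMulAsReduceJoinExample {

    public static void main(String[] args) {
        // A 2x3 and a 3x2 matrix; the contracted dimension must carry the same name (d1) in both.
        MapContext context = new MapContext();
        context.put("tensor0", new TensorValue(Tensor.from("tensor(d0[2],d1[3]):[1,2,3,4,5,6]")));
        context.put("tensor1", new TensorValue(Tensor.from("tensor(d1[3],d2[2]):[1,2,3,4,5,6]")));

        // Multiply cell-wise over the shared dimension d1, then sum it away; the result keeps d0 and d2.
        RankingExpression matmul =
                RankingExpression.from("reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d1)");

        Tensor result = matmul.evaluate(context).asTensor();
        System.out.println(result); // expected: tensor(d0[2],d2[2]) with values [22, 28, 49, 64]
    }
}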
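Note: similarly, the new Split and Tile operations are mapped to bound generate expressions that peek into the input tensor with a computed index (an offset for Split, a modulo for Tile), as exercised by testSplit and testTile above. A minimal sketch of that mapping, reusing expression strings from those tests; the class name SplitTileGenerateExample and its main method are illustrative only.

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.Tensor;

public class SplitTileGenerateExample {

    public static void main(String[] args) {
        MapContext context = new MapContext();

        // Split: the second half of a 6-element vector is a generated 3-element tensor
        // whose cells read the input at an offset index.
        context.put("tensor0", new TensorValue(Tensor.from("tensor(input0[6]):[1,2,3,4,5,6]")));
        Tensor secondHalf = RankingExpression.from("tensor(d0[3])(tensor0{input0:(d0+3)})")
                                             .evaluate(context).asTensor();
        System.out.println(secondHalf); // expected: tensor(d0[3]):[4, 5, 6]

        // Tile: repeating a 2x2 matrix twice along the second dimension is the same
        // generate pattern with modulo indexing.
        context.put("tensor1", new TensorValue(Tensor.from("tensor(input0[2],input1[2]):[1,2,3,4]")));
        Tensor tiled = RankingExpression.from("tensor(d0[2],d1[4])(tensor1{input0:(d0 % 2), input1:(d1 % 2)})")
                                        .evaluate(context).asTensor();
        System.out.println(tiled); // expected: tensor(d0[2],d1[4]):[1, 2, 1, 2, 3, 4, 3, 4]
    }
}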