author     Lester Solbakken <lesters@oath.com>  2020-04-03 11:29:43 +0200
committer  Lester Solbakken <lesters@oath.com>  2020-04-03 11:29:43 +0200
commit     3789127189224d6cbd6f109b9a95f848869ea6cc (patch)
tree       79cef74e6c61da059ed0eae79632fa001433ddc2
parent     706cb2d3b2d623318ba9c0a8db0e4355448af65a (diff)
for testing only (lesters/bert-testing)
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java  43
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java  5
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java  2
-rw-r--r--  config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java  1
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java  11
-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java  260
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java  15
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java  8
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java  196
-rw-r--r--  config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java  13
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java  4
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java  25
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java  25
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java  4
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java  56
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java  8
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java  12
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java  6
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java  92
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java  7
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java  131
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java  2
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java  60
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java  1
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java  9
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java  119
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java  100
-rw-r--r--  model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java  54
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/BertImportTestCase.java  281
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java  140
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java  22
-rw-r--r--  model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/LesterTensorflowImportTestCase.java  162
-rw-r--r--  model-integration/src/test/models/onnx/simple/concat.onnx  bin 0 -> 135 bytes
-rwxr-xr-x  model-integration/src/test/models/onnx/simple/concat.py  25
-rw-r--r--  model-integration/src/test/models/onnx/simple/const.onnx  bin 0 -> 97 bytes
-rwxr-xr-x  model-integration/src/test/models/onnx/simple/const.py  26
-rw-r--r--  model-integration/src/test/models/onnx/simple/gather.onnx  bin 150 -> 150 bytes
-rw-r--r--  model-integration/src/test/models/onnx/simple/simple.onnx  4
-rw-r--r--  searchlib/abi-spec.json  21
-rwxr-xr-x  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java  2
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java  31
-rwxr-xr-x  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java  49
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java  8
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java  46
-rw-r--r--  searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java  20
-rw-r--r--  searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java  120
-rw-r--r--  vespajlib/abi-spec.json  2
-rw-r--r--  vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java  7
-rw-r--r--  vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java  4
-rw-r--r--  vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java  4
50 files changed, 1982 insertions, 261 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index 6de7c985326..65443117c0a 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -47,12 +47,15 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
private final SortedSet<Reference> queryFeaturesNotDeclared;
private boolean tensorsAreUsed;
+ private final MapEvaluationTypeContext parent;
+
MapEvaluationTypeContext(Collection<ExpressionFunction> functions, Map<Reference, TensorType> featureTypes) {
super(functions);
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = new ArrayDeque<>();
this.queryFeaturesNotDeclared = new TreeSet<>();
tensorsAreUsed = false;
+ parent = null;
}
private MapEvaluationTypeContext(Map<String, ExpressionFunction> functions,
@@ -60,12 +63,14 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
Map<Reference, TensorType> featureTypes,
Deque<Reference> currentResolutionCallStack,
SortedSet<Reference> queryFeaturesNotDeclared,
- boolean tensorsAreUsed) {
+ boolean tensorsAreUsed,
+ MapEvaluationTypeContext parent) {
super(functions, bindings);
this.featureTypes.putAll(featureTypes);
this.currentResolutionCallStack = currentResolutionCallStack;
this.queryFeaturesNotDeclared = queryFeaturesNotDeclared;
this.tensorsAreUsed = tensorsAreUsed;
+ this.parent = parent;
}
public void setType(Reference reference, TensorType type) {
@@ -82,16 +87,45 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
resolvedTypes.clear();
}
+ private TensorType resolvedType(Reference reference, int depth) {
+// System.out.println(indent + "In resolvedtype - resolving type for " + reference.toString());
+ TensorType resolvedType = resolvedTypes.get(reference);
+ if (resolvedType != null) {
+// System.out.println("Found previously resolved type for " + reference + " at depth " + depth + ": (" + resolvedType + ")");
+ return resolvedType;
+ }
+ if (parent != null) return parent.resolvedType(reference, depth + 1); // what about argument types? Careful with this!
+// System.out.println("Could NOT find type for " + reference + " - down to depth " + depth);
+ return null;
+ }
+
+ private MapEvaluationTypeContext findOriginalParent() {
+ if (parent != null)
+ return parent.findOriginalParent();
+ return this;
+ }
+
@Override
public TensorType getType(Reference reference) {
// computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
- TensorType resolvedType = resolvedTypes.get(reference);
+ // TensorType resolvedType = resolvedTypes.get(reference);
+ TensorType resolvedType = resolvedType(reference, 0);
if (resolvedType != null) return resolvedType;
resolvedType = resolveType(reference);
if (resolvedType == null)
return defaultTypeOf(reference); // Don't store fallback to default as we may know more later
- resolvedTypes.put(reference, resolvedType);
+
+// System.out.println("Resolved type of " + reference + ": (" + resolvedType + ")");
+
+ // Need a concept of global or local here.
+ // For globals - put it in the lowest parent!
+ MapEvaluationTypeContext originalParent = findOriginalParent();
+ if (originalParent == null) {
+ originalParent = this;
+ }
+ originalParent.resolvedTypes.put(reference, resolvedType);
+
if (resolvedType.rank() > 0)
tensorsAreUsed = true;
return resolvedType;
@@ -103,6 +137,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
currentResolutionCallStack.stream().map(Reference::toString).collect(Collectors.joining(" -> ")) +
" -> " + reference);
+
// Bound to a function argument, and not to a same-named identifier (which would lead to a loop)?
Optional<String> binding = boundIdentifier(reference);
if (binding.isPresent() && ! binding.get().equals(reference.toString())) {
@@ -254,7 +289,7 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
featureTypes,
currentResolutionCallStack,
queryFeaturesNotDeclared,
- tensorsAreUsed);
+ tensorsAreUsed, this);
}
}
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 23eb814de81..ea126123a25 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -680,11 +680,12 @@ public class RankProfile implements Cloneable {
Map<String, RankingExpressionFunction> inlineFunctions =
compileFunctions(this::getInlineFunctions, queryProfiles, featureTypes, importedModels, Collections.emptyMap(), expressionTransforms);
+ firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
+ secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
+
// Function compiling second pass: compile all functions and insert previously compiled inline functions
functions = compileFunctions(this::getFunctions, queryProfiles, featureTypes, importedModels, inlineFunctions, expressionTransforms);
- firstPhaseRanking = compile(this.getFirstPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
- secondPhaseRanking = compile(this.getSecondPhaseRanking(), queryProfiles, featureTypes, importedModels, getConstants(), inlineFunctions, expressionTransforms);
}
private void checkNameCollisions(Map<String, RankingExpressionFunction> functions, Map<String, Value> constants) {
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 1a22b98fd9f..3578cc786ed 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -222,7 +222,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
List<ExpressionFunction> functionExpressions) {
SerializationContext context = new SerializationContext(functionExpressions);
for (Map.Entry<String, RankProfile.RankingExpressionFunction> e : functions.entrySet()) {
+ System.out.println("Deriving: " + e.getKey());
String expressionString = e.getValue().function().getBody().getRoot().toString(new StringBuilder(), context, null, null).toString();
+ System.out.println("-> Done deriving: " + e.getKey() + ": " + expressionString);
context.addFunctionSerialization(RankingExpression.propertyName(e.getKey()), expressionString);
for (Map.Entry<String, TensorType> argumentType : e.getValue().function().argumentTypes().entrySet())
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
index 89b8889b4ae..8d9098a10f1 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/processing/RankingExpressionTypeResolver.java
@@ -79,6 +79,7 @@ public class RankingExpressionTypeResolver extends Processor {
}
context.forgetResolvedTypes();
+ System.out.println("Resolving type for " + function.getKey());
TensorType type = resolveType(expressionFunction.getBody(), "function '" + function.getKey() + "'", context);
function.getValue().setReturnType(type);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
index b6f7ab4ff62..d3e029b8de5 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/RankSetupValidator.java
@@ -102,7 +102,8 @@ public class RankSetupValidator extends Validator {
}
private void deleteTempDir(File dir) {
- IOUtils.recursiveDeleteDir(dir);
+ System.out.println("Here we were supposed to delete tmpdir: " + dir.getAbsolutePath());
+// IOUtils.recursiveDeleteDir(dir);
}
private void writeConfigs(String dir, AbstractConfigProducer<?> producer) throws IOException {
@@ -133,7 +134,13 @@ public class RankSetupValidator extends Validator {
}
private static void writeConfig(String dir, String configName, ConfigInstance config) throws IOException {
- IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false);
+
+ String output = StringUtilities.implodeMultiline(ConfigInstance.serialize(config));
+ System.out.println("Writing config for in " + dir + " for configName '" + configName + "' ");
+ System.out.println(output);
+ IOUtils.writeFile(dir + configName, output, false);
+
+// IOUtils.writeFile(dir + configName, StringUtilities.implodeMultiline(ConfigInstance.serialize(config)), false);
}
private boolean execValidate(String configId, SearchCluster sc, String sdName, DeployLogger deployLogger) {
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index c3d6f457ce8..9f649bc820a 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -386,138 +386,138 @@ public class ConvertedModel {
*/
private static void reduceBatchDimensions(RankingExpression expression, ImportedMlModel model,
RankProfile profile, QueryProfileRegistry queryProfiles) {
- MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
-
- // Add any missing inputs for type resolution
- Set<String> functionNames = new HashSet<>();
- addFunctionNamesIn(expression.getRoot(), functionNames, model);
- for (String functionName : functionNames) {
- Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
- if (requiredType.isPresent()) {
- Reference ref = Reference.fromIdentifier(functionName);
- if (typeContext.getType(ref).equals(TensorType.empty)) {
- typeContext.setType(ref, requiredType.get());
- }
- }
- }
- typeContext.forgetResolvedTypes();
-
- TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
-
- // Check generated functions for inputs to reduce
- for (String functionName : functionNames) {
- if ( ! model.functions().containsKey(functionName)) continue;
-
- RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
- if (rankingExpressionFunction == null) {
- throw new IllegalArgumentException("Model refers to generated function '" + functionName +
- "but this function is not present in " + profile);
- }
- RankingExpression functionExpression = rankingExpressionFunction.function().getBody();
- functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext));
- }
-
- // Check expression for inputs to reduce
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, model, typeContext);
- TensorType typeAfterReducing = root.type(typeContext);
- root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
- expression.setRoot(root);
- }
-
- private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
- MapEvaluationTypeContext typeContext) {
- if (node instanceof TensorFunctionNode) {
- TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
- if (tensorFunction instanceof Rename) {
- List<ExpressionNode> children = ((TensorFunctionNode)node).children();
- if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(tensorFunction, typeContext);
- }
- }
- // Modify any renames in expression to disregard batch dimension
- else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
- TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function());
- TensorType childType = childFunction.type(typeContext);
- Rename rename = (Rename) tensorFunction;
- List<String> from = new ArrayList<>();
- List<String> to = new ArrayList<>();
- for (TensorType.Dimension dimension : childType.dimensions()) {
- int i = rename.fromDimensions().indexOf(dimension.name());
- if (i < 0) {
- throw new IllegalArgumentException("Rename does not contain dimension '" +
- dimension + "' in child expression type: " + childType);
- }
- from.add((String)rename.fromDimensions().get(i));
- to.add((String)rename.toDimensions().get(i));
- }
- return new TensorFunctionNode(new Rename<>(childFunction, from, to));
- }
- }
- }
- if (node instanceof ReferenceNode) {
- ReferenceNode referenceNode = (ReferenceNode) node;
- if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext);
- }
- }
- if (node instanceof CompositeNode) {
- List<ExpressionNode> children = ((CompositeNode)node).children();
- List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
- for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
- }
- return ((CompositeNode)node).setChildren(transformedChildren);
- }
- return node;
- }
-
- private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) {
- TensorFunction result = function;
- TensorType type = function.type(context);
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
- }
- if (reduceDimensions.size() > 0) {
- result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
- context.forgetResolvedTypes(); // We changed types
- }
- }
- return new TensorFunctionNode(result);
+// MapEvaluationTypeContext typeContext = profile.typeContext(queryProfiles);
+//
+// // Add any missing inputs for type resolution
+// Set<String> functionNames = new HashSet<>();
+// addFunctionNamesIn(expression.getRoot(), functionNames, model);
+// for (String functionName : functionNames) {
+// Optional<TensorType> requiredType = model.inputTypeSpec(functionName).map(TensorType::fromSpec);
+// if (requiredType.isPresent()) {
+// Reference ref = Reference.fromIdentifier(functionName);
+// if (typeContext.getType(ref).equals(TensorType.empty)) {
+// typeContext.setType(ref, requiredType.get());
+// }
+// }
+// }
+// typeContext.forgetResolvedTypes();
+//
+// TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+//
+// // Check generated functions for inputs to reduce
+// for (String functionName : functionNames) {
+// if ( ! model.functions().containsKey(functionName)) continue;
+//
+// RankProfile.RankingExpressionFunction rankingExpressionFunction = profile.getFunctions().get(functionName);
+// if (rankingExpressionFunction == null) {
+// throw new IllegalArgumentException("Model refers to generated function '" + functionName +
+// "but this function is not present in " + profile);
+// }
+// RankingExpression functionExpression = rankingExpressionFunction.function().getBody();
+// functionExpression.setRoot(reduceBatchDimensionsAtInput(functionExpression.getRoot(), model, typeContext));
+// }
+//
+// // Check expression for inputs to reduce
+// ExpressionNode root = expression.getRoot();
+// root = reduceBatchDimensionsAtInput(root, model, typeContext);
+// TensorType typeAfterReducing = root.type(typeContext);
+// root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+// expression.setRoot(root);
}
- /**
- * If batch dimensions have been reduced away above, bring them back here
- * for any following computation of the tensor.
- */
- // TODO: determine when this is not necessary!
- private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
- if (after.equals(before)) return node;
-
- TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
- for (TensorType.Dimension dimension : before.dimensions()) {
- if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
- typeBuilder.indexed(dimension.name(), 1);
- }
- }
- TensorType expandDimensionsType = typeBuilder.build();
- if (expandDimensionsType.dimensions().size() > 0) {
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
- Generate generatedFunction = new Generate(expandDimensionsType,
- new GeneratorLambdaFunctionNode(expandDimensionsType,
- generatedExpression)
- .asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
- return node;
- }
+// private static ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, ImportedMlModel model,
+// MapEvaluationTypeContext typeContext) {
+// if (node instanceof TensorFunctionNode) {
+// TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+// if (tensorFunction instanceof Rename) {
+// List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+// if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+// ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
+// return reduceBatchDimensionExpression(tensorFunction, typeContext);
+// }
+// }
+// // Modify any renames in expression to disregard batch dimension
+// else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
+// TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function());
+// TensorType childType = childFunction.type(typeContext);
+// Rename rename = (Rename) tensorFunction;
+// List<String> from = new ArrayList<>();
+// List<String> to = new ArrayList<>();
+// for (TensorType.Dimension dimension : childType.dimensions()) {
+// int i = rename.fromDimensions().indexOf(dimension.name());
+// if (i < 0) {
+// throw new IllegalArgumentException("Rename does not contain dimension '" +
+// dimension + "' in child expression type: " + childType);
+// }
+// from.add((String)rename.fromDimensions().get(i));
+// to.add((String)rename.toDimensions().get(i));
+// }
+// return new TensorFunctionNode(new Rename<>(childFunction, from, to));
+// }
+// }
+// }
+// if (node instanceof ReferenceNode) {
+// ReferenceNode referenceNode = (ReferenceNode) node;
+// if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
+// return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext);
+// }
+// }
+// if (node instanceof CompositeNode) {
+// List<ExpressionNode> children = ((CompositeNode)node).children();
+// List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+// for (ExpressionNode child : children) {
+// transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
+// }
+// return ((CompositeNode)node).setChildren(transformedChildren);
+// }
+// return node;
+// }
+//
+// private static ExpressionNode reduceBatchDimensionExpression(TensorFunction function, MapEvaluationTypeContext context) {
+// TensorFunction result = function;
+// TensorType type = function.type(context);
+// if (type.dimensions().size() > 1) {
+// List<String> reduceDimensions = new ArrayList<>();
+// for (TensorType.Dimension dimension : type.dimensions()) {
+// if (dimension.size().orElse(-1L) == 1) {
+// reduceDimensions.add(dimension.name());
+// }
+// }
+// if (reduceDimensions.size() > 0) {
+// result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+// context.forgetResolvedTypes(); // We changed types
+// }
+// }
+// return new TensorFunctionNode(result);
+// }
+//
+// /**
+// * If batch dimensions have been reduced away above, bring them back here
+// * for any following computation of the tensor.
+// */
+// // TODO: determine when this is not necessary!
+// private static ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+// if (after.equals(before)) return node;
+//
+// TensorType.Builder typeBuilder = new TensorType.Builder(after.valueType());
+// for (TensorType.Dimension dimension : before.dimensions()) {
+// if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+// typeBuilder.indexed(dimension.name(), 1);
+// }
+// }
+// TensorType expandDimensionsType = typeBuilder.build();
+// if (expandDimensionsType.dimensions().size() > 0) {
+// ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1.0));
+// Generate generatedFunction = new Generate(expandDimensionsType,
+// new GeneratorLambdaFunctionNode(expandDimensionsType,
+// generatedExpression)
+// .asLongListToDoubleOperator());
+// Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply());
+// return new TensorFunctionNode(expand);
+// }
+// return node;
+// }
/**
* If a constant c is overridden by a function, we need to replace instances of "constant(c)" by "c" in expressions.
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
index d84d967a184..8bc9040577b 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/RankingExpressionConstantsTestCase.java
@@ -2,8 +2,11 @@
package com.yahoo.searchdefinition;
import com.yahoo.collections.Pair;
+import com.yahoo.config.model.application.provider.BaseDeployLogger;
+import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
+import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.yolean.Exceptions;
import com.yahoo.searchdefinition.derived.AttributeFields;
import com.yahoo.searchdefinition.derived.RawRankProfile;
@@ -82,6 +85,18 @@ public class RankingExpressionConstantsTestCase extends SchemaTestCase {
new AttributeFields(s)).configProperties();
assertEquals("(rankingExpression(foo).rankingScript,14.0)", rankProperties.get(0).toString());
assertEquals("(rankingExpression(firstphase).rankingScript,16.6)", rankProperties.get(2).toString());
+
+ try {
+ DerivedConfiguration config = new DerivedConfiguration(s,
+ new BaseDeployLogger(),
+ new TestProperties(),
+ rankProfileRegistry,
+ queryProfileRegistry,
+ new ImportedMlModels());
+ config.export("/Users/lesters/temp/bert/idea/");
+ } catch (Exception e) {
+ throw new IllegalArgumentException(e);
+ }
}
@Test
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
index 0cd6674751e..d5638da224c 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankProfileSearchFixture.java
@@ -41,6 +41,14 @@ class RankProfileSearchFixture {
private Search search;
private Map<String, RankProfile> compiledRankProfiles = new HashMap<>();
+ // TEMP
+ public RankProfileRegistry getRankProfileRegistry() {
+ return rankProfileRegistry;
+ }
+ public QueryProfileRegistry getQueryProfileRegistry() {
+ return queryProfileRegistry;
+ }
+
RankProfileSearchFixture(String rankProfiles) throws ParseException {
this(MockApplicationPackage.createEmpty(), new QueryProfileRegistry(), rankProfiles);
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java
new file mode 100644
index 00000000000..2c0620a0c52
--- /dev/null
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithBertTestCase.java
@@ -0,0 +1,196 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchdefinition.processing;
+
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
+import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
+import ai.vespa.rankingexpression.importer.lightgbm.LightGBMImporter;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import ai.vespa.rankingexpression.importer.tensorflow.TensorFlowImporter;
+import ai.vespa.rankingexpression.importer.xgboost.XGBoostImporter;
+import com.google.common.collect.ImmutableList;
+import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.BaseDeployLogger;
+import com.yahoo.config.model.deploy.TestProperties;
+import com.yahoo.io.IOUtils;
+import com.yahoo.path.Path;
+import com.yahoo.search.query.profile.QueryProfileRegistry;
+import com.yahoo.searchdefinition.RankProfile;
+import com.yahoo.searchdefinition.RankProfileRegistry;
+import com.yahoo.searchdefinition.Search;
+import com.yahoo.searchdefinition.SearchBuilder;
+import com.yahoo.searchdefinition.derived.DerivedConfiguration;
+import com.yahoo.searchdefinition.parser.ParseException;
+import com.yahoo.searchdefinition.processing.RankingExpressionWithTensorFlowTestCase.StoringApplicationPackage;
+import com.yahoo.vespa.model.VespaModel;
+import com.yahoo.vespa.model.ml.ImportedModelTester;
+import com.yahoo.yolean.Exceptions;
+import org.junit.After;
+import org.junit.Ignore;
+import org.junit.Test;
+
+import java.io.IOException;
+import java.util.Optional;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNull;
+import static org.junit.Assert.fail;
+
+public class RankingExpressionWithBertTestCase {
+
+ private final Path applicationDir = Path.fromString("src/test/integration/bert/");
+
+ /** The model name */
+ private final static String name = "bertsquad8";
+
+ private final static String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(" + name + "_Variable), f(a,b)(a * b)), sum, d2), constant(" + name + "_Variable_1), f(a,b)(a + b))";
+
+ @After
+ public void removeGeneratedModelFiles() {
+ IOUtils.recursiveDeleteDir(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ }
+
+
+ @Ignore
+ @Test
+ public void testGlobalBertModel() throws IOException {
+ ImportedModelTester tester = new ImportedModelTester(name, applicationDir);
+ VespaModel model = tester.createVespaModel();
+// tester.assertLargeConstant(name + "_Variable_1", model, Optional.of(10L));
+// tester.assertLargeConstant(name + "_Variable", model, Optional.of(7840L));
+
+ // At this point the expression is stored - copy the application to another location which does not have a models dir
+ Path storedAppDir = applicationDir.append("copy");
+ try {
+ storedAppDir.toFile().mkdirs();
+ IOUtils.copy(applicationDir.append("services.xml").toString(), storedAppDir.append("services.xml").toString());
+ IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(),
+ storedAppDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile());
+ ImportedModelTester storedTester = new ImportedModelTester(name, storedAppDir);
+ VespaModel storedModel = storedTester.createVespaModel();
+// tester.assertLargeConstant(name + "_Variable_1", storedModel, Optional.of(10L));
+// tester.assertLargeConstant(name + "_Variable", storedModel, Optional.of(7840L));
+ }
+ finally {
+ IOUtils.recursiveDeleteDir(storedAppDir.toFile());
+ }
+ }
+
+ @Ignore
+ @Test
+ public void testBertRankProfile() throws Exception {
+ StoringApplicationPackage application = new StoringApplicationPackage((applicationDir));
+
+ ImmutableList<MlModelImporter> importers = ImmutableList.of(new TensorFlowImporter(),
+ new OnnxImporter(),
+ new LightGBMImporter(),
+ new XGBoostImporter());
+
+ String rankProfiles = " rank-profile my_profile {\n" +
+ " first-phase {\n" +
+ " expression: onnx('bertsquad8.onnx', 'default', 'unstack')" +
+ " }\n" +
+ " }";
+
+ RankProfileRegistry rankProfileRegistry = new RankProfileRegistry();
+ QueryProfileRegistry queryProfileRegistry = application.getQueryProfiles();
+
+ SearchBuilder builder = new SearchBuilder(application, rankProfileRegistry, queryProfileRegistry);
+ String sdContent = "search test {\n" +
+ " document test {\n" +
+ " field unique_ids type tensor(d0[1]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field input_ids type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field input_mask type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }\n" +
+ " field segment_ids type tensor(d0[1],d1[256]) {\n" +
+ " indexing: summary | attribute\n" +
+ " }" +
+ " }\n" +
+ " rank-profile my_profile inherits default {\n" +
+ " function inline unique_ids_raw_output___9() {\n" +
+ " expression: attribute(unique_ids)\n" +
+ " }\n" +
+ " function inline input_ids() {\n" +
+ " expression: attribute(input_ids)\n" +
+ " }\n" +
+ " function inline input_mask() {\n" +
+ " expression: attribute(input_mask)\n" +
+ " }\n" +
+ " function inline segment_ids() {\n" +
+ " expression: attribute(segment_ids)\n" +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: onnx(\"bertsquad8.onnx\", \"default\", \"unstack\") \n" +
+ " }\n" +
+ " }" +
+ "}";
+ builder.importString(sdContent);
+ builder.build();
+ Search search = builder.getSearch();
+
+ RankProfile compiled = rankProfileRegistry.get(search, "my_profile")
+ .compile(queryProfileRegistry,
+ new ImportedMlModels(applicationDir.toFile(), importers));
+
+ DerivedConfiguration config = new DerivedConfiguration(search,
+ new BaseDeployLogger(),
+ new TestProperties(),
+ rankProfileRegistry,
+ queryProfileRegistry,
+ new ImportedMlModels());
+
+ config.export("/Users/lesters/temp/bert/idea/");
+
+// fixture.assertFirstPhaseExpression(vespaExpression, "my_profile");
+ System.out.println("Joda");
+ }
+
+ private RankProfileSearchFixture fixtureWith(String placeholderExpression, String firstPhaseExpression,
+ String constant, String field) {
+ return fixtureWith(placeholderExpression, firstPhaseExpression, constant, field, "Placeholder",
+ new StoringApplicationPackage(applicationDir));
+ }
+
+ private RankProfileSearchFixture uncompiledFixtureWith(String rankProfile, StoringApplicationPackage application) {
+ try {
+ return new RankProfileSearchFixture(application, application.getQueryProfiles(),
+ rankProfile, null, null);
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+ private RankProfileSearchFixture fixtureWith(String functionExpression,
+ String firstPhaseExpression,
+ String constant,
+ String field,
+ String functionName,
+ StoringApplicationPackage application) {
+ try {
+ RankProfileSearchFixture fixture = new RankProfileSearchFixture(
+ application,
+ application.getQueryProfiles(),
+ " rank-profile my_profile {\n" +
+ " function " + functionName + "() {\n" +
+ " expression: " + functionExpression +
+ " }\n" +
+ " first-phase {\n" +
+ " expression: " + firstPhaseExpression +
+ " }\n" +
+ " }",
+ constant,
+ field);
+ fixture.compileRankProfile("my_profile", applicationDir.append("models"));
+ return fixture;
+ }
+ catch (ParseException e) {
+ throw new IllegalArgumentException(e);
+ }
+ }
+
+}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index cba931e81f0..c444bf8d7dc 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -1,8 +1,11 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchdefinition.processing;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModels;
import com.yahoo.config.application.api.ApplicationFile;
import com.yahoo.config.application.api.ApplicationPackage;
+import com.yahoo.config.model.application.provider.BaseDeployLogger;
+import com.yahoo.config.model.deploy.TestProperties;
import com.yahoo.config.model.test.MockApplicationPackage;
import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.io.IOUtils;
@@ -10,6 +13,7 @@ import com.yahoo.io.reader.NamedReader;
import com.yahoo.path.Path;
import com.yahoo.search.query.profile.QueryProfileRegistry;
import com.yahoo.searchdefinition.RankingConstant;
+import com.yahoo.searchdefinition.derived.DerivedConfiguration;
import com.yahoo.searchdefinition.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
@@ -385,6 +389,15 @@ public class RankingExpressionWithTensorFlowTestCase {
finally {
IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile());
}
+
+ DerivedConfiguration config = new DerivedConfiguration(search.search(),
+ new BaseDeployLogger(),
+ new TestProperties(),
+ search.getRankProfileRegistry(),
+ search.getQueryProfileRegistry(),
+ new ImportedMlModels());
+ config.export("/Users/lesters/temp/bert/idea/");
+
}
private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
index c7f320ed3b4..87f7c1c71f8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/DimensionRenamer.java
@@ -66,7 +66,7 @@ public class DimensionRenamer {
void solve() {
log.log(Level.FINE, () -> "Rename problem:\n" + constraintsToString(constraints));
- renames = solve(100000);
+ renames = solve(100000000);
log.log(Level.FINE, () -> "Rename solution:\n" + renamesToString(renames));
}
@@ -86,7 +86,7 @@ public class DimensionRenamer {
private Map<String, Integer> solveWithOrWithoutSoftConstraints(int maxIterations) {
Map<String, Integer> solution = NamingConstraintSolver.solve(dimensions, constraints, maxIterations);
- if ( solution == null) {
+ if (solution == null) {
ListMap<Arc, Constraint> hardConstraints = new ListMap<>();
boolean anyRemoved = copyHard(constraints, hardConstraints);
if (anyRemoved) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index 14aa3ebf84e..3c8a6bde232 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -7,6 +7,7 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
@@ -74,6 +75,8 @@ public class IntermediateGraph {
renameDimensions();
}
+ static int counter = 0;
+
/**
* Find dimension names to avoid excessive renaming while evaluating the model.
*/
@@ -93,16 +96,34 @@ public class IntermediateGraph {
}
private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
+ Set<String> operations = new HashSet<>();
+ addDimensionNameConstraints(operation, renamer, operations);
+ }
+
+ private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer, Set<String> operations) {
+ if (operations.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, operations));
operation.addDimensionNameConstraints(renamer);
+ operations.add(operation.name());
}
}
private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
+ Set<String> operations = new HashSet<>();
+ renameDimensions(operation, renamer, operations);
+ }
+
+ private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer, Set<String> operations) {
+ if (operations.contains(operation.name())) {
+ return;
+ }
if (operation.type().isPresent()) {
- operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.inputs().forEach(input -> renameDimensions(input, renamer, operations));
operation.renameDimensions(renamer);
+ operations.add(operation.name());
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 3774e64c886..7fad077ceb2 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -3,11 +3,14 @@ package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.configmodelview.MlModelImporter;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import ai.vespa.rankingexpression.importer.operations.Constant;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
@@ -15,9 +18,11 @@ import com.yahoo.text.ExpressionFormatter;
import com.yahoo.yolean.Exceptions;
import java.io.File;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.logging.Level;
import java.util.logging.Logger;
@@ -122,8 +127,16 @@ public abstract class ModelImporter implements MlModelImporter {
return operation.function();
}
+ private static boolean isImported(IntermediateOperation operation, ImportedModel model) {
+ return model.expressions().containsKey(operation.name()); // test for others?
+ }
+
private static void importExpressionInputs(IntermediateOperation operation, ImportedModel model) {
- operation.inputs().forEach(input -> importExpression(input, model));
+ operation.inputs().forEach(input -> {
+ if ( ! isImported(operation, model)) {
+ importExpression(input, model);
+ }
+ });
}
private static Optional<TensorFunction> importConstant(IntermediateOperation operation, ImportedModel model) {
@@ -206,18 +219,22 @@ public abstract class ModelImporter implements MlModelImporter {
private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
for (ImportedModel.Signature signature : model.signatures().values()) {
for (String outputName : signature.outputs().values()) {
- reportWarnings(graph.get(outputName), model);
+ reportWarnings(graph.get(outputName), model, new HashSet<String>());
}
}
}
- private static void reportWarnings(IntermediateOperation operation, ImportedModel model) {
+ private static void reportWarnings(IntermediateOperation operation, ImportedModel model, Set<String> reported) {
+ if (reported.contains(operation.name())) {
+ return;
+ }
for (String warning : operation.warnings()) {
// If we want to report warnings, that code goes here
}
for (IntermediateOperation input : operation.inputs()) {
- reportWarnings(input, model);
+ reportWarnings(input, model, reported);
}
+ reported.add(operation.name());
}
/**
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
index 21cc6b27dad..9a7fcc85ee1 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/NamingConstraintSolver.java
@@ -37,7 +37,8 @@ class NamingConstraintSolver {
private static ListMap<String, Integer> allPossibilities(Set<String> dimensions) {
ListMap<String, Integer> all = new ListMap<>();
for (String dimension : dimensions) {
- for (int i = 0; i < dimensions.size(); ++i)
+ // 20 (different dimension names) should be enough for most problems.
+ for (int i = 0; i < Math.min(dimensions.size(), 20); ++i)
all.put(dimension, i);
}
return all;
@@ -89,6 +90,7 @@ class NamingConstraintSolver {
workList.add(constraint);
}
}
+ if (iterations > maxIterations) return false;
}
return true;
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index ffc64c38f16..c98a5c7d4f5 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -2,7 +2,6 @@
package ai.vespa.rankingexpression.importer.onnx;
-import ai.vespa.rankingexpression.importer.operations.ExpandDims;
import ai.vespa.rankingexpression.importer.operations.Gather;
import ai.vespa.rankingexpression.importer.operations.OnnxCast;
import ai.vespa.rankingexpression.importer.operations.Gemm;
@@ -12,7 +11,10 @@ import ai.vespa.rankingexpression.importer.operations.Reduce;
import ai.vespa.rankingexpression.importer.operations.Select;
import ai.vespa.rankingexpression.importer.operations.Slice;
import ai.vespa.rankingexpression.importer.operations.Softmax;
+import ai.vespa.rankingexpression.importer.operations.Split;
import ai.vespa.rankingexpression.importer.operations.Squeeze;
+import ai.vespa.rankingexpression.importer.operations.Tile;
+import ai.vespa.rankingexpression.importer.operations.Transpose;
import ai.vespa.rankingexpression.importer.operations.Unsqueeze;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
@@ -32,6 +34,8 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import java.util.Collection;
+import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
@@ -53,19 +57,21 @@ class GraphImporter {
private static IntermediateOperation mapOperation(Onnx.NodeProto node,
List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
+ IntermediateGraph graph,
+ int outputIndex) {
String type = node.getOpType();
String modelName = graph.name();
String nodeName = getNodeName(node);
AttributeConverter attributes = AttributeConverter.convert(node);
- return mapOperation(type, inputs, modelName, nodeName, attributes);
+ return mapOperation(type, inputs, modelName, nodeName, attributes, outputIndex);
}
static IntermediateOperation mapOperation(String opType,
List<IntermediateOperation> inputs,
String modelName,
String nodeName,
- AttributeConverter attributes) {
+ AttributeConverter attributes,
+ int outputIndex) {
switch (opType.toLowerCase()) {
case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
@@ -115,17 +121,21 @@ class GraphImporter {
case "slice": return new Slice(modelName, nodeName, inputs, attributes);
case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
+ case "split": return new Split(modelName, nodeName, inputs, attributes, outputIndex);
case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
case "where": return new Select(modelName, nodeName, inputs);
case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
+ case "tile": return new Tile(modelName, nodeName, inputs);
+ case "transpose": return new Transpose(modelName, nodeName, inputs, attributes);
case "unsqueeze": return new Unsqueeze(modelName, nodeName, inputs, attributes);
}
IntermediateOperation op = new NoOp(modelName, nodeName, inputs);
op.warning("Operation '" + opType + "' is currently not implemented");
+ System.out.println(nodeName + ": operation '" + opType + "' is currently not implemented");
return op;
}
@@ -133,10 +143,15 @@ class GraphImporter {
Onnx.GraphProto onnxGraph = model.getGraph();
IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
+ System.out.println("Importing operations...");
importOperations(onnxGraph, intermediateGraph);
+ System.out.println("Verifying no warnings...");
verifyNoWarnings(intermediateGraph);
+ System.out.println("Verifying output types...");
verifyOutputTypes(onnxGraph, intermediateGraph);
+ System.out.println("Ok...");
+
return intermediateGraph;
}
@@ -150,8 +165,10 @@ class GraphImporter {
Onnx.GraphProto onnxGraph,
IntermediateGraph intermediateGraph) {
if (intermediateGraph.alreadyImported(name)) {
+// System.out.println("Trying to import '" + name + "' but is was already imported.");
return intermediateGraph.get(name);
}
+// System.out.println("Importing '" + name + "' ...");
IntermediateOperation operation;
if (isArgumentTensor(name, onnxGraph)) {
Onnx.ValueInfoProto valueInfoProto = getArgumentTensor(name, onnxGraph);
@@ -163,16 +180,21 @@ class GraphImporter {
intermediateGraph.inputs(intermediateGraph.defaultSignature())
.put(IntermediateOperation.namePartOf(name), operation.vespaName());
+// System.out.println(" '" + name + "' imported as argument...");
+
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
OrderedTensorType defaultType = TypeConverter.typeFrom(tensorProto);
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
+// System.out.println(" '" + name + "' imported as constant...");
+
} else {
Onnx.NodeProto node = getNodeFromGraph(name, onnxGraph);
+ int outputIndex = getOutputIndex(node, name);
List<IntermediateOperation> inputs = importOperationInputs(node, onnxGraph, intermediateGraph);
- operation = mapOperation(node, inputs, intermediateGraph);
+ operation = mapOperation(node, inputs, intermediateGraph, outputIndex);
// propagate constant values if all inputs are constant
if (operation.isConstant()) {
@@ -183,8 +205,12 @@ class GraphImporter {
intermediateGraph.outputs(intermediateGraph.defaultSignature())
.put(IntermediateOperation.namePartOf(name), operation.name());
}
+
+// System.out.println(" '" + name + "' imported as normal...");
+
}
intermediateGraph.put(operation.name(), operation);
+ intermediateGraph.put(name, operation);
return operation;
}
@@ -262,7 +288,8 @@ class GraphImporter {
Onnx.ValueInfoProto onnxNode = getOutputNode(output.getKey(), onnxGraph);
OrderedTensorType type = operation.type().orElseThrow(
() -> new IllegalArgumentException("Output of '" + output.getValue() + "' has no type."));
- TypeConverter.verifyType(onnxNode.getType(), type);
+ System.out.println(onnxNode.getType() + " vs. " + type);
+ //TypeConverter.verifyType(onnxNode.getType(), type);
}
}
@@ -296,6 +323,10 @@ class GraphImporter {
return graph.getNodeList().stream().filter(node -> node.getName().equals(nodeName)).findFirst();
}
+ private static int getOutputIndex(Onnx.NodeProto node, String outputName) {
+ return node.getOutputCount() == 0 ? 0 : Math.max(node.getOutputList().indexOf(outputName), 0);
+ }
+
private static String getNodeName(Onnx.NodeProto node) {
String nodeName = node.getName();
if (nodeName.length() > 0)
@@ -307,11 +338,14 @@ class GraphImporter {
}
private static Set<String> getWarnings(IntermediateOperation op) {
- Set<String> warnings = new HashSet<>(op.warnings());
- for (IntermediateOperation input : op.inputs()) {
- warnings.addAll(getWarnings(input));
- }
- return warnings;
+ java.util.Map<String, Set<String>> warnings = new HashMap<>();
+ getWarnings(op, warnings);
+ return warnings.values().stream().flatMap(Collection::stream).collect(Collectors.toSet());
}
+ private static void getWarnings(IntermediateOperation op, java.util.Map<String, Set<String>> warnings) {
+ if (warnings.containsKey(op.name())) return;
+ op.inputs().forEach(input -> getWarnings(input, warnings));
+ warnings.put(op.name(), new HashSet<>(op.warnings()));
+ }
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index 01fd7ee55bd..956d727fbad 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -54,10 +54,10 @@ public class Const extends IntermediateOperation {
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + super.vespaName();
- }
+// @Override
+// public String vespaName() {
+// return modelName + "_" + super.vespaName();
+// }
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
index ad56eefe5f2..b12f83f274b 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java
@@ -22,10 +22,10 @@ public class Constant extends IntermediateOperation {
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + vespaName(name);
- }
+// @Override
+// public String vespaName() {
+// return modelName + "_" + vespaName(name);
+// }
@Override
protected OrderedTensorType lazyGetType() {
@@ -61,7 +61,9 @@ public class Constant extends IntermediateOperation {
public Constant withInputs(List<IntermediateOperation> inputs) {
if ( ! inputs.isEmpty())
throw new IllegalArgumentException("Constant cannot take inputs");
- return new Constant(modelName(), name(), type);
+ Constant constant = new Constant(modelName(), name(), type);
+ constant.setConstantValueFunction(constantValueFunction);
+ return constant;
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
index 5463f645355..af192fcec38 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java
@@ -12,12 +12,6 @@ public class Identity extends IntermediateOperation {
super(modelName, nodeName, inputs);
}
- /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
- @Override
- public String vespaName() {
- return modelName + "_" + super.vespaName();
- }
-
@Override
protected OrderedTensorType lazyGetType() {
if (!allInputTypesPresent(1))
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 2aa8b2a0d48..83e15a4081a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -3,6 +3,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
@@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
@@ -47,6 +49,8 @@ public abstract class IntermediateOperation {
protected TensorFunction rankingExpressionFunction = null;
protected boolean exportAsRankingFunction = false;
+ private boolean hasRenamedDimensions = false;
+
private final List<String> importWarnings = new ArrayList<>();
private Value constantValue = null;
private List<IntermediateOperation> controlInputs = Collections.emptyList();
@@ -121,7 +125,10 @@ public abstract class IntermediateOperation {
}
/** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
+ public void renameDimensions(DimensionRenamer renamer) {
+ type = type.rename(renamer);
+ hasRenamedDimensions = true;
+ }
/** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
public boolean isInput() { return false; }
@@ -144,7 +151,11 @@ public abstract class IntermediateOperation {
}
/** Set the constant value function */
- public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; }
+ public void setConstantValueFunction(Function<OrderedTensorType, Value> func) {
+ this.constantValueFunction = func;
+ }
+
+ public boolean hasConstantValueFunction() { return constantValueFunction != null; }
/** Sets the external control inputs */
public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; }
@@ -153,12 +164,23 @@ public abstract class IntermediateOperation {
public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
/** Retrieve the valid Vespa name of this node */
- public String vespaName() { return vespaName(name); }
- public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; }
+ public String vespaName() {
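+        // Constants are prefixed by "modelName_" to avoid name conflicts between models
+        // (previously done in the Const/Constant/Identity subclasses)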
+ if (isConstant())
+ return modelName + "_" + vespaName(name);
+ return vespaName(name);
+ }
+
+ public String vespaName(String name) {
+ return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null;
+ }
/** Retrieve the valid Vespa name of this node if it is a ranking expression function */
public String rankingExpressionFunctionName() {
- return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null;
+ String vespaName = vespaName();
+ if (vespaName == null) {
+ return null;
+ }
+ return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName;
}
/** Retrieve the list of warnings produced during its lifetime */
@@ -185,30 +207,80 @@ public abstract class IntermediateOperation {
/** Recursively evaluates this operation's constant value to avoid doing it run-time. */
public Value evaluateAsConstant(OrderedTensorType type) {
+// System.out.println("Starting constant evaluation for " + name);
if ( ! isConstant() ) {
throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant.");
}
- Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN));
- if (type != null && ! val.asTensor().type().equals(type.type()) ) {
+ if (type == null) {
+ System.out.println("Evaluating as constant for " + name + " with type null! Probably an error.");
+ }
+
+ IntermediateOperation evaluateOn = this;
+ if ( ! hasRenamedDimensions) {
+ // make a copy of the tree, perform renaming and evaluate
+ IntermediateOperation copy = copyTree(0);
+ optimizeAndRename(copy);
+ evaluateOn = copy;
+ }
+ Value val = evaluateOn.evaluateAsConstant(new MapContext(DoubleValue.NaN), 0);
+
+ if (type == null) {
+ return val;
+ }
+ Tensor tensor = val.asTensor(); //.withType(type.type());
+ if ( ! tensor.type().isRenamableTo(type.type()) ) {
throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " +
"Expected: " + type.type() + " Got: " + val.asTensor().type());
}
- return val;
+ // set constant value so we don't have to re-evaluate
+ setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type())));
+// System.out.println("Returning constant evaluation for " + name);
+ return new TensorValue(tensor.withType(type.type()));
+ }
+
+ private IntermediateOperation copyTree(int indent) {
+ String indentString = ""; for (int i = 0; i < indent; ++i) indentString += " ";
+// System.out.println(indentString + "Copying " + name);
+ List<IntermediateOperation> in = new ArrayList<>();
+ if (constantValue != null) {
+// System.out.println(indentString + name + " has a constant value");
+ IntermediateOperation constant = new Constant(modelName, name, type);
+ constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type())));
+ return constant;
+ }
+ inputs.forEach(i -> in.add(i.copyTree(indent + 1)));
+ IntermediateOperation copy = withInputs(in);
+ if (constantValueFunction != null) {
+ copy.constantValueFunction = constantValueFunction; // works?
+ }
+ return copy;
+ }
+
+ private TensorFunction optimizeAndRename(IntermediateOperation op) {
+ IntermediateGraph graph = new IntermediateGraph(modelName);
+ graph.put(name, op);
+ graph.outputs(graph.defaultSignature()).put(name, name);
+ graph.optimize();
+ return op.function().get();
}
- private Value evaluateAsConstant(Context context) {
+ private Value evaluateAsConstant(Context context, int indent) {
+ String in = ""; for (int i = 0; i < indent; ++i) in += " ";
+// System.out.println(in + "Constant evaluating for " + name);
String constantName = "constant(" + vespaName() + ")";
Value result = context.get(constantName);
if (result == DoubleValue.NaN) {
if (constantValue != null) {
+// System.out.println(in + name + " has constant value.");
result = constantValue;
} else if (inputs.size() == 0) {
+// System.out.println(in + name + " has no inputs.");
if (getConstantValue().isEmpty()) {
throw new IllegalArgumentException("Error in evaluating constant for " + name);
}
result = getConstantValue().get();
} else {
- inputs.forEach(i -> i.evaluateAsConstant(context));
+ inputs.forEach(i -> i.evaluateAsConstant(context, indent+1));
result = new TensorValue(lazyGetFunction().evaluate(context));
}
context.put(constantName, result);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index adb54474812..3211a44fa68 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -82,6 +82,13 @@ public class Join extends IntermediateOperation {
bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce);
}
+        // retain the original input order: a and b may have been swapped above, and the operator need not be commutative
+ if (a == inputs.get(1)) {
+ TensorFunction temp = bReducedFunction;
+ bReducedFunction = aReducedFunction;
+ aReducedFunction = temp;
+ }
+
return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 6849e64641e..1eb21eb2a5e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -4,6 +4,9 @@ package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.text.ExpressionFormatter;
@@ -20,64 +23,126 @@ public class MatMul extends IntermediateOperation {
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(2)) return null;
+ OrderedTensorType aType = inputs.get(0).type().get();
+ OrderedTensorType bType = inputs.get(1).type().get();
+
+        // TODO: add more input validity checks here
+ if (aType.type().rank() < 1 || bType.type().rank() < 1)
+ throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1");
+
OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
- typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
- typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
+ OrderedTensorType largestRankType = aType.rank() >= bType.rank() ? aType : bType;
+ for (int i = 0; i < largestRankType.rank() - 2; ++i) {
+ typeBuilder.add(largestRankType.dimensions().get(i));
+ }
+ if (aType.rank() >= 2) {
+ typeBuilder.add(aType.dimensions().get(aType.rank() - 2));
+ }
+ if (bType.rank() >= 2) {
+ typeBuilder.add(bType.dimensions().get(bType.rank() - 1));
+ }
return typeBuilder.build();
}
@Override
protected TensorFunction lazyGetFunction() {
if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
OrderedTensorType aType = inputs.get(0).type().get();
- OrderedTensorType bType = inputs.get(1).type().get();
- if (aType.type().rank() < 2 || bType.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (aType.type().rank() != bType.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
Optional<TensorFunction> aFunction = inputs.get(0).function();
Optional<TensorFunction> bFunction = inputs.get(1).function();
- if (!aFunction.isPresent() || !bFunction.isPresent()) {
- return null;
- }
- return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+
+        // the only special case left here is dimensions of size 1; see lazyGetType
+
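+        // Generalized matmul: join the two inputs with multiplication and sum-reduce over the
+        // contracted dimension, which (after dimension renaming) is the last dimension of the first input.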
+ return new com.yahoo.tensor.functions.Reduce(new Join(aFunction.get(), bFunction.get(), ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ aType.dimensions().get(aType.rank() - 1).name());
}
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
if ( ! allInputTypesPresent(2)) return;
- List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
- List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+ /*
+ * A: a1, a2, a3, a4
+ * B: b1, b2, b3, b4
+ *
+ * a4 == b3
+ * a3 < b4
+ * a3 < a4
+ * b4 < b3
+ *
+         * a1 == b1 -> but also in terms of size.
+ * a2 == b2
+ * etc
+ */
+
+ OrderedTensorType typeA = inputs.get(0).type().get();
+ OrderedTensorType typeB = inputs.get(1).type().get();
+
+ String lastDimA = typeA.dimensions().get(typeA.rank()-1).name();
+ String lastDimB = typeB.dimensions().get(typeB.rank()-1).name();
+ String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name();
+ String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name();
+
+ // The last dimension of A should have the same name as the second-to-last dimension of B
+ renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this);
- assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
- assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+ // For efficiency, the dimensions to join over should be innermost - soft constraint
+ if (typeA.rank() >= 2) {
+ renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ if (typeB.rank() >= 2) {
+ renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this);
+ }
- String aDim0 = aDimensions.get(0).name();
- String aDim1 = aDimensions.get(1).name();
- String bDim0 = bDimensions.get(0).name();
- String bDim1 = bDimensions.get(1).name();
+ // The second-to-last dimension of a should have a different name than the last dimension of b
+ if (typeA.rank() >= 2 && typeB.rank() >= 2) {
+ renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this);
+ }
- // The second dimension of a should have the same name as the first dimension of b
- renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this);
+ // a1 < a2 < a3 < a4
+ OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB;
+ for (int i = 0; i < largestRankType.rank() - 2; ++i) {
+ String iDim = largestRankType.dimensionNames().get(i);
+ for (int j = i+1; j < largestRankType.rank() - 2; ++j) {
+ String jDim = largestRankType.dimensionNames().get(j);
+ renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this);
+ }
+ }
+
+        // TODO: handle dimensions of differing sizes
+
+ // a1 == b1 etc
+ if (typeA.rank() == typeB.rank()) {
+ for (int i = 0; i < typeA.rank() - 2; ++i) {
+ renamer.addConstraint(typeA.dimensionNames().get(i), typeB.dimensionNames().get(i), DimensionRenamer.Constraint.equal(false), this);
+ }
+ }
- // The first dimension of a should have a different name than the second dimension of b
- renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this);
- // For efficiency, the dimensions to join over should be innermost - soft constraint
- renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this);
- renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this);
- }
- private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
- if (dimensions.size() >= 2) return;
- throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
- " but got just " + dimensions + " from\n" +
- ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
+
+ // So, what about the other dimensions?
+// if (aDimensions.size() > 2) {
+// for (int i = 1; i < aDimensions.size(); ++i) {
+// renamer.addConstraint(aDimensions.get(0).name(), aDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this);
+// }
+// for (int i = 0; i < bDimensions.size(); ++i) {
+// renamer.addConstraint(aDimensions.get(0).name(), bDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this);
+// }
+// }
+
}
+// private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
+// if (dimensions.size() >= 2) return;
+// throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
+// " but got just " + dimensions + " from\n" +
+// ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
+// }
+
@Override
public MatMul withInputs(List<IntermediateOperation> inputs) {
return new MatMul(modelName(), name(), inputs);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
index e040ae62149..07ac457cca8 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java
@@ -54,7 +54,7 @@ public class Rename extends IntermediateOperation {
}
public void renameDimensions(DimensionRenamer renamer) {
- type = type.rename(renamer);
+ super.renameDimensions(renamer);
from = renamer.dimensionNameOf(from).orElse(from);
to = renamer.dimensionNameOf(to).orElse(to);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index c88fc18e6c6..f96dd420d30 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -2,8 +2,10 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
@@ -11,8 +13,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.Function;
+import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -27,6 +32,8 @@ import java.util.List;
import java.util.Optional;
import java.util.stream.Collectors;
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
public class Reshape extends IntermediateOperation {
private final AttributeMap attributeMap;
@@ -38,6 +45,10 @@ public class Reshape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
if (inputs.size() == 2) {
return typeWithShapeAsInput();
} else if (inputs.size() == 1) {
@@ -126,10 +137,54 @@ public class Reshape extends IntermediateOperation {
return new Reshape(modelName(), name(), inputs, attributeMap);
}
- public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
+ public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) {
if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type())))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
+ IntermediateOperation input = inputs.get(0);
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+        // e.g. (d0 * 2 + d1)
+ ExpressionNode unrolled = new EmbracedNode(unrollTensorExpression(outputType));
+
+ long innerSize = 1;
+ for (int dim = 0; dim < inputType.rank(); ++dim) {
+ innerSize *= inputType.dimensions().get(dim).size().get();
+ }
+
+ for (int dim = 0; dim < inputType.rank(); ++dim) {
+ String inputDimensionName = inputType.dimensions().get(dim).name();
+ long inputDimensionSize = inputType.dimensions().get(dim).size().get();
+ long previousInnerSize = innerSize;
+ innerSize /= inputDimensionSize;
+
+ ExpressionNode inputDimensionExpression;
+ if (inputDimensionSize == 1) {
+ inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero));
+ } else if (dim == (inputType.rank() - 1)) {
+ ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
+ ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size);
+ inputDimensionExpression = new EmbracedNode(div);
+ } else {
+ ExpressionNode size = new ConstantNode(new DoubleValue(innerSize));
+ ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize));
+ ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize);
+ ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size);
+ inputDimensionExpression = new EmbracedNode(new FunctionNode(Function.floor, div));
+ }
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression)));
+ }
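+        // Example (a sketch, not taken from a real model): reshaping tensor(d0[2],d1[6]) to tensor(d0[3],d1[4])
+        // generates roughly  tensor(d0[3],d1[4])(input{d0:(floor(((4 * d0 + d1) % 12) / 6)), d1:((4 * d0 + d1) % 6)})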
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(outputType.type(), wrapScalar(sliceExpression));
+ return generate;
+
+ /*
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
// then use the dimension order of the new shape to roll back into a tensor.
// Here we create a transformation tensor that is multiplied with the from tensor to map into
@@ -168,11 +223,14 @@ public class Reshape extends IntermediateOperation {
result = new Rename(result, to, from);
}
return result;
+ */
}
+ /*
private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) {
return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent());
}
+ */
private static ExpressionNode unrollTensorExpression(OrderedTensorType type) {
if (type.rank() == 0)
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
index e5463291ef8..8dd1e3ff33d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java
@@ -182,7 +182,6 @@ public class Slice extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- // Todo: what to do?
for (int i = 0; i < type.dimensions().size(); i++) {
renamer.addDimension(type.dimensions().get(i).name());
for (int j = i + 1; j < type.dimensions().size(); j++) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
index 83086926316..e2b83246bfc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Map;
import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;
@@ -28,6 +29,10 @@ public class Softmax extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
if ( ! allInputTypesPresent(1)) return null;
+
+        // the input is referenced twice (to avoid overflow), so make it its own function.
+ inputs.get(0).exportAsRankingFunction = true;
+
return inputs.get(0).type().get();
}
@@ -50,7 +55,9 @@ public class Softmax extends IntermediateOperation {
}
TensorFunction input = inputs.get(0).function().get();
- TensorFunction exp = new Map(input, ScalarFunctions.exp());
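+        // Numerically stable softmax: subtracting the per-dimension max before exp leaves the result
+        // unchanged, since exp(x - m) / sum(exp(x - m)) = exp(x) / sum(exp(x)), but avoids overflow.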
+ TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions);
+ TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow
+ TensorFunction exp = new Map(cap, ScalarFunctions.exp());
TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions);
TensorFunction div = new Join(exp, sum, ScalarFunctions.divide());
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
new file mode 100644
index 00000000000..02d780c52cd
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java
@@ -0,0 +1,119 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+public class Split extends IntermediateOperation {
+
+ private final AttributeMap attributes;
+ private final int output;
+
+ private final int axis;
+ private int start;
+ private int end;
+
+ public Split(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes, int output) {
+ super(modelName, nodeName, inputs);
+ this.attributes = attributes;
+ this.output = output;
+ axis = (int) attributes.get("axis").orElse(DoubleValue.zero).asDouble();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1))
+ return null;
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ int axisSize = inputType.dimensions().get(axis).size().get().intValue();
+ start = 0;
+ end = axisSize;
+
+ if (attributes.getList("split").isPresent()) {
+ List<Value> splitList = attributes.getList("split").get();
+ if (output > splitList.size()) {
+ throw new IllegalArgumentException("Split in " + name + ": output out of range of split list");
+ }
+ for (int i = 0; i < output; ++i) {
+ start += (int) splitList.get(i).asDouble();
+ }
+ if (output < splitList.size()) {
+ end = start + (int) splitList.get(output).asDouble();
+ }
+ } else {
+ start = axisSize / 2 * output;
+ end = start + axisSize / 2;
+ }
+
+ if (start >= axisSize || start < 0) {
+ throw new IllegalArgumentException("Split in " + name + ": split start index out of range (" + start + ")");
+ }
+ if (end > axisSize || end < 0) {
+ throw new IllegalArgumentException("Split in " + name + ": split end index out of range (" + end + ")");
+ }
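+        // Example (a sketch): with axis=1, split=[2,3] and output=1 this selects start=2, end=5,
+        // i.e. indices [2,5) along axis 1 of the input.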
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < inputType.rank(); ++i) {
+ TensorType.Dimension inputDimension = inputType.dimensions().get(i);
+ long dimSize = i == axis ? end - start : inputDimension.size().get();
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), dimSize));
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ OrderedTensorType inputType = input.type().get();
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ for (int i = 0; i < inputType.rank(); ++i) {
+ String inputDimensionName = inputType.dimensions().get(i).name();
+ ExpressionNode reference = new ReferenceNode(inputDimensionName);
+ ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0)));
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset))));
+ }
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ return generate;
+ }
+
+ @Override
+ public Split withInputs(List<IntermediateOperation> inputs) {
+ return new Split(modelName(), name(), inputs, attributes, output);
+ }
+
+ @Override
+ public String operationName() { return "Split"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
new file mode 100644
index 00000000000..8d3468f3d04
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java
@@ -0,0 +1,100 @@
+// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar;
+
+/**
+ * Onnx tile operation.
+ */
+public class Tile extends IntermediateOperation {
+
+ public Tile(String modelName, String nodeName, List<IntermediateOperation> inputs) {
+ super(modelName, nodeName, inputs);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(2)) return null;
+
+ // required as we use tensor create
+ inputs.get(0).exportAsRankingFunction = true;
+
+ IntermediateOperation repeats = inputs.get(1);
+ if (repeats.getConstantValue().isEmpty())
+ throw new IllegalArgumentException("Tile " + name + ": repeats input must be a constant.");
+
+ Tensor shape = repeats.getConstantValue().get().asTensor();
+ if (shape.type().rank() != 1)
+ throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor.");
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ if (shape.type().dimensions().get(0).size().get() != inputType.rank())
+ throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank.");
+
+ List<Integer> dimSizes = new ArrayList<>(inputType.rank());
+ shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue()));
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < dimSizes.size(); ++i) {
+ TensorType.Dimension inputDimension = inputType.dimensions().get(i);
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get() * dimSizes.get(i)));
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(2)) return null;
+
+ IntermediateOperation input = inputs.get(0);
+ OrderedTensorType inputType = input.type().get();
+ String inputFunctionName = input.rankingExpressionFunctionName();
+
+ List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>();
+
+ for (int axis = 0; axis < inputType.rank(); ++axis) {
+ String inputDimensionName = inputType.dimensions().get(axis).name();
+ long inputDimensionSize = inputType.dimensions().get(axis).size().get();
+
+ ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize));
+ ExpressionNode reference = new ReferenceNode(inputDimensionName);
+ ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size);
+ dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod))));
+ }
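+        // Example (a sketch): tiling tensor(d0[2]) with repeats [3] produces tensor(d0[6]) where output
+        // cell d0 reads input cell (d0 % 2), i.e. the input values repeated three times.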
+
+ TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName));
+ com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues);
+ ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices);
+
+ TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression));
+ return generate;
+ }
+
+ @Override
+ public Tile withInputs(List<IntermediateOperation> inputs) {
+ return new Tile(modelName(), name(), inputs);
+ }
+
+ @Override
+ public String operationName() { return "Tile"; }
+
+}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
new file mode 100644
index 00000000000..178759fbf2a
--- /dev/null
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java
@@ -0,0 +1,54 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.operations;
+
+import ai.vespa.rankingexpression.importer.OrderedTensorType;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+
+import java.util.List;
+
+public class Transpose extends IntermediateOperation {
+
+ private final AttributeMap attributes;
+
+ public Transpose(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) {
+ super(modelName, nodeName, inputs);
+ this.attributes = attributes;
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) return null;
+
+ OrderedTensorType inputType = inputs.get(0).type().get();
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
+ for (int i = 0; i < inputType.rank(); ++i) {
+ int inputIndex = inputType.rank() - 1 - i;
+ if (attributes.getList("perm").isPresent()) {
+ inputIndex = (int) attributes.getList("perm").get().get(i).asDouble();
+ }
+ TensorType.Dimension inputDimension = inputType.dimensions().get(inputIndex);
+ typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get()));
+ }
+ OrderedTensorType result = typeBuilder.build();
+        return result;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1))
+ return null;
+ return inputs.get(0).function().orElse(null);
+ }
+
+ @Override
+ public Transpose withInputs(List<IntermediateOperation> inputs) {
+ return new Transpose(modelName(), name(), inputs, attributes);
+ }
+
+ @Override
+ public String operationName() { return "Transpose"; }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/BertImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/BertImportTestCase.java
new file mode 100644
index 00000000000..f4ed2f1b64d
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/BertImportTestCase.java
@@ -0,0 +1,281 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.onnx;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import com.yahoo.io.IOUtils;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.parser.ParseException;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.ScalarFunction;
+import com.yahoo.tensor.functions.Slice;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.tensorflow.op.core.Rank;
+
+import java.io.BufferedReader;
+import java.io.IOException;
+import java.sql.Ref;
+import java.util.HashSet;
+import java.util.List;
+import java.util.Map;
+import java.util.Set;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author lesters
+ */
+public class BertImportTestCase extends TestableModel {
+
+ @Test
+ public void test() throws Exception {
+ String filename = "/Users/lesters/github/onnx-models/text/machine_comprehension/bert-squad/java.txt";
+ List<String> lines = IOUtils.getLines(filename);
+ Tensor tensor = Tensor.from(lines.get(0));
+ TestableModelContext context = new TestableModelContext();
+ context.put("test", new TensorValue(tensor));
+
+ // Tensor: tensor(d1[1],d2[256],d4[12],d5[64])
+
+ String expr = "tensor(d0[256],d1[768])" +
+ "((test{" +
+ "d1:(floor(0.0)), " +
+ "d2:(floor(((768.0 * d0 + d1) % 196608) / 768.0)), " +
+ "d4:(floor(((768.0 * d0 + d1) % 768.0) / 64.0)), " +
+ "d5:(floor((768.0 * d0 + d1) % 64.0))" +
+ "}))";
+ Tensor result = new RankingExpression(expr).evaluate(context).asTensor();
+
+        assertEquals(-6074.247, result.sum().asDouble(), 1e-3);
+ }
+
+ @Ignore
+ @Test
+ public void testBertImport() {
+ ImportedModel model = new OnnxImporter().importModel("test", "/Users/lesters/github/onnx-models/text/machine_comprehension/bert-squad/bertsquad8_modified.onnx");
+// ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/bert/bertsquad8_modified.onnx");
+// ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/bert/bertsquad10.onnx");
+// assertEquals(0, model.signature("default").skippedOutputs().size());
+// Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
+// assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.28258783057229725, -0.0685615853647904]]"), onnxResult);
+
+ String filename = "/Users/lesters/github/onnx-models/text/machine_comprehension/bert-squad/context.vespa";
+
+ // bert/encoder/layer_0/attention/self/mul_2
+ assert null != model.largeConstants().get("test_bert_encoder_layer_0_attention_self_Reshape_3__294");
+
+ TestableModelContext context;
+ if (true) {
+ // inputs
+ Tensor unique_ids_raw_output__9 = Tensor.from("tensor(d0[1]):[1]");
+ Tensor input_ids = Tensor.from("tensor(d0[1],d1[256]):[101,2073,2003,1996,5661,10549,2000,2175,1029,102,1999,2049,2220,2086,1010,1996,2047,4680,2415,3478,2000,3113,5270,1998,6599,10908,1012,1031,2260,1033,2011,2526,1010,2116,13773,3028,5661,2020,10549,1996,2172,3469,9587,9363,2638,2415,1999,2624,3799,2058,1996,2624,4560,4680,2415,2349,2000,1996,3732,1005,1055,3132,2686,1012,1037,10428,5468,2000,5446,2019,4935,3081,1037,3309,4171,3478,2000,3362,1996,3223,2048,1011,12263,3484,2000,3413,1012,1999,2238,2384,1010,2136,2624,4560,2328,1996,2148,2534,1010,1037,1002,1020,1012,6255,2454,1010,2630,1998,2317,9311,1010,5815,3770,1010,2199,2675,2519,1006,1021,1010,4278,25525,1007,1997,8327,2686,102,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]");
+ Tensor input_mask = Tensor.from("tensor(d0[1],d1[256]):[1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]");
+ Tensor segment_ids = Tensor.from("tensor(d0[1],d1[256]):[0,0,0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,1,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0]");
+
+ context = contextFrom(model);
+ context.put("unique_ids_raw_output___9", new TensorValue(unique_ids_raw_output__9));
+ context.put("input_ids", new TensorValue(input_ids));
+ context.put("input_mask", new TensorValue(input_mask));
+ context.put("segment_ids", new TensorValue(segment_ids));
+
+ context.write(filename);
+ } else {
+ context = TestableModelContext.read(filename);
+ }
+
+ // expected outputs from onnxruntime
+ Tensor unique_ids = Tensor.from("tensor(d0[1]):[1]");
+ Tensor unstack_0 = Tensor.from("tensor(d0[1],d1[256]):[-7.169589,-8.165145,-8.795558,-8.276284,-8.408593,-8.313643,-8.421538,-8.771402,-8.71111,-8.014886,-6.4415646,-7.5764513,-7.7209125,-8.689668,-8.05441,-4.357495,-4.082243,-2.4557219,-6.421309,-7.8627315,-8.612887,-8.122109,-8.072487,-8.678347,-8.467162,-8.818881,-8.143695,-6.412044,-7.765201,-8.125683,-2.5739796,-4.2929254,-7.8812947,-2.4893682,-1.5166948,-6.2354407,-4.039099,-5.6837378,-0.41342238,3.0958412,1.5454307,-0.89450985,6.0985346,-4.108738,-4.67186,-2.6797965,-0.65007347,4.2300944,-2.9132476,-4.853151,1.1584995,4.041984,-3.5257776,-2.3050616,-4.363427,-7.1510825,-8.426602,-6.6682553,-7.027374,-8.076435,-8.3017435,-6.9958987,-7.243815,-7.1347113,-7.5506253,-7.771371,-8.606251,-7.472072,-7.902196,-7.563202,-7.330995,-7.5767503,-7.8097973,-6.645113,-8.927777,-8.438513,-8.708496,-8.474434,-8.231956,-8.635139,-7.8764973,-8.80273,-9.103729,-9.07057,-8.610826,-9.084642,-8.795743,-7.2711506,-7.733648,-8.708181,-8.020964,-7.1652384,-7.9469404,-9.461184,-8.624146,-7.2252526,-6.4015207,-9.220176,-9.195709,-8.228707,-7.9646325,-8.685807,-8.980191,-8.858017,-9.290145,-8.921865,-7.656322,-8.872562,-8.898288,-8.683226,-9.219653,-8.371141,-7.130355,-8.930712,-9.05438,-8.771264,-9.621703,-8.550959,-7.327657,-9.138217,-9.377564,-9.111144,-9.653343,-8.726485,-8.215803,-9.300696,-8.044907,-8.641199,-8.641449,-8.640882,-8.641207,-8.648375,-8.645785,-8.639973,-8.650788,-8.660226,-8.6503525,-8.6601925,-8.647732,-8.652576,-8.665123,-8.653585,-8.653888,-8.661093,-8.661934,-8.6514845,-8.662573,-8.671499,-8.661195,-8.667901,-8.666959,-8.659721,-8.673244,-8.678537,-8.66441,-8.651034,-8.660175,-8.659063,-8.657169,-8.6603565,-8.6569,-8.649067,-8.651927,-8.6421995,-8.649052,-8.6478615,-8.6426935,-8.646153,-8.646865,-8.636821,-8.643324,-8.645994,-8.639597,-8.647679,-8.655649,-8.6609745,-8.654906,-8.6613455,-8.656511,-8.663024,-8.675192,-8.663131,-8.665018,-8.652522,-8.661668,-8.66894,-8.670112,-8.67217,-8.657303,-8.651893,-8.652592,-8.650168,-8.640702,-8.636455,-8.647628,-8.638621,-8.648,-8.656844,-8.649821,-8.657603,-8.648884,-8.661986,-8.663507,-8.652322,-8.662775,-8.664504,-8.662872,-8.668943,-8.6559105,-8.655738,-8.671845,-8.6666,-8.659552,-8.679308,-8.659756,-8.664594,-8.6688175,-8.666396,-8.673796,-8.65924,-8.664916,-8.6703005,-8.6611395,-8.660061,-8.660967,-8.672797,-8.66394,-8.657039,-8.671023,-8.663469,-8.659371,-8.6713705,-8.659359,-8.649764,-8.6620035,-8.656843,-8.654225,-8.661666,-8.647326,-8.652874,-8.650523,-8.644273,-8.649993,-8.65307,-8.645219,-8.6537075,-8.655814,-8.654312,-8.658724,-8.666763,-8.654713,-8.662302,-8.672376,-8.661079,-8.659652,-8.661736]");
+ Tensor unstack_1 = Tensor.from("tensor(d0[1],d1[256]):[-5.1743593,-8.167716,-8.096918,-8.610186,-8.627197,-8.518608,-8.413071,-8.04796,-8.405228,-5.775467,-8.891069,-8.499419,-8.482899,-7.4575906,-5.1060586,-9.029796,-7.9796743,-7.411322,-0.62632525,-8.209348,-8.202109,-8.436105,-8.226212,-8.245562,-7.7150273,-5.4672513,-6.134469,-8.531252,-7.390566,-6.717802,-7.9110403,-5.084878,-5.02966,-7.6901536,-7.6643076,-0.42670453,-4.2289968,-6.957412,-5.192218,-6.1616683,-6.4489427,-3.5914042,-3.7853065,-6.857571,-2.3781726,6.1620126,-3.007885,-4.688912,6.258016,-5.2202945,-6.7945094,-5.1450105,0.7468612,-4.919924,5.489712,-7.307814,-7.952,-9.152897,-6.4863043,-8.328119,-7.9448185,-8.395245,-3.6581624,-0.8252581,-8.731679,-8.624653,-7.61354,-8.755644,-8.341698,-8.758186,-5.954141,-8.560192,-8.833243,-7.6137505,-5.96118,-8.43961,-8.188338,-8.373185,-8.683964,-8.246368,-8.824446,-8.05728,-7.623751,-7.56998,-8.277908,-6.8986,-6.8709283,-9.279125,-8.84588,-7.6791453,-5.2976,-9.191589,-8.797903,-6.440836,-8.179676,-9.236156,-8.972708,-5.8724217,-6.928253,-8.685118,-8.84946,-8.293621,-7.8572874,-8.053903,-7.398021,-7.549705,-9.004784,-8.060446,-7.950672,-7.2188964,-6.497633,-8.454956,-9.045556,-7.8463507,-7.771165,-8.067679,-5.9176393,-8.09684,-8.5619955,-7.5696144,-7.100621,-7.0136676,-6.464568,-8.108538,-8.516457,-5.488856,-5.853514,-8.457255,-8.457188,-8.454844,-8.448273,-8.447864,-8.447953,-8.445874,-8.444903,-8.442595,-8.448225,-8.443758,-8.451776,-8.447646,-8.440473,-8.447313,-8.44705,-8.443977,-8.442994,-8.449743,-8.441086,-8.433916,-8.43898,-8.435363,-8.434594,-8.431777,-8.433416,-8.433545,-8.442987,-8.453411,-8.450068,-8.4503565,-8.451651,-8.450909,-8.454222,-8.456041,-8.452284,-8.449699,-8.454986,-8.455096,-8.459543,-8.458114,-8.458371,-8.4632635,-8.458183,-8.457299,-8.458008,-8.452067,-8.444335,-8.442348,-8.445211,-8.441855,-8.443939,-8.441303,-8.436119,-8.442878,-8.439337,-8.446676,-8.441184,-8.438475,-8.440033,-8.4386015,-8.447922,-8.455316,-8.452563,-8.454967,-8.459164,-8.460839,-8.453004,-8.451543,-8.446279,-8.441412,-8.448481,-8.446184,-8.448539,-8.445241,-8.444487,-8.450539,-8.446448,-8.446319,-8.447268,-8.440758,-8.448286,-8.447366,-8.437631,-8.441085,-8.444475,-8.431786,-8.441355,-8.436929,-8.432141,-8.436456,-8.435032,-8.445299,-8.442143,-8.438964,-8.445743,-8.445099,-8.444958,-8.438029,-8.439503,-8.446831,-8.43919,-8.442334,-8.446472,-8.442076,-8.449043,-8.451941,-8.449556,-8.454564,-8.455859,-8.452123,-8.461076,-8.45802,-8.456931,-8.458485,-8.45496,-8.4508295,-8.453123,-8.451649,-8.451098,-8.450148,-8.446929,-8.44253,-8.44839,-8.444667,-8.437894,-8.444409,-8.444666,-8.441956]");
+
+
+// model.functions().forEach((k, v) -> {
+// evaluateFunction(context, model, k, "");
+// });
+
+// RankingExpression e = model.expressions().get("unique_ids_graph_outputs_Identity__10");
+// evaluateFunctionDependencies(context, model, e.getRoot(), "");
+// Tensor result = e.evaluate(context).asTensor();
+// assertEquals(result, unique_ids);
+
+ RankingExpression e = model.expressions().get("bert/encoder/layer_0/output/LayerNorm/batchnorm/add_1");
+
+ evaluateFunctionDependencies(context, model, e.getRoot(), "");
+ context.write(filename);
+ Tensor result = e.evaluate(context).asTensor();
+ double sum = result.sum().asDouble();
+ System.out.println(sum);
+
+ Tensor matmul1 = model.expressions().get("bert/encoder/layer_0/attention/self/MatMul_1").evaluate(context).asTensor();
+
+ Tensor transpose = model.expressions().get("bert/encoder/layer_0/attention/self/transpose_3").evaluate(context).asTensor();
+ String cast = model.largeConstants().get("test_bert_encoder_layer_0_attention_self_Reshape_3__294");
+ Tensor reshape = model.expressions().get("bert/encoder/layer_0/attention/self/Reshape_3").evaluate(context).asTensor();
+ Tensor matmul = model.expressions().get("bert/encoder/layer_0/attention/output/dense/MatMul").evaluate(context).asTensor();
+ Tensor add = model.expressions().get("bert/encoder/layer_0/attention/output/dense/BiasAdd").evaluate(context).asTensor();
+
+ Tensor add1 = model.expressions().get("bert/encoder/layer_0/attention/output/add").evaluate(context).asTensor();
+ Tensor add2 = model.expressions().get("bert/encoder/layer_0/attention/output/LayerNorm/batchnorm/add_1").evaluate(context).asTensor();
+ Tensor add3 = model.expressions().get("bert/encoder/layer_0/output/add").evaluate(context).asTensor();
+ Tensor add4 = model.expressions().get("bert/encoder/layer_0/output/LayerNorm/batchnorm/add_1").evaluate(context).asTensor();
+
+ assertEquals(result, unique_ids);
+
+// Tensor result = model.expressions().get("unique_ids_graph_outputs_Identity__10").evaluate(context).asTensor();
+// assertEquals(result, unique_ids);
+
+// result = model.expressions().get("unstack_graph_outputs_Identity__7").evaluate(context).asTensor(); // or map from signature outputs
+// assertEquals(result, unstack_0);
+
+        // a bug here in outputs: there is only one (unstack), but we need two: unstack:0 and unstack:1
+
+ }
+
+
+ private void evaluateFunction(Context context, ImportedModel model, String functionName, String in) {
+ if (!context.names().contains(functionName)) {
+ RankingExpression e = RankingExpression.from(model.functions().get(functionName));
+ System.out.println(in + "Looking for dependencies of function " + functionName + ": " + e.toString());
+ evaluateFunctionDependencies(context, model, e.getRoot(), in);
+ System.out.println(in + "Evaluating function " + functionName + ": " + e.toString());
+ long start = System.currentTimeMillis();
+ Tensor result = e.evaluate(context).asTensor();
+ context.put(functionName, new TensorValue(result));
+ long end = System.currentTimeMillis();
+ System.out.println(in + "[" + (end - start) + "] completed " + functionName + " (" + result.type() + "), context is: " + context.names().size() + " " + contextSize(context));
+ } else {
+ System.out.println(in + "Function " + functionName + " already evaluated...");
+ }
+ }
+
+ private long contextSize(Context context) {
+ long size = 0;
+ for (String name : context.names()) {
+ Tensor val = context.getTensor(name);
+ if (val != null) size += val.size();
+ }
+ return size;
+ }
+
+ private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node, String in) {
+ if (node instanceof ReferenceNode) {
+ String name = node.toString();
+ ReferenceNode ref = (ReferenceNode) node;
+ if (ref.getName().equals("constant")) {
+ String constant = ref.getArguments().expressions().get(0).toString();
+ if (!context.names().contains(constant)) {
+ String value = null;
+ if (model.smallConstants().containsKey(constant)) {
+ value = model.smallConstants().get(constant);
+ }
+ if (model.largeConstants().containsKey(constant)) {
+ value = model.largeConstants().get(constant);
+ }
+ if (value != null) {
+ System.out.println(in + "Adding constant: " + name);
+ long start = System.currentTimeMillis();
+ Tensor val = Tensor.from(value);
+ context.put(name, new TensorValue(val));
+ long end = System.currentTimeMillis();
+ System.out.println(in + "Added constant: " + name + " (" + val.type() + ") in [" + (end - start) + "]");
+ }
+ }
+ }
+ if (model.functions().containsKey(name)) {
+ evaluateFunction(context, model, name, in + " ");
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ if (node instanceof TensorFunctionNode && ((TensorFunctionNode)node).function() instanceof Generate) {
+ Generate generate = (Generate) ((TensorFunctionNode)node).function();
+ TensorFunctionNode.ExpressionScalarFunction func = (TensorFunctionNode.ExpressionScalarFunction) generate.getBoundGenerator();
+ if (func != null) {
+ ExpressionNode bound = func.getExpression();
+ if (bound.toString().contains("imported_ml_")) {
+ System.out.println(in + "Found expression inside generator: " + bound.toString());
+ evaluateFunctionDependencies(context, model, bound, in);
+ }
+ }
+ }
+ else if (node instanceof TensorFunctionNode && ((TensorFunctionNode)node).function() instanceof Slice) {
+ Slice<Reference> slice = (Slice<Reference>) ((TensorFunctionNode)node).function();
+ for (Slice.DimensionValue<Reference> value : slice.getSubspaceAddress()) {
+ TensorFunctionNode.ExpressionScalarFunction func = (TensorFunctionNode.ExpressionScalarFunction) value.index().orElse(null);
+ if (func != null) {
+ ExpressionNode bound = func.getExpression();
+ if (bound.toString().contains("imported_ml_")) {
+ System.out.println(in + "Found expression inside slice: " + bound.toString());
+ evaluateFunctionDependencies(context, model, bound, in);
+ }
+ }
+ }
+ }
+ for (ExpressionNode child : ((CompositeNode)node).children()) {
+ evaluateFunctionDependencies(context, model, child, in);
+ }
+ }
+ }
+
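+ /** Creates a context with the model's small and large constants pre-bound as "constant(name)". */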
+ static TestableModelContext contextFrom(ImportedModel result) {
+ TestableModelContext context = new TestableModelContext();
+ if (result != null) {
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ }
+ return context;
+ }
+
+ private static class TestableModelContext extends MapContext implements ContextIndex {
+ @Override
+ public int size() {
+ return bindings().size();
+ }
+ @Override
+ public int getIndex(String name) {
+ throw new UnsupportedOperationException(this + " does not support index lookup by name");
+ }
+
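+ /** Writes every binding to the given file as a tab-separated name/tensor line, the format read(String) parses. */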
+ public void write(String filename) {
+ try {
+ for (Map.Entry<String, Value> entry: bindings().entrySet()) {
+ String line = entry.getKey() + "\t" + entry.getValue().asTensor() + "\n";
+ IOUtils.writeFile(filename, line, true);
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ }
+
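+ /** Reads a context back from a file written by write(String): one tab-separated name/tensor binding per line. */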
+ public static TestableModelContext read(String filename) {
+ System.out.println("Reading content from " + filename);
+ TestableModelContext context = new TestableModelContext();
+ try (BufferedReader reader = IOUtils.createReader(filename)) {
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] strings = line.trim().split("\t");
+ String name = strings[0];
+ Tensor tensor = Tensor.from(strings[1]);
+ context.put(name, new TensorValue(tensor));
+ }
+ } catch (IOException e) {
+ throw new RuntimeException(e);
+ }
+ System.out.println("Done reading context");
+ return context;
+ }
+ }
+
+}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
index 94c5577357b..0c9acc9b372 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxOperationsTestCase.java
@@ -107,6 +107,18 @@ public class OnnxOperationsTestCase {
assertEval("less", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a < b))", x, y));
assertEval("equal", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(a == b))", x, y));
assertEval("pow", x, y, evaluate("join(x, rename(y, d0, d2), f(a,b)(pow(a,b)))", x, y));
+
+ // broadcasting - opposite order
+ x = evaluate("random(d0[4]) + 1");
+ y = evaluate("random(d0[2],d1[3],d2[4]) + 1");
+ assertEval("add", x, y, evaluate("rename(x, d0, d2) + y", x, y));
+ assertEval("sub", x, y, evaluate("rename(x, d0, d2) - y", x, y));
+ assertEval("mul", x, y, evaluate("rename(x, d0, d2) * y", x, y));
+ assertEval("div", x, y, evaluate("rename(x, d0, d2) / y", x, y));
+ assertEval("greater", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a > b))", x, y));
+ assertEval("less", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a < b))", x, y));
+ assertEval("equal", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(a == b))", x, y));
+ assertEval("pow", x, y, evaluate("join(rename(x, d0, d2), y, f(a,b)(pow(a,b)))", x, y));
}
@Test
@@ -185,9 +197,55 @@ public class OnnxOperationsTestCase {
@Test
public void testMatMul1() throws ParseException {
- Tensor a = evaluate("tensor(d0[2],d1[3]):[1, 2, 3, 4, 5, 6]");
- Tensor b = evaluate("tensor(d0[3],d1[2]):[7, 8, 9, 10, 11, 12]");
- assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[58, 64, 139, 154]"));
+ Tensor a = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ Tensor b = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("91"));
+
+ a = evaluate("tensor(d0[3]):[1,2,3]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2]):[22, 28]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3]):[1,2,3]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2]):[14, 32]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[2],d1[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]");
+// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[1],d1[3],d2[2]):[1,2,3,4,5,6]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[2],d2[2]):[22,28,49,64]"));
+
+ a = evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]");
+ b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]"));
+
+ a = evaluate("tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+ b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+ assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]"));
+
+
+// a = evaluate("tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]");
+// b = evaluate("tensor(d0[1],d1[4],d2[3],d3[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]"));
+
+// a = evaluate("tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]");
+// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,58,64,139,154,22,28,49,64,58,64,139,154]"));
+
+// a = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2],d3[2]):[22,28,49,64,58,64,139,154,22,28,49,64,58,64,139,154]"));
+
+// a = evaluate("tensor(d0[3]]):[1,2,3]");
+// assertEval("matmul", a, b, evaluate("tensor(d0[1],d1[4],d2[2]):[22,28,58,64,22,28,58,64]"));
}
@Test
@@ -217,6 +275,10 @@ public class OnnxOperationsTestCase {
y = evaluate("tensor(d0[4]):[3,2,-1,1]");
assertEval("reshape", x, y, evaluate("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]"));
+
+ x = evaluate("tensor(d0[1],d1[2],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12]");
+ y = evaluate("tensor(d0[2]):[2,6]");
+ assertEval("reshape", x, y, evaluate("tensor(d0[2],d1[6]):[1,2,3,4,5,6,7,8,9,10,11,12]"));
}
@Test
@@ -435,6 +497,50 @@ public class OnnxOperationsTestCase {
}
+ @Test
+ public void testTranspose1() throws ParseException {
+ Tensor x = evaluate("tensor(d0[2],d1[3]):[[1,2,3],[4,5,6]]");
+ assertEval("transpose", x, evaluate("tensor(d0[3],d1[2]):[[1,4],[2,5],[3,6]]"));
+ }
+
+ @Test
+ public void testTile6() throws ParseException {
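+ // Tile repeats the input along each dimension as many times as given by the repeats tensor.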
+ Tensor x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]");
+ Tensor y = evaluate("tensor(d0[2]):[1,2]");
+ assertEval("tile", x, y, evaluate("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]"));
+
+ x = evaluate("tensor(d0[2],d1[2]):[1,2,3,4]");
+ y = evaluate("tensor(d0[2]):[3,1]");
+ assertEval("tile", x, y, evaluate("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]"));
+
+ x = evaluate("tensor(d0[1],d1[1],d2[1]):[1]");
+ y = evaluate("tensor(d0[3]):[1,6,1]");
+ assertEval("tile", x, y, evaluate("tensor(d0[1],d1[6],d2[1]):[1,1,1,1,1,1]"));
+ }
+
+ @Test
+ public void testSplit2() throws ParseException {
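+ // Split divides the input along an axis (equal parts when no "split" attribute is given); the trailing int selects which output to check.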
+ Tensor x = evaluate("tensor(d0[6]):[1,2,3,4,5,6]");
+ assertEval("split", x, evaluate("tensor(d0[3]):[1,2,3]"), 0);
+ assertEval("split", x, evaluate("tensor(d0[3]):[4,5,6]"), 1);
+ assertEval("split", x, evaluate("tensor(d0[2]):[1,2]"), createAttribute("split", new int[] {2}), 0);
+ assertEval("split", x, evaluate("tensor(d0[4]):[3,4,5,6]"), createAttribute("split", new int[] {2}), 1);
+ assertEval("split", x, evaluate("tensor(d0[3]):[3,4,5]"), createAttribute("split", new int[] {2,3}), 1);
+ assertEval("split", x, evaluate("tensor(d0[1]):[6]"), createAttribute("split", new int[] {2,3}), 2);
+
+ x = evaluate("tensor(d0[2],d1[3]):[1,2,3,4,5,6]");
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"));
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), 0);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), 1);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[1,2,3]"), createAttribute("split", new int[] {1}), 0);
+ assertEval("split", x, evaluate("tensor(d0[1],d1[3]):[4,5,6]"), createAttribute("split", new int[] {1}), 1);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[1,4]"), createAttribute("axis", 1), 0);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[2,5]"), createAttribute("axis", 1), 1);
+ assertEval("split", x, evaluate("tensor(d0[2],d1[1]):[3,6]"), createAttribute("axis", 1), 2);
+ }
+
private Tensor evaluate(String expr) throws ParseException {
return evaluate(expr, null, null, null);
}
@@ -461,41 +567,49 @@ public class OnnxOperationsTestCase {
}
private void assertEval(String opName, Tensor x, Tensor expected) {
- assertEval(opName, x, null, null, null, null, expected, null);
+ assertEval(opName, x, null, null, null, null, expected, null, 0);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected, int output) {
+ assertEval(opName, x, null, null, null, null, expected, null, output);
}
private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, null, null, null, null, expected, attr);
+ assertEval(opName, x, null, null, null, null, expected, attr, 0);
+ }
+
+ private void assertEval(String opName, Tensor x, Tensor expected, AttributeConverter attr, int output) {
+ assertEval(opName, x, null, null, null, null, expected, attr, output);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, y, null, null, null, expected, attr);
+ assertEval(opName, x, y, null, null, null, expected, attr, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor expected) {
- assertEval(opName, x, y, null, null, null, expected, null);
+ assertEval(opName, x, y, null, null, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected) {
- assertEval(opName, x, y, z, null, null, expected, null);
+ assertEval(opName, x, y, z, null, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor expected, AttributeConverter attr) {
- assertEval(opName, x, y, z, null, null, expected, attr);
+ assertEval(opName, x, y, z, null, null, expected, attr, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor expected) {
- assertEval(opName, x, y, z, q, null, expected, null);
+ assertEval(opName, x, y, z, q, null, expected, null, 0);
}
private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected) {
- assertEval(opName, x, y, z, q, r, expected, null);
+ assertEval(opName, x, y, z, q, r, expected, null, 0);
}
- private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr) {
+ private void assertEval(String opName, Tensor x, Tensor y, Tensor z, Tensor q, Tensor r, Tensor expected, AttributeConverter attr, int output) {
Context context = new MapContext(DoubleValue.NaN);
List<IntermediateOperation> inputs = createInputs(context, x, y, z, q, r);
- IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build());
+ IntermediateOperation op = mapOperation(opName, inputs, modelName, opName, attr != null ? attr : createAttributes().build(), output);
optimizeAndRename(opName, op);
Tensor result = evaluate(op);
assertEquals(expected, result);
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
index 9631bddd93d..abecf4f5cb4 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/SimpleImportTestCase.java
@@ -35,6 +35,15 @@ public class SimpleImportTestCase {
}
@Test
+ public void testConstant() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/const.onnx");
+
+ MapContext context = new MapContext();
+ Tensor result = model.expressions().get("output").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor():0.42"));
+ }
+
+ @Test
public void testGather() {
ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/gather.onnx");
@@ -48,6 +57,19 @@ public class SimpleImportTestCase {
assertEquals(result, Tensor.from("tensor(d0[2],d1[2],d2[2]):[1, 2, 3, 4, 3, 4, 5, 6]"));
}
+ @Test
+ public void testConcat() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/simple/concat.onnx");
+
+ MapContext context = new MapContext();
+ context.put("i", new TensorValue(Tensor.from("tensor(d0[1]):[1]")));
+ context.put("j", new TensorValue(Tensor.from("tensor(d0[1]):[2]")));
+ context.put("k", new TensorValue(Tensor.from("tensor(d0[1]):[3]")));
+
+ Tensor result = model.expressions().get("y").evaluate(context).asTensor();
+ assertEquals(result, Tensor.from("tensor(d0[3]):[1, 2, 3]"));
+ }
+
private void evaluateFunction(Context context, ImportedModel model, String functionName) {
if (!context.names().contains(functionName)) {
RankingExpression e = RankingExpression.from(model.functions().get(functionName));
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/LesterTensorflowImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/LesterTensorflowImportTestCase.java
new file mode 100644
index 00000000000..66af131aa36
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/LesterTensorflowImportTestCase.java
@@ -0,0 +1,162 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.tensorflow;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
+import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
+import com.yahoo.collections.Pair;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
+import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.system.ProcessExecuter;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.junit.Assert;
+import org.junit.Ignore;
+import org.junit.Test;
+import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
+
+import java.io.IOException;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.util.List;
+import java.util.Map;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertTrue;
+
+public class LesterTensorflowImportTestCase {
+
+ @Test
+ @Ignore
+ public void testPyTorchExport() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/pytorch/test.onnx");
+ Tensor onnxResult = evaluateVespa(model, "output", model.inputs());
+ assertEquals(Tensor.from("tensor(d0[1],d1[2]):[[0.2134095202835272, -0.08556838456161658]]"), onnxResult);
+ }
+
+ @Test
+ @Ignore
+ public void testBERT() {
+ ImportedModel model = new OnnxImporter().importModel("test", "src/test/models/onnx/bert/bertsquad10.onnx");
+ }
+
+ private Tensor evaluateVespa(ImportedModel model, String operationName, Map<String, TensorType> inputs) {
+ Context context = contextFrom(model);
+ for (Map.Entry<String, TensorType> entry : inputs.entrySet()) {
+ Tensor argument = vespaInputArgument(1, entry.getValue().dimensions().get(1).size().get().intValue());
+ context.put(entry.getKey(), new TensorValue(argument));
+ }
+ model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
+ RankingExpression expression = model.expressions().get(operationName);
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+ return expression.evaluate(context).asTensor();
+ }
+
+ @Test
+ @Ignore
+ public void testModelImport() {
+
+ // Must change to TF 2.0 in Java!
+
+ String modelDir = "src/test/models/tensorflow/tf2/saved_model/";
+ // output function
+ String operationName = "out";
+
+ // Import TF
+ SavedModelBundle tensorFlowModel = SavedModelBundle.load(modelDir, "serve");
+ ImportedModel model = new TensorFlowImporter().importModel("test", modelDir, tensorFlowModel);
+ ImportedModel.Signature signature = model.signature("serving_default");
+ assertEquals("Should have no skipped outputs", 0, model.signature("serving_default").skippedOutputs().size());
+ ImportedMlFunction output = signature.outputFunction("output", operationName);
+ assertNotNull(output);
+
+ // Test TF
+ Session.Runner runner = tensorFlowModel.session().runner();
+ runner.feed("x", tensorFlowFloatInputArgument(1, 4));
+ List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
+ assertEquals(1, results.size());
+ Tensor tfResult = TensorConverter.toVespaTensor(results.get(0));
+
+ // Test Vespa
+ Context context = contextFrom(model);
+ context.put("x", new TensorValue(vespaInputArgument(1, 4)));
+ model.functions().forEach((k, v) -> evaluateFunction(context, model, k));
+ RankingExpression expression = model.expressions().get(operationName);
+ ExpressionOptimizer optimizer = new ExpressionOptimizer();
+ optimizer.optimize(expression, (ContextIndex)context);
+ Tensor vespaResult = expression.evaluate(context).asTensor();
+
+ // Equal result?
+ System.out.println(tfResult);
+ System.out.println(vespaResult);
+ assertEquals(tfResult, vespaResult);
+ }
+
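+ // Builds a d0Size x d1Size float input where cell (d0, d1) holds d1 / d1Size, mirroring vespaInputArgument below.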
+ private org.tensorflow.Tensor<?> tensorFlowFloatInputArgument(int d0Size, int d1Size) {
+ FloatBuffer fb1 = FloatBuffer.allocate(d0Size * d1Size);
+ int i = 0;
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; ++d1)
+ fb1.put(i++, (float)(d1 * 1.0 / d1Size));
+ return org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb1);
+ }
+
+ private Tensor vespaInputArgument(int d0Size, int d1Size) {
+ Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
+ for (int d0 = 0; d0 < d0Size; d0++)
+ for (int d1 = 0; d1 < d1Size; d1++)
+ b.cell(d1 * 1.0 / d1Size, d0, d1);
+ return b.build();
+ }
+
+ static Context contextFrom(ImportedModel result) {
+ TestableModelContext context = new TestableModelContext();
+ result.largeConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ result.smallConstants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(Tensor.from(tensor))));
+ return context;
+ }
+
+ private void evaluateFunction(Context context, ImportedModel model, String functionName) {
+ if (!context.names().contains(functionName)) {
+ RankingExpression e = RankingExpression.from(model.functions().get(functionName));
+ evaluateFunctionDependencies(context, model, e.getRoot());
+ context.put(functionName, new TensorValue(e.evaluate(context).asTensor()));
+ }
+ }
+
+ private void evaluateFunctionDependencies(Context context, ImportedModel model, ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ String name = node.toString();
+ if (model.functions().containsKey(name)) {
+ evaluateFunction(context, model, name);
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children()) {
+ evaluateFunctionDependencies(context, model, child);
+ }
+ }
+ }
+
+ private static class TestableModelContext extends MapContext implements ContextIndex {
+ @Override
+ public int size() {
+ return bindings().size();
+ }
+ @Override
+ public int getIndex(String name) {
+ throw new UnsupportedOperationException(this + " does not support index lookup by name");
+ }
+ }
+
+}
diff --git a/model-integration/src/test/models/onnx/simple/concat.onnx b/model-integration/src/test/models/onnx/simple/concat.onnx
new file mode 100644
index 00000000000..945bc3c9445
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/concat.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/concat.py b/model-integration/src/test/models/onnx/simple/concat.py
new file mode 100755
index 00000000000..b77cf5decc1
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/concat.py
@@ -0,0 +1,25 @@
+# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+import numpy as np
+from onnx import helper, TensorProto
+
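+# Builds a tiny ONNX model that concatenates the three one-element inputs i, j, k into a single output y of size 3.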
+i_type = helper.make_tensor_value_info('i', TensorProto.FLOAT, [1])
+j_type = helper.make_tensor_value_info('j', TensorProto.FLOAT, [1])
+k_type = helper.make_tensor_value_info('k', TensorProto.FLOAT, [1])
+
+output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [3])
+
+node = onnx.helper.make_node(
+ 'Concat',
+ inputs=['i', 'j', 'k'],
+ outputs=['y'],
+ axis=0,
+)
+graph_def = onnx.helper.make_graph(
+ nodes = [node],
+ name = 'concat_test',
+ inputs = [i_type, j_type, k_type],
+ outputs = [output_type]
+)
+model_def = helper.make_model(graph_def, producer_name='concat.py')
+onnx.save(model_def, 'concat.onnx')
diff --git a/model-integration/src/test/models/onnx/simple/const.onnx b/model-integration/src/test/models/onnx/simple/const.onnx
new file mode 100644
index 00000000000..c75a92ff12c
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/const.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/const.py b/model-integration/src/test/models/onnx/simple/const.py
new file mode 100755
index 00000000000..35d6dee1346
--- /dev/null
+++ b/model-integration/src/test/models/onnx/simple/const.py
@@ -0,0 +1,26 @@
+# Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+import onnx
+import numpy as np
+from onnx import helper, TensorProto
+
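+# Builds a tiny ONNX model with a single Constant node that outputs the scalar 0.42 as y.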
+output_type = helper.make_tensor_value_info('y', TensorProto.FLOAT, [])
+
+node = onnx.helper.make_node(
+ 'Constant',
+ inputs=[],
+ outputs=['y'],
+ value=onnx.helper.make_tensor(
+ name='const_tensor',
+ data_type=onnx.TensorProto.FLOAT,
+ dims=(),
+ vals=[0.42]
+ ),
+)
+graph_def = onnx.helper.make_graph(
+ nodes = [node],
+ name = 'constant_test',
+ inputs = [],
+ outputs = [output_type]
+)
+model_def = helper.make_model(graph_def, producer_name='const.py')
+onnx.save(model_def, 'const.onnx')
diff --git a/model-integration/src/test/models/onnx/simple/gather.onnx b/model-integration/src/test/models/onnx/simple/gather.onnx
index 62451ad953d..0647d86ed0f 100644
--- a/model-integration/src/test/models/onnx/simple/gather.onnx
+++ b/model-integration/src/test/models/onnx/simple/gather.onnx
Binary files differ
diff --git a/model-integration/src/test/models/onnx/simple/simple.onnx b/model-integration/src/test/models/onnx/simple/simple.onnx
index 1c746c90efa..41b458451d0 100644
--- a/model-integration/src/test/models/onnx/simple/simple.onnx
+++ b/model-integration/src/test/models/onnx/simple/simple.onnx
@@ -1,4 +1,4 @@
- simple.py:ã
+ simple.py:ã
0
query_tensor
attribute_tensormatmul"MatMul
@@ -20,4 +20,4 @@
output


-B
+B \ No newline at end of file
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index d058104bf1b..39642f5cb50 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1413,6 +1413,7 @@
"public void <init>(java.util.Collection, java.util.Map)",
"public void <init>(java.util.Map)",
"public void <init>(java.util.Map, java.util.Map)",
+ "public void <init>(java.util.Map, java.util.Map, com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext)",
"public com.yahoo.searchlib.rankingexpression.ExpressionFunction getFunction(java.lang.String)",
"protected com.google.common.collect.ImmutableMap functions()",
"public java.lang.String getBinding(java.lang.String)",
@@ -1568,7 +1569,7 @@
"public void <init>(java.util.Map)",
"public void <init>(java.util.Collection, java.util.Map)",
"public void <init>(java.util.Collection, java.util.Map, java.util.Map)",
- "public void <init>(com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map)",
+ "public void <init>(com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map, com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext)",
"public void addFunctionSerialization(java.lang.String, java.lang.String)",
"public void addArgumentTypeSerialization(java.lang.String, java.lang.String, com.yahoo.tensor.TensorType)",
"public void addFunctionTypeSerialization(java.lang.String, com.yahoo.tensor.TensorType)",
@@ -1597,6 +1598,24 @@
],
"fields": []
},
+ "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionScalarFunction": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "com.yahoo.tensor.functions.ScalarFunction"
+ ],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getExpression()",
+ "public java.lang.Double apply(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public java.lang.String toString()",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public bridge synthetic java.lang.Object apply(java.lang.Object)"
+ ],
+ "fields": []
+ },
"com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction": {
"superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
"interfaces": [],
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 674571ff73e..f2f8799b342 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -134,6 +134,8 @@ public class ExpressionFunction {
for (int i = 0; i < arguments.size() && i < argumentValues.size(); ++i) {
argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(new StringBuilder(), context, path, null).toString());
}
+ String symbol = toSymbol(argumentBindings);
+ System.out.println("Expanding function " + symbol);
return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
index 83aabada8f0..9d094ce06f4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java
@@ -22,6 +22,8 @@ public class FunctionReferenceContext {
/** Mapping from argument names to the expressions they resolve to */
private final Map<String, String> bindings = new HashMap<>();
+ private final FunctionReferenceContext parent;
+
/** Create a context for a single serialization task */
public FunctionReferenceContext() {
this(Collections.emptyList());
@@ -43,9 +45,14 @@ public class FunctionReferenceContext {
/** Create a context for a single serialization task */
public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings) {
+ this(functions, bindings, null);
+ }
+
+ public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings, FunctionReferenceContext parent) {
this.functions = ImmutableMap.copyOf(functions);
if (bindings != null)
this.bindings.putAll(bindings);
+ this.parent = parent;
}
private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) {
@@ -56,16 +63,34 @@ public class FunctionReferenceContext {
}
/** Returns a function or null if it isn't defined in this context */
- public ExpressionFunction getFunction(String name) { return functions.get(name); }
+ public ExpressionFunction getFunction(String name) {
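+ // Look in this context first, then fall back to the parent context chain.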
+ ExpressionFunction function = functions.get(name);
+ if (function != null) {
+ return function;
+ }
+ if (parent != null) {
+ return parent.getFunction(name);
+ }
+ return null;
+ }
protected ImmutableMap<String, ExpressionFunction> functions() { return functions; }
/** Returns the resolution of an identifier, or null if it isn't defined in this context */
- public String getBinding(String name) { return bindings.get(name); }
+ public String getBinding(String name) {
+ String binding = bindings.get(name);
+ if (binding != null) {
+ return binding;
+ }
+ if (parent != null) {
+ return parent.getBinding(name);
+ }
+ return null;
+ }
/** Returns a new context with the bindings replaced by the given bindings */
public FunctionReferenceContext withBindings(Map<String, String> bindings) {
- return new FunctionReferenceContext(this.functions, bindings);
+ return new FunctionReferenceContext(this.functions, bindings, this);
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 8fec3603f3e..a994f5247b7 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -74,20 +74,49 @@ public final class ReferenceNode extends CompositeNode {
return string.append(context.getBinding(getName()));
}
+ String name = getName();
// A reference to a function?
ExpressionFunction function = context.getFunction(getName());
if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) {
// a function reference: replace by the referenced function wrapped in rankingExpression
- if (path == null)
- path = new ArrayDeque<>();
- String myPath = getName() + getArguments().expressions();
- if (path.contains(myPath))
- throw new IllegalStateException("Cycle in ranking expression function: " + path);
- path.addLast(myPath);
- ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path);
- path.removeLast();
- context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString());
- return string.append("rankingExpression(").append(instance.getName()).append(')');
+// if (path == null)
+// path = new ArrayDeque<>();
+// String myPath = getName() + getArguments().expressions();
+// if (path.contains(myPath))
+// throw new IllegalStateException("Cycle in ranking expression function: " + path);
+// path.addLast(myPath);
+// ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path);
+// path.removeLast();
+// context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString());
+// return string.append("rankingExpression(").append(instance.getName()).append(')');
+
+// return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString());
+
+ // hack for testing:
+ // So, this worked. Meaning that when expanding we could probably cut down on the context tree?
+// String expression = function.getBody().toString();
+// context.addFunctionSerialization(RankingExpression.propertyName(getName()), expression); // <- actually set by deriveFunctionProperties - this will only overwrite
+
+ String prefix = string.toString(); // incredibly ugly hack - for testing this
+
+ // there is a problem here with the input values
+ if (prefix.startsWith("attribute")) {
+ if (name.equals("segment_ids") || name.equals("input_mask") || name.equals("input_ids")) {
+ return string.append(getName()); // TODO: figure out how to detect these input names properly
+ }
+ }
+
+ // so, in one case
+// rankprofile[2].fef.property[35].name "rankingExpression(imported_ml_function_bertsquad8_input_ids).rankingScript"
+// rankprofile[2].fef.property[35].value "input_ids"
+ // vs
+// rankprofile[2].fef.property[2].name "rankingExpression(input_ids).rankingScript"
+// rankprofile[2].fef.property[2].value "attribute(input_ids)"
+ // the first serialization is wrong; in that case we need the second form below
+
+ return string.append("rankingExpression(").append(getName()).append(')');
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index d7807caa2b6..c79f5556e03 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -50,7 +50,7 @@ public class SerializationContext extends FunctionReferenceContext {
*/
public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings,
Map<String, String> serializedFunctions) {
- this(toMap(functions), bindings, serializedFunctions);
+ this(toMap(functions), bindings, serializedFunctions, null);
}
private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) {
@@ -69,8 +69,8 @@ public class SerializationContext extends FunctionReferenceContext {
* is <b>transferred</b> to this and will be modified in it
*/
public SerializationContext(ImmutableMap<String,ExpressionFunction> functions, Map<String, String> bindings,
- Map<String, String> serializedFunctions) {
- super(functions, bindings);
+ Map<String, String> serializedFunctions, FunctionReferenceContext root) {
+ super(functions, bindings, root);
this.serializedFunctions = serializedFunctions;
}
@@ -92,7 +92,7 @@ public class SerializationContext extends FunctionReferenceContext {
@Override
public SerializationContext withBindings(Map<String, String> bindings) {
- return new SerializationContext(functions(), bindings, this.serializedFunctions);
+ return new SerializationContext(functions(), bindings, this.serializedFunctions, this);
}
public Map<String, String> serializedFunctions() { return serializedFunctions; }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 6e1cdf52158..1ab9702367a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -143,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode {
return new ExpressionScalarFunction(node);
}
- private static class ExpressionScalarFunction implements ScalarFunction<Reference> {
+ public static class ExpressionScalarFunction implements ScalarFunction<Reference> {
private final ExpressionNode expression;
@@ -151,6 +151,10 @@ public class TensorFunctionNode extends CompositeNode {
this.expression = expression;
}
+ public ExpressionNode getExpression() {
+ return expression;
+ }
+
@Override
public Double apply(EvaluationContext<Reference> context) {
return expression.evaluate(new ContextWrapper(context)).asDouble();
@@ -321,13 +325,45 @@ public class TensorFunctionNode extends CompositeNode {
public ToStringContext parent() { return wrappedToStringContext; }
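+ // Counts the contexts wrapped by this one (both the to-string and serialization branches); handy when debugging context growth.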
+ private int contextNodes() {
+ int nodes = 0;
+ if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) {
+ nodes += ((ExpressionToStringContext)wrappedToStringContext).contextNodes();
+ } else if (wrappedToStringContext != null) {
+ nodes += 1;
+ }
+ if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) {
+ nodes += ((ExpressionToStringContext)wrappedSerializationContext).contextNodes();
+ } else if (wrappedSerializationContext != null) {
+ nodes += 1;
+ }
+ return nodes + 1;
+ }
+
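+ // Returns the maximum nesting depth of wrapped ExpressionToStringContexts.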
+ private int contextDepth() {
+ int depth = 0;
+ if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) {
+ depth += ((ExpressionToStringContext)wrappedToStringContext).contextDepth();
+ }
+ if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) {
+ int d = ((ExpressionToStringContext)wrappedSerializationContext).contextDepth();
+ depth = Math.max(depth, d);
+ }
+ return depth + 1;
+ }
+
/** Returns the resolution of an identifier, or null if it isn't defined in this context */
@Override
public String getBinding(String name) {
- if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null)
- return wrappedToStringContext.getBinding(name);
- else
- return wrappedSerializationContext.getBinding(name);
+// System.out.println("getBinding for " + name + " with node count " + contextNodes() + " and max depth " + contextDepth());
+ String binding;
+ if (wrappedToStringContext != null) {
+ binding = wrappedToStringContext.getBinding(name);
+ if (binding != null) {
+ return binding;
+ }
+ }
+ return wrappedSerializationContext.getBinding(name);
}
/** Returns a new context with the bindings replaced by the given bindings */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
index a541eac2421..95652bb0e15 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.transform;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
@@ -28,8 +29,16 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext
return node;
}
+ /** Returns true if the given reference is an attribute, constant or query feature */
+ // TEMP: from config-model module
+ public static boolean isSimpleFeature(Reference reference) {
+ if ( ! reference.isSimple()) return false;
+ String name = reference.name();
+ return name.equals("attribute") || name.equals("constant") || name.equals("query");
+ }
+
private ExpressionNode transformFeature(ReferenceNode node, TransformContext context) {
- if (!node.getArguments().isEmpty())
+ if ( ! node.getArguments().isEmpty() && ! isSimpleFeature(node.reference()))
return transformArguments(node, context);
else
return transformConstantReference(node, context);
@@ -44,7 +53,14 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext
}
private ExpressionNode transformConstantReference(ReferenceNode node, TransformContext context) {
- Value value = context.constants().get(node.getName());
+ String name = node.getName();
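+ // For a constant(name) reference, resolve the constant by its inner name.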
+ if (node.reference().name().equals("constant")) {
+ ExpressionNode arg = node.getArguments().expressions().get(0);
+ if (arg instanceof ReferenceNode) {
+ name = ((ReferenceNode)arg).getName();
+ }
+ }
+ Value value = context.constants().get(name); // works if "constant(...)" is added
if (value == null || value.type().rank() > 0) {
return node; // not a number constant reference
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 807eb3aa7ce..7b246f22cf2 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -390,6 +390,126 @@ public class EvaluationTestCase {
}
@Test
+ public void testTile() {
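+ // Tile written out as a tensor generate that indexes the input modulo its dimension sizes (the form Tile presumably maps to).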
+ EvaluationTester tester = new EvaluationTester();
+
+ tester.assertEvaluates("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]",
+ "tensor(d0[2],d1[4])(tensor0{input0:(d0 % 2), input1:(d1 % 2) } )",
+ "tensor(input0[2],input1[2]):[1, 2, 3, 4]",
+ "tensor(repeats0[2]):[1,2]");
+
+ tester.assertEvaluates("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]",
+ "tensor(d0[6],d1[2])(tensor0{input0:(d0 % 2), input1:(d1 % 2) } )",
+ "tensor(input0[2],input1[2]):[1, 2, 3, 4]",
+ "tensor(repeats0[2]):[3,1]");
+ }
+
+ @Test
+ public void testReshape() {
+ EvaluationTester tester = new EvaluationTester();
+
+ tester.assertEvaluates("tensor(d0[4]):[1,2,3,4]",
+ "tensor(d0[4])(tensor0{a0:(d0 / 2), a1:(d0 % 2)})",
+ "tensor(a0[2],a1[2]):[1,2,3,4]",
+ "tensor(d0[1]):[4]");
+
+ tester.assertEvaluates("tensor(d0[2],d1[2]):[1,2,3,4]",
+ "tensor(d0[2],d1[2])(tensor0{a0:(d0), a1:(d1)})",
+ "tensor(a0[2],a1[2]):[1,2,3,4]",
+ "tensor(d0[2]):[2,2]");
+
+ tester.assertEvaluates("tensor(d0[2],d1[1],d2[2]):[1,2,3,4]",
+ "tensor(d0[2],d1[1],d2[2])(tensor0{a0:(d0), a1:(d2)})",
+ "tensor(a0[2],a1[2]):[1,2,3,4]",
+ "tensor(d0[3]):[2,1,2]");
+
+ tester.assertEvaluates("tensor(d0[3],d1[2]):[1,2,3,4,5,6]",
+ "tensor(d0[3],d1[2])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })",
+ "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]",
+ "tensor(d0[2]):[3,2]");
+
+ tester.assertEvaluates("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]",
+ "tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })",
+ "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]",
+ "tensor(d0[4]):[3,2,-1,1]");
+
+ }
+
+ @Test
+ public void testMatmul() {
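+ // MatMul written as reduce(join(...), sum, sharedDim): multiply along the shared dimension and sum it away.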
+ EvaluationTester tester = new EvaluationTester();
+
+ tester.assertEvaluates("tensor():{91}",
+ "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)",
+ "tensor(d0[6]):[1,2,3,4,5,6]",
+ "tensor(d0[6]):[1,2,3,4,5,6]");
+
+ tester.assertEvaluates("tensor(d1[2]):[22, 28]",
+ "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)",
+ "tensor(d0[3]):[1,2,3]",
+ "tensor(d0[3],d1[2]):[1,2,3,4,5,6]");
+
+ tester.assertEvaluates("tensor(d1[2]):[22, 28]",
+ "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)",
+ "tensor(d0[3],d1[2]):[1,2,3,4,5,6]",
+ "tensor(d0[3]):[1,2,3]");
+
+ tester.assertEvaluates("tensor(d0[2],d2[2]):[22,28,49,64]",
+ "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d1)",
+ "tensor(d0[2],d1[3]):[1,2,3,4,5,6]",
+ "tensor(d1[3],d2[2]):[1,2,3,4,5,6]");
+
+ tester.assertEvaluates("tensor(d0[1],d1[2],d3[2]):[22,28,49,64]",
+ "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d2)",
+ "tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]",
+ "tensor(d2[3],d3[2]):[1,2,3,4,5,6]");
+
+ tester.assertEvaluates("tensor(d0[1],d1[2],d3[2]):[22,28,49,64]",
+ "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d2)",
+ "tensor(d1[2],d2[3]):[1,2,3,4,5,6]",
+ "tensor(d0[1],d2[3],d3[2]):[1,2,3,4,5,6]");
+
+ tester.assertEvaluates("tensor(d0[1],d1[4],d2[2],d4[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]",
+ "reduce(join(tensor0{d1:0}, tensor1, f(x,y)(x*y)), sum, d3)", // notice peek
+ "tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]",
+ "tensor(d0[1],d1[4],d3[3],d4[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+
+ tester.assertEvaluates("tensor(d0[1],d1[4],d2[2],d4[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]",
+ "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d3)",
+ "tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]",
+ "tensor(d0[1],d1[4],d3[3],d4[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]");
+
+ }
+
+ @Test
+ public void testSplit() {
+ EvaluationTester tester = new EvaluationTester();
+
+ tester.assertEvaluates("tensor(d0[3]):[1,2,3]",
+ "tensor(d0[3])(tensor0{input0:(d0)} )",
+ "tensor(input0[6]):[1,2,3,4,5,6]");
+ tester.assertEvaluates("tensor(d0[3]):[4,5,6]",
+ "tensor(d0[3])(tensor0{input0:(d0+3)} )",
+ "tensor(input0[6]):[1,2,3,4,5,6]");
+ tester.assertEvaluates("tensor(d0[4]):[3,4,5,6]",
+ "tensor(d0[4])(tensor0{input0:(d0+2)} )",
+ "tensor(input0[6]):[1,2,3,4,5,6]");
+ tester.assertEvaluates("tensor(d0[2]):[3,4]",
+ "tensor(d0[2])(tensor0{input0:(d0+2)} )",
+ "tensor(input0[6]):[1,2,3,4,5,6]");
+ tester.assertEvaluates("tensor(d0[2]):[5,6]",
+ "tensor(d0[2])(tensor0{input0:(d0+4)} )",
+ "tensor(input0[6]):[1,2,3,4,5,6]");
+
+ tester.assertEvaluates("tensor(d0[1],d1[3]):[1,2,3]",
+ "tensor(d0[1],d1[3])(tensor0{input0:(d0), input1:(d1)} )",
+ "tensor(input0[2],input1[3]):[[1,2,3],[4,5,6]]");
+ tester.assertEvaluates("tensor(d0[1],d1[3]):[4,5,6]",
+ "tensor(d0[1],d1[3])(tensor0{input0:(d0+1), input1:(d1)} )",
+ "tensor(input0[2],input1[3]):[[1,2,3],[4,5,6]]");
+ }
+
+ @Test
public void testTake() {
EvaluationTester tester = new EvaluationTester();
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 04f859e2802..2e13fe0de26 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1653,6 +1653,7 @@
"public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)",
"public static com.yahoo.tensor.functions.Generate free(com.yahoo.tensor.TensorType, java.util.function.Function)",
"public static com.yahoo.tensor.functions.Generate bound(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)",
+ "public com.yahoo.tensor.functions.ScalarFunction getBoundGenerator()",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
@@ -2548,6 +2549,7 @@
"public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.Slice withArguments(java.util.List)",
+ "public java.util.List getSubspaceAddress()",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
"public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
index 219a3fa2278..68c4aa4c809 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java
@@ -58,7 +58,12 @@ class IndexedDoubleTensor extends IndexedTensor {
@Override
public IndexedTensor.BoundBuilder cell(double value, long ... indexes) {
- values[(int)toValueIndex(indexes, sizes())] = value;
+ int index = (int)toValueIndex(indexes, sizes());
+ if (index < 0 || index >= values.length) {
+ System.out.println("Argh");
+ }
+ values[index] = value;
+// values[(int)toValueIndex(indexes, sizes())] = value;
return this;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index fa3d70a4ddf..2a2551c9a58 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -71,6 +71,10 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
throw new IllegalArgumentException("A generated tensor can only have indexed bound dimensions");
}
+ public ScalarFunction<NAMETYPE> getBoundGenerator() {
+ return boundGenerator;
+ }
+
@Override
public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index bccd66acd31..a0a3552eb92 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -53,6 +53,10 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
return new Slice<>(arguments.get(0), subspaceAddress);
}
+ public List<DimensionValue<NAMETYPE>> getSubspaceAddress() {
+ return subspaceAddress;
+ }
+
@Override
public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }