diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations')
13 files changed, 539 insertions, 62 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 01fd7ee55bd..956d727fbad 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -54,10 +54,10 @@ public class Const extends IntermediateOperation { } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + super.vespaName(); - } +// @Override +// public String vespaName() { +// return modelName + "_" + super.vespaName(); +// } @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java index ad56eefe5f2..b12f83f274b 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Constant.java @@ -22,10 +22,10 @@ public class Constant extends IntermediateOperation { } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + vespaName(name); - } +// @Override +// public String vespaName() { +// return modelName + "_" + vespaName(name); +// } @Override protected OrderedTensorType lazyGetType() { @@ -61,7 +61,9 @@ public class Constant extends IntermediateOperation { public Constant withInputs(List<IntermediateOperation> inputs) { if ( ! inputs.isEmpty()) throw new IllegalArgumentException("Constant cannot take inputs"); - return new Constant(modelName(), name(), type); + Constant constant = new Constant(modelName(), name(), type); + constant.setConstantValueFunction(constantValueFunction); + return constant; } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java index 5463f645355..af192fcec38 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Identity.java @@ -12,12 +12,6 @@ public class Identity extends IntermediateOperation { super(modelName, nodeName, inputs); } - /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ - @Override - public String vespaName() { - return modelName + "_" + super.vespaName(); - } - @Override protected OrderedTensorType lazyGetType() { if (!allInputTypesPresent(1)) diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 2aa8b2a0d48..83e15a4081a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; @@ -13,6 +14,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; @@ -47,6 +49,8 @@ public abstract class IntermediateOperation { protected TensorFunction rankingExpressionFunction = null; protected boolean exportAsRankingFunction = false; + private boolean hasRenamedDimensions = false; + private final List<String> importWarnings = new ArrayList<>(); private Value constantValue = null; private List<IntermediateOperation> controlInputs = Collections.emptyList(); @@ -121,7 +125,10 @@ public abstract class IntermediateOperation { } /** Performs dimension rename for this operation */ - public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } + public void renameDimensions(DimensionRenamer renamer) { + type = type.rename(renamer); + hasRenamedDimensions = true; + } /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ public boolean isInput() { return false; } @@ -144,7 +151,11 @@ public abstract class IntermediateOperation { } /** Set the constant value function */ - public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { this.constantValueFunction = func; } + public void setConstantValueFunction(Function<OrderedTensorType, Value> func) { + this.constantValueFunction = func; + } + + public boolean hasConstantValueFunction() { return constantValueFunction != null; } /** Sets the external control inputs */ public void setControlInputs(List<IntermediateOperation> inputs) { this.controlInputs = inputs; } @@ -153,12 +164,23 @@ public abstract class IntermediateOperation { public List<IntermediateOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } /** Retrieve the valid Vespa name of this node */ - public String vespaName() { return vespaName(name); } - public String vespaName(String name) { return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; } + public String vespaName() { + if (isConstant()) + return modelName + "_" + vespaName(name); + return vespaName(name); + } + + public String vespaName(String name) { + return name != null ? namePartOf(name).replace('/', '_').replace('.', '_') : null; + } /** Retrieve the valid Vespa name of this node if it is a ranking expression function */ public String rankingExpressionFunctionName() { - return vespaName() != null ? FUNCTION_PREFIX + modelName + "_" + vespaName() : null; + String vespaName = vespaName(); + if (vespaName == null) { + return null; + } + return isConstant() ? "constant(" + vespaName + ")" : FUNCTION_PREFIX + modelName + "_" + vespaName; } /** Retrieve the list of warnings produced during its lifetime */ @@ -185,30 +207,80 @@ public abstract class IntermediateOperation { /** Recursively evaluates this operation's constant value to avoid doing it run-time. */ public Value evaluateAsConstant(OrderedTensorType type) { +// System.out.println("Starting constant evaluation for " + name); if ( ! isConstant() ) { throw new IllegalArgumentException("Attempted to evaluate non-constant operation as a constant."); } - Value val = evaluateAsConstant(new MapContext(DoubleValue.NaN)); - if (type != null && ! val.asTensor().type().equals(type.type()) ) { + if (type == null) { + System.out.println("Evaluating as constant for " + name + " with type null! Probably an error."); + } + + IntermediateOperation evaluateOn = this; + if ( ! hasRenamedDimensions) { + // make a copy of the tree, perform renaming and evaluate + IntermediateOperation copy = copyTree(0); + optimizeAndRename(copy); + evaluateOn = copy; + } + Value val = evaluateOn.evaluateAsConstant(new MapContext(DoubleValue.NaN), 0); + + if (type == null) { + return val; + } + Tensor tensor = val.asTensor(); //.withType(type.type()); + if ( ! tensor.type().isRenamableTo(type.type()) ) { throw new IllegalArgumentException("Constant evaluation in " + name + " resulted in wrong type. " + "Expected: " + type.type() + " Got: " + val.asTensor().type()); } - return val; + // set constant value so we don't have to re-evaluate + setConstantValueFunction(t -> new TensorValue(tensor.withType(t.type()))); +// System.out.println("Returning constant evaluation for " + name); + return new TensorValue(tensor.withType(type.type())); + } + + private IntermediateOperation copyTree(int indent) { + String indentString = ""; for (int i = 0; i < indent; ++i) indentString += " "; +// System.out.println(indentString + "Copying " + name); + List<IntermediateOperation> in = new ArrayList<>(); + if (constantValue != null) { +// System.out.println(indentString + name + " has a constant value"); + IntermediateOperation constant = new Constant(modelName, name, type); + constant.setConstantValueFunction(t -> new TensorValue(constantValue.asTensor().withType(t.type()))); + return constant; + } + inputs.forEach(i -> in.add(i.copyTree(indent + 1))); + IntermediateOperation copy = withInputs(in); + if (constantValueFunction != null) { + copy.constantValueFunction = constantValueFunction; // works? + } + return copy; + } + + private TensorFunction optimizeAndRename(IntermediateOperation op) { + IntermediateGraph graph = new IntermediateGraph(modelName); + graph.put(name, op); + graph.outputs(graph.defaultSignature()).put(name, name); + graph.optimize(); + return op.function().get(); } - private Value evaluateAsConstant(Context context) { + private Value evaluateAsConstant(Context context, int indent) { + String in = ""; for (int i = 0; i < indent; ++i) in += " "; +// System.out.println(in + "Constant evaluating for " + name); String constantName = "constant(" + vespaName() + ")"; Value result = context.get(constantName); if (result == DoubleValue.NaN) { if (constantValue != null) { +// System.out.println(in + name + " has constant value."); result = constantValue; } else if (inputs.size() == 0) { +// System.out.println(in + name + " has no inputs."); if (getConstantValue().isEmpty()) { throw new IllegalArgumentException("Error in evaluating constant for " + name); } result = getConstantValue().get(); } else { - inputs.forEach(i -> i.evaluateAsConstant(context)); + inputs.forEach(i -> i.evaluateAsConstant(context, indent+1)); result = new TensorValue(lazyGetFunction().evaluate(context)); } context.put(constantName, result); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index adb54474812..3211a44fa68 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -82,6 +82,13 @@ public class Join extends IntermediateOperation { bReducedFunction = new Reduce(b.function().get(), Reduce.Aggregator.sum, bDimensionsToReduce); } + // retain order of inputs + if (a == inputs.get(1)) { + TensorFunction temp = bReducedFunction; + bReducedFunction = aReducedFunction; + aReducedFunction = temp; + } + return new com.yahoo.tensor.functions.Join(aReducedFunction, bReducedFunction, operator); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 6849e64641e..1eb21eb2a5e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -4,6 +4,9 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.text.ExpressionFormatter; @@ -20,64 +23,126 @@ public class MatMul extends IntermediateOperation { protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + + // add some more checks here + if (aType.type().rank() < 1 || bType.type().rank() < 1) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 1"); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); - typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); - typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + OrderedTensorType largestRankType = aType.rank() >= bType.rank() ? aType : bType; + for (int i = 0; i < largestRankType.rank() - 2; ++i) { + typeBuilder.add(largestRankType.dimensions().get(i)); + } + if (aType.rank() >= 2) { + typeBuilder.add(aType.dimensions().get(aType.rank() - 2)); + } + if (bType.rank() >= 2) { + typeBuilder.add(bType.dimensions().get(bType.rank() - 1)); + } return typeBuilder.build(); } @Override protected TensorFunction lazyGetFunction() { if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; OrderedTensorType aType = inputs.get(0).type().get(); - OrderedTensorType bType = inputs.get(1).type().get(); - if (aType.type().rank() < 2 || bType.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (aType.type().rank() != bType.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - Optional<TensorFunction> aFunction = inputs.get(0).function(); Optional<TensorFunction> bFunction = inputs.get(1).function(); - if (!aFunction.isPresent() || !bFunction.isPresent()) { - return null; - } - return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + + // only change to this is for dimensions with size 1 - check in getType + + return new com.yahoo.tensor.functions.Reduce(new Join(aFunction.get(), bFunction.get(), ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + aType.dimensions().get(aType.rank() - 1).name()); } @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { if ( ! allInputTypesPresent(2)) return; - List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); - List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + /* + * A: a1, a2, a3, a4 + * B: b1, b2, b3, b4 + * + * a4 == b3 + * a3 < b4 + * a3 < a4 + * b4 < b3 + * + * a1 == b1 -> men også størrelsesmessig. + * a2 == b2 + * etc + */ + + OrderedTensorType typeA = inputs.get(0).type().get(); + OrderedTensorType typeB = inputs.get(1).type().get(); + + String lastDimA = typeA.dimensions().get(typeA.rank()-1).name(); + String lastDimB = typeB.dimensions().get(typeB.rank()-1).name(); + String secondLastDimA = typeA.dimensions().get(Math.max(0,typeA.rank()-2)).name(); + String secondLastDimB = typeB.dimensions().get(Math.max(0,typeB.rank()-2)).name(); + + // The last dimension of A should have the same name as the second-to-last dimension of B + renamer.addConstraint(lastDimA, secondLastDimB, DimensionRenamer.Constraint.equal(false), this); - assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); - assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + // For efficiency, the dimensions to join over should be innermost - soft constraint + if (typeA.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimA, DimensionRenamer.Constraint.lessThan(true), this); + } + if (typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimB, lastDimB, DimensionRenamer.Constraint.greaterThan(true), this); + } - String aDim0 = aDimensions.get(0).name(); - String aDim1 = aDimensions.get(1).name(); - String bDim0 = bDimensions.get(0).name(); - String bDim1 = bDimensions.get(1).name(); + // The second-to-last dimension of a should have a different name than the last dimension of b + if (typeA.rank() >= 2 && typeB.rank() >= 2) { + renamer.addConstraint(secondLastDimA, lastDimB, DimensionRenamer.Constraint.lessThan(false), this); + } - // The second dimension of a should have the same name as the first dimension of b - renamer.addConstraint(aDim1, bDim0, DimensionRenamer.Constraint.equal(false), this); + // a1 < a2 < a3 < a4 + OrderedTensorType largestRankType = typeA.rank() >= typeB.rank() ? typeA : typeB; + for (int i = 0; i < largestRankType.rank() - 2; ++i) { + String iDim = largestRankType.dimensionNames().get(i); + for (int j = i+1; j < largestRankType.rank() - 2; ++j) { + String jDim = largestRankType.dimensionNames().get(j); + renamer.addConstraint(iDim, jDim, DimensionRenamer.Constraint.lessThan(true), this); + } + } + + // TODO: handle non similar sizes + + // a1 == b1 etc + if (typeA.rank() == typeB.rank()) { + for (int i = 0; i < typeA.rank() - 2; ++i) { + renamer.addConstraint(typeA.dimensionNames().get(i), typeB.dimensionNames().get(i), DimensionRenamer.Constraint.equal(false), this); + } + } - // The first dimension of a should have a different name than the second dimension of b - renamer.addConstraint(aDim0, bDim1, DimensionRenamer.Constraint.lessThan(false), this); - // For efficiency, the dimensions to join over should be innermost - soft constraint - renamer.addConstraint(aDim0, aDim1, DimensionRenamer.Constraint.lessThan(true), this); - renamer.addConstraint(bDim0, bDim1, DimensionRenamer.Constraint.greaterThan(true), this); - } - private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { - if (dimensions.size() >= 2) return; - throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + - " but got just " + dimensions + " from\n" + - ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); + + // So, what about the other dimensions? +// if (aDimensions.size() > 2) { +// for (int i = 1; i < aDimensions.size(); ++i) { +// renamer.addConstraint(aDimensions.get(0).name(), aDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this); +// } +// for (int i = 0; i < bDimensions.size(); ++i) { +// renamer.addConstraint(aDimensions.get(0).name(), bDimensions.get(i).name(), DimensionRenamer.Constraint.notEqual(false), this); +// } +// } + } +// private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { +// if (dimensions.size() >= 2) return; +// throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + +// " but got just " + dimensions + " from\n" + +// ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); +// } + @Override public MatMul withInputs(List<IntermediateOperation> inputs) { return new MatMul(modelName(), name(), inputs); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java index e040ae62149..07ac457cca8 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Rename.java @@ -54,7 +54,7 @@ public class Rename extends IntermediateOperation { } public void renameDimensions(DimensionRenamer renamer) { - type = type.rename(renamer); + super.renameDimensions(renamer); from = renamer.dimensionNameOf(from).orElse(from); to = renamer.dimensionNameOf(to).orElse(to); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index c88fc18e6c6..f96dd420d30 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -2,8 +2,10 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; @@ -11,8 +13,11 @@ import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.Function; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -27,6 +32,8 @@ import java.util.List; import java.util.Optional; import java.util.stream.Collectors; +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + public class Reshape extends IntermediateOperation { private final AttributeMap attributeMap; @@ -38,6 +45,10 @@ public class Reshape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + if (inputs.size() == 2) { return typeWithShapeAsInput(); } else if (inputs.size() == 1) { @@ -126,10 +137,54 @@ public class Reshape extends IntermediateOperation { return new Reshape(modelName(), name(), inputs, attributeMap); } - public static TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { + public TensorFunction reshape(TensorFunction inputFunction, OrderedTensorType inputType, OrderedTensorType outputType) { if ( ! OrderedTensorType.tensorSize(inputType.type()).equals(OrderedTensorType.tensorSize(outputType.type()))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); + IntermediateOperation input = inputs.get(0); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + // ala (d0 * 2 + d1) + ExpressionNode unrolled = new EmbracedNode(unrollTensorExpression(outputType)); + + long innerSize = 1; + for (int dim = 0; dim < inputType.rank(); ++dim) { + innerSize *= inputType.dimensions().get(dim).size().get(); + } + + for (int dim = 0; dim < inputType.rank(); ++dim) { + String inputDimensionName = inputType.dimensions().get(dim).name(); + long inputDimensionSize = inputType.dimensions().get(dim).size().get(); + long previousInnerSize = innerSize; + innerSize /= inputDimensionSize; + + ExpressionNode inputDimensionExpression; + if (inputDimensionSize == 1) { + inputDimensionExpression = new EmbracedNode(new ConstantNode(DoubleValue.zero)); + } else if (dim == (inputType.rank() - 1)) { + ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); + ExpressionNode div = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, size); + inputDimensionExpression = new EmbracedNode(div); + } else { + ExpressionNode size = new ConstantNode(new DoubleValue(innerSize)); + ExpressionNode previousSize = new ConstantNode(new DoubleValue(previousInnerSize)); + ExpressionNode mod = new ArithmeticNode(unrolled, ArithmeticOperator.MODULO, previousSize); + ExpressionNode div = new ArithmeticNode(new EmbracedNode(mod), ArithmeticOperator.DIVIDE, size); + inputDimensionExpression = new EmbracedNode(new FunctionNode(Function.floor, div)); + } + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(inputDimensionExpression))); + } + + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(outputType.type(), wrapScalar(sliceExpression)); + return generate; + + /* // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, // then use the dimension order of the new shape to roll back into a tensor. // Here we create a transformation tensor that is multiplied with the from tensor to map into @@ -168,11 +223,14 @@ public class Reshape extends IntermediateOperation { result = new Rename(result, to, from); } return result; + */ } + /* private static boolean dimensionNamesOverlap(OrderedTensorType a, OrderedTensorType b) { return a.dimensionNames().stream().anyMatch(d -> b.type().indexOfDimension(d).isPresent()); } + */ private static ExpressionNode unrollTensorExpression(OrderedTensorType type) { if (type.rank() == 0) diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java index e5463291ef8..8dd1e3ff33d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Slice.java @@ -182,7 +182,6 @@ public class Slice extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - // Todo: what to do? for (int i = 0; i < type.dimensions().size(); i++) { renamer.addDimension(type.dimensions().get(i).name()); for (int j = i + 1; j < type.dimensions().size(); j++) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java index 83086926316..e2b83246bfc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Softmax.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Map; import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.TensorFunction; @@ -28,6 +29,10 @@ public class Softmax extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { if ( ! allInputTypesPresent(1)) return null; + + // input is referenced twice due to avoidance of overflow. so make this it's own function. + inputs.get(0).exportAsRankingFunction = true; + return inputs.get(0).type().get(); } @@ -50,7 +55,9 @@ public class Softmax extends IntermediateOperation { } TensorFunction input = inputs.get(0).function().get(); - TensorFunction exp = new Map(input, ScalarFunctions.exp()); + TensorFunction max = new Reduce(input, Reduce.Aggregator.max, reduceDimensions); + TensorFunction cap = new Join(input, max, ScalarFunctions.subtract()); // to avoid overflow + TensorFunction exp = new Map(cap, ScalarFunctions.exp()); TensorFunction sum = new Reduce(exp, Reduce.Aggregator.sum, reduceDimensions); TensorFunction div = new Join(exp, sum, ScalarFunctions.divide()); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java new file mode 100644 index 00000000000..02d780c52cd --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Split.java @@ -0,0 +1,119 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +public class Split extends IntermediateOperation { + + private final AttributeMap attributes; + private final int output; + + private final int axis; + private int start; + private int end; + + public Split(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes, int output) { + super(modelName, nodeName, inputs); + this.attributes = attributes; + this.output = output; + axis = (int) attributes.get("axis").orElse(DoubleValue.zero).asDouble(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) + return null; + OrderedTensorType inputType = inputs.get(0).type().get(); + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + int axisSize = inputType.dimensions().get(axis).size().get().intValue(); + start = 0; + end = axisSize; + + if (attributes.getList("split").isPresent()) { + List<Value> splitList = attributes.getList("split").get(); + if (output > splitList.size()) { + throw new IllegalArgumentException("Split in " + name + ": output out of range of split list"); + } + for (int i = 0; i < output; ++i) { + start += (int) splitList.get(i).asDouble(); + } + if (output < splitList.size()) { + end = start + (int) splitList.get(output).asDouble(); + } + } else { + start = axisSize / 2 * output; + end = start + axisSize / 2; + } + + if (start >= axisSize || start < 0) { + throw new IllegalArgumentException("Split in " + name + ": split start index out of range (" + start + ")"); + } + if (end > axisSize || end < 0) { + throw new IllegalArgumentException("Split in " + name + ": split end index out of range (" + end + ")"); + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < inputType.rank(); ++i) { + TensorType.Dimension inputDimension = inputType.dimensions().get(i); + long dimSize = i == axis ? end - start : inputDimension.size().get(); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), dimSize)); + } + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) return null; + + IntermediateOperation input = inputs.get(0); + OrderedTensorType inputType = input.type().get(); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + for (int i = 0; i < inputType.rank(); ++i) { + String inputDimensionName = inputType.dimensions().get(i).name(); + ExpressionNode reference = new ReferenceNode(inputDimensionName); + ExpressionNode offset = new ArithmeticNode(reference, ArithmeticOperator.PLUS, new ConstantNode(new DoubleValue(i == axis ? start : 0))); + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(offset)))); + } + + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + return generate; + } + + @Override + public Split withInputs(List<IntermediateOperation> inputs) { + return new Split(modelName(), name(), inputs, attributes, output); + } + + @Override + public String operationName() { return "Split"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java new file mode 100644 index 00000000000..8d3468f3d04 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Tile.java @@ -0,0 +1,100 @@ +// Copyright 2020 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.EmbracedNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +import static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode.wrapScalar; + +/** + * Onnx tile operation. + */ +public class Tile extends IntermediateOperation { + + public Tile(String modelName, String nodeName, List<IntermediateOperation> inputs) { + super(modelName, nodeName, inputs); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) return null; + + // required as we use tensor create + inputs.get(0).exportAsRankingFunction = true; + + IntermediateOperation repeats = inputs.get(1); + if (repeats.getConstantValue().isEmpty()) + throw new IllegalArgumentException("Tile " + name + ": repeats input must be a constant."); + + Tensor shape = repeats.getConstantValue().get().asTensor(); + if (shape.type().rank() != 1) + throw new IllegalArgumentException("Tile " + name + ": repeats must be a 1-d tensor."); + + OrderedTensorType inputType = inputs.get(0).type().get(); + if (shape.type().dimensions().get(0).size().get() != inputType.rank()) + throw new IllegalArgumentException("Tile " + name + ": repeats must be the same size as input rank."); + + List<Integer> dimSizes = new ArrayList<>(inputType.rank()); + shape.valueIterator().forEachRemaining(v -> dimSizes.add(v.intValue())); + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < dimSizes.size(); ++i) { + TensorType.Dimension inputDimension = inputType.dimensions().get(i); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get() * dimSizes.get(i))); + } + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(2)) return null; + + IntermediateOperation input = inputs.get(0); + OrderedTensorType inputType = input.type().get(); + String inputFunctionName = input.rankingExpressionFunctionName(); + + List<com.yahoo.tensor.functions.Slice.DimensionValue<Reference>> dimensionValues = new ArrayList<>(); + + for (int axis = 0; axis < inputType.rank(); ++axis) { + String inputDimensionName = inputType.dimensions().get(axis).name(); + long inputDimensionSize = inputType.dimensions().get(axis).size().get(); + + ExpressionNode size = new ConstantNode(new DoubleValue(inputDimensionSize)); + ExpressionNode reference = new ReferenceNode(inputDimensionName); + ExpressionNode mod = new ArithmeticNode(reference, ArithmeticOperator.MODULO, size); + dimensionValues.add(new com.yahoo.tensor.functions.Slice.DimensionValue<>(Optional.of(inputDimensionName), wrapScalar(new EmbracedNode(mod)))); + } + + TensorFunction<Reference> inputIndices = new TensorFunctionNode.ExpressionTensorFunction(new ReferenceNode(inputFunctionName)); + com.yahoo.tensor.functions.Slice<Reference> sliceIndices = new com.yahoo.tensor.functions.Slice<>(inputIndices, dimensionValues); + ExpressionNode sliceExpression = new TensorFunctionNode(sliceIndices); + + TensorFunction generate = Generate.bound(type.type(), wrapScalar(sliceExpression)); + return generate; + } + + @Override + public Tile withInputs(List<IntermediateOperation> inputs) { + return new Tile(modelName(), name(), inputs); + } + + @Override + public String operationName() { return "Tile"; } + +} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java new file mode 100644 index 00000000000..178759fbf2a --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Transpose.java @@ -0,0 +1,54 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.operations; + +import ai.vespa.rankingexpression.importer.OrderedTensorType; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; + +import java.util.List; + +public class Transpose extends IntermediateOperation { + + private final AttributeMap attributes; + + public Transpose(String modelName, String nodeName, List<IntermediateOperation> inputs, AttributeMap attributes) { + super(modelName, nodeName, inputs); + this.attributes = attributes; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) return null; + + OrderedTensorType inputType = inputs.get(0).type().get(); + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); + for (int i = 0; i < inputType.rank(); ++i) { + int inputIndex = inputType.rank() - 1 - i; + if (attributes.getList("perm").isPresent()) { + inputIndex = (int) attributes.getList("perm").get().get(i).asDouble(); + } + TensorType.Dimension inputDimension = inputType.dimensions().get(inputIndex); + typeBuilder.add(TensorType.Dimension.indexed(inputDimension.name(), inputDimension.size().get())); + } + OrderedTensorType result = typeBuilder.build(); + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) + return null; + return inputs.get(0).function().orElse(null); + } + + @Override + public Transpose withInputs(List<IntermediateOperation> inputs) { + return new Transpose(modelName(), name(), inputs, attributes); + } + + @Override + public String operationName() { return "Transpose"; } + +} |