diff options
9 files changed, 107 insertions, 101 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java index 0f08bf0bf21..9ef1a3f6e32 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java @@ -56,7 +56,7 @@ public abstract class FieldUpdateHelper { } else if (upd instanceof ArithmeticValueUpdate) { if (((ArithmeticValueUpdate)upd).getOperator() == ArithmeticValueUpdate.Operator.DIV && ((ArithmeticValueUpdate)upd).getOperand().doubleValue() == 0) { - throw new IllegalArgumentException("Division by zero."); + throw new IllegalArgumentException("Div by zero."); } val.assign(upd.getValue()); return val; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 8dcd31b270e..e47f2ad53d9 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -7,10 +7,8 @@ import com.yahoo.io.IOUtils; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Matmul; import com.yahoo.tensor.functions.Rename; @@ -34,6 +32,7 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.function.DoubleBinaryOperator; +import java.util.function.DoubleUnaryOperator; import java.util.stream.Collectors; /** @@ -146,11 +145,13 @@ public class TensorFlowImporter { GraphDef graph, String indent) { // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops - switch (tfNode.getOp()) { - case "Identity" : return identity(tfNode, inputs); - case "Add" : return join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add()); - case "MatMul" : return matmul(importArguments(tfNode, inputs, graph, indent)); - case "Softmax" : return softmax(importArguments(tfNode, inputs, graph, indent)); + // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ + switch (tfNode.getOp().toLowerCase()) { + case "add" : case "add_n" : return join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add()); + case "acos" : return map(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.acos()); + case "identity" : return identity(tfNode, inputs); + case "matmul" : return matmul(importArguments(tfNode, inputs, graph, indent)); + case "softmax" : return softmax(importArguments(tfNode, inputs, graph, indent)); default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported"); } } @@ -162,15 +163,51 @@ public class TensorFlowImporter { } private TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) { + // Note that this generalizes the corresponding TF function as it does not verify that the tensor + // types are the same, with the assumption that this already happened on the TF side + // (and if not, this should do the right thing anyway) ensureArguments(2, arguments, "join"); TypedTensorFunction a = arguments.get(0); TypedTensorFunction b = arguments.get(0); - // TODO: Verify with TF doc - TensorType resultType = Join.resultType(a.type(), b.type()); + + TensorType resultType = Join.outputType(a.type(), b.type()); Join function = new Join(a.function(), b.function(), doubleFunction); return new TypedTensorFunction(resultType, function); } - + + private TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) { + ensureArguments(1, arguments, "apply"); + TypedTensorFunction a = arguments.get(0); + + TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); + com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); + return new TypedTensorFunction(resultType, function); + } + + private TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) { + // TODO: Verify with TF documentation + String name; + TensorType inputType; + if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants + if (tfNode.getInputList().size() != 1) + throw new IllegalArgumentException("A Variable/read node must have one input but has " + + tfNode.getInputList().size()); + name = tfNode.getInput(0); + AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); + if (shapes == null) + throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape"); + inputType = importTensorType(shapes.getList().getShape(0)); + } + else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name + name = tfNode.getName(); + inputType = inputs.get(name); + if (inputType == null) + throw new IllegalArgumentException("An identity operation node is referencing input '" + name + + "', but there is no such input"); + } + return new TypedTensorFunction(inputType, new VariableTensor(name)); + } + private TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { ensureArguments(2, arguments, "matmul"); TypedTensorFunction a = arguments.get(0); @@ -183,7 +220,7 @@ public class TensorFlowImporter { // Let the second-to-last dimension of the second tensor be the same as the last dimension of the first // and the last dimension of the second argument be not present in the first argument, while leaving the // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. - + // TODO: Check if transpose_a or transpose_b is set and rename differently accordingly String beforeLastDim = "d" + (a.type().rank() - 1); @@ -193,40 +230,17 @@ public class TensorFlowImporter { Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim), ImmutableList.of(lastDim, afterLastDim)); Matmul matmul = new Matmul(a.function(), renamedB, lastDim); - return new TypedTensorFunction(Matmul.resultType(a.type(), b.type(), lastDim), + return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim), new Rename(matmul, afterLastDim, lastDim)); } private TypedTensorFunction softmax(List<TypedTensorFunction> arguments) { ensureArguments(1, arguments, "softmax"); TypedTensorFunction a = arguments.get(0); - String dimension = "d0"; // TODO: Verify with TF doc + // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 + String dimension = "d" + (a.type().rank() - 1); Softmax softmax = new Softmax(a.function(), dimension); - return new TypedTensorFunction(Softmax.resultType(a.type(), dimension), softmax); - } - - private TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) { - // TODO: Verify with TF documentation - String name; - TensorType inputType; - if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants - if (tfNode.getInputList().size() != 1) - throw new IllegalArgumentException("A Variable/read node must have one input but has " + - tfNode.getInputList().size()); - name = tfNode.getInput(0); - AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); - if (shapes == null) - throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape"); - inputType = importTensorType(shapes.getList().getShape(0)); - } - else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name - name = tfNode.getName(); - inputType = inputs.get(name); - if (inputType == null) - throw new IllegalArgumentException("An identity operation node is referencing input '" + name + - "', but there is no such input"); - } - return new TypedTensorFunction(inputType, new VariableTensor(name)); + return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); } private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) { diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java index 30328c3d9fe..bfe2bb3a63b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java @@ -23,7 +23,7 @@ public class TensorFlowImporterTestCase { "softmax(join(rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + "rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " + "f(a,b)(a + b)), " + - "d0)", + "d1)", toNonPrimitiveString(expressions.get(0))); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index c89f63c0395..9a37127e1f0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -47,7 +47,7 @@ public class Join extends PrimitiveTensorFunction { } /** Returns the type resulting from applying Join to the two given types */ - public static TensorType resultType(TensorType a, TensorType b) { + public static TensorType outputType(TensorType a, TensorType b) { TensorType.Builder typeBuilder = new TensorType.Builder(); for (int i = 0; i < a.dimensions().size(); ++i) { TensorType.Dimension aDim = a.dimensions().get(i); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index a9872bb42d8..d322a6ab497 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableMap; import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; @@ -31,6 +32,8 @@ public class Map extends PrimitiveTensorFunction { this.argument = argument; this.mapper = mapper; } + + public static TensorType outputType(TensorType inputType) { return inputType; } public TensorFunction argument() { return argument; } public DoubleUnaryOperator mapper() { return mapper; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index cbb3f159623..5e102454487 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -22,8 +22,8 @@ public class Matmul extends CompositeTensorFunction { this.dimension = dimension; } - public static TensorType resultType(TensorType a, TensorType b, String dimension) { - return Reduce.resultType(Join.resultType(a, b), ImmutableList.of(dimension)); + public static TensorType outputType(TensorType a, TensorType b, String dimension) { + return Reduce.outputType(Join.outputType(a, b), ImmutableList.of(dimension)); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index aa28a26deb2..a51df12e522 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,9 +61,9 @@ public class Reduce extends PrimitiveTensorFunction { this.dimensions = ImmutableList.copyOf(dimensions); } - public static TensorType resultType(TensorType type, List<String> reduceDimensions) { + public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { TensorType.Builder b = new TensorType.Builder(); - for (TensorType.Dimension dimension : type.dimensions()) { + for (TensorType.Dimension dimension : inputType.dimensions()) { if ( ! reduceDimensions.contains(dimension.name())) b.dimension(dimension); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index 99f79cb735a..fb5029fbfd6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -21,101 +21,87 @@ import java.util.stream.Collectors; @Beta public class ScalarFunctions { - public static DoubleBinaryOperator add() { return new Addition(); } - public static DoubleBinaryOperator multiply() { return new Multiplication(); } - public static DoubleBinaryOperator divide() { return new Division(); } + public static DoubleBinaryOperator add() { return new Add(); } + public static DoubleBinaryOperator divide() { return new Divide(); } public static DoubleBinaryOperator equal() { return new Equal(); } - public static DoubleUnaryOperator square() { return new Square(); } + public static DoubleBinaryOperator multiply() { return new Multiply(); } + + public static DoubleUnaryOperator acos() { return new Acos(); } + public static DoubleUnaryOperator exp() { return new Exp(); } public static DoubleUnaryOperator sqrt() { return new Sqrt(); } - public static DoubleUnaryOperator exp() { return new Exponent(); } + public static DoubleUnaryOperator square() { return new Square(); } + public static Function<List<Integer>, Double> random() { return new Random(); } public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } - public static class Addition implements DoubleBinaryOperator { + // Binary operators ----------------------------------------------------------------------------- + public static class Add implements DoubleBinaryOperator { @Override public double applyAsDouble(double left, double right) { return left + right; } - @Override public String toString() { return "f(a,b)(a + b)"; } - } - public static class Multiplication implements DoubleBinaryOperator { - + public static class Equal implements DoubleBinaryOperator { @Override - public double applyAsDouble(double left, double right) { return left * right; } - + public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } @Override - public String toString() { return "f(a,b)(a * b)"; } - + public String toString() { return "f(a,b)(a==b)"; } } - public static class Division implements DoubleBinaryOperator { - + public static class Exp implements DoubleUnaryOperator { @Override - public double applyAsDouble(double left, double right) { return left / right; } - + public double applyAsDouble(double operand) { return Math.exp(operand); } @Override - public String toString() { return "f(a,b)(a / b)"; } + public String toString() { return "f(a)(exp(a))"; } } - public static class Equal implements DoubleBinaryOperator { - + public static class Multiply implements DoubleBinaryOperator { @Override - public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; } + public double applyAsDouble(double left, double right) { return left * right; } + @Override + public String toString() { return "f(a,b)(a * b)"; } + } + public static class Divide implements DoubleBinaryOperator { @Override - public String toString() { return "f(a,b)(a==b)"; } + public double applyAsDouble(double left, double right) { return left / right; } + @Override + public String toString() { return "f(a,b)(a / b)"; } } - public static class Square implements DoubleUnaryOperator { + // Unary operators ------------------------------------------------------------------------------ + public static class Acos implements DoubleUnaryOperator { @Override - public double applyAsDouble(double operand) { return operand * operand; } - + public double applyAsDouble(double operand) { return Math.acos(operand); } @Override - public String toString() { return "f(a)(a * a)"; } - + public String toString() { return "f(a)(acos(a))"; } } public static class Sqrt implements DoubleUnaryOperator { - @Override public double applyAsDouble(double operand) { return Math.sqrt(operand); } - @Override public String toString() { return "f(a)(sqrt(a))"; } - } - public static class Exponent implements DoubleUnaryOperator { + public static class Square implements DoubleUnaryOperator { @Override - public double applyAsDouble(double operand) { return Math.exp(operand); } + public double applyAsDouble(double operand) { return operand * operand; } @Override - public String toString() { return "f(a)(exp(a))"; } + public String toString() { return "f(a)(a * a)"; } } - public static class Random implements Function<List<Integer>, Double> { - - @Override - public Double apply(List<Integer> values) { - return ThreadLocalRandom.current().nextDouble(); - } - - @Override - public String toString() { return "random"; } + // Variable-length operators ----------------------------------------------------------------------------- - } - - public static class EqualElements implements Function<List<Integer>, Double> { - - private final ImmutableList<String> argumentNames; - + public static class EqualElements implements Function<List<Integer>, Double> { + private final ImmutableList<String> argumentNames; private EqualElements(List<String> argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @@ -128,7 +114,6 @@ public class ScalarFunctions { return 0.0; return 1.0; } - @Override public String toString() { if (argumentNames.size() == 0) return "1"; @@ -143,13 +128,19 @@ public class ScalarFunctions { } return b.toString(); } + } + public static class Random implements Function<List<Integer>, Double> { + @Override + public Double apply(List<Integer> values) { + return ThreadLocalRandom.current().nextDouble(); + } + @Override + public String toString() { return "random"; } } public static class SumElements implements Function<List<Integer>, Double> { - private final ImmutableList<String> argumentNames; - private SumElements(List<String> argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @@ -161,12 +152,10 @@ public class ScalarFunctions { sum += value; return (double)sum; } - @Override public String toString() { return argumentNames.stream().collect(Collectors.joining("+")); } - } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index 45f78389c16..c856b548180 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -22,8 +22,8 @@ public class Softmax extends CompositeTensorFunction { this.dimension = dimension; } - public static TensorType resultType(TensorType type, String dimension) { - return Reduce.resultType(type, ImmutableList.of(dimension)); + public static TensorType outputType(TensorType inputType, String dimension) { + return Reduce.outputType(inputType, ImmutableList.of(dimension)); } @Override |