summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java90
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java97
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java4
9 files changed, 107 insertions, 101 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java
index 0f08bf0bf21..9ef1a3f6e32 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/FieldUpdateHelper.java
@@ -56,7 +56,7 @@ public abstract class FieldUpdateHelper {
} else if (upd instanceof ArithmeticValueUpdate) {
if (((ArithmeticValueUpdate)upd).getOperator() == ArithmeticValueUpdate.Operator.DIV &&
((ArithmeticValueUpdate)upd).getOperand().doubleValue() == 0) {
- throw new IllegalArgumentException("Division by zero.");
+ throw new IllegalArgumentException("Div by zero.");
}
val.assign(upd.getValue());
return val;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 8dcd31b270e..e47f2ad53d9 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -7,10 +7,8 @@ import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Matmul;
import com.yahoo.tensor.functions.Rename;
@@ -34,6 +32,7 @@ import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.function.DoubleBinaryOperator;
+import java.util.function.DoubleUnaryOperator;
import java.util.stream.Collectors;
/**
@@ -146,11 +145,13 @@ public class TensorFlowImporter {
GraphDef graph,
String indent) {
// Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
- switch (tfNode.getOp()) {
- case "Identity" : return identity(tfNode, inputs);
- case "Add" : return join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add());
- case "MatMul" : return matmul(importArguments(tfNode, inputs, graph, indent));
- case "Softmax" : return softmax(importArguments(tfNode, inputs, graph, indent));
+ // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
+ switch (tfNode.getOp().toLowerCase()) {
+ case "add" : case "add_n" : return join(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.add());
+ case "acos" : return map(importArguments(tfNode, inputs, graph, indent), ScalarFunctions.acos());
+ case "identity" : return identity(tfNode, inputs);
+ case "matmul" : return matmul(importArguments(tfNode, inputs, graph, indent));
+ case "softmax" : return softmax(importArguments(tfNode, inputs, graph, indent));
default : throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + "' is not supported");
}
}
@@ -162,15 +163,51 @@ public class TensorFlowImporter {
}
private TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
+ // Note that this generalizes the corresponding TF function as it does not verify that the tensor
+ // types are the same, with the assumption that this already happened on the TF side
+ // (and if not, this should do the right thing anyway)
ensureArguments(2, arguments, "join");
TypedTensorFunction a = arguments.get(0);
TypedTensorFunction b = arguments.get(0);
- // TODO: Verify with TF doc
- TensorType resultType = Join.resultType(a.type(), b.type());
+
+ TensorType resultType = Join.outputType(a.type(), b.type());
Join function = new Join(a.function(), b.function(), doubleFunction);
return new TypedTensorFunction(resultType, function);
}
-
+
+ private TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
+ ensureArguments(1, arguments, "apply");
+ TypedTensorFunction a = arguments.get(0);
+
+ TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
+ com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
+ return new TypedTensorFunction(resultType, function);
+ }
+
+ private TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) {
+ // TODO: Verify with TF documentation
+ String name;
+ TensorType inputType;
+ if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants
+ if (tfNode.getInputList().size() != 1)
+ throw new IllegalArgumentException("A Variable/read node must have one input but has " +
+ tfNode.getInputList().size());
+ name = tfNode.getInput(0);
+ AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
+ if (shapes == null)
+ throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape");
+ inputType = importTensorType(shapes.getList().getShape(0));
+ }
+ else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
+ name = tfNode.getName();
+ inputType = inputs.get(name);
+ if (inputType == null)
+ throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
+ "', but there is no such input");
+ }
+ return new TypedTensorFunction(inputType, new VariableTensor(name));
+ }
+
private TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
ensureArguments(2, arguments, "matmul");
TypedTensorFunction a = arguments.get(0);
@@ -183,7 +220,7 @@ public class TensorFlowImporter {
// Let the second-to-last dimension of the second tensor be the same as the last dimension of the first
// and the last dimension of the second argument be not present in the first argument, while leaving the
// rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
-
+
// TODO: Check if transpose_a or transpose_b is set and rename differently accordingly
String beforeLastDim = "d" + (a.type().rank() - 1);
@@ -193,40 +230,17 @@ public class TensorFlowImporter {
Rename renamedB = new Rename(b.function(), ImmutableList.of(beforeLastDim, lastDim),
ImmutableList.of(lastDim, afterLastDim));
Matmul matmul = new Matmul(a.function(), renamedB, lastDim);
- return new TypedTensorFunction(Matmul.resultType(a.type(), b.type(), lastDim),
+ return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), lastDim),
new Rename(matmul, afterLastDim, lastDim));
}
private TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
ensureArguments(1, arguments, "softmax");
TypedTensorFunction a = arguments.get(0);
- String dimension = "d0"; // TODO: Verify with TF doc
+ // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
+ String dimension = "d" + (a.type().rank() - 1);
Softmax softmax = new Softmax(a.function(), dimension);
- return new TypedTensorFunction(Softmax.resultType(a.type(), dimension), softmax);
- }
-
- private TypedTensorFunction identity(NodeDef tfNode, Map<String, TensorType> inputs) {
- // TODO: Verify with TF documentation
- String name;
- TensorType inputType;
- if (tfNode.getName().endsWith("/read")) { // A node reading a variable supplied with this model TODO: We need to turn those into constants
- if (tfNode.getInputList().size() != 1)
- throw new IllegalArgumentException("A Variable/read node must have one input but has " +
- tfNode.getInputList().size());
- name = tfNode.getInput(0);
- AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
- if (shapes == null)
- throw new IllegalArgumentException("Referenced variable '" + name + " is missing a tensor output shape");
- inputType = importTensorType(shapes.getList().getShape(0));
- }
- else { // a referenced input (query or document tensor) TODO: How to map to attribute/query name
- name = tfNode.getName();
- inputType = inputs.get(name);
- if (inputType == null)
- throw new IllegalArgumentException("An identity operation node is referencing input '" + name +
- "', but there is no such input");
- }
- return new TypedTensorFunction(inputType, new VariableTensor(name));
+ return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
}
private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
index 30328c3d9fe..bfe2bb3a63b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporterTestCase.java
@@ -23,7 +23,7 @@ public class TensorFlowImporterTestCase {
"softmax(join(rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " +
"rename(matmul(x, rename(x, (d1, d2), (d2, d3)), d2), d3, d2), " +
"f(a,b)(a + b)), " +
- "d0)",
+ "d1)",
toNonPrimitiveString(expressions.get(0)));
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index c89f63c0395..9a37127e1f0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -47,7 +47,7 @@ public class Join extends PrimitiveTensorFunction {
}
/** Returns the type resulting from applying Join to the two given types */
- public static TensorType resultType(TensorType a, TensorType b) {
+ public static TensorType outputType(TensorType a, TensorType b) {
TensorType.Builder typeBuilder = new TensorType.Builder();
for (int i = 0; i < a.dimensions().size(); ++i) {
TensorType.Dimension aDim = a.dimensions().get(i);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index a9872bb42d8..d322a6ab497 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -6,6 +6,7 @@ import com.google.common.collect.ImmutableMap;
import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import java.util.Collections;
@@ -31,6 +32,8 @@ public class Map extends PrimitiveTensorFunction {
this.argument = argument;
this.mapper = mapper;
}
+
+ public static TensorType outputType(TensorType inputType) { return inputType; }
public TensorFunction argument() { return argument; }
public DoubleUnaryOperator mapper() { return mapper; }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index cbb3f159623..5e102454487 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -22,8 +22,8 @@ public class Matmul extends CompositeTensorFunction {
this.dimension = dimension;
}
- public static TensorType resultType(TensorType a, TensorType b, String dimension) {
- return Reduce.resultType(Join.resultType(a, b), ImmutableList.of(dimension));
+ public static TensorType outputType(TensorType a, TensorType b, String dimension) {
+ return Reduce.outputType(Join.outputType(a, b), ImmutableList.of(dimension));
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index aa28a26deb2..a51df12e522 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -61,9 +61,9 @@ public class Reduce extends PrimitiveTensorFunction {
this.dimensions = ImmutableList.copyOf(dimensions);
}
- public static TensorType resultType(TensorType type, List<String> reduceDimensions) {
+ public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
TensorType.Builder b = new TensorType.Builder();
- for (TensorType.Dimension dimension : type.dimensions()) {
+ for (TensorType.Dimension dimension : inputType.dimensions()) {
if ( ! reduceDimensions.contains(dimension.name()))
b.dimension(dimension);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index 99f79cb735a..fb5029fbfd6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -21,101 +21,87 @@ import java.util.stream.Collectors;
@Beta
public class ScalarFunctions {
- public static DoubleBinaryOperator add() { return new Addition(); }
- public static DoubleBinaryOperator multiply() { return new Multiplication(); }
- public static DoubleBinaryOperator divide() { return new Division(); }
+ public static DoubleBinaryOperator add() { return new Add(); }
+ public static DoubleBinaryOperator divide() { return new Divide(); }
public static DoubleBinaryOperator equal() { return new Equal(); }
- public static DoubleUnaryOperator square() { return new Square(); }
+ public static DoubleBinaryOperator multiply() { return new Multiply(); }
+
+ public static DoubleUnaryOperator acos() { return new Acos(); }
+ public static DoubleUnaryOperator exp() { return new Exp(); }
public static DoubleUnaryOperator sqrt() { return new Sqrt(); }
- public static DoubleUnaryOperator exp() { return new Exponent(); }
+ public static DoubleUnaryOperator square() { return new Square(); }
+
public static Function<List<Integer>, Double> random() { return new Random(); }
public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
- public static class Addition implements DoubleBinaryOperator {
+ // Binary operators -----------------------------------------------------------------------------
+ public static class Add implements DoubleBinaryOperator {
@Override
public double applyAsDouble(double left, double right) { return left + right; }
-
@Override
public String toString() { return "f(a,b)(a + b)"; }
-
}
- public static class Multiplication implements DoubleBinaryOperator {
-
+ public static class Equal implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left * right; }
-
+ public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
@Override
- public String toString() { return "f(a,b)(a * b)"; }
-
+ public String toString() { return "f(a,b)(a==b)"; }
}
- public static class Division implements DoubleBinaryOperator {
-
+ public static class Exp implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left / right; }
-
+ public double applyAsDouble(double operand) { return Math.exp(operand); }
@Override
- public String toString() { return "f(a,b)(a / b)"; }
+ public String toString() { return "f(a)(exp(a))"; }
}
- public static class Equal implements DoubleBinaryOperator {
-
+ public static class Multiply implements DoubleBinaryOperator {
@Override
- public double applyAsDouble(double left, double right) { return left == right ? 1 : 0; }
+ public double applyAsDouble(double left, double right) { return left * right; }
+ @Override
+ public String toString() { return "f(a,b)(a * b)"; }
+ }
+ public static class Divide implements DoubleBinaryOperator {
@Override
- public String toString() { return "f(a,b)(a==b)"; }
+ public double applyAsDouble(double left, double right) { return left / right; }
+ @Override
+ public String toString() { return "f(a,b)(a / b)"; }
}
- public static class Square implements DoubleUnaryOperator {
+ // Unary operators ------------------------------------------------------------------------------
+ public static class Acos implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return operand * operand; }
-
+ public double applyAsDouble(double operand) { return Math.acos(operand); }
@Override
- public String toString() { return "f(a)(a * a)"; }
-
+ public String toString() { return "f(a)(acos(a))"; }
}
public static class Sqrt implements DoubleUnaryOperator {
-
@Override
public double applyAsDouble(double operand) { return Math.sqrt(operand); }
-
@Override
public String toString() { return "f(a)(sqrt(a))"; }
-
}
- public static class Exponent implements DoubleUnaryOperator {
+ public static class Square implements DoubleUnaryOperator {
@Override
- public double applyAsDouble(double operand) { return Math.exp(operand); }
+ public double applyAsDouble(double operand) { return operand * operand; }
@Override
- public String toString() { return "f(a)(exp(a))"; }
+ public String toString() { return "f(a)(a * a)"; }
}
- public static class Random implements Function<List<Integer>, Double> {
-
- @Override
- public Double apply(List<Integer> values) {
- return ThreadLocalRandom.current().nextDouble();
- }
-
- @Override
- public String toString() { return "random"; }
+ // Variable-length operators -----------------------------------------------------------------------------
- }
-
- public static class EqualElements implements Function<List<Integer>, Double> {
-
- private final ImmutableList<String> argumentNames;
-
+ public static class EqualElements implements Function<List<Integer>, Double> {
+ private final ImmutableList<String> argumentNames;
private EqualElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@@ -128,7 +114,6 @@ public class ScalarFunctions {
return 0.0;
return 1.0;
}
-
@Override
public String toString() {
if (argumentNames.size() == 0) return "1";
@@ -143,13 +128,19 @@ public class ScalarFunctions {
}
return b.toString();
}
+ }
+ public static class Random implements Function<List<Integer>, Double> {
+ @Override
+ public Double apply(List<Integer> values) {
+ return ThreadLocalRandom.current().nextDouble();
+ }
+ @Override
+ public String toString() { return "random"; }
}
public static class SumElements implements Function<List<Integer>, Double> {
-
private final ImmutableList<String> argumentNames;
-
private SumElements(List<String> argumentNames) {
this.argumentNames = ImmutableList.copyOf(argumentNames);
}
@@ -161,12 +152,10 @@ public class ScalarFunctions {
sum += value;
return (double)sum;
}
-
@Override
public String toString() {
return argumentNames.stream().collect(Collectors.joining("+"));
}
-
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index 45f78389c16..c856b548180 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -22,8 +22,8 @@ public class Softmax extends CompositeTensorFunction {
this.dimension = dimension;
}
- public static TensorType resultType(TensorType type, String dimension) {
- return Reduce.resultType(type, ImmutableList.of(dimension));
+ public static TensorType outputType(TensorType inputType, String dimension) {
+ return Reduce.outputType(inputType, ImmutableList.of(dimension));
}
@Override