aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-01-26 13:45:36 +0100
committerLester Solbakken <lesters@oath.com>2018-01-26 14:14:29 +0100
commit089a765734b1791995510e97a0852bd7a89b3c0b (patch)
tree46e70b9e7855dc8cd1ecee119c868a555dcd89af /searchlib/src/main/java
parent780264290b9e15f0594991b5dba8f1dc2021f92d (diff)
Refactor tensorflow import and add dropout test
Diffstat (limited to 'searchlib/src/main/java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java132
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java750
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java211
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java9
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java9
6 files changed, 775 insertions, 342 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java
new file mode 100644
index 00000000000..5f0c016881a
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java
@@ -0,0 +1,132 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.TensorProto;
+import org.tensorflow.framework.TensorShapeProto;
+
+/**
+ * @author lesters
+ */
+public class AttrValueConverter {
+
+ public static Tensor toVespaTensor(NodeDef tfNode, String attr) {
+ if (!tfNode.getAttrMap().containsKey(attr)) {
+ throw new IllegalArgumentException(tfNode.getName() + " has no attribute called " + attr);
+ }
+ AttrValue attrValue = tfNode.getAttrMap().get(attr);
+ switch (attrValue.getValueCase()) {
+ case TENSOR:
+ return buildFromTensor(attrValue);
+ case B:
+ return buildFromSingleValue(attrValue.getB() ? 1.0 : 0.0);
+ case F:
+ return buildFromSingleValue(attrValue.getF());
+ case I:
+ return buildFromSingleValue(attrValue.getI());
+ }
+
+ throw new IllegalArgumentException(tfNode.getName() +
+ ": unsupported attribute type: '" + attrValue.getValueCase().toString() + "'");
+ }
+
+ private static Tensor buildFromSingleValue(double value) {
+ TensorType type = new TensorType.Builder().build();
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ builder.cellByDirectIndex(0, value);
+ return builder.build();
+ }
+
+ private static Tensor buildFromTensor(AttrValue attrValue) {
+ TensorProto tensorProto = attrValue.getTensor();
+ TensorType type = toVespaTensorType(tensorProto.getTensorShape());
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ Values values = valuesOf(tensorProto);
+ for (int i = 0; i < values.size(); ++i) {
+ builder.cellByDirectIndex(i, values.get(i));
+ }
+ Tensor tensor = builder.build();
+ return tensor;
+ }
+
+ private static Values valuesOf(TensorProto tensorProto) {
+ switch (tensorProto.getDtype()) {
+ case DT_BOOL:
+ return new BoolValues(tensorProto);
+ case DT_HALF:
+ return new HalfValues(tensorProto);
+ case DT_INT16:
+ case DT_INT32:
+ return new IntValues(tensorProto);
+ case DT_INT64:
+ return new Int64Values(tensorProto);
+ case DT_FLOAT:
+ return new FloatValues(tensorProto);
+ case DT_DOUBLE:
+ return new DoubleValues(tensorProto);
+ }
+
+ throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
+ }
+
+ public static TensorType toVespaTensorType(TensorShapeProto shapeProto) {
+ TensorType.Builder b = new TensorType.Builder();
+ for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) {
+ int dimensionSize = (int)dimension.getSize();
+ if (dimensionSize >= 0)
+ b.indexed("d" + b.rank(), dimensionSize);
+ else
+ b.indexed("d" + b.rank()); // unbound size
+ }
+ return b.build();
+ }
+
+ private static abstract class Values {
+ protected final TensorProto tensorProto;
+ protected Values(TensorProto tensorProto) { this.tensorProto = tensorProto; }
+ abstract double get(int i);
+ abstract int size();
+ }
+
+ private static class BoolValues extends Values {
+ BoolValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; }
+ @Override int size() { return tensorProto.getBoolValCount(); }
+ }
+
+ private static class HalfValues extends Values {
+ HalfValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getHalfVal(i); }
+ @Override int size() { return tensorProto.getHalfValCount(); }
+ }
+
+ private static class IntValues extends Values {
+ IntValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getIntVal(i); }
+ @Override int size() { return tensorProto.getIntValCount(); }
+ }
+
+ private static class Int64Values extends Values {
+ Int64Values(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getInt64Val(i); }
+ @Override int size() { return tensorProto.getInt64ValCount(); }
+ }
+
+ private static class FloatValues extends Values {
+ FloatValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getFloatVal(i); }
+ @Override int size() { return tensorProto.getFloatValCount(); }
+ }
+
+ private static class DoubleValues extends Values {
+ DoubleValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getDoubleVal(i); }
+ @Override int size() { return tensorProto.getDoubleValCount(); }
+ }
+
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
index 85452d16a77..816ef38e128 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
@@ -24,157 +24,318 @@ import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
+import java.util.function.Function;
import java.util.stream.Collectors;
-import java.util.stream.StreamSupport;
+import java.util.stream.Stream;
/**
* Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
*
* @author bratseth
+ * @author lesters
*/
class OperationMapper {
+ // A note on conversion from implicitly numbered to explicitly named dimensions:
+ //
+ // Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
+ // 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
+ // comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
+ // around dimension renaming operations which mirrors those built into the TF operation definitions.
+ //
+ // To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
+ // dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
+ // and the result is then renamed again (if necessary) to recover this convention across a full nested
+ // computation.
+ //
+ // This requires us to track tensor types throughout the conversion.
+
+
+ // Supported TensorFlow operations
+ enum Operation {
+
+ // TODO: move the implementations to specific files as we support more operations
+
+ /*
+ * array ops
+ */
+ CONST (OperationMapper::constant),
+ EXPANDDIMS (OperationMapper::expandDims),
+ IDENTITY (OperationMapper::identity),
+ PLACEHOLDER (OperationMapper::placeholder),
+ PLACEHOLDERWITHDEFAULT (OperationMapper::placeholderWithDefault),
+ RESHAPE (OperationMapper::reshape),
+ SQUEEZE (OperationMapper::squeeze),
+
+ /*
+ * control flow
+ */
+ MERGE (OperationMapper::merge),
+ SWITCH (OperationMapper::switchOp),
+
+ /*
+ * math ops
+ */
+ ADD (OperationMapper::add),
+ ADD_N (OperationMapper::add),
+ ACOS (OperationMapper::acos),
+ DIV (OperationMapper::div),
+ REALDIV (OperationMapper::div),
+ FLOOR (OperationMapper::floor),
+ MATMUL (OperationMapper::matmul),
+ MAXIMUM (OperationMapper::maximum),
+ MEAN (OperationMapper::mean),
+ REDUCEMEAN (OperationMapper::mean),
+ MUL (OperationMapper::mul),
+ MULTIPLY (OperationMapper::mul),
+ RSQRT (OperationMapper::rsqrt),
+ SELECT (OperationMapper::select),
+ WHERE3 (OperationMapper::select),
+ SIGMOID (OperationMapper::sigmoid),
+ SQUAREDDIFFERENCE (OperationMapper::squaredDifference),
+ SUB (OperationMapper::sub),
+ SUBTRACT (OperationMapper::sub),
+
+ /*
+ * nn ops
+ */
+ BIASADD (OperationMapper::add),
+ ELU (OperationMapper::elu),
+ RELU (OperationMapper::relu),
+ SELU (OperationMapper::selu),
+ SOFTMAX (OperationMapper::softMax),
+
+ /*
+ * state ops
+ */
+ VARIABLE (OperationMapper::variable),
+ VARIABLEV2 (OperationMapper::variable),
+
+ /*
+ * evaluation no-ops
+ */
+ STOPGRADIENT (OperationMapper::identity),
+ NOOP (OperationMapper::noOp);
+
+
+ private final Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func;
+
+ Operation(Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func) {
+ this.func = func;
+ }
+
+ Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) {
+ return func.apply(params);
+ }
+
+ }
+
+ static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) {
+ Optional<Operation> operation = Stream.of(Operation.values())
+ .filter(op -> op.name().equalsIgnoreCase(params.node().getOp()))
+ .findFirst();
+ if (operation.isPresent()) {
+ return operation.get().map(params);
+ }
+ params.signature().importWarning("TensorFlow operation '" + params.node().getOp() +
+ "' in node '" + params.node().getName() + "' is not supported.");
+ return Optional.empty();
+ }
+
+
/*
- A note on conversion from implicitly numbered to explicitly named dimensions:
- Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
- 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
- comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
- around dimension renaming operations which mirrors those built into the TF operation definitions.
-
- To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
- dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
- and the result is then renamed again (if necessary) to recover this convention across a full nested
- computation.
-
- This requires us to track tensor types throughout the conversion.
+ * Operations
*/
- private TensorConverter tensorConverter = new TensorConverter();
+ private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) {
+ Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value");
+ return createConstant(params, value);
+ }
- TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) {
- ensureArguments(2, arguments, "join");
- TypedTensorFunction a = arguments.get(0);
- TypedTensorFunction b = arguments.get(1);
+ private static Optional<TypedTensorFunction> expandDims(TensorFlowImporter.Parameters params) {
+ if (!checkInputs(params, 2)) {
+ return Optional.empty();
+ }
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
- if (a.type().rank() == 0 && b.type().rank() > 0) {
- return new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction));
+ Tensor axis = getConstantTensor(params, params.node().getInput(1));
+ if (axis.type().rank() != 0) {
+ throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar");
}
- if (b.type().rank() == 0 && a.type().rank() > 0) {
- return new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction));
+
+ TensorFunction inputFunction = inputs.get(0).get().function();
+ TensorType inputType = inputs.get(0).get().type();
+
+ int dimensionToInsert = (int)axis.asDouble();
+ if (dimensionToInsert < 0) {
+ dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- if (a.type().rank() == b.type().rank()) {
- return new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction));
+
+ TensorType.Builder outputTypeBuilder = new TensorType.Builder();
+ int dimensionIndex = 0;
+ for (int i = 0; i < inputType.dimensions().size() + 1; ++i) {
+ String name = String.format("temp_%d", i);
+ Long size;
+ if (i == dimensionToInsert) {
+ size = 1L;
+ } else {
+ size = dimensionSize(inputType.dimensions().get(dimensionIndex));
+ dimensionIndex++;
+ }
+ outputTypeBuilder.indexed(name, size);
}
- // Well now we have entered the wonderful world of "broadcasting"
- // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
- // I'm not able to extract from that any unambiguous specification of which dimensions
- // should be "stretched" when the tensor do not have the same number of dimensions.
- // From trying this with TensorFlow it appears that the second tensor is matched to the
- // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
- // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
+ return reshape(inputFunction, inputType, outputTypeBuilder.build());
+ }
- if (a.type().rank() > b.type().rank()) {
- TensorFunction renameFunction = renameForBroadcast(a, b);
- return new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction));
+ private static Optional<TypedTensorFunction> identity(TensorFlowImporter.Parameters params) {
+ if (!checkInputs(params, 1)) {
+ return Optional.empty();
}
- TensorFunction renameFunction = renameForBroadcast(b, a);
- return new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction));
+ return params.inputs().get(0);
}
- private TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) {
- List<String> renameFrom = new ArrayList<>();
- List<String> renameTo = new ArrayList<>();
- int sizeDifference = a.type().rank() - b.type().rank();
- for (int i = 0; i < b.type().rank(); i++) {
- renameFrom.add(b.type().dimensions().get(i).name());
- renameTo.add("d" + (sizeDifference + i));
+ private static Optional<TypedTensorFunction> placeholder(TensorFlowImporter.Parameters params) {
+ String name = params.node().getName();
+ TensorType type = params.result().arguments().get(name);
+ if (type == null) {
+ throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
+ "', but there is no such placeholder");
}
- return new Rename(b.function(), renameFrom, renameTo);
+ // Included literally in the expression and so must be produced by a separate macro in the rank profile
+ TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name));
+ return Optional.of(output);
+ }
+
+ private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) {
+ String name = params.node().getInput(0);
+ Tensor defaultValue = getConstantTensor(params, name);
+ params.result().constant(name, defaultValue);
+ params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
+ // The default value will be provided by the macro. Users can override macro to change value.
+ TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name));
+ return Optional.of(output);
}
- TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) {
- ensureArguments(1, arguments, "apply");
- TypedTensorFunction a = arguments.get(0);
+ private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) {
+ if (!checkInputs(params, 2)) {
+ return Optional.empty();
+ }
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+ Tensor shape = getConstantTensor(params, params.node().getInput(1));
- TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
- com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
- return new TypedTensorFunction(resultType, function);
+ TensorFunction inputFunction = inputs.get(0).get().function();
+ TensorType inputType = inputs.get(0).get().type();
+
+ TensorType.Builder outputTypeBuilder = new TensorType.Builder();
+ int dimensionIndex = 0;
+ for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
+ Tensor.Cell cell = cellIterator.next();
+ int size = cell.getValue().intValue();
+ if (size < 0) {
+ size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue();
+ }
+ outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size);
+ dimensionIndex++;
+ }
+ return reshape(inputFunction, inputType, outputTypeBuilder.build());
}
- TypedTensorFunction placeholder(NodeDef tfNode, TensorFlowModel result) {
- String name = tfNode.getName();
- TensorType type = result.arguments().get(name);
- if (type == null)
- throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
- "', but there is no such placeholder");
- // Included literally in the expression and so must be produced by a separate macro in the rank profile
- return new TypedTensorFunction(type, new VariableTensor(name));
+ private static Optional<TypedTensorFunction> squeeze(TensorFlowImporter.Parameters params) {
+ if (!checkInputs(params, 1)) {
+ return Optional.empty();
+ }
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+
+ TensorFunction inputFunction = inputs.get(0).get().function();
+ TensorType inputType = inputs.get(0).get().type();
+ List<String> squeezeDimensions;
+
+ AttrValue squeezeDimsAttr = params.node().getAttrMap().get("squeeze_dims");
+ if (squeezeDimsAttr == null) {
+ squeezeDimensions = inputType.dimensions().stream().
+ filter(dim -> dimensionSize(dim) == 1).
+ map(TensorType.Dimension::name).
+ collect(Collectors.toList());
+ } else {
+ squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
+ map(i -> i < 0 ? inputType.dimensions().size() - i : i).
+ map(i -> inputType.dimensions().get(i.intValue())).
+ filter(dim -> dimensionSize(dim) == 1).
+ map(TensorType.Dimension::name).
+ collect(Collectors.toList());
+ }
+
+ if (squeezeDimensions.isEmpty()) {
+ return inputs.get(0);
+ }
+
+ TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
+ TensorType outputType = Reduce.outputType(inputType, squeezeDimensions);
+ TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
+ return Optional.of(output);
}
- TypedTensorFunction placeholderWithDefault(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
- String name = tfNode.getInput(0);
- Tensor defaultValue = getConstantTensor(model, name);
- result.constant(name, defaultValue);
- result.macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
- // The default value will be provided by the macro. Users can override macro to change value.
- return new TypedTensorFunction(defaultValue.type(), new VariableTensor(name));
+ private static Optional<TypedTensorFunction> merge(TensorFlowImporter.Parameters params) {
+ return params.inputs().stream()
+ .filter(Optional::isPresent)
+ .findFirst()
+ .orElse(Optional.empty());
}
- TypedTensorFunction constant(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
- String name = tfNode.getName();
- if (tfNode.getInputList().size() != 0) {
- throw new IllegalArgumentException("A constant node must have zero inputs but '" + name + "' has " +
- tfNode.getInputList().size());
+ private static Optional<TypedTensorFunction> switchOp(TensorFlowImporter.Parameters params) {
+ if (!checkInputs(params, 2)) {
+ return Optional.empty();
+ }
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+ Tensor predicate = getConstantTensor(params, params.node().getInput(1));
+ if (predicate.type().rank() != 0) {
+ throw new IllegalArgumentException("'switch': predicate must be a scalar");
+ }
+ double pred = predicate.asDouble();
+ int output = params.port().length() > 0 ? Integer.parseInt(params.port()) : 0;
+ if (output < 0 || output > 1) {
+ throw new IllegalArgumentException("'switch': predicate is not boolean");
+ }
+ if (pred == output) {
+ return inputs.get(0);
}
- return importConstantTensor(tfNode, model, result, name);
+ return Optional.empty();
}
- TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) {
- if ( ! tfNode.getName().endsWith("/read"))
- throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " +
- "nodes are only supported when reading variables");
- if (tfNode.getInputList().size() != 1)
- throw new IllegalArgumentException("A Variable/read node must have one input but '" +
- tfNode.getName() + "' has " + tfNode.getInputList().size());
+ private static Optional<TypedTensorFunction> add(TensorFlowImporter.Parameters params) {
+ return join(params, ScalarFunctions.add());
+ }
- String name = tfNode.getInput(0);
- return importConstantTensor(tfNode, model, result, name);
+ private static Optional<TypedTensorFunction> acos(TensorFlowImporter.Parameters params) {
+ return map(params, ScalarFunctions.acos());
}
- private TypedTensorFunction importConstantTensor(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result, String name) {
- AttrValue shapes = tfNode.getAttrMap().get("_output_shapes");
- if (shapes == null)
- throw new IllegalArgumentException("'" + name + "' is missing a tensor shape");
- Tensor constant = getConstantTensor(model, name);
- result.constant(name, constant);
- return new TypedTensorFunction(constant.type(),
- new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(\"" + name + "\")")));
+ private static Optional<TypedTensorFunction> div(TensorFlowImporter.Parameters params) {
+ return join(params, ScalarFunctions.divide());
}
- private Tensor getConstantTensor(SavedModelBundle model, String name) {
- Session.Runner fetched = model.session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " +
- importedTensors.size());
- return tensorConverter.toVespaTensor(importedTensors.get(0));
+ private static Optional<TypedTensorFunction> floor(TensorFlowImporter.Parameters params) {
+ return map(params, ScalarFunctions.floor());
}
- TypedTensorFunction matmul(List<TypedTensorFunction> arguments) {
- ensureArguments(2, arguments, "matmul");
- TypedTensorFunction a = arguments.get(0);
- TypedTensorFunction b = arguments.get(1);
+ private static Optional<TypedTensorFunction> matmul(TensorFlowImporter.Parameters params) {
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+ if (!checkInputs(params, 2)) {
+ return Optional.empty();
+ }
+
+ TypedTensorFunction a = inputs.get(0).get();
+ TypedTensorFunction b = inputs.get(1).get();
if (a.type().rank() < 2 || b.type().rank() < 2)
throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
if (a.type().rank() != b.type().rank())
@@ -190,17 +351,24 @@ class OperationMapper {
Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"),
ImmutableList.of("d1", afterLastDim));
Matmul matmul = new Matmul(a.function(), renamedB, "d1");
- return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
+ TypedTensorFunction output = new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
new Rename(matmul, afterLastDim, "d1"));
+ return Optional.of(output);
}
- TypedTensorFunction mean(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) {
- ensureArguments(2, arguments, "mean");
- Tensor reductionIndices = getConstantTensor(model, tfNode.getInput(1));
+ private static Optional<TypedTensorFunction> maximum(TensorFlowImporter.Parameters params) {
+ return join(params, ScalarFunctions.max());
+ }
- TensorFunction inputFunction = arguments.get(0).function();
- TensorType inputType = arguments.get(0).type();
+ private static Optional<TypedTensorFunction> mean(TensorFlowImporter.Parameters params) {
+ if (!checkInputs(params, 2)) {
+ return Optional.empty();
+ }
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+ TensorFunction inputFunction = inputs.get(0).get().function();
+ TensorType inputType = inputs.get(0).get().type();
+ Tensor reductionIndices = getConstantTensor(params, params.node().getInput(1));
List<String> reduceDimensions = new ArrayList<>();
for (Iterator<Tensor.Cell> cellIterator = reductionIndices.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
@@ -214,122 +382,195 @@ class OperationMapper {
TensorType outputType = Reduce.outputType(inputType, reduceDimensions);
TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
- if (shouldKeepDimensions(tfNode)) {
+ if (shouldKeepDimensions(params)) {
return reshape(outputFunction, outputType, keepDimensionType(inputType, reduceDimensions));
}
-
TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return output;
+ return Optional.of(output);
}
- private boolean shouldKeepDimensions(NodeDef tfNode) {
- AttrValue keepDimsAttr = tfNode.getAttrMap().get("keep_dims");
- return keepDimsAttr != null && keepDimsAttr.getB();
+ private static Optional<TypedTensorFunction> mul(TensorFlowImporter.Parameters params) {
+ return join(params, ScalarFunctions.multiply());
}
- private TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) {
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension: inputType.dimensions()) {
- String name = dimension.name();
- Long size = dimensionSize(dimension);
- if (reduceDimensions.contains(name)) {
- size = 1L;
- }
- builder.indexed(name, size);
- }
- return builder.build();
+ private static Optional<TypedTensorFunction> rsqrt(TensorFlowImporter.Parameters params) {
+ return map(params, ScalarFunctions.rsqrt());
}
- private TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) {
- for (int i = 0; i < type.dimensions().size(); ++i) {
- String correct = String.format("d%d", i);
- String current = type.dimensions().get(i).name();
- if (!current.equals(correct)) {
- return fixNamingConvention(type, function);
- }
+ private static Optional<TypedTensorFunction> select(TensorFlowImporter.Parameters params) {
+ if (!checkInputs(params, 3)) {
+ return Optional.empty();
}
- return new TypedTensorFunction(type, function);
- }
+ Tensor condition = getConstantTensor(params, params.node().getInput(0));
- private TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) {
- TensorType.Builder correctType = new TensorType.Builder();
- List<String> from = new ArrayList<>();
- List<String> to = new ArrayList<>();
- for (int i = 0; i < type.dimensions().size(); ++i) {
- String correct = String.format("d%d", i);
- String current = type.dimensions().get(i).name();
- if (!current.equals(correct)) {
- from.add(current);
- to.add(correct);
- }
- correctType.indexed(correct, dimensionSize(type.dimensions().get(i)));
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+ TypedTensorFunction x = inputs.get(1).get();
+ TypedTensorFunction y = inputs.get(2).get();
+ if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) {
+ throw new IllegalArgumentException("'Select': input tensors must have the same shape");
}
- if (from.size() > 0) {
- function = new Rename(function, from, to);
- type = correctType.build();
+
+ if (condition.type().rank() == 0) {
+ return Optional.of((int)condition.asDouble() == 0 ? y : x);
}
- return new TypedTensorFunction(type, function);
+ if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
+ return Optional.of(condition.cellIterator().next().getValue().intValue() == 0 ? y : x);
+ }
+
+ // The task is to select cells from 'x' or 'y' based on 'condition'.
+ // If 'condition' is 0 (false), select from 'y', if 1 (true) select
+ // from 'x'. We do this by individually joining 'x' and 'y' with
+ // 'condition', and then joining the resulting two tensors.
+
+ Optional<TypedTensorFunction> conditionFunction = importConstantTensor(params, params.node().getInput(0));
+ if (!conditionFunction.isPresent()) {
+ return Optional.empty();
+ }
+ TensorFunction xCond = new Join(x.function(), conditionFunction.get().function(), ScalarFunctions.multiply());
+ TensorFunction yCond = new Join(y.function(), conditionFunction.get().function(), new DoubleBinaryOperator() {
+ @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
+ @Override public String toString() { return "f(a,b)(a * (1-b))"; }
+ });
+ TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add());
+ TypedTensorFunction output = new TypedTensorFunction(x.type(), outputFunction);
+ return Optional.of(output);
}
- TypedTensorFunction noOp(List<TypedTensorFunction> arguments) {
- ensureArguments(1, arguments, "noOp");
- return arguments.get(0);
+ private static Optional<TypedTensorFunction> sigmoid(TensorFlowImporter.Parameters params) {
+ return map(params, ScalarFunctions.sigmoid());
}
- TypedTensorFunction expandDims(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) {
- ensureArguments(2, arguments, "expandDims");
- Tensor axis = getConstantTensor(model, tfNode.getInput(1));
- if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar");
+ private static Optional<TypedTensorFunction> squaredDifference(TensorFlowImporter.Parameters params) {
+ return join(params, ScalarFunctions.squareddifference());
+ }
+
+ private static Optional<TypedTensorFunction> sub(TensorFlowImporter.Parameters params) {
+ return join(params, ScalarFunctions.subtract());
+ }
+
+ private static Optional<TypedTensorFunction> elu(TensorFlowImporter.Parameters params) {
+ return map(params, ScalarFunctions.elu());
+ }
+
+ private static Optional<TypedTensorFunction> relu(TensorFlowImporter.Parameters params) {
+ return map(params, ScalarFunctions.relu());
+ }
+
+ private static Optional<TypedTensorFunction> selu(TensorFlowImporter.Parameters params) {
+ return map(params, ScalarFunctions.selu());
+ }
+
+ private static Optional<TypedTensorFunction> softMax(TensorFlowImporter.Parameters params) {
+ if (!checkInputs(params, 1)) {
+ return Optional.empty();
}
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+ TypedTensorFunction a = inputs.get(0).get();
+ // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
+ String dimension = "d" + (a.type().rank() - 1);
+ Softmax softmax = new Softmax(a.function(), dimension);
+ TypedTensorFunction output = new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
+ return Optional.of(output);
+ }
- TensorFunction inputFunction = arguments.get(0).function();
- TensorType inputType = arguments.get(0).type();
+ private static Optional<TypedTensorFunction> variable(TensorFlowImporter.Parameters params) {
+ return importConstantTensor(params, params.node().getName());
+ }
- int dimensionToInsert = (int)axis.asDouble();
- if (dimensionToInsert < 0) {
- dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
+ private static Optional<TypedTensorFunction> noOp(TensorFlowImporter.Parameters params) {
+ return Optional.empty();
+ }
+
+ /*
+ * Utility
+ */
+
+ private static Optional<TypedTensorFunction> join(TensorFlowImporter.Parameters params, DoubleBinaryOperator doubleFunction) {
+ if (!checkInputs(params, 2)) {
+ return Optional.empty();
}
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TensorType.Builder outputTypeBuilder = new TensorType.Builder();
- int dimensionIndex = 0;
- for (int i = 0; i < inputType.dimensions().size() + 1; ++i) {
- String name = String.format("temp_%d", i);
- Long size;
- if (i == dimensionToInsert) {
- size = 1L;
- } else {
- size = dimensionSize(inputType.dimensions().get(dimensionIndex));
- dimensionIndex++;
- }
- outputTypeBuilder.indexed(name, size);
+ TypedTensorFunction a = inputs.get(0).get();
+ TypedTensorFunction b = inputs.get(1).get();
+
+ if (a.type().rank() == 0 && b.type().rank() > 0) {
+ return Optional.of(new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction)));
+ }
+ if (b.type().rank() == 0 && a.type().rank() > 0) {
+ return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)));
+ }
+ if (a.type().rank() == b.type().rank()) {
+ return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)));
}
- return reshape(inputFunction, inputType, outputTypeBuilder.build());
+ // Well now we have entered the wonderful world of "broadcasting"
+ // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+ // I'm not able to extract from that any unambiguous specification of which dimensions
+ // should be "stretched" when the tensors do not have the same number of dimensions.
+ // From trying this with TensorFlow it appears that the second tensor is matched to the
+ // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
+ // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
+
+ if (a.type().rank() > b.type().rank()) {
+ TensorFunction renameFunction = renameForBroadcast(a, b);
+ return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction)));
+ }
+ TensorFunction renameFunction = renameForBroadcast(b, a);
+ return Optional.of(new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction)));
+ }
+
+ private static TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) {
+ List<String> renameFrom = new ArrayList<>();
+ List<String> renameTo = new ArrayList<>();
+ int sizeDifference = a.type().rank() - b.type().rank();
+ for (int i = 0; i < b.type().rank(); i++) {
+ renameFrom.add(b.type().dimensions().get(i).name());
+ renameTo.add("d" + (sizeDifference + i));
+ }
+ return new Rename(b.function(), renameFrom, renameTo);
}
- TypedTensorFunction reshape(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) {
- ensureArguments(2, arguments, "reshape");
- Tensor shape = getConstantTensor(model, tfNode.getInput(1));
+ private static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params, DoubleUnaryOperator doubleFunction) {
+ if (!checkInputs(params, 1)) {
+ return Optional.empty();
+ }
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+ TypedTensorFunction a = inputs.get(0).get();
+ TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
+ com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
+ return Optional.of(new TypedTensorFunction(resultType, function));
+ }
- TensorFunction inputFunction = arguments.get(0).function();
- TensorType inputType = arguments.get(0).type();
+ private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) {
+ params.result().constant(params.node().getName(), constant);
+ TypedTensorFunction output = new TypedTensorFunction(constant.type(),
+ new TensorFunctionNode.TensorFunctionExpressionNode(
+ new ReferenceNode("constant(\"" + params.node().getName() + "\")")));
+ return Optional.of(output);
+ }
- TensorType.Builder outputTypeBuilder = new TensorType.Builder();
- int dimensionIndex = 0;
- for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int size = cell.getValue().intValue();
- if (size < 0) {
- size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue();
- }
- outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size);
- dimensionIndex++;
+ private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) {
+ if (params.result().constants().containsKey(name)) {
+ return params.result().constants().get(name);
}
- return reshape(inputFunction, inputType, outputTypeBuilder.build());
+ Session.Runner fetched = params.model().session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if (importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " +
+ importedTensors.size());
+ return TensorConverter.toVespaTensor(importedTensors.get(0));
+ }
+
+ private static Optional<TypedTensorFunction> importConstantTensor(TensorFlowImporter.Parameters params, String name) {
+ AttrValue shapes = params.node().getAttrMap().get("_output_shapes");
+ if (shapes == null)
+ throw new IllegalArgumentException("'" + name + "' is missing a tensor shape");
+ Tensor constant = getConstantTensor(params, name);
+ return createConstant(params, constant);
}
- private TypedTensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
+ private static Optional<TypedTensorFunction> reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
if (!tensorSize(inputType).equals(tensorSize(outputType))) {
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
}
@@ -353,10 +594,10 @@ class OperationMapper {
Reduce.Aggregator.sum,
inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return output;
+ return Optional.of(output);
}
- private ExpressionNode unrollTensorExpression(TensorType type) {
+ private static ExpressionNode unrollTensorExpression(TensorType type) {
if (type.rank() == 0) {
return new ConstantNode(DoubleValue.zero);
}
@@ -378,80 +619,56 @@ class OperationMapper {
return new ArithmeticNode(children, operators);
}
- TypedTensorFunction select(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result, List<TypedTensorFunction> arguments) {
- ensureArguments(3, arguments, "select");
- Tensor condition = getConstantTensor(model, tfNode.getInput(0));
-
- TypedTensorFunction x = arguments.get(1);
- TypedTensorFunction y = arguments.get(2);
- if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) {
- throw new IllegalArgumentException("'Select': input tensors must have the same shape");
- }
+ private static boolean shouldKeepDimensions(TensorFlowImporter.Parameters params) {
+ AttrValue keepDimsAttr = params.node().getAttrMap().get("keep_dims");
+ return keepDimsAttr != null && keepDimsAttr.getB();
+ }
- if (condition.type().rank() == 0) {
- return (int)condition.asDouble() == 0 ? y : x;
- }
- if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
- return condition.cellIterator().next().getValue().intValue() == 0 ? y : x;
+ private static TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) {
+ TensorType.Builder builder = new TensorType.Builder();
+ for (TensorType.Dimension dimension: inputType.dimensions()) {
+ String name = dimension.name();
+ Long size = dimensionSize(dimension);
+ if (reduceDimensions.contains(name)) {
+ size = 1L;
+ }
+ builder.indexed(name, size);
}
-
- // The task is to select cells from 'x' or 'y' based on 'condition'.
- // If 'condition' is 0 (false), select from 'y', if 1 (true) select
- // from 'x'. We do this by individually joining 'x' and 'y' with
- // 'condition', and then joining the resulting two tensors.
-
- TypedTensorFunction conditionFunction = importConstantTensor(tfNode, model, result, tfNode.getInput(0));
- TensorFunction xCond = new Join(x.function(), conditionFunction.function(), ScalarFunctions.multiply());
- TensorFunction yCond = new Join(y.function(), conditionFunction.function(), new DoubleBinaryOperator() {
- @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
- @Override public String toString() { return "f(a,b)(a * (1-b))"; }
- });
- TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add());
- return new TypedTensorFunction(x.type(), outputFunction);
+ return builder.build();
}
- TypedTensorFunction softmax(List<TypedTensorFunction> arguments) {
- ensureArguments(1, arguments, "softmax");
- TypedTensorFunction a = arguments.get(0);
- // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
- String dimension = "d" + (a.type().rank() - 1);
- Softmax softmax = new Softmax(a.function(), dimension);
- return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
+ private static TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) {
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ String correct = String.format("d%d", i);
+ String current = type.dimensions().get(i).name();
+ if (!current.equals(correct)) {
+ return fixNamingConvention(type, function);
+ }
+ }
+ return new TypedTensorFunction(type, function);
}
- TypedTensorFunction squeeze(NodeDef tfNode, List<TypedTensorFunction> arguments) {
- ensureArguments(1, arguments, "squeeze");
-
- TensorFunction inputFunction = arguments.get(0).function();
- TensorType inputType = arguments.get(0).type();
- List<String> squeezeDimensions;
-
- AttrValue squeezeDimsAttr = tfNode.getAttrMap().get("squeeze_dims");
- if (squeezeDimsAttr == null) {
- squeezeDimensions = inputType.dimensions().stream().
- filter(dim -> dimensionSize(dim) == 1).
- map(TensorType.Dimension::name).
- collect(Collectors.toList());
- } else {
- squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
- map(i -> i < 0 ? inputType.dimensions().size() - i : i).
- map(i -> inputType.dimensions().get(i.intValue())).
- filter(dim -> dimensionSize(dim) == 1).
- map(TensorType.Dimension::name).
- collect(Collectors.toList());
+ private static TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) {
+ TensorType.Builder correctType = new TensorType.Builder();
+ List<String> from = new ArrayList<>();
+ List<String> to = new ArrayList<>();
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ String correct = String.format("d%d", i);
+ String current = type.dimensions().get(i).name();
+ if (!current.equals(correct)) {
+ from.add(current);
+ to.add(correct);
+ }
+ correctType.indexed(correct, dimensionSize(type.dimensions().get(i)));
}
-
- if (squeezeDimensions.isEmpty()) {
- return arguments.get(0);
+ if (from.size() > 0) {
+ function = new Rename(function, from, to);
+ type = correctType.build();
}
-
- TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
- TensorType outputType = Reduce.outputType(inputType, squeezeDimensions);
- TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return output;
+ return new TypedTensorFunction(type, function);
}
- private Long tensorSize(TensorType type) {
+ private static Long tensorSize(TensorType type) {
Long size = 1L;
for (TensorType.Dimension dimension : type.dimensions()) {
size *= dimensionSize(dimension);
@@ -459,14 +676,21 @@ class OperationMapper {
return size;
}
- private Long dimensionSize(TensorType.Dimension dim) {
+ private static Long dimensionSize(TensorType.Dimension dim) {
return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
}
- private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) {
- if ( arguments.size() != count)
- throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName +
- ", but got " + arguments.size());
+ private static boolean checkInputs(TensorFlowImporter.Parameters params, int expected) {
+ List<Optional<TypedTensorFunction>> inputs = params.inputs();
+ if (!inputs.stream().allMatch(Optional::isPresent)) {
+ return false;
+ }
+ if (inputs.size() != expected) {
+ params.signature().importWarning("Expected " + expected +
+ " arguments to " + params.node().getOp() + ", but got " + inputs.size());
+ return false;
+ }
+ return true;
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
index 8edb9b9b7a1..b88ffce275a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
@@ -19,7 +19,7 @@ import java.nio.LongBuffer;
*/
public class TensorConverter {
- public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
+ public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
TensorType type = toVespaTensorType(tfTensor.shape());
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
@@ -28,7 +28,7 @@ public class TensorConverter {
return builder.build();
}
- private TensorType toVespaTensorType(long[] shape) {
+ private static TensorType toVespaTensorType(long[] shape) {
TensorType.Builder b = new TensorType.Builder();
int dimensionIndex = 0;
for (long dimensionSize : shape) {
@@ -38,7 +38,7 @@ public class TensorConverter {
return b.build();
}
- private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
+ private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
switch (tfTensor.dataType()) {
case DOUBLE: return new DoubleValues(tfTensor);
case FLOAT: return new FloatValues(tfTensor);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 4780c39d21d..3a6b3f23a1d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -4,7 +4,6 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.framework.GraphDef;
@@ -12,12 +11,14 @@ import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
-import org.tensorflow.framework.TensorShapeProto;
import java.io.File;
import java.io.IOException;
+import java.util.ArrayList;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.stream.Collectors;
/**
@@ -27,8 +28,6 @@ import java.util.stream.Collectors;
*/
public class TensorFlowImporter {
- private final OperationMapper operationMapper = new OperationMapper();
-
/**
* Imports a saved TensorFlow model from a directory.
* The model should be saved as a .pbtxt or .pb file.
@@ -68,9 +67,21 @@ public class TensorFlowImporter {
for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
String outputName = output.getKey();
try {
- NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef());
- importNode(node, graph.getGraphDef(), model, result);
- signature.output(outputName, nameOf(output.getValue().getName()));
+ NodeDef node = getNode(namePartOf(output.getValue().getName()), graph.getGraphDef());
+ Parameters params = createParameters(graph.getGraphDef(), model, result, signature, node, "");
+
+ // Commonly, there are multiple paths through a TensorFlow graph, for instance for
+ // training and testing/evaluation. Examples are dropout and batch norm. For Vespa
+ // we are not concerned with training paths, so we can ignore non-supported operations
+ // as long as they are on a path that will not be evaluated run time. Operations
+ // that fail import will not have a value present in the optionals. However, the
+ // final output node must have value present. It is an error if it does not.
+
+ Optional<TypedTensorFunction> outputFunction = importNode(params);
+ if (!outputFunction.isPresent()) {
+ throw new IllegalArgumentException(signature.importWarnings().stream().collect(Collectors.joining("\n")));
+ }
+ signature.output(outputName, namePartOf(output.getValue().getName()));
}
catch (IllegalArgumentException e) {
signature.skippedOutput(outputName, Exceptions.toMessageString(e));
@@ -82,93 +93,59 @@ public class TensorFlowImporter {
private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) {
inputInfoMap.forEach((key, value) -> {
- String argumentName = nameOf(value.getName());
- TensorType argumentType = importTensorType(value.getTensorShape());
+ String argumentName = namePartOf(value.getName());
+ TensorType argumentType = AttrValueConverter.toVespaTensorType(value.getTensorShape());
// Arguments are (Placeholder) nodes, so not local to the signature:
signature.owner().argument(argumentName, argumentType);
signature.input(key, argumentName);
});
}
- private TensorType importTensorType(TensorShapeProto tensorShape) {
- TensorType.Builder b = new TensorType.Builder();
- for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) {
- int dimensionSize = (int)dimension.getSize();
- if (dimensionSize >= 0)
- b.indexed("d" + b.rank(), dimensionSize);
- else
- b.indexed("d" + b.rank()); // unbound size
+ /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
+ private Optional<TypedTensorFunction> importNode(Parameters params) {
+ String nodeName = params.node().getName();
+ if (params.imported().containsKey(nodeName)) {
+ return Optional.of(params.imported().get(nodeName));
}
- return b.build();
- }
- /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
- TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result);
+ Optional<TypedTensorFunction> function = OperationMapper.map(params);
+ if (!function.isPresent()) {
+ return Optional.empty();
+ }
+ if (!controlDependenciesArePresent(params)) {
+ return Optional.empty();
+ }
+ params.imported().put(nodeName, function.get());
+
try {
// We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
// will be used. We parse the TensorFunction here to convert it to a RankingExpression tree
- result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), function.function().toString()));
+ params.result().expression(nodeName,
+ new RankingExpression(params.node().getName(), function.get().function().toString()));
return function;
}
catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function.function() +
+ throw new RuntimeException("Tensorflow function " + function.get().function() +
" cannot be parsed as a ranking expression", e);
}
}
+ private boolean controlDependenciesArePresent(Parameters params) {
+ return params.node().getInputList().stream()
+ .filter(TensorFlowImporter::isControlDependency)
+ .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName))))
+ .allMatch(Optional::isPresent);
+ }
-
- private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) {
- // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops
- // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/
- switch (tfNode.getOp().toLowerCase()) {
- // array ops
- case "const" : return operationMapper.constant(tfNode, model, result);
- case "expanddims" : return operationMapper.expandDims(tfNode, model, importArguments(tfNode, graph, model, result));
- case "identity" : return operationMapper.identity(tfNode, model, result);
- case "placeholder" : return operationMapper.placeholder(tfNode, result);
- case "placeholderwithdefault" : return operationMapper.placeholderWithDefault(tfNode, model, result);
- case "reshape" : return operationMapper.reshape(tfNode, model, importArguments(tfNode, graph, model, result));
- case "squeeze" : return operationMapper.squeeze(tfNode, importArguments(tfNode, graph, model, result));
-
- // math ops
- case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
- case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos());
- case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result));
- case "maximum" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.max());
- case "mean" : case "reducemean": return operationMapper.mean(tfNode, model, importArguments(tfNode, graph, model, result));
- case "multiply": case "mul" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.multiply());
- case "rsqrt": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.rsqrt());
- case "where3": case "select" : return operationMapper.select(tfNode, model, result, importArguments(tfNode, graph, model, result));
- case "sigmoid": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.sigmoid());
- case "squareddifference" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.squareddifference());
- case "subtract" : case "sub" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.subtract());
-
- // nn ops
- case "biasadd" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add());
- case "elu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.elu());
- case "relu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.relu());
- case "selu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.selu());
- case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result));
-
- // evaluation no-ops
- case "stopgradient" :
- case "noop":
- return operationMapper.noOp(importArguments(tfNode, graph, model, result));
-
- // not supported
- default :
- throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() +
- "' is not supported (" + tfNode.getName() + ")");
- }
+ private static boolean isControlDependency(String nodeName) {
+ return nodeName.startsWith("^");
}
- private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model,
- TensorFlowModel result) {
- return tfNode.getInputList().stream()
- .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result))
- .collect(Collectors.toList());
+ private List<Optional<TypedTensorFunction>> importArguments(Parameters params) {
+ return params.node().getInputList().stream()
+ .filter(nodeName -> !isControlDependency(nodeName))
+ .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName))))
+ .collect(Collectors.toList());
}
private NodeDef getNode(String name, GraphDef graph) {
@@ -182,8 +159,94 @@ public class TensorFlowImporter {
* A method signature input and output has the form name:index.
* This returns the name part without the index.
*/
- private String nameOf(String name) {
+ private static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
return name.split(":")[0];
}
+ /**
+ * This returns the index part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ private static String indexPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? "" : name.substring(i + 1);
+ }
+
+
+ private Parameters createParameters(GraphDef graph,
+ SavedModelBundle model,
+ TensorFlowModel result,
+ TensorFlowModel.Signature signature,
+ NodeDef node,
+ String port) {
+ return new Parameters(this, graph, model, result, signature, new HashMap<>(), node, port);
+ }
+
+ /** Parameter object to hold important data while importing */
+ static final class Parameters {
+ private final TensorFlowImporter owner;
+ private final GraphDef graph;
+ private final SavedModelBundle model;
+ private final TensorFlowModel result;
+ private final TensorFlowModel.Signature signature;
+ private final Map<String, TypedTensorFunction> imported;
+ private final NodeDef node;
+ private final String port;
+
+ private Parameters(TensorFlowImporter owner,
+ GraphDef graph,
+ SavedModelBundle model,
+ TensorFlowModel result,
+ TensorFlowModel.Signature signature,
+ Map<String, TypedTensorFunction> imported,
+ NodeDef node,
+ String port) {
+ this.owner = owner;
+ this.graph = graph;
+ this.model = model;
+ this.result = result;
+ this.signature = signature;
+ this.imported = imported;
+ this.node = node;
+ this.port = port;
+ }
+
+ GraphDef graph() {
+ return this.graph;
+ }
+
+ SavedModelBundle model() {
+ return this.model;
+ }
+
+ TensorFlowModel result() {
+ return this.result;
+ }
+
+ TensorFlowModel.Signature signature() {
+ return this.signature;
+ }
+
+ Map<String, TypedTensorFunction> imported() {
+ return this.imported;
+ }
+
+ NodeDef node() {
+ return node;
+ }
+
+ String port() {
+ return port;
+ }
+
+ Parameters copy(NodeDef node, String port) {
+ return new Parameters(this.owner, this.graph, this.model, this.result, this.signature, this.imported, node, port);
+ }
+
+ List<Optional<TypedTensorFunction>> inputs() {
+ return owner.importArguments(this);
+ }
+ }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index 1a6c93384ea..60aaf8ddce1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -5,8 +5,10 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
+import java.util.List;
import java.util.Map;
/**
@@ -67,6 +69,7 @@ public class TensorFlowModel {
private final Map<String, String> inputs = new HashMap<>();
private final Map<String, String> outputs = new HashMap<>();
private final Map<String, String> skippedOutputs = new HashMap<>();
+ private final List<String> importWarnings = new ArrayList<>();
Signature(String name) {
this.name = name;
@@ -75,6 +78,7 @@ public class TensorFlowModel {
void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
void output(String name, String expressionName) { outputs.put(name, expressionName); }
void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ void importWarning(String warning) { importWarnings.add(warning); }
public String name() { return name; }
@@ -99,6 +103,11 @@ public class TensorFlowModel {
*/
public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); }
+ /**
+ * Returns an immutable list of possibly non-fatal warnings encountered during import.
+ */
+ public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
+
/** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */
public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
index 962f9dda0a6..600225bfe76 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
@@ -19,7 +19,12 @@ final class TypedTensorFunction {
this.function = function;
}
- public TensorType type() { return type; }
- public TensorFunction function() { return function; }
+ public TensorType type() {
+ return type;
+ }
+
+ public TensorFunction function() {
+ return function;
+ }
}