diff options
author | Lester Solbakken <lesters@oath.com> | 2018-01-26 13:45:36 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-01-26 14:14:29 +0100 |
commit | 089a765734b1791995510e97a0852bd7a89b3c0b (patch) | |
tree | 46e70b9e7855dc8cd1ecee119c868a555dcd89af /searchlib/src/main/java | |
parent | 780264290b9e15f0594991b5dba8f1dc2021f92d (diff) |
Refactor tensorflow import and add dropout test
Diffstat (limited to 'searchlib/src/main/java')
6 files changed, 775 insertions, 342 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java new file mode 100644 index 00000000000..5f0c016881a --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java @@ -0,0 +1,132 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.TensorProto; +import org.tensorflow.framework.TensorShapeProto; + +/** + * @author lesters + */ +public class AttrValueConverter { + + public static Tensor toVespaTensor(NodeDef tfNode, String attr) { + if (!tfNode.getAttrMap().containsKey(attr)) { + throw new IllegalArgumentException(tfNode.getName() + " has no attribute called " + attr); + } + AttrValue attrValue = tfNode.getAttrMap().get(attr); + switch (attrValue.getValueCase()) { + case TENSOR: + return buildFromTensor(attrValue); + case B: + return buildFromSingleValue(attrValue.getB() ? 1.0 : 0.0); + case F: + return buildFromSingleValue(attrValue.getF()); + case I: + return buildFromSingleValue(attrValue.getI()); + } + + throw new IllegalArgumentException(tfNode.getName() + + ": unsupported attribute type: '" + attrValue.getValueCase().toString() + "'"); + } + + private static Tensor buildFromSingleValue(double value) { + TensorType type = new TensorType.Builder().build(); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + builder.cellByDirectIndex(0, value); + return builder.build(); + } + + private static Tensor buildFromTensor(AttrValue attrValue) { + TensorProto tensorProto = attrValue.getTensor(); + TensorType type = toVespaTensorType(tensorProto.getTensorShape()); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + Values values = valuesOf(tensorProto); + for (int i = 0; i < values.size(); ++i) { + builder.cellByDirectIndex(i, values.get(i)); + } + Tensor tensor = builder.build(); + return tensor; + } + + private static Values valuesOf(TensorProto tensorProto) { + switch (tensorProto.getDtype()) { + case DT_BOOL: + return new BoolValues(tensorProto); + case DT_HALF: + return new HalfValues(tensorProto); + case DT_INT16: + case DT_INT32: + return new IntValues(tensorProto); + case DT_INT64: + return new Int64Values(tensorProto); + case DT_FLOAT: + return new FloatValues(tensorProto); + case DT_DOUBLE: + return new DoubleValues(tensorProto); + } + + throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); + } + + public static TensorType toVespaTensorType(TensorShapeProto shapeProto) { + TensorType.Builder b = new TensorType.Builder(); + for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) { + int dimensionSize = (int)dimension.getSize(); + if (dimensionSize >= 0) + b.indexed("d" + b.rank(), dimensionSize); + else + b.indexed("d" + b.rank()); // unbound size + } + return b.build(); + } + + private static abstract class Values { + protected final TensorProto tensorProto; + protected Values(TensorProto tensorProto) { this.tensorProto = tensorProto; } + abstract double get(int i); + abstract int size(); + } + + private static class BoolValues extends Values { + BoolValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; } + @Override int size() { return tensorProto.getBoolValCount(); } + } + + private static class HalfValues extends Values { + HalfValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getHalfVal(i); } + @Override int size() { return tensorProto.getHalfValCount(); } + } + + private static class IntValues extends Values { + IntValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getIntVal(i); } + @Override int size() { return tensorProto.getIntValCount(); } + } + + private static class Int64Values extends Values { + Int64Values(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getInt64Val(i); } + @Override int size() { return tensorProto.getInt64ValCount(); } + } + + private static class FloatValues extends Values { + FloatValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getFloatVal(i); } + @Override int size() { return tensorProto.getFloatValCount(); } + } + + private static class DoubleValues extends Values { + DoubleValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getDoubleVal(i); } + @Override int size() { return tensorProto.getDoubleValCount(); } + } + + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java index 85452d16a77..816ef38e128 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java @@ -24,157 +24,318 @@ import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.tensor.functions.Softmax; import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.SavedModelBundle; import org.tensorflow.Session; import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; import java.util.ArrayList; import java.util.Iterator; import java.util.List; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; +import java.util.function.Function; import java.util.stream.Collectors; -import java.util.stream.StreamSupport; +import java.util.stream.Stream; /** * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions. * * @author bratseth + * @author lesters */ class OperationMapper { + // A note on conversion from implicitly numbered to explicitly named dimensions: + // + // Vespa tensor dimensions are explicitly named and thus have an explicit notion of being + // 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation + // comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation + // around dimension renaming operations which mirrors those built into the TF operation definitions. + // + // To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' + // dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation + // and the result is then renamed again (if necessary) to recover this convention across a full nested + // computation. + // + // This requires us to track tensor types throughout the conversion. + + + // Supported TensorFlow operations + enum Operation { + + // TODO: move the implementations to specific files as we support more operations + + /* + * array ops + */ + CONST (OperationMapper::constant), + EXPANDDIMS (OperationMapper::expandDims), + IDENTITY (OperationMapper::identity), + PLACEHOLDER (OperationMapper::placeholder), + PLACEHOLDERWITHDEFAULT (OperationMapper::placeholderWithDefault), + RESHAPE (OperationMapper::reshape), + SQUEEZE (OperationMapper::squeeze), + + /* + * control flow + */ + MERGE (OperationMapper::merge), + SWITCH (OperationMapper::switchOp), + + /* + * math ops + */ + ADD (OperationMapper::add), + ADD_N (OperationMapper::add), + ACOS (OperationMapper::acos), + DIV (OperationMapper::div), + REALDIV (OperationMapper::div), + FLOOR (OperationMapper::floor), + MATMUL (OperationMapper::matmul), + MAXIMUM (OperationMapper::maximum), + MEAN (OperationMapper::mean), + REDUCEMEAN (OperationMapper::mean), + MUL (OperationMapper::mul), + MULTIPLY (OperationMapper::mul), + RSQRT (OperationMapper::rsqrt), + SELECT (OperationMapper::select), + WHERE3 (OperationMapper::select), + SIGMOID (OperationMapper::sigmoid), + SQUAREDDIFFERENCE (OperationMapper::squaredDifference), + SUB (OperationMapper::sub), + SUBTRACT (OperationMapper::sub), + + /* + * nn ops + */ + BIASADD (OperationMapper::add), + ELU (OperationMapper::elu), + RELU (OperationMapper::relu), + SELU (OperationMapper::selu), + SOFTMAX (OperationMapper::softMax), + + /* + * state ops + */ + VARIABLE (OperationMapper::variable), + VARIABLEV2 (OperationMapper::variable), + + /* + * evaluation no-ops + */ + STOPGRADIENT (OperationMapper::identity), + NOOP (OperationMapper::noOp); + + + private final Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func; + + Operation(Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func) { + this.func = func; + } + + Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) { + return func.apply(params); + } + + } + + static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) { + Optional<Operation> operation = Stream.of(Operation.values()) + .filter(op -> op.name().equalsIgnoreCase(params.node().getOp())) + .findFirst(); + if (operation.isPresent()) { + return operation.get().map(params); + } + params.signature().importWarning("TensorFlow operation '" + params.node().getOp() + + "' in node '" + params.node().getName() + "' is not supported."); + return Optional.empty(); + } + + /* - A note on conversion from implicitly numbered to explicitly named dimensions: - Vespa tensor dimensions are explicitly named and thus have an explicit notion of being - 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation - comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation - around dimension renaming operations which mirrors those built into the TF operation definitions. - - To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' - dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation - and the result is then renamed again (if necessary) to recover this convention across a full nested - computation. - - This requires us to track tensor types throughout the conversion. + * Operations */ - private TensorConverter tensorConverter = new TensorConverter(); + private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) { + Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value"); + return createConstant(params, value); + } - TypedTensorFunction join(List<TypedTensorFunction> arguments, DoubleBinaryOperator doubleFunction) { - ensureArguments(2, arguments, "join"); - TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(1); + private static Optional<TypedTensorFunction> expandDims(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); - if (a.type().rank() == 0 && b.type().rank() > 0) { - return new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction)); + Tensor axis = getConstantTensor(params, params.node().getInput(1)); + if (axis.type().rank() != 0) { + throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar"); } - if (b.type().rank() == 0 && a.type().rank() > 0) { - return new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)); + + TensorFunction inputFunction = inputs.get(0).get().function(); + TensorType inputType = inputs.get(0).get().type(); + + int dimensionToInsert = (int)axis.asDouble(); + if (dimensionToInsert < 0) { + dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - if (a.type().rank() == b.type().rank()) { - return new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)); + + TensorType.Builder outputTypeBuilder = new TensorType.Builder(); + int dimensionIndex = 0; + for (int i = 0; i < inputType.dimensions().size() + 1; ++i) { + String name = String.format("temp_%d", i); + Long size; + if (i == dimensionToInsert) { + size = 1L; + } else { + size = dimensionSize(inputType.dimensions().get(dimensionIndex)); + dimensionIndex++; + } + outputTypeBuilder.indexed(name, size); } - // Well now we have entered the wonderful world of "broadcasting" - // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html - // I'm not able to extract from that any unambiguous specification of which dimensions - // should be "stretched" when the tensor do not have the same number of dimensions. - // From trying this with TensorFlow it appears that the second tensor is matched to the - // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. - // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + return reshape(inputFunction, inputType, outputTypeBuilder.build()); + } - if (a.type().rank() > b.type().rank()) { - TensorFunction renameFunction = renameForBroadcast(a, b); - return new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction)); + private static Optional<TypedTensorFunction> identity(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 1)) { + return Optional.empty(); } - TensorFunction renameFunction = renameForBroadcast(b, a); - return new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction)); + return params.inputs().get(0); } - private TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) { - List<String> renameFrom = new ArrayList<>(); - List<String> renameTo = new ArrayList<>(); - int sizeDifference = a.type().rank() - b.type().rank(); - for (int i = 0; i < b.type().rank(); i++) { - renameFrom.add(b.type().dimensions().get(i).name()); - renameTo.add("d" + (sizeDifference + i)); + private static Optional<TypedTensorFunction> placeholder(TensorFlowImporter.Parameters params) { + String name = params.node().getName(); + TensorType type = params.result().arguments().get(name); + if (type == null) { + throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + + "', but there is no such placeholder"); } - return new Rename(b.function(), renameFrom, renameTo); + // Included literally in the expression and so must be produced by a separate macro in the rank profile + TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(name)); + return Optional.of(output); + } + + private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) { + String name = params.node().getInput(0); + Tensor defaultValue = getConstantTensor(params, name); + params.result().constant(name, defaultValue); + params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); + // The default value will be provided by the macro. Users can override macro to change value. + TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); + return Optional.of(output); } - TypedTensorFunction map(List<TypedTensorFunction> arguments, DoubleUnaryOperator doubleFunction) { - ensureArguments(1, arguments, "apply"); - TypedTensorFunction a = arguments.get(0); + private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + Tensor shape = getConstantTensor(params, params.node().getInput(1)); - TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); - com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); - return new TypedTensorFunction(resultType, function); + TensorFunction inputFunction = inputs.get(0).get().function(); + TensorType inputType = inputs.get(0).get().type(); + + TensorType.Builder outputTypeBuilder = new TensorType.Builder(); + int dimensionIndex = 0; + for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int size = cell.getValue().intValue(); + if (size < 0) { + size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue(); + } + outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size); + dimensionIndex++; + } + return reshape(inputFunction, inputType, outputTypeBuilder.build()); } - TypedTensorFunction placeholder(NodeDef tfNode, TensorFlowModel result) { - String name = tfNode.getName(); - TensorType type = result.arguments().get(name); - if (type == null) - throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + - "', but there is no such placeholder"); - // Included literally in the expression and so must be produced by a separate macro in the rank profile - return new TypedTensorFunction(type, new VariableTensor(name)); + private static Optional<TypedTensorFunction> squeeze(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 1)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + + TensorFunction inputFunction = inputs.get(0).get().function(); + TensorType inputType = inputs.get(0).get().type(); + List<String> squeezeDimensions; + + AttrValue squeezeDimsAttr = params.node().getAttrMap().get("squeeze_dims"); + if (squeezeDimsAttr == null) { + squeezeDimensions = inputType.dimensions().stream(). + filter(dim -> dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } else { + squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). + map(i -> i < 0 ? inputType.dimensions().size() - i : i). + map(i -> inputType.dimensions().get(i.intValue())). + filter(dim -> dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } + + if (squeezeDimensions.isEmpty()) { + return inputs.get(0); + } + + TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); + TensorType outputType = Reduce.outputType(inputType, squeezeDimensions); + TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); + return Optional.of(output); } - TypedTensorFunction placeholderWithDefault(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) { - String name = tfNode.getInput(0); - Tensor defaultValue = getConstantTensor(model, name); - result.constant(name, defaultValue); - result.macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); - // The default value will be provided by the macro. Users can override macro to change value. - return new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); + private static Optional<TypedTensorFunction> merge(TensorFlowImporter.Parameters params) { + return params.inputs().stream() + .filter(Optional::isPresent) + .findFirst() + .orElse(Optional.empty()); } - TypedTensorFunction constant(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) { - String name = tfNode.getName(); - if (tfNode.getInputList().size() != 0) { - throw new IllegalArgumentException("A constant node must have zero inputs but '" + name + "' has " + - tfNode.getInputList().size()); + private static Optional<TypedTensorFunction> switchOp(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + Tensor predicate = getConstantTensor(params, params.node().getInput(1)); + if (predicate.type().rank() != 0) { + throw new IllegalArgumentException("'switch': predicate must be a scalar"); + } + double pred = predicate.asDouble(); + int output = params.port().length() > 0 ? Integer.parseInt(params.port()) : 0; + if (output < 0 || output > 1) { + throw new IllegalArgumentException("'switch': predicate is not boolean"); + } + if (pred == output) { + return inputs.get(0); } - return importConstantTensor(tfNode, model, result, name); + return Optional.empty(); } - TypedTensorFunction identity(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result) { - if ( ! tfNode.getName().endsWith("/read")) - throw new IllegalArgumentException("Encountered identity node " + tfNode.getName() + ", but identify " + - "nodes are only supported when reading variables"); - if (tfNode.getInputList().size() != 1) - throw new IllegalArgumentException("A Variable/read node must have one input but '" + - tfNode.getName() + "' has " + tfNode.getInputList().size()); + private static Optional<TypedTensorFunction> add(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.add()); + } - String name = tfNode.getInput(0); - return importConstantTensor(tfNode, model, result, name); + private static Optional<TypedTensorFunction> acos(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.acos()); } - private TypedTensorFunction importConstantTensor(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result, String name) { - AttrValue shapes = tfNode.getAttrMap().get("_output_shapes"); - if (shapes == null) - throw new IllegalArgumentException("'" + name + "' is missing a tensor shape"); - Tensor constant = getConstantTensor(model, name); - result.constant(name, constant); - return new TypedTensorFunction(constant.type(), - new TensorFunctionNode.TensorFunctionExpressionNode(new ReferenceNode("constant(\"" + name + "\")"))); + private static Optional<TypedTensorFunction> div(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.divide()); } - private Tensor getConstantTensor(SavedModelBundle model, String name) { - Session.Runner fetched = model.session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + - importedTensors.size()); - return tensorConverter.toVespaTensor(importedTensors.get(0)); + private static Optional<TypedTensorFunction> floor(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.floor()); } - TypedTensorFunction matmul(List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "matmul"); - TypedTensorFunction a = arguments.get(0); - TypedTensorFunction b = arguments.get(1); + private static Optional<TypedTensorFunction> matmul(TensorFlowImporter.Parameters params) { + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + + TypedTensorFunction a = inputs.get(0).get(); + TypedTensorFunction b = inputs.get(1).get(); if (a.type().rank() < 2 || b.type().rank() < 2) throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); if (a.type().rank() != b.type().rank()) @@ -190,17 +351,24 @@ class OperationMapper { Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"), ImmutableList.of("d1", afterLastDim)); Matmul matmul = new Matmul(a.function(), renamedB, "d1"); - return new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), + TypedTensorFunction output = new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), new Rename(matmul, afterLastDim, "d1")); + return Optional.of(output); } - TypedTensorFunction mean(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "mean"); - Tensor reductionIndices = getConstantTensor(model, tfNode.getInput(1)); + private static Optional<TypedTensorFunction> maximum(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.max()); + } - TensorFunction inputFunction = arguments.get(0).function(); - TensorType inputType = arguments.get(0).type(); + private static Optional<TypedTensorFunction> mean(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 2)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + TensorFunction inputFunction = inputs.get(0).get().function(); + TensorType inputType = inputs.get(0).get().type(); + Tensor reductionIndices = getConstantTensor(params, params.node().getInput(1)); List<String> reduceDimensions = new ArrayList<>(); for (Iterator<Tensor.Cell> cellIterator = reductionIndices.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); @@ -214,122 +382,195 @@ class OperationMapper { TensorType outputType = Reduce.outputType(inputType, reduceDimensions); TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); - if (shouldKeepDimensions(tfNode)) { + if (shouldKeepDimensions(params)) { return reshape(outputFunction, outputType, keepDimensionType(inputType, reduceDimensions)); } - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return output; + return Optional.of(output); } - private boolean shouldKeepDimensions(NodeDef tfNode) { - AttrValue keepDimsAttr = tfNode.getAttrMap().get("keep_dims"); - return keepDimsAttr != null && keepDimsAttr.getB(); + private static Optional<TypedTensorFunction> mul(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.multiply()); } - private TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) { - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension: inputType.dimensions()) { - String name = dimension.name(); - Long size = dimensionSize(dimension); - if (reduceDimensions.contains(name)) { - size = 1L; - } - builder.indexed(name, size); - } - return builder.build(); + private static Optional<TypedTensorFunction> rsqrt(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.rsqrt()); } - private TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) { - for (int i = 0; i < type.dimensions().size(); ++i) { - String correct = String.format("d%d", i); - String current = type.dimensions().get(i).name(); - if (!current.equals(correct)) { - return fixNamingConvention(type, function); - } + private static Optional<TypedTensorFunction> select(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 3)) { + return Optional.empty(); } - return new TypedTensorFunction(type, function); - } + Tensor condition = getConstantTensor(params, params.node().getInput(0)); - private TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) { - TensorType.Builder correctType = new TensorType.Builder(); - List<String> from = new ArrayList<>(); - List<String> to = new ArrayList<>(); - for (int i = 0; i < type.dimensions().size(); ++i) { - String correct = String.format("d%d", i); - String current = type.dimensions().get(i).name(); - if (!current.equals(correct)) { - from.add(current); - to.add(correct); - } - correctType.indexed(correct, dimensionSize(type.dimensions().get(i))); + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + TypedTensorFunction x = inputs.get(1).get(); + TypedTensorFunction y = inputs.get(2).get(); + if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) { + throw new IllegalArgumentException("'Select': input tensors must have the same shape"); } - if (from.size() > 0) { - function = new Rename(function, from, to); - type = correctType.build(); + + if (condition.type().rank() == 0) { + return Optional.of((int)condition.asDouble() == 0 ? y : x); } - return new TypedTensorFunction(type, function); + if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { + return Optional.of(condition.cellIterator().next().getValue().intValue() == 0 ? y : x); + } + + // The task is to select cells from 'x' or 'y' based on 'condition'. + // If 'condition' is 0 (false), select from 'y', if 1 (true) select + // from 'x'. We do this by individually joining 'x' and 'y' with + // 'condition', and then joining the resulting two tensors. + + Optional<TypedTensorFunction> conditionFunction = importConstantTensor(params, params.node().getInput(0)); + if (!conditionFunction.isPresent()) { + return Optional.empty(); + } + TensorFunction xCond = new Join(x.function(), conditionFunction.get().function(), ScalarFunctions.multiply()); + TensorFunction yCond = new Join(y.function(), conditionFunction.get().function(), new DoubleBinaryOperator() { + @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } + @Override public String toString() { return "f(a,b)(a * (1-b))"; } + }); + TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add()); + TypedTensorFunction output = new TypedTensorFunction(x.type(), outputFunction); + return Optional.of(output); } - TypedTensorFunction noOp(List<TypedTensorFunction> arguments) { - ensureArguments(1, arguments, "noOp"); - return arguments.get(0); + private static Optional<TypedTensorFunction> sigmoid(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.sigmoid()); } - TypedTensorFunction expandDims(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "expandDims"); - Tensor axis = getConstantTensor(model, tfNode.getInput(1)); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar"); + private static Optional<TypedTensorFunction> squaredDifference(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.squareddifference()); + } + + private static Optional<TypedTensorFunction> sub(TensorFlowImporter.Parameters params) { + return join(params, ScalarFunctions.subtract()); + } + + private static Optional<TypedTensorFunction> elu(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.elu()); + } + + private static Optional<TypedTensorFunction> relu(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.relu()); + } + + private static Optional<TypedTensorFunction> selu(TensorFlowImporter.Parameters params) { + return map(params, ScalarFunctions.selu()); + } + + private static Optional<TypedTensorFunction> softMax(TensorFlowImporter.Parameters params) { + if (!checkInputs(params, 1)) { + return Optional.empty(); } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + TypedTensorFunction a = inputs.get(0).get(); + // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 + String dimension = "d" + (a.type().rank() - 1); + Softmax softmax = new Softmax(a.function(), dimension); + TypedTensorFunction output = new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); + return Optional.of(output); + } - TensorFunction inputFunction = arguments.get(0).function(); - TensorType inputType = arguments.get(0).type(); + private static Optional<TypedTensorFunction> variable(TensorFlowImporter.Parameters params) { + return importConstantTensor(params, params.node().getName()); + } - int dimensionToInsert = (int)axis.asDouble(); - if (dimensionToInsert < 0) { - dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; + private static Optional<TypedTensorFunction> noOp(TensorFlowImporter.Parameters params) { + return Optional.empty(); + } + + /* + * Utility + */ + + private static Optional<TypedTensorFunction> join(TensorFlowImporter.Parameters params, DoubleBinaryOperator doubleFunction) { + if (!checkInputs(params, 2)) { + return Optional.empty(); } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TensorType.Builder outputTypeBuilder = new TensorType.Builder(); - int dimensionIndex = 0; - for (int i = 0; i < inputType.dimensions().size() + 1; ++i) { - String name = String.format("temp_%d", i); - Long size; - if (i == dimensionToInsert) { - size = 1L; - } else { - size = dimensionSize(inputType.dimensions().get(dimensionIndex)); - dimensionIndex++; - } - outputTypeBuilder.indexed(name, size); + TypedTensorFunction a = inputs.get(0).get(); + TypedTensorFunction b = inputs.get(1).get(); + + if (a.type().rank() == 0 && b.type().rank() > 0) { + return Optional.of(new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction))); + } + if (b.type().rank() == 0 && a.type().rank() > 0) { + return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction))); + } + if (a.type().rank() == b.type().rank()) { + return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction))); } - return reshape(inputFunction, inputType, outputTypeBuilder.build()); + // Well now we have entered the wonderful world of "broadcasting" + // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html + // I'm not able to extract from that any unambiguous specification of which dimensions + // should be "stretched" when the tensor do not have the same number of dimensions. + // From trying this with TensorFlow it appears that the second tensor is matched to the + // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. + // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + + if (a.type().rank() > b.type().rank()) { + TensorFunction renameFunction = renameForBroadcast(a, b); + return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction))); + } + TensorFunction renameFunction = renameForBroadcast(b, a); + return Optional.of(new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction))); + } + + private static TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) { + List<String> renameFrom = new ArrayList<>(); + List<String> renameTo = new ArrayList<>(); + int sizeDifference = a.type().rank() - b.type().rank(); + for (int i = 0; i < b.type().rank(); i++) { + renameFrom.add(b.type().dimensions().get(i).name()); + renameTo.add("d" + (sizeDifference + i)); + } + return new Rename(b.function(), renameFrom, renameTo); } - TypedTensorFunction reshape(NodeDef tfNode, SavedModelBundle model, List<TypedTensorFunction> arguments) { - ensureArguments(2, arguments, "reshape"); - Tensor shape = getConstantTensor(model, tfNode.getInput(1)); + private static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params, DoubleUnaryOperator doubleFunction) { + if (!checkInputs(params, 1)) { + return Optional.empty(); + } + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + TypedTensorFunction a = inputs.get(0).get(); + TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); + com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); + return Optional.of(new TypedTensorFunction(resultType, function)); + } - TensorFunction inputFunction = arguments.get(0).function(); - TensorType inputType = arguments.get(0).type(); + private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) { + params.result().constant(params.node().getName(), constant); + TypedTensorFunction output = new TypedTensorFunction(constant.type(), + new TensorFunctionNode.TensorFunctionExpressionNode( + new ReferenceNode("constant(\"" + params.node().getName() + "\")"))); + return Optional.of(output); + } - TensorType.Builder outputTypeBuilder = new TensorType.Builder(); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); - if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue(); - } - outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size); - dimensionIndex++; + private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) { + if (params.result().constants().containsKey(name)) { + return params.result().constants().get(name); } - return reshape(inputFunction, inputType, outputTypeBuilder.build()); + Session.Runner fetched = params.model().session().runner().fetch(name); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if (importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + + importedTensors.size()); + return TensorConverter.toVespaTensor(importedTensors.get(0)); + } + + private static Optional<TypedTensorFunction> importConstantTensor(TensorFlowImporter.Parameters params, String name) { + AttrValue shapes = params.node().getAttrMap().get("_output_shapes"); + if (shapes == null) + throw new IllegalArgumentException("'" + name + "' is missing a tensor shape"); + Tensor constant = getConstantTensor(params, name); + return createConstant(params, constant); } - private TypedTensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { + private static Optional<TypedTensorFunction> reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { if (!tensorSize(inputType).equals(tensorSize(outputType))) { throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); } @@ -353,10 +594,10 @@ class OperationMapper { Reduce.Aggregator.sum, inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return output; + return Optional.of(output); } - private ExpressionNode unrollTensorExpression(TensorType type) { + private static ExpressionNode unrollTensorExpression(TensorType type) { if (type.rank() == 0) { return new ConstantNode(DoubleValue.zero); } @@ -378,80 +619,56 @@ class OperationMapper { return new ArithmeticNode(children, operators); } - TypedTensorFunction select(NodeDef tfNode, SavedModelBundle model, TensorFlowModel result, List<TypedTensorFunction> arguments) { - ensureArguments(3, arguments, "select"); - Tensor condition = getConstantTensor(model, tfNode.getInput(0)); - - TypedTensorFunction x = arguments.get(1); - TypedTensorFunction y = arguments.get(2); - if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) { - throw new IllegalArgumentException("'Select': input tensors must have the same shape"); - } + private static boolean shouldKeepDimensions(TensorFlowImporter.Parameters params) { + AttrValue keepDimsAttr = params.node().getAttrMap().get("keep_dims"); + return keepDimsAttr != null && keepDimsAttr.getB(); + } - if (condition.type().rank() == 0) { - return (int)condition.asDouble() == 0 ? y : x; - } - if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { - return condition.cellIterator().next().getValue().intValue() == 0 ? y : x; + private static TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension: inputType.dimensions()) { + String name = dimension.name(); + Long size = dimensionSize(dimension); + if (reduceDimensions.contains(name)) { + size = 1L; + } + builder.indexed(name, size); } - - // The task is to select cells from 'x' or 'y' based on 'condition'. - // If 'condition' is 0 (false), select from 'y', if 1 (true) select - // from 'x'. We do this by individually joining 'x' and 'y' with - // 'condition', and then joining the resulting two tensors. - - TypedTensorFunction conditionFunction = importConstantTensor(tfNode, model, result, tfNode.getInput(0)); - TensorFunction xCond = new Join(x.function(), conditionFunction.function(), ScalarFunctions.multiply()); - TensorFunction yCond = new Join(y.function(), conditionFunction.function(), new DoubleBinaryOperator() { - @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } - @Override public String toString() { return "f(a,b)(a * (1-b))"; } - }); - TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add()); - return new TypedTensorFunction(x.type(), outputFunction); + return builder.build(); } - TypedTensorFunction softmax(List<TypedTensorFunction> arguments) { - ensureArguments(1, arguments, "softmax"); - TypedTensorFunction a = arguments.get(0); - // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 - String dimension = "d" + (a.type().rank() - 1); - Softmax softmax = new Softmax(a.function(), dimension); - return new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); + private static TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) { + for (int i = 0; i < type.dimensions().size(); ++i) { + String correct = String.format("d%d", i); + String current = type.dimensions().get(i).name(); + if (!current.equals(correct)) { + return fixNamingConvention(type, function); + } + } + return new TypedTensorFunction(type, function); } - TypedTensorFunction squeeze(NodeDef tfNode, List<TypedTensorFunction> arguments) { - ensureArguments(1, arguments, "squeeze"); - - TensorFunction inputFunction = arguments.get(0).function(); - TensorType inputType = arguments.get(0).type(); - List<String> squeezeDimensions; - - AttrValue squeezeDimsAttr = tfNode.getAttrMap().get("squeeze_dims"); - if (squeezeDimsAttr == null) { - squeezeDimensions = inputType.dimensions().stream(). - filter(dim -> dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); - } else { - squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). - map(i -> i < 0 ? inputType.dimensions().size() - i : i). - map(i -> inputType.dimensions().get(i.intValue())). - filter(dim -> dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); + private static TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) { + TensorType.Builder correctType = new TensorType.Builder(); + List<String> from = new ArrayList<>(); + List<String> to = new ArrayList<>(); + for (int i = 0; i < type.dimensions().size(); ++i) { + String correct = String.format("d%d", i); + String current = type.dimensions().get(i).name(); + if (!current.equals(correct)) { + from.add(current); + to.add(correct); + } + correctType.indexed(correct, dimensionSize(type.dimensions().get(i))); } - - if (squeezeDimensions.isEmpty()) { - return arguments.get(0); + if (from.size() > 0) { + function = new Rename(function, from, to); + type = correctType.build(); } - - TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); - TensorType outputType = Reduce.outputType(inputType, squeezeDimensions); - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return output; + return new TypedTensorFunction(type, function); } - private Long tensorSize(TensorType type) { + private static Long tensorSize(TensorType type) { Long size = 1L; for (TensorType.Dimension dimension : type.dimensions()) { size *= dimensionSize(dimension); @@ -459,14 +676,21 @@ class OperationMapper { return size; } - private Long dimensionSize(TensorType.Dimension dim) { + private static Long dimensionSize(TensorType.Dimension dim) { return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); } - private void ensureArguments(int count, List<TypedTensorFunction> arguments, String operationName) { - if ( arguments.size() != count) - throw new IllegalArgumentException("Expected " + count + " arguments to " + operationName + - ", but got " + arguments.size()); + private static boolean checkInputs(TensorFlowImporter.Parameters params, int expected) { + List<Optional<TypedTensorFunction>> inputs = params.inputs(); + if (!inputs.stream().allMatch(Optional::isPresent)) { + return false; + } + if (inputs.size() != expected) { + params.signature().importWarning("Expected " + expected + + " arguments to " + params.node().getOp() + ", but got " + inputs.size()); + return false; + } + return true; } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java index 8edb9b9b7a1..b88ffce275a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java @@ -19,7 +19,7 @@ import java.nio.LongBuffer; */ public class TensorConverter { - public Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { + public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { TensorType type = toVespaTensorType(tfTensor.shape()); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); @@ -28,7 +28,7 @@ public class TensorConverter { return builder.build(); } - private TensorType toVespaTensorType(long[] shape) { + private static TensorType toVespaTensorType(long[] shape) { TensorType.Builder b = new TensorType.Builder(); int dimensionIndex = 0; for (long dimensionSize : shape) { @@ -38,7 +38,7 @@ public class TensorConverter { return b.build(); } - private Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { + private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { switch (tfTensor.dataType()) { case DOUBLE: return new DoubleValues(tfTensor); case FLOAT: return new FloatValues(tfTensor); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 4780c39d21d..3a6b3f23a1d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -4,7 +4,6 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.ScalarFunctions; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; import org.tensorflow.framework.GraphDef; @@ -12,12 +11,14 @@ import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.SignatureDef; import org.tensorflow.framework.TensorInfo; -import org.tensorflow.framework.TensorShapeProto; import java.io.File; import java.io.IOException; +import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.stream.Collectors; /** @@ -27,8 +28,6 @@ import java.util.stream.Collectors; */ public class TensorFlowImporter { - private final OperationMapper operationMapper = new OperationMapper(); - /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a .pbtxt or .pb file. @@ -68,9 +67,21 @@ public class TensorFlowImporter { for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { String outputName = output.getKey(); try { - NodeDef node = getNode(nameOf(output.getValue().getName()), graph.getGraphDef()); - importNode(node, graph.getGraphDef(), model, result); - signature.output(outputName, nameOf(output.getValue().getName())); + NodeDef node = getNode(namePartOf(output.getValue().getName()), graph.getGraphDef()); + Parameters params = createParameters(graph.getGraphDef(), model, result, signature, node, ""); + + // Commonly, there are multiple paths through a TensorFlow graph, for instance for + // training and testing/evaluation. Examples are dropout and batch norm. For Vespa + // we are not concerned with training paths, so we can ignore non-supported operations + // as long as they are on a path that will not be evaluated run time. Operations + // that fail import will not have a value present in the optionals. However, the + // final output node must have value present. It is an error if it does not. + + Optional<TypedTensorFunction> outputFunction = importNode(params); + if (!outputFunction.isPresent()) { + throw new IllegalArgumentException(signature.importWarnings().stream().collect(Collectors.joining("\n"))); + } + signature.output(outputName, namePartOf(output.getValue().getName())); } catch (IllegalArgumentException e) { signature.skippedOutput(outputName, Exceptions.toMessageString(e)); @@ -82,93 +93,59 @@ public class TensorFlowImporter { private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) { inputInfoMap.forEach((key, value) -> { - String argumentName = nameOf(value.getName()); - TensorType argumentType = importTensorType(value.getTensorShape()); + String argumentName = namePartOf(value.getName()); + TensorType argumentType = AttrValueConverter.toVespaTensorType(value.getTensorShape()); // Arguments are (Placeholder) nodes, so not local to the signature: signature.owner().argument(argumentName, argumentType); signature.input(key, argumentName); }); } - private TensorType importTensorType(TensorShapeProto tensorShape) { - TensorType.Builder b = new TensorType.Builder(); - for (TensorShapeProto.Dim dimension : tensorShape.getDimList()) { - int dimensionSize = (int)dimension.getSize(); - if (dimensionSize >= 0) - b.indexed("d" + b.rank(), dimensionSize); - else - b.indexed("d" + b.rank()); // unbound size + /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ + private Optional<TypedTensorFunction> importNode(Parameters params) { + String nodeName = params.node().getName(); + if (params.imported().containsKey(nodeName)) { + return Optional.of(params.imported().get(nodeName)); } - return b.build(); - } - /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private TypedTensorFunction importNode(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { - TypedTensorFunction function = tensorFunctionOf(tfNode, graph, model, result); + Optional<TypedTensorFunction> function = OperationMapper.map(params); + if (!function.isPresent()) { + return Optional.empty(); + } + if (!controlDependenciesArePresent(params)) { + return Optional.empty(); + } + params.imported().put(nodeName, function.get()); + try { // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree - result.expression(tfNode.getName(), new RankingExpression(tfNode.getName(), function.function().toString())); + params.result().expression(nodeName, + new RankingExpression(params.node().getName(), function.get().function().toString())); return function; } catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function.function() + + throw new RuntimeException("Tensorflow function " + function.get().function() + " cannot be parsed as a ranking expression", e); } } + private boolean controlDependenciesArePresent(Parameters params) { + return params.node().getInputList().stream() + .filter(TensorFlowImporter::isControlDependency) + .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName)))) + .allMatch(Optional::isPresent); + } - - private TypedTensorFunction tensorFunctionOf(NodeDef tfNode, GraphDef graph, SavedModelBundle model, TensorFlowModel result) { - // Import arguments lazily below, as some nodes have arguments unused arguments leading to unsupported ops - // TODO: Implement mapping of more functions from https://www.tensorflow.org/api_docs/python/ - switch (tfNode.getOp().toLowerCase()) { - // array ops - case "const" : return operationMapper.constant(tfNode, model, result); - case "expanddims" : return operationMapper.expandDims(tfNode, model, importArguments(tfNode, graph, model, result)); - case "identity" : return operationMapper.identity(tfNode, model, result); - case "placeholder" : return operationMapper.placeholder(tfNode, result); - case "placeholderwithdefault" : return operationMapper.placeholderWithDefault(tfNode, model, result); - case "reshape" : return operationMapper.reshape(tfNode, model, importArguments(tfNode, graph, model, result)); - case "squeeze" : return operationMapper.squeeze(tfNode, importArguments(tfNode, graph, model, result)); - - // math ops - case "add" : case "add_n" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); - case "acos" : return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.acos()); - case "matmul" : return operationMapper.matmul(importArguments(tfNode, graph, model, result)); - case "maximum" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.max()); - case "mean" : case "reducemean": return operationMapper.mean(tfNode, model, importArguments(tfNode, graph, model, result)); - case "multiply": case "mul" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.multiply()); - case "rsqrt": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.rsqrt()); - case "where3": case "select" : return operationMapper.select(tfNode, model, result, importArguments(tfNode, graph, model, result)); - case "sigmoid": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.sigmoid()); - case "squareddifference" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.squareddifference()); - case "subtract" : case "sub" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.subtract()); - - // nn ops - case "biasadd" : return operationMapper.join(importArguments(tfNode, graph, model, result), ScalarFunctions.add()); - case "elu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.elu()); - case "relu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.relu()); - case "selu": return operationMapper.map(importArguments(tfNode, graph, model, result), ScalarFunctions.selu()); - case "softmax" : return operationMapper.softmax(importArguments(tfNode, graph, model, result)); - - // evaluation no-ops - case "stopgradient" : - case "noop": - return operationMapper.noOp(importArguments(tfNode, graph, model, result)); - - // not supported - default : - throw new IllegalArgumentException("Conversion of TensorFlow operation '" + tfNode.getOp() + - "' is not supported (" + tfNode.getName() + ")"); - } + private static boolean isControlDependency(String nodeName) { + return nodeName.startsWith("^"); } - private List<TypedTensorFunction> importArguments(NodeDef tfNode, GraphDef graph, SavedModelBundle model, - TensorFlowModel result) { - return tfNode.getInputList().stream() - .map(argNode -> importNode(getNode(nameOf(argNode), graph), graph, model, result)) - .collect(Collectors.toList()); + private List<Optional<TypedTensorFunction>> importArguments(Parameters params) { + return params.node().getInputList().stream() + .filter(nodeName -> !isControlDependency(nodeName)) + .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName)))) + .collect(Collectors.toList()); } private NodeDef getNode(String name, GraphDef graph) { @@ -182,8 +159,94 @@ public class TensorFlowImporter { * A method signature input and output has the form name:index. * This returns the name part without the index. */ - private String nameOf(String name) { + private static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; return name.split(":")[0]; } + /** + * This return the index part. Indexes are used for nodes with + * multiple outputs. + */ + private static String indexPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? "" : name.substring(i + 1); + } + + + private Parameters createParameters(GraphDef graph, + SavedModelBundle model, + TensorFlowModel result, + TensorFlowModel.Signature signature, + NodeDef node, + String port) { + return new Parameters(this, graph, model, result, signature, new HashMap<>(), node, port); + } + + /** Parameter object to hold important data while importing */ + static final class Parameters { + private final TensorFlowImporter owner; + private final GraphDef graph; + private final SavedModelBundle model; + private final TensorFlowModel result; + private final TensorFlowModel.Signature signature; + private final Map<String, TypedTensorFunction> imported; + private final NodeDef node; + private final String port; + + private Parameters(TensorFlowImporter owner, + GraphDef graph, + SavedModelBundle model, + TensorFlowModel result, + TensorFlowModel.Signature signature, + Map<String, TypedTensorFunction> imported, + NodeDef node, + String port) { + this.owner = owner; + this.graph = graph; + this.model = model; + this.result = result; + this.signature = signature; + this.imported = imported; + this.node = node; + this.port = port; + } + + GraphDef graph() { + return this.graph; + } + + SavedModelBundle model() { + return this.model; + } + + TensorFlowModel result() { + return this.result; + } + + TensorFlowModel.Signature signature() { + return this.signature; + } + + Map<String, TypedTensorFunction> imported() { + return this.imported; + } + + NodeDef node() { + return node; + } + + String port() { + return port; + } + + Parameters copy(NodeDef node, String port) { + return new Parameters(this.owner, this.graph, this.model, this.result, this.signature, this.imported, node, port); + } + + List<Optional<TypedTensorFunction>> inputs() { + return owner.importArguments(this); + } + } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java index 1a6c93384ea..60aaf8ddce1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java @@ -5,8 +5,10 @@ import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.List; import java.util.Map; /** @@ -67,6 +69,7 @@ public class TensorFlowModel { private final Map<String, String> inputs = new HashMap<>(); private final Map<String, String> outputs = new HashMap<>(); private final Map<String, String> skippedOutputs = new HashMap<>(); + private final List<String> importWarnings = new ArrayList<>(); Signature(String name) { this.name = name; @@ -75,6 +78,7 @@ public class TensorFlowModel { void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } void output(String name, String expressionName) { outputs.put(name, expressionName); } void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); } + void importWarning(String warning) { importWarnings.add(warning); } public String name() { return name; } @@ -99,6 +103,11 @@ public class TensorFlowModel { */ public Map<String, String> skippedOutputs() { return Collections.unmodifiableMap(skippedOutputs); } + /** + * Returns an immutable list of possibly non-fatal warnings encountered during import. + */ + public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } + /** Returns owner().expressions().get(outputs.get(outputName)), e.g the expression this output references */ public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java index 962f9dda0a6..600225bfe76 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java @@ -19,7 +19,12 @@ final class TypedTensorFunction { this.function = function; } - public TensorType type() { return type; } - public TensorFunction function() { return function; } + public TensorType type() { + return type; + } + + public TensorFunction function() { + return function; + } } |