diff options
author | Lester Solbakken <lesters@oath.com> | 2018-02-05 16:04:42 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-02-22 12:54:34 +0100 |
commit | b1f46fcd0495dbce905fb8b7318781f4cf5965a7 (patch) | |
tree | d0a0506fe66e5af4af2a927101a0eb9ed9420d38 /searchlib | |
parent | e307df56eaaf5b0ebca5aefb7f7e0c5c3a970bdb (diff) |
Refactor TensorFlow import and add dimension renaming.
Diffstat (limited to 'searchlib')
31 files changed, 2393 insertions, 1187 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java deleted file mode 100644 index 5f0c016881a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java +++ /dev/null @@ -1,132 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.TensorProto; -import org.tensorflow.framework.TensorShapeProto; - -/** - * @author lesters - */ -public class AttrValueConverter { - - public static Tensor toVespaTensor(NodeDef tfNode, String attr) { - if (!tfNode.getAttrMap().containsKey(attr)) { - throw new IllegalArgumentException(tfNode.getName() + " has no attribute called " + attr); - } - AttrValue attrValue = tfNode.getAttrMap().get(attr); - switch (attrValue.getValueCase()) { - case TENSOR: - return buildFromTensor(attrValue); - case B: - return buildFromSingleValue(attrValue.getB() ? 1.0 : 0.0); - case F: - return buildFromSingleValue(attrValue.getF()); - case I: - return buildFromSingleValue(attrValue.getI()); - } - - throw new IllegalArgumentException(tfNode.getName() + - ": unsupported attribute type: '" + attrValue.getValueCase().toString() + "'"); - } - - private static Tensor buildFromSingleValue(double value) { - TensorType type = new TensorType.Builder().build(); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - builder.cellByDirectIndex(0, value); - return builder.build(); - } - - private static Tensor buildFromTensor(AttrValue attrValue) { - TensorProto tensorProto = attrValue.getTensor(); - TensorType type = toVespaTensorType(tensorProto.getTensorShape()); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - Values values = valuesOf(tensorProto); - for (int i = 0; i < values.size(); ++i) { - builder.cellByDirectIndex(i, values.get(i)); - } - Tensor tensor = builder.build(); - return tensor; - } - - private static Values valuesOf(TensorProto tensorProto) { - switch (tensorProto.getDtype()) { - case DT_BOOL: - return new BoolValues(tensorProto); - case DT_HALF: - return new HalfValues(tensorProto); - case DT_INT16: - case DT_INT32: - return new IntValues(tensorProto); - case DT_INT64: - return new Int64Values(tensorProto); - case DT_FLOAT: - return new FloatValues(tensorProto); - case DT_DOUBLE: - return new DoubleValues(tensorProto); - } - - throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); - } - - public static TensorType toVespaTensorType(TensorShapeProto shapeProto) { - TensorType.Builder b = new TensorType.Builder(); - for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) { - int dimensionSize = (int)dimension.getSize(); - if (dimensionSize >= 0) - b.indexed("d" + b.rank(), dimensionSize); - else - b.indexed("d" + b.rank()); // unbound size - } - return b.build(); - } - - private static abstract class Values { - protected final TensorProto tensorProto; - protected Values(TensorProto tensorProto) { this.tensorProto = tensorProto; } - abstract double get(int i); - abstract int size(); - } - - private static class BoolValues extends Values { - BoolValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; } - @Override int size() { return tensorProto.getBoolValCount(); } - } - - private static class HalfValues extends Values { - HalfValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getHalfVal(i); } - @Override int size() { return tensorProto.getHalfValCount(); } - } - - private static class IntValues extends Values { - IntValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getIntVal(i); } - @Override int size() { return tensorProto.getIntValCount(); } - } - - private static class Int64Values extends Values { - Int64Values(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getInt64Val(i); } - @Override int size() { return tensorProto.getInt64ValCount(); } - } - - private static class FloatValues extends Values { - FloatValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getFloatVal(i); } - @Override int size() { return tensorProto.getFloatValCount(); } - } - - private static class DoubleValues extends Values { - DoubleValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getDoubleVal(i); } - @Override int size() { return tensorProto.getDoubleValCount(); } - } - - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java deleted file mode 100644 index ef82045e771..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java +++ /dev/null @@ -1,715 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.google.common.collect.ImmutableList; -import com.yahoo.searchlib.rankingexpression.RankingExpression; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; -import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; -import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; -import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.Matmul; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.ScalarFunctions; -import com.yahoo.tensor.functions.Softmax; -import com.yahoo.tensor.functions.TensorFunction; -import org.tensorflow.Session; -import org.tensorflow.framework.AttrValue; - -import java.util.ArrayList; -import java.util.Iterator; -import java.util.List; -import java.util.Optional; -import java.util.function.DoubleBinaryOperator; -import java.util.function.DoubleUnaryOperator; -import java.util.function.Function; -import java.util.stream.Collectors; -import java.util.stream.Stream; - -/** - * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions. - * - * @author bratseth - * @author lesters - */ -class OperationMapper { - - // A note on conversion from implicitly numbered to explicitly named dimensions: - // - // Vespa tensor dimensions are explicitly named and thus have an explicit notion of being - // 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation - // comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation - // around dimension renaming operations which mirrors those built into the TF operation definitions. - // - // To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost' - // dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation - // and the result is then renamed again (if necessary) to recover this convention across a full nested - // computation. - // - // This requires us to track tensor types throughout the conversion. - - - // Supported TensorFlow operations - enum Operation { - - // TODO: move the implementations to specific files as we support more operations - - /* - * array ops - */ - CONST (OperationMapper::constant), - EXPANDDIMS (OperationMapper::expandDims), - IDENTITY (OperationMapper::identity), - PLACEHOLDER (OperationMapper::placeholder), - PLACEHOLDERWITHDEFAULT (OperationMapper::placeholderWithDefault), - RESHAPE (OperationMapper::reshape), - SQUEEZE (OperationMapper::squeeze), - - /* - * control flow - */ - MERGE (OperationMapper::merge), - SWITCH (OperationMapper::switchOp), - - /* - * math ops - */ - ADD (OperationMapper::add), - ADD_N (OperationMapper::add), - ACOS (OperationMapper::acos), - DIV (OperationMapper::div), - REALDIV (OperationMapper::div), - FLOOR (OperationMapper::floor), - MATMUL (OperationMapper::matmul), - MAXIMUM (OperationMapper::maximum), - MEAN (OperationMapper::mean), - REDUCEMEAN (OperationMapper::mean), - MUL (OperationMapper::mul), - MULTIPLY (OperationMapper::mul), - RSQRT (OperationMapper::rsqrt), - SELECT (OperationMapper::select), - WHERE3 (OperationMapper::select), - SIGMOID (OperationMapper::sigmoid), - SQUAREDDIFFERENCE (OperationMapper::squaredDifference), - SUB (OperationMapper::sub), - SUBTRACT (OperationMapper::sub), - - /* - * nn ops - */ - BIASADD (OperationMapper::add), - ELU (OperationMapper::elu), - RELU (OperationMapper::relu), - SELU (OperationMapper::selu), - SOFTMAX (OperationMapper::softMax), - - /* - * state ops - */ - VARIABLE (OperationMapper::variable), - VARIABLEV2 (OperationMapper::variable), - - /* - * evaluation no-ops - */ - STOPGRADIENT (OperationMapper::identity), - NOOP (OperationMapper::noOp); - - - private final Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func; - - Operation(Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func) { - this.func = func; - } - - Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) { - return func.apply(params); - } - - } - - static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) { - Optional<Operation> operation = Stream.of(Operation.values()) - .filter(op -> op.name().equalsIgnoreCase(params.node().getOp())) - .findFirst(); - if (operation.isPresent()) { - return operation.get().map(params); - } - params.signature().importWarning("TensorFlow operation '" + params.node().getOp() + - "' in node '" + params.node().getName() + "' is not supported."); - return Optional.empty(); - } - - - // Operations --------------------------------- - - private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) { - Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value"); - if (value.type().rank() == 0) { - TypedTensorFunction output = new TypedTensorFunction(value.type(), - new TensorFunctionNode.TensorFunctionExpressionNode( - new ConstantNode(new DoubleValue(value.asDouble())))); - return Optional.of(output); - } - return createConstant(params, value); - } - - private static Optional<TypedTensorFunction> expandDims(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - - Tensor axis = getConstantTensor(params, params.node().getInput(1)); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar"); - } - - TensorFunction inputFunction = inputs.get(0).get().function(); - TensorType inputType = inputs.get(0).get().type(); - - int dimensionToInsert = (int)axis.asDouble(); - if (dimensionToInsert < 0) { - dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; - } - - TensorType.Builder outputTypeBuilder = new TensorType.Builder(); - int dimensionIndex = 0; - for (int i = 0; i < inputType.dimensions().size() + 1; ++i) { - String name = String.format("temp_%d", i); - Long size; - if (i == dimensionToInsert) { - size = 1L; - } else { - size = dimensionSize(inputType.dimensions().get(dimensionIndex)); - dimensionIndex++; - } - outputTypeBuilder.indexed(name, size); - } - - return reshape(inputFunction, inputType, outputTypeBuilder.build()); - } - - private static Optional<TypedTensorFunction> identity(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 1)) { - return Optional.empty(); - } - return params.inputs().get(0); - } - - private static Optional<TypedTensorFunction> placeholder(TensorFlowImporter.Parameters params) { - String name = params.node().getName(); - String vespaName = toVespaName(params.node().getName()); - TensorType type = params.result().arguments().get(name); - if (type == null) { - throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name + - "', but there is no such placeholder"); - } - params.result().requiredMacro(vespaName, type); - // Included literally in the expression and so must be produced by a separate macro in the rank profile - TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(vespaName, type)); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) { - String name = toVespaName(params.node().getInput(0)); - Tensor defaultValue = getConstantTensor(params, params.node().getInput(0)); - params.result().largeConstant(name, defaultValue); - params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")"))); - // The default value will be provided by the macro. Users can override macro to change value. - TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name)); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) { - if ( ! checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - Tensor shape = getConstantTensor(params, params.node().getInput(1)); - - TensorFunction inputFunction = inputs.get(0).get().function(); - TensorType inputType = inputs.get(0).get().type(); - - TensorType.Builder outputTypeBuilder = new TensorType.Builder(); - int dimensionIndex = 0; - for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int size = cell.getValue().intValue(); - if (size < 0) { - size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue(); - } - outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size); - dimensionIndex++; - } - return reshape(inputFunction, inputType, outputTypeBuilder.build()); - } - - private static Optional<TypedTensorFunction> squeeze(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 1)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - - TensorFunction inputFunction = inputs.get(0).get().function(); - TensorType inputType = inputs.get(0).get().type(); - List<String> squeezeDimensions; - - AttrValue squeezeDimsAttr = params.node().getAttrMap().get("squeeze_dims"); - if (squeezeDimsAttr == null) { - squeezeDimensions = inputType.dimensions().stream(). - filter(dim -> dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); - } else { - squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). - map(i -> i < 0 ? inputType.dimensions().size() - i : i). - map(i -> inputType.dimensions().get(i.intValue())). - filter(dim -> dimensionSize(dim) == 1). - map(TensorType.Dimension::name). - collect(Collectors.toList()); - } - - if (squeezeDimensions.isEmpty()) { - return inputs.get(0); - } - - TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); - TensorType outputType = Reduce.outputType(inputType, squeezeDimensions); - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> merge(TensorFlowImporter.Parameters params) { - return params.inputs().stream() - .filter(Optional::isPresent) - .findFirst() - .orElse(Optional.empty()); - } - - private static Optional<TypedTensorFunction> switchOp(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - Tensor predicate = getConstantTensor(params, params.node().getInput(1)); - if (predicate.type().rank() != 0) { - throw new IllegalArgumentException("'switch': predicate must be a scalar"); - } - double pred = predicate.asDouble(); - int output = params.port().length() > 0 ? Integer.parseInt(params.port()) : 0; - if (output < 0 || output > 1) { - throw new IllegalArgumentException("'switch': predicate is not boolean"); - } - if (pred == output) { - return inputs.get(0); - } - return Optional.empty(); - } - - private static Optional<TypedTensorFunction> add(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.add()); - } - - private static Optional<TypedTensorFunction> acos(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.acos()); - } - - private static Optional<TypedTensorFunction> div(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.divide()); - } - - private static Optional<TypedTensorFunction> floor(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.floor()); - } - - private static Optional<TypedTensorFunction> matmul(TensorFlowImporter.Parameters params) { - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - - TypedTensorFunction a = inputs.get(0).get(); - TypedTensorFunction b = inputs.get(1).get(); - if (a.type().rank() < 2 || b.type().rank() < 2) - throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); - if (a.type().rank() != b.type().rank()) - throw new IllegalArgumentException("Tensors in matmul must have the same rank"); - - String afterLastDim = "d" + (a.type().rank() + 1); - // Let the first dimension of the second tensor be the same as the second dimension of the first - // and the second dimension of the second argument be not present in the first argument, while leaving the - // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication. - - // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly - - Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"), - ImmutableList.of("d1", afterLastDim)); - Matmul matmul = new Matmul(a.function(), renamedB, "d1"); - TypedTensorFunction output = new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"), - new Rename(matmul, afterLastDim, "d1")); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> maximum(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.max()); - } - - private static Optional<TypedTensorFunction> mean(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TensorFunction inputFunction = inputs.get(0).get().function(); - TensorType inputType = inputs.get(0).get().type(); - - Tensor reductionIndices = getConstantTensor(params, params.node().getInput(1)); - List<String> reduceDimensions = new ArrayList<>(); - for (Iterator<Tensor.Cell> cellIterator = reductionIndices.cellIterator(); cellIterator.hasNext();) { - Tensor.Cell cell = cellIterator.next(); - int dimensionIndex = cell.getValue().intValue(); - if (dimensionIndex < 0) { - dimensionIndex = inputType.dimensions().size() - dimensionIndex; - } - reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); - } - - TensorType outputType = Reduce.outputType(inputType, reduceDimensions); - TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); - - if (shouldKeepDimensions(params)) { - return reshape(outputFunction, outputType, keepDimensionType(inputType, reduceDimensions)); - } - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> mul(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.multiply()); - } - - private static Optional<TypedTensorFunction> rsqrt(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.rsqrt()); - } - - private static Optional<TypedTensorFunction> select(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 3)) { - return Optional.empty(); - } - Tensor condition = getConstantTensor(params, params.node().getInput(0)); - - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TypedTensorFunction x = inputs.get(1).get(); - TypedTensorFunction y = inputs.get(2).get(); - if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) { - throw new IllegalArgumentException("'Select': input tensors must have the same shape"); - } - - if (condition.type().rank() == 0) { - return Optional.of((int)condition.asDouble() == 0 ? y : x); - } - if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { - return Optional.of(condition.cellIterator().next().getValue().intValue() == 0 ? y : x); - } - - // The task is to select cells from 'x' or 'y' based on 'condition'. - // If 'condition' is 0 (false), select from 'y', if 1 (true) select - // from 'x'. We do this by individually joining 'x' and 'y' with - // 'condition', and then joining the resulting two tensors. - - Optional<TypedTensorFunction> conditionFunction = importConstantTensor(params, params.node().getInput(0)); - if (!conditionFunction.isPresent()) { - return Optional.empty(); - } - TensorFunction xCond = new Join(x.function(), conditionFunction.get().function(), ScalarFunctions.multiply()); - TensorFunction yCond = new Join(y.function(), conditionFunction.get().function(), new DoubleBinaryOperator() { - @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } - @Override public String toString() { return "f(a,b)(a * (1-b))"; } - }); - TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add()); - TypedTensorFunction output = new TypedTensorFunction(x.type(), outputFunction); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> sigmoid(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.sigmoid()); - } - - private static Optional<TypedTensorFunction> squaredDifference(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.squareddifference()); - } - - private static Optional<TypedTensorFunction> sub(TensorFlowImporter.Parameters params) { - return join(params, ScalarFunctions.subtract()); - } - - private static Optional<TypedTensorFunction> elu(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.elu()); - } - - private static Optional<TypedTensorFunction> relu(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.relu()); - } - - private static Optional<TypedTensorFunction> selu(TensorFlowImporter.Parameters params) { - return map(params, ScalarFunctions.selu()); - } - - private static Optional<TypedTensorFunction> softMax(TensorFlowImporter.Parameters params) { - if (!checkInputs(params, 1)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TypedTensorFunction a = inputs.get(0).get(); - // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1 - String dimension = "d" + (a.type().rank() - 1); - Softmax softmax = new Softmax(a.function(), dimension); - TypedTensorFunction output = new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax); - return Optional.of(output); - } - - private static Optional<TypedTensorFunction> variable(TensorFlowImporter.Parameters params) { - return importConstantTensor(params, params.node().getName()); - } - - private static Optional<TypedTensorFunction> noOp(TensorFlowImporter.Parameters params) { - return Optional.empty(); - } - - /* - * Utility - */ - - private static Optional<TypedTensorFunction> join(TensorFlowImporter.Parameters params, DoubleBinaryOperator doubleFunction) { - if (!checkInputs(params, 2)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - - TypedTensorFunction a = inputs.get(0).get(); - TypedTensorFunction b = inputs.get(1).get(); - - if (a.type().rank() == 0 && b.type().rank() > 0) { - return Optional.of(new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction))); - } - if (b.type().rank() == 0 && a.type().rank() > 0) { - return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction))); - } - if (a.type().rank() == b.type().rank()) { - return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction))); - } - - // Well now we have entered the wonderful world of "broadcasting" - // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html - // I'm not able to extract from that any unambiguous specification of which dimensions - // should be "stretched" when the tensor do not have the same number of dimensions. - // From trying this with TensorFlow it appears that the second tensor is matched to the - // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. - // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). - - if (a.type().rank() > b.type().rank()) { - TensorFunction renameFunction = renameForBroadcast(a, b); - return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction))); - } - TensorFunction renameFunction = renameForBroadcast(b, a); - return Optional.of(new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction))); - } - - private static TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) { - List<String> renameFrom = new ArrayList<>(); - List<String> renameTo = new ArrayList<>(); - int sizeDifference = a.type().rank() - b.type().rank(); - for (int i = 0; i < b.type().rank(); i++) { - renameFrom.add(b.type().dimensions().get(i).name()); - renameTo.add("d" + (sizeDifference + i)); - } - return new Rename(b.function(), renameFrom, renameTo); - } - - private static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params, DoubleUnaryOperator doubleFunction) { - if (!checkInputs(params, 1)) { - return Optional.empty(); - } - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - TypedTensorFunction a = inputs.get(0).get(); - TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type()); - com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction); - return Optional.of(new TypedTensorFunction(resultType, function)); - } - - private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) { - String name = toVespaName(params.node().getName()); - if (constant.type().rank() == 0 || constant.size() <= 1) { - params.result().smallConstant(name, constant); - } else { - params.result().largeConstant(name, constant); - } - TypedTensorFunction output = new TypedTensorFunction(constant.type(), - new TensorFunctionNode.TensorFunctionExpressionNode( - new ReferenceNode("constant(\"" + name + "\")"))); - return Optional.of(output); - } - - private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) { - String vespaName = toVespaName(name); - if (params.result().smallConstants().containsKey(vespaName)) { - return params.result().smallConstants().get(vespaName); - } - if (params.result().largeConstants().containsKey(vespaName)) { - return params.result().largeConstants().get(vespaName); - } - Session.Runner fetched = params.model().session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " + - importedTensors.size()); - return TensorConverter.toVespaTensor(importedTensors.get(0)); - } - - private static Optional<TypedTensorFunction> importConstantTensor(TensorFlowImporter.Parameters params, String name) { - AttrValue shapes = params.node().getAttrMap().get("_output_shapes"); - if (shapes == null) - throw new IllegalArgumentException("'" + name + "' is missing a tensor shape"); - Tensor constant = getConstantTensor(params, name); - return createConstant(params, constant); - } - - private static Optional<TypedTensorFunction> reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!tensorSize(inputType).equals(tensorSize(outputType))) { - throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); - } - - // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, - // then use the dimension order of the new shape to roll back into a tensor. - // Here we create a transformation tensor that is multiplied with the from tensor to map into - // the new shape. We have to introduce temporary dimension names and rename back if dimension names - // in the new and old tensor type overlap. - - ExpressionNode unrollFrom = unrollTensorExpression(inputType); - ExpressionNode unrollTo = unrollTensorExpression(outputType); - ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo); - - TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); - Generate transformTensor = new Generate(transformationType, - new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - - TensorFunction outputFunction = new Reduce( - new Join(inputFunction, transformTensor, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); - TypedTensorFunction output = checkNamingConvention(outputType, outputFunction); - return Optional.of(output); - } - - private static ExpressionNode unrollTensorExpression(TensorType type) { - if (type.rank() == 0) { - return new ConstantNode(DoubleValue.zero); - } - List<ExpressionNode> children = new ArrayList<>(); - List<ArithmeticOperator> operators = new ArrayList<>(); - int size = 1; - for (int i = type.dimensions().size() - 1; i >= 0; --i) { - TensorType.Dimension dimension = type.dimensions().get(i); - children.add(0, new ReferenceNode(dimension.name())); - if (size > 1) { - operators.add(0, ArithmeticOperator.MULTIPLY); - children.add(0, new ConstantNode(new DoubleValue(size))); - } - size *= dimensionSize(dimension); - if (i > 0) { - operators.add(0, ArithmeticOperator.PLUS); - } - } - return new ArithmeticNode(children, operators); - } - - private static boolean shouldKeepDimensions(TensorFlowImporter.Parameters params) { - AttrValue keepDimsAttr = params.node().getAttrMap().get("keep_dims"); - return keepDimsAttr != null && keepDimsAttr.getB(); - } - - private static TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) { - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension: inputType.dimensions()) { - String name = dimension.name(); - Long size = dimensionSize(dimension); - if (reduceDimensions.contains(name)) { - size = 1L; - } - builder.indexed(name, size); - } - return builder.build(); - } - - private static TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) { - for (int i = 0; i < type.dimensions().size(); ++i) { - String correct = String.format("d%d", i); - String current = type.dimensions().get(i).name(); - if (!current.equals(correct)) { - return fixNamingConvention(type, function); - } - } - return new TypedTensorFunction(type, function); - } - - private static TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) { - TensorType.Builder correctType = new TensorType.Builder(); - List<String> from = new ArrayList<>(); - List<String> to = new ArrayList<>(); - for (int i = 0; i < type.dimensions().size(); ++i) { - String correct = String.format("d%d", i); - String current = type.dimensions().get(i).name(); - if (!current.equals(correct)) { - from.add(current); - to.add(correct); - } - correctType.indexed(correct, dimensionSize(type.dimensions().get(i))); - } - if (from.size() > 0) { - function = new Rename(function, from, to); - type = correctType.build(); - } - return new TypedTensorFunction(type, function); - } - - private static Long tensorSize(TensorType type) { - Long size = 1L; - for (TensorType.Dimension dimension : type.dimensions()) { - size *= dimensionSize(dimension); - } - return size; - } - - private static Long dimensionSize(TensorType.Dimension dim) { - return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); - } - - private static boolean checkInputs(TensorFlowImporter.Parameters params, int expected) { - List<Optional<TypedTensorFunction>> inputs = params.inputs(); - if (!inputs.stream().allMatch(Optional::isPresent)) { - return false; - } - if (inputs.size() != expected) { - params.signature().importWarning("Expected " + expected + - " arguments to " + params.node().getOp() + ", but got " + inputs.size()); - return false; - } - return true; - } - - public static String toVespaName(String name) { - return name != null ? name.replace('/', '_') : null; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java deleted file mode 100644 index b88ffce275a..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ /dev/null @@ -1,155 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; - - -/** - * Converts TensorFlow tensors into Vespa tensors. - * - * @author bratseth - */ -public class TensorConverter { - - public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { - TensorType type = toVespaTensorType(tfTensor.shape()); - Values values = readValuesOf(tfTensor); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - for (int i = 0; i < values.size(); i++) - builder.cellByDirectIndex(i, values.get(i)); - return builder.build(); - } - - private static TensorType toVespaTensorType(long[] shape) { - TensorType.Builder b = new TensorType.Builder(); - int dimensionIndex = 0; - for (long dimensionSize : shape) { - if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... - b.indexed("d" + (dimensionIndex++), dimensionSize); - } - return b.build(); - } - - private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { - switch (tfTensor.dataType()) { - case DOUBLE: return new DoubleValues(tfTensor); - case FLOAT: return new FloatValues(tfTensor); - case BOOL: return new BoolValues(tfTensor); - case UINT8: return new IntValues(tfTensor); - case INT32: return new IntValues(tfTensor); - case INT64: return new LongValues(tfTensor); - default: - throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); - } - } - - /** Allows reading values from buffers of various numeric types as bytes */ - private static abstract class Values { - - private final int size; - - protected Values(int size) { - this.size = size; - } - - abstract double get(int i); - - int size() { return size; } - - } - - private static class DoubleValues extends Values { - - private final DoubleBuffer values; - - DoubleValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = DoubleBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class FloatValues extends Values { - - private final FloatBuffer values; - - FloatValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = FloatBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class BoolValues extends Values { - - private final ByteBuffer values; - - BoolValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = ByteBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class IntValues extends Values { - - private final IntBuffer values; - - IntValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = IntBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - - private static class LongValues extends Values { - - private final LongBuffer values; - - LongValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = LongBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - - @Override - double get(int i) { - return values.get(i); - } - - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index c97ee2b1514..7116d430502 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -2,10 +2,20 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; import com.yahoo.searchlib.rankingexpression.parser.ParseException; -import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.yolean.Exceptions; import org.tensorflow.SavedModelBundle; +import org.tensorflow.Session; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.MetaGraphDef; import org.tensorflow.framework.NodeDef; @@ -24,6 +34,7 @@ import java.util.stream.Collectors; * Converts a saved TensorFlow model into a ranking expression and set of constants. * * @author bratseth + * @author lesters */ public class TensorFlowImporter { @@ -57,196 +68,303 @@ public class TensorFlowImporter { } } - private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) { - TensorFlowModel result = new TensorFlowModel(); + /** + * Imports the TensorFlow graph by first importing the tensor types, then + * finding a suitable set of dimensions names for each + * placeholder/constant/variable, then importing the expressions. + */ + private static TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle bundle) { + TensorFlowModel model = new TensorFlowModel(); + OperationIndex index = new OperationIndex(); + + importSignatures(graph, model); + importNodes(graph, model, index); + findDimensionNames(model, index); + importExpressions(model, index, bundle); + + // nodes with multiple outputs are calculated multiple times. consider adding macros for those. + + reportWarnings(model, index); + + return model; + } + + private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) { for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { - TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName" + String signatureName = signatureEntry.getKey(); + TensorFlowModel.Signature signature = model.signature(signatureName); + + Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); + for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { + String inputName = input.getKey(); + signature.input(inputName, namePartOf(input.getValue().getName())); + } - importInputs(signatureEntry.getValue().getInputsMap(), signature); - for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) { + Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); + for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { String outputName = output.getKey(); - try { - NodeDef node = getNode(namePartOf(output.getValue().getName()), graph.getGraphDef()); - Parameters params = createParameters(graph.getGraphDef(), model, result, signature, node, ""); - - // Commonly, there are multiple paths through a TensorFlow graph, for instance for - // training and testing/evaluation. Examples are dropout and batch norm. For Vespa - // we are not concerned with training paths, so we can ignore non-supported operations - // as long as they are on a path that will not be evaluated run time. Operations - // that fail import will not have a value present in the optionals. However, the - // final output node must have value present. It is an error if it does not. - - Optional<TypedTensorFunction> outputFunction = importNode(params); - if (!outputFunction.isPresent()) { - throw new IllegalArgumentException(signature.importWarnings().stream().collect(Collectors.joining("\n"))); - } - signature.output(outputName, namePartOf(output.getValue().getName())); - } - catch (IllegalArgumentException e) { - signature.skippedOutput(outputName, Exceptions.toMessageString(e)); - } + signature.output(outputName, namePartOf(output.getValue().getName())); } } - return result; } - private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) { - inputInfoMap.forEach((key, value) -> { - String argumentName = namePartOf(value.getName()); - TensorType argumentType = AttrValueConverter.toVespaTensorType(value.getTensorShape()); - // Arguments are (Placeholder) nodes, so not local to the signature: - signature.owner().argument(argumentName, argumentType); - signature.input(key, argumentName); - }); + private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String inputName : signature.inputs().values()) { + if (inputName.equals(operation.node().getName())) { + return true; + } + } + } + return false; } - /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */ - private Optional<TypedTensorFunction> importNode(Parameters params) { - String nodeName = params.node().getName(); - if (params.imported().containsKey(nodeName)) { - return Optional.of(params.imported().get(nodeName)); + private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + if (outputName.equals(operation.node().getName())) { + return true; + } + } } + return false; + } - Optional<TypedTensorFunction> function = OperationMapper.map(params); - if ( ! function.isPresent()) { - return Optional.empty(); - } - if ( ! controlDependenciesArePresent(params)) { - return Optional.empty(); + private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + importNode(outputName, graph.getGraphDef(), index); + } } - params.imported().put(nodeName, function.get()); + } - try { - // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output - // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree - params.result().expression(nodeName, - new RankingExpression(nodeName, function.get().function().toString())); - return function; + private static TensorFlowOperation importNode(String name, GraphDef graph, OperationIndex index) { + if (index.alreadyImported(name)) { + return index.get(name); } - catch (ParseException e) { - throw new RuntimeException("Tensorflow function " + function.get().function() + - " cannot be parsed as a ranking expression", e); + NodeDef node = getTensorFlowNodeFromGraph(namePartOf(name), graph); + List<TensorFlowOperation> inputs = importNodeInputs(node, graph, index); + TensorFlowOperation operation = OperationMapper.get(node, inputs, portPartOf(name)); + index.put(name, operation); + + List<TensorFlowOperation> controlInputs = importControlInputs(node, graph, index); + if (controlInputs.size() > 0) { + operation.setControlInputs(controlInputs); } - } - private boolean controlDependenciesArePresent(Parameters params) { - return params.node().getInputList().stream() - .filter(TensorFlowImporter::isControlDependency) - .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName)))) - .allMatch(Optional::isPresent); + return operation; } - private static boolean isControlDependency(String nodeName) { - return nodeName.startsWith("^"); + private static List<TensorFlowOperation> importNodeInputs(NodeDef node, GraphDef graph, OperationIndex index) { + return node.getInputList().stream() + .filter(name -> ! isControlDependency(name)) + .map(name -> importNode(name, graph, index)) + .collect(Collectors.toList()); } - private List<Optional<TypedTensorFunction>> importArguments(Parameters params) { - return params.node().getInputList().stream() - .filter(nodeName -> !isControlDependency(nodeName)) - .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName)))) + private static List<TensorFlowOperation> importControlInputs(NodeDef node, GraphDef graph, OperationIndex index) { + return node.getInputList().stream() + .filter(name -> isControlDependency(name)) + .map(name -> importNode(name, graph, index)) .collect(Collectors.toList()); } - private NodeDef getNode(String name, GraphDef graph) { - return graph.getNodeList().stream() - .filter(node -> node.getName().equals(name)) - .findFirst() - .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'")); + private static boolean isControlDependency(String name) { + return name.startsWith("^"); } - /** - * A method signature input and output has the form name:index. - * This returns the name part without the index. - */ - private static String namePartOf(String name) { - name = name.startsWith("^") ? name.substring(1) : name; - return name.split(":")[0]; + /** Find dimension names to avoid excessive renaming while evaluating the model. */ + private static void findDimensionNames(TensorFlowModel model, OperationIndex index) { + DimensionRenamer renamer = new DimensionRenamer(); + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + addDimensionNameConstraints(index.get(output), renamer); + } + } + renamer.solve(); + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + renameDimensions(index.get(output), renamer); + } + } } - /** - * This return the index part. Indexes are used for nodes with - * multiple outputs. - */ - private static String indexPartOf(String name) { - int i = name.indexOf(":"); - return i < 0 ? "" : name.substring(i + 1); + private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); + operation.addDimensionNameConstraints(renamer); + } } + private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) { + if (operation.type().isPresent()) { + operation.inputs().forEach(input -> renameDimensions(input, renamer)); + operation.renameDimensions(renamer); + } + } - private Parameters createParameters(GraphDef graph, - SavedModelBundle model, - TensorFlowModel result, - TensorFlowModel.Signature signature, - NodeDef node, - String port) { - return new Parameters(this, graph, model, result, signature, new HashMap<>(), node, port); + private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String outputName : signature.outputs().values()) { + try { + Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle); + if (!function.isPresent()) { + signature.skippedOutput(outputName, "No valid output function could be found."); + } + } + catch (IllegalArgumentException e) { + signature.skippedOutput(outputName, Exceptions.toMessageString(e)); + } + } + } } - /** Parameter object to hold important data while importing */ - static final class Parameters { + private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) { + if (!operation.type().isPresent()) { + return Optional.empty(); + } + if (operation.isConstant()) { + return importConstant(model, operation, bundle); + } - private final TensorFlowImporter owner; - private final GraphDef graph; - private final SavedModelBundle model; - private final TensorFlowModel result; - private final TensorFlowModel.Signature signature; - private final Map<String, TypedTensorFunction> imported; - private final NodeDef node; - private final String port; + importInputExpressions(operation, model, bundle); + importRankingExpression(model, operation); + importInputExpression(model, operation); + importMacroExpression(model, operation); - private Parameters(TensorFlowImporter owner, - GraphDef graph, - SavedModelBundle model, - TensorFlowModel result, - TensorFlowModel.Signature signature, - Map<String, TypedTensorFunction> imported, - NodeDef node, - String port) { - this.owner = owner; - this.graph = graph; - this.model = model; - this.result = result; - this.signature = signature; - this.imported = imported; - this.node = node; - this.port = port; - } + return operation.function(); + } - GraphDef graph() { - return this.graph; + private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) { + operation.inputs().forEach(input -> importExpression(input, model, bundle)); + } + + private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.macro().isPresent()) { + model.macro(operation.vespaName(), operation.macro().get()); } + } - SavedModelBundle model() { - return this.model; + private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, SavedModelBundle bundle) { + String name = operation.vespaName(); + if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { + return operation.function(); } - TensorFlowModel result() { - return this.result; + Tensor tensor; + if (operation.getConstantValue().isPresent()) { + Value value = operation.getConstantValue().get(); + if ( ! (value instanceof TensorValue)) { + return operation.function(); // scalar values are inserted directly into the expression + } + tensor = value.asTensor(); + } else { + Session.Runner fetched = bundle.session().runner().fetch(operation.node().getName()); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if (importedTensors.size() != 1) { + throw new IllegalStateException("Expected 1 tensor from fetching " + operation.node().getName() + ", but got " + + importedTensors.size()); + } + // Here we use the type from the operation, which will have correct dimension names after name resolving + tensor = TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get()); + operation.setConstantValue(new TensorValue(tensor)); } - TensorFlowModel.Signature signature() { - return this.signature; + if (tensor.type().rank() == 0 || tensor.size() <= 1) { + model.smallConstant(operation.vespaName(), tensor); + } else { + model.largeConstant(operation.vespaName(), tensor); } + return operation.function(); + } + + private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.function().isPresent()) { + String name = operation.node().getName(); + if (!model.expressions().containsKey(operation.node().getName())) { + TensorFunction function = operation.function().get(); + + // Make sure output adheres to standard naming convention + if (isSignatureOutput(model, operation)) { + OrderedTensorType operationType = operation.type().get(); + OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node()); + if ( ! operationType.equals(standardNamingType)) { + List<String> renameFrom = operationType.dimensionNames(); + List<String> renameTo = standardNamingType.dimensionNames(); + function = new Rename(function, renameFrom, renameTo); + } + } - Map<String, TypedTensorFunction> imported() { - return this.imported; + try { + // We add all intermediate nodes imported as separate expressions. Only + // those referenced in a signature output will be used. We parse the + // TensorFunction here to convert it to a RankingExpression tree. + model.expression(name, new RankingExpression(name, function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } + } } + } - NodeDef node() { - return node; + private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) { + if (operation.isInput() && isSignatureInput(model, operation)) { + // All inputs must have dimensions with standard naming convention: d0, d1, ... + OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node()); + model.argument(operation.node().getName(), standardNamingConvention.type()); + model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); } + } - String port() { - return port; + private static void reportWarnings(TensorFlowModel model, OperationIndex index) { + for (TensorFlowModel.Signature signature : model.signatures().values()) { + for (String output : signature.outputs().values()) { + reportWarnings(index.get(output), signature); + } } + } - Parameters copy(NodeDef node, String port) { - return new Parameters(this.owner, this.graph, this.model, this.result, this.signature, this.imported, node, port); + private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) { + for (String warning : operation.warnings()) { + signature.importWarning(warning); } + } - List<Optional<TypedTensorFunction>> inputs() { - return owner.importArguments(this); + private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) { + for (NodeDef node : graph.getNodeList()) { + if (node.getName().equals(name)) { + return node; + } } + throw new IllegalArgumentException("Could not find node '" + name + "'"); + } + + /** + * A method signature input and output has the form name:index. + * This returns the name part without the index. + */ + private static String namePartOf(String name) { + name = name.startsWith("^") ? name.substring(1) : name; + return name.split(":")[0]; + } + + /** + * This return the output port part. Indexes are used for nodes with + * multiple outputs. + */ + private static int portPartOf(String name) { + int i = name.indexOf(":"); + return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); + } + + + private static class OperationIndex { + private final Map<String, TensorFlowOperation> index = new HashMap<>(); + public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); } + public TensorFlowOperation get(String key) { return index.get(key); } + public boolean alreadyImported(String key) { return index.containsKey(key); } } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java deleted file mode 100644 index 600225bfe76..00000000000 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java +++ /dev/null @@ -1,30 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package com.yahoo.searchlib.rankingexpression.integration.tensorflow; - -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.functions.TensorFunction; - -/** - * A tensor function returning a specific tensor type - * - * @author bratseth - */ -final class TypedTensorFunction { - - private final TensorType type; - private final TensorFunction function; - - public TypedTensorFunction(TensorType type, TensorFunction function) { - this.type = type; - this.function = function; - } - - public TensorType type() { - return type; - } - - public TensorFunction function() { - return function; - } - -} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java new file mode 100644 index 00000000000..c1665d066a4 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java @@ -0,0 +1,210 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Objects; +import java.util.Optional; + +/** + * A constraint satisfier to find suitable dimension names to reduce the + * amount of necessary renaming during evaluation of an imported model. + * + * @author lesters + */ +public class DimensionRenamer { + + private final String dimensionPrefix; + private final Map<String, List<Integer>> variables = new HashMap<>(); + private final Map<Arc, Constraint> constraints = new HashMap<>(); + private final Map<String, Integer> renames = new HashMap<>(); + + private int iterations = 0; + + public DimensionRenamer() { + this("d"); + } + + public DimensionRenamer(String dimensionPrefix) { + this.dimensionPrefix = dimensionPrefix; + } + + /** + * Add a dimension name variable. + */ + public void addDimension(String name) { + variables.computeIfAbsent(name, d -> new ArrayList<>()); + } + + /** + * Add a constraint between dimension names. + */ + public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) { + Arc arc = new Arc(from, to, operation); + Arc opposite = arc.opposite(); + constraints.put(arc, pred); + constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric + } + + /** + * Retrieve resulting name of dimension after solving for constraints. + */ + public Optional<String> dimensionNameOf(String name) { + if (!renames.containsKey(name)) { + return Optional.empty(); + } + return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name))); + } + + /** + * Perform iterative arc consistency until we have found a solution. After + * an initial iteration, the variables (dimensions) will have multiple + * valid values. Find a single valid assignment by iteratively locking one + * dimension after another, and running the arc consistency algorithm + * multiple times. + * + * This requires having constraints that result in an absolute ordering: + * equals, lesserThan and greaterThan do that, but adding notEquals does + * not typically result in a guaranteed ordering. If that is needed, the + * algorithm below needs to be adapted with a backtracking (tree) search + * to find solutions. + */ + public void solve(int maxIterations) { + initialize(); + + // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts + + for (String dimension : variables.keySet()) { + List<Integer> values = variables.get(dimension); + if (values.size() > 1) { + if (!ac3()) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution."); + } + values.sort(Integer::compare); + variables.put(dimension, Collections.singletonList(values.get(0))); + } + renames.put(dimension, variables.get(dimension).get(0)); + if (iterations > maxIterations) { + throw new IllegalArgumentException("Dimension renamer unable to find a solution within " + + maxIterations + " iterations"); + } + } + + // Todo: handle failure more gracefully: + // If a solution can't be found, look at the operation node in the arc + // with the most remaining constraints, and inject a rename operation. + // Then run this algorithm again. + } + + public void solve() { + solve(100000); + } + + private void initialize() { + for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) { + List<Integer> values = variable.getValue(); + for (int i = 0; i < variables.size(); ++i) { + values.add(i); // invariant: values are in increasing order + } + } + } + + private boolean ac3() { + Deque<Arc> workList = new ArrayDeque<>(constraints.keySet()); + while (!workList.isEmpty()) { + Arc arc = workList.pop(); + iterations += 1; + if (revise(arc)) { + if (variables.get(arc.from).size() == 0) { + return false; // no solution found + } + for (Arc constraint : constraints.keySet()) { + if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) { + workList.add(constraint); + } + } + } + } + return true; + } + + private boolean revise(Arc arc) { + boolean revised = false; + for(Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) { + Integer from = fromIterator.next(); + boolean satisfied = false; + for (Iterator<Integer> toIterator = variables.get(arc.to).iterator(); toIterator.hasNext(); ) { + Integer to = toIterator.next(); + if (constraints.get(arc).test(from, to)) { + satisfied = true; + } + } + if (!satisfied) { + fromIterator.remove(); + revised = true; + } + } + return revised; + } + + public interface Constraint { + boolean test(Integer x, Integer y); + } + + public static boolean equals(Integer x, Integer y) { + return Objects.equals(x, y); + } + + public static boolean lesserThan(Integer x, Integer y) { + return x < y; + } + + public static boolean greaterThan(Integer x, Integer y) { + return x > y; + } + + private static class Arc { + + private final String from; + private final String to; + private final TensorFlowOperation operation; + + Arc(String from, String to, TensorFlowOperation operation) { + this.from = from; + this.to = to; + this.operation = operation; + } + + Arc opposite() { + return new Arc(to, from, operation); + } + + @Override + public int hashCode() { + return Objects.hash(from, to); + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof Arc)) { + return false; + } + Arc other = (Arc) obj; + return Objects.equals(from, other.from) && Objects.equals(to, other.to); + } + + @Override + public String toString() { + return String.format("%s -> %s", from, to); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java new file mode 100644 index 00000000000..0fe73fad8ce --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java @@ -0,0 +1,108 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; +import com.yahoo.tensor.functions.ScalarFunctions; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +/** + * Maps from TensorFlow operations to Vespa operations. + * + * @author bratseth + * @author lesters + */ +public class OperationMapper { + + public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) { + switch (node.getOp().toLowerCase()) { + /* + * array ops + */ + case "const": return new Const(node, inputs, port); + case "expanddims": return new ExpandDims(node, inputs, port); + case "identity": return new Identity(node, inputs, port); + case "placeholder": return new Placeholder(node, inputs, port); + case "placeholderwithdefault": return new PlaceholderWithDefault(node, inputs, port); + case "reshape": return new Reshape(node, inputs, port); + case "shape": return new Shape(node, inputs, port); + case "squeeze": return new Squeeze(node, inputs, port); + + /* + * control flow + */ + case "merge": return new Merge(node, inputs, port); + case "switch": return new Switch(node, inputs, port); + + /* + * math ops + */ + case "add": return new Join(node, inputs, port, ScalarFunctions.add()); + case "add_n": return new Join(node, inputs, port, ScalarFunctions.add()); + case "acos": return new Map(node, inputs, port, ScalarFunctions.acos()); + case "div": return new Join(node, inputs, port, ScalarFunctions.divide()); + case "realdiv": return new Join(node, inputs, port, ScalarFunctions.divide()); + case "floor": return new Map(node, inputs, port, ScalarFunctions.floor()); + case "matmul": return new Matmul(node, inputs, port); + case "maximum": return new Join(node, inputs, port, ScalarFunctions.max()); + case "mean": return new Mean(node, inputs, port); + case "reducemean": return new Mean(node, inputs, port); + case "mul": return new Join(node, inputs, port, ScalarFunctions.multiply()); + case "multiply": return new Join(node, inputs, port, ScalarFunctions.multiply()); + case "rsqrt": return new Map(node, inputs, port, ScalarFunctions.rsqrt()); + case "select": return new Select(node, inputs, port); + case "where3": return new Select(node, inputs, port); + case "sigmoid": return new Map(node, inputs, port, ScalarFunctions.sigmoid()); + case "squareddifference": return new Join(node, inputs, port, ScalarFunctions.squareddifference()); + case "sub": return new Join(node, inputs, port, ScalarFunctions.subtract()); + case "subtract": return new Join(node, inputs, port, ScalarFunctions.subtract()); + + /* + * nn ops + */ + case "biasadd": return new Join(node, inputs, port, ScalarFunctions.add()); + case "elu": return new Map(node, inputs, port, ScalarFunctions.elu()); + case "relu": return new Map(node, inputs, port, ScalarFunctions.relu()); + case "selu": return new Map(node, inputs, port, ScalarFunctions.selu()); + + /* + * random ops + */ + + /* + * state ops + */ + case "variable": return new Variable(node, inputs, port); + case "variablev2": return new Variable(node, inputs, port); + + /* + * evaluation no-ops + */ + case "stopgradient":return new Identity(node, inputs, port); + case "noop": return new NoOp(node, inputs, port); + } + return new NoOp(node, inputs, port); + } + +} + + + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java new file mode 100644 index 00000000000..3742e443a06 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java @@ -0,0 +1,237 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.tensor.TensorType; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; +import org.tensorflow.framework.TensorShapeProto; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * A Vespa tensor type is ordered by the lexicographical ordering of dimension + * names. TensorFlow tensors have an explicit ordering of their dimensions. + * During import, we need to track the Vespa dimension that matches the + * corresponding TensorFlow dimension as the ordering can change after + * dimension renaming. That is the purpose of this class. + * + * @author lesters + */ +public class OrderedTensorType { + + private final TensorType type; + private final List<TensorType.Dimension> dimensions; + + private final long[] innerSizesTensorFlow; + private final long[] innerSizesVespa; + private final int[] dimensionMap; + + private OrderedTensorType(List<TensorType.Dimension> dimensions) { + this.dimensions = Collections.unmodifiableList(dimensions); + this.type = new TensorType.Builder(dimensions).build(); + this.innerSizesTensorFlow = new long[dimensions.size()]; + this.innerSizesVespa = new long[dimensions.size()]; + this.dimensionMap = createDimensionMap(); + } + + public TensorType type() { + return this.type; + } + + public List<TensorType.Dimension> dimensions() { + return dimensions; + } + + public List<String> dimensionNames() { + return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList()); + } + + private int[] createDimensionMap() { + int numDimensions = dimensions.size(); + if (numDimensions == 0) { + return null; + } + innerSizesTensorFlow[numDimensions - 1] = 1; + innerSizesVespa[numDimensions - 1] = 1; + for (int i = numDimensions - 1; --i >= 0; ) { + innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1]; + innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1]; + } + int[] mapping = new int[numDimensions]; + for (int i = 0; i < numDimensions; ++i) { + TensorType.Dimension dim1 = dimensions().get(i); + for (int j = 0; j < numDimensions; ++j) { + TensorType.Dimension dim2 = type.dimensions().get(j); + if (dim1.equals(dim2)) { + mapping[i] = j; + break; + } + } + } + return mapping; + } + + /** + * When dimension ordering between Vespa and TensorFlow differs, i.e. + * after dimension renaming, use the dimension map to read in values + * so that they are correctly laid out in memory for Vespa. + * Used when importing tensors from TensorFlow. + */ + public int toDirectIndex(int index) { + if (dimensions.size() == 0) { + return 0; + } + if (dimensionMap == null) { + throw new IllegalArgumentException("Dimension map is not available"); + } + int directIndex = 0; + long rest = index; + for (int i = 0; i < dimensions.size(); ++i) { + long address = rest / innerSizesTensorFlow[i]; + directIndex += innerSizesVespa[dimensionMap[i]] * address; + rest %= innerSizesTensorFlow[i]; + } + return directIndex; + } + + @Override + public boolean equals(Object obj) { + if (obj == null || !(obj instanceof OrderedTensorType)) { + return false; + } + OrderedTensorType other = (OrderedTensorType) obj; + if (dimensions.size() != dimensions.size()) { + return false; + } + List<TensorType.Dimension> thisDimensions = this.dimensions(); + List<TensorType.Dimension> otherDimensions = other.dimensions(); + for (int i = 0; i < thisDimensions.size(); ++i) { + if (!thisDimensions.get(i).equals(otherDimensions.get(i))) { + return false; + } + } + return true; + } + + public static void verifyType(NodeDef node, OrderedTensorType type) { + if (type == null) { + return; + } + TensorShapeProto shape = tensorFlowShape(node); + if (shape != null && type.type != null) { + if (shape.getDimCount() != type.type.rank()) { + throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + + "does not match Vespa shape"); + } + for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions.size(); ++tensorFlowIndex) { + int vespaIndex = type.dimensionMap[tensorFlowIndex]; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); + TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); + if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { + throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + + "does not match Vespa dimensions"); + } + } + } + } + + private static TensorShapeProto tensorFlowShape(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); + if (attrValueList == null) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "does not exist"); + } + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { + throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + + "is not of expected type"); + } + List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); + return shapeList.get(0); // support multiple outputs? + } + + public static OrderedTensorType rename(OrderedTensorType type, DimensionRenamer renamer) { + List<TensorType.Dimension> renamedDimensions = new ArrayList<>(type.dimensions.size()); + for (TensorType.Dimension dimension : type.dimensions) { + String oldName = dimension.name(); + Optional<String> newName = renamer.dimensionNameOf(oldName); + if (!newName.isPresent()) + return type; // presumably, already renamed + TensorType.Dimension.Type dimensionType = dimension.type(); + if (dimensionType == TensorType.Dimension.Type.indexedBound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); + } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) { + renamedDimensions.add(TensorType.Dimension.indexed(newName.get())); + } else if (dimensionType == TensorType.Dimension.Type.mapped) { + renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); + } + } + return new OrderedTensorType(renamedDimensions); + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node) { + return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ... + } + + public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { + Builder builder = new Builder(node); + TensorShapeProto shape = tensorFlowShape(node); + for (int i = 0; i < shape.getDimCount(); ++ i) { + String dimensionName = dimensionPrefix + i; + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); + if (tensorFlowDimension.getSize() >= 0) { + builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); + } else { + builder.add(TensorType.Dimension.indexed(dimensionName)); + } + } + return builder.build(); + } + + public static class Builder { + + private final TensorShapeProto shape; + private final List<TensorType.Dimension> dimensions; + + public Builder(NodeDef node) { + this.shape = tensorFlowShape(node); + this.dimensions = new ArrayList<>(shape.getDimCount()); + } + + public Builder add(TensorType.Dimension vespaDimension) { + int index = dimensions.size(); + TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index); + long size = tensorFlowDimension.getSize(); + if (size >= 0) { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension types"); + } + if (!vespaDimension.size().isPresent()) { + throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + + "not have a size"); + } + if (vespaDimension.size().get() != size) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension sizes. TensorFlow: " + size + " Vespa: " + vespaDimension.size().get()); + } + } else { + if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { + throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + + "dimension types"); + } + } + this.dimensions.add(vespaDimension); + return this; + } + + public OrderedTensorType build() { + return new OrderedTensorType(dimensions); + } + + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java new file mode 100644 index 00000000000..3f55e622fdf --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java @@ -0,0 +1,224 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer; + +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import org.tensorflow.framework.TensorProto; + +import java.nio.ByteBuffer; +import java.nio.DoubleBuffer; +import java.nio.FloatBuffer; +import java.nio.IntBuffer; +import java.nio.LongBuffer; + + +/** + * Converts TensorFlow tensors into Vespa tensors. + * + * @author bratseth + * @author lesters + */ +public class TensorConverter { + + public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { + return toVespaTensor(tfTensor, "d"); + } + + public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix); + Values values = readValuesOf(tfTensor); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + for (int i = 0; i < values.size(); i++) + builder.cellByDirectIndex(i, values.get(i)); + return builder.build(); + } + + public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) { + Values values = readValuesOf(tfTensor); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); + for (int i = 0; i < values.size(); i++) { + builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i)); + } + return builder.build(); + } + + public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) { + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); + Values values = readValuesOf(tensorProto); + for (int i = 0; i < values.size(); ++i) { + builder.cellByDirectIndex(i, values.get(i)); + } + return builder.build(); + } + + private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) { + TensorType.Builder b = new TensorType.Builder(); + int dimensionIndex = 0; + for (long dimensionSize : shape) { + if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... + b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); + } + return b.build(); + } + + public static Long tensorSize(TensorType type) { + Long size = 1L; + for (TensorType.Dimension dimension : type.dimensions()) { + size *= dimensionSize(dimension); + } + return size; + } + + public static Long dimensionSize(TensorType.Dimension dim) { + return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); + } + + private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { + switch (tfTensor.dataType()) { + case DOUBLE: return new DoubleValues(tfTensor); + case FLOAT: return new FloatValues(tfTensor); + case BOOL: return new BoolValues(tfTensor); + case UINT8: return new IntValues(tfTensor); + case INT32: return new IntValues(tfTensor); + case INT64: return new LongValues(tfTensor); + } + throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + + tfTensor.dataType() + " to a Vespa tensor"); + } + + private static Values readValuesOf(TensorProto tensorProto) { + switch (tensorProto.getDtype()) { + case DT_BOOL: + return new ProtoBoolValues(tensorProto); + case DT_HALF: + return new ProtoHalfValues(tensorProto); + case DT_INT16: + case DT_INT32: + return new ProtoIntValues(tensorProto); + case DT_INT64: + return new ProtoInt64Values(tensorProto); + case DT_FLOAT: + return new ProtoFloatValues(tensorProto); + case DT_DOUBLE: + return new ProtoDoubleValues(tensorProto); + } + throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); + } + + /** Allows reading values from buffers of various numeric types as bytes */ + private static abstract class Values { + abstract double get(int i); + abstract int size(); + } + + private static abstract class TensorFlowValues extends Values { + private final int size; + TensorFlowValues(int size) { + this.size = size; + } + @Override int size() { return this.size; } + } + + private static class DoubleValues extends TensorFlowValues { + private final DoubleBuffer values; + DoubleValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = DoubleBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static class FloatValues extends TensorFlowValues { + private final FloatBuffer values; + FloatValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = FloatBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static class BoolValues extends TensorFlowValues { + private final ByteBuffer values; + BoolValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = ByteBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static class IntValues extends TensorFlowValues { + private final IntBuffer values; + IntValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = IntBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static class LongValues extends TensorFlowValues { + private final LongBuffer values; + LongValues(org.tensorflow.Tensor<?> tfTensor) { + super(tfTensor.numElements()); + values = LongBuffer.allocate(tfTensor.numElements()); + tfTensor.writeTo(values); + } + @Override double get(int i) { + return values.get(i); + } + } + + private static abstract class ProtoValues extends Values { + protected final TensorProto tensorProto; + protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; } + } + + private static class ProtoBoolValues extends ProtoValues { + ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; } + @Override int size() { return tensorProto.getBoolValCount(); } + } + + private static class ProtoHalfValues extends ProtoValues { + ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getHalfVal(i); } + @Override int size() { return tensorProto.getHalfValCount(); } + } + + private static class ProtoIntValues extends ProtoValues { + ProtoIntValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getIntVal(i); } + @Override int size() { return tensorProto.getIntValCount(); } + } + + private static class ProtoInt64Values extends ProtoValues { + ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getInt64Val(i); } + @Override int size() { return tensorProto.getInt64ValCount(); } + } + + private static class ProtoFloatValues extends ProtoValues { + ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getFloatVal(i); } + @Override int size() { return tensorProto.getFloatValCount(); } + } + + private static class ProtoDoubleValues extends ProtoValues { + ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); } + @Override double get(int i) { return tensorProto.getDoubleVal(i); } + @Override int size() { return tensorProto.getDoubleValCount(); } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java new file mode 100644 index 00000000000..7decef51ab7 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java @@ -0,0 +1,93 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class Const extends TensorFlowOperation { + + public Const(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + setConstantValue(value()); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + } + + @Override + public Optional<TensorFunction> function() { + if (function == null) { + function = lazyGetFunction(); + } + return Optional.ofNullable(function); + } + + @Override + protected TensorFunction lazyGetFunction() { + ExpressionNode expressionNode; + if (type.type().rank() == 0 && getConstantValue().isPresent()) { + expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue()); + } else { + expressionNode = new ReferenceNode("constant(\"" + vespaName() + "\")"); + } + return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + setConstantValue(value()); + } + + @Override + public boolean isConstant() { + return true; + } + + private Value value() { + if (!node.getAttrMap().containsKey("value")) { + throw new IllegalArgumentException("Node '" + node.getName() + "' of type " + + "const has missing 'value' attribute"); + } + AttrValue attrValue = node.getAttrMap().get("value"); + if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { + return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type())); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.B) { + return new BooleanValue(attrValue.getB()); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.I) { + return new DoubleValue(attrValue.getI()); + } + if (attrValue.getValueCase() == AttrValue.ValueCase.F) { + return new DoubleValue(attrValue.getF()); + } + throw new IllegalArgumentException("Requesting value of constant in " + + node.getName() + " but type is not recognized."); + } +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java new file mode 100644 index 00000000000..c1ad21f41d8 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java @@ -0,0 +1,107 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +public class ExpandDims extends TensorFlowOperation { + + private List<String> expandDimensions; + + public ExpandDims(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + + TensorFlowOperation axisOperation = inputs().get(1); + if (!axisOperation.getConstantValue().isPresent()) { + throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + + "axis must be a constant."); + } + Tensor axis = axisOperation.getConstantValue().get().asTensor(); + if (axis.type().rank() != 0) { + throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " + + "axis argument must be a scalar."); + } + + OrderedTensorType inputType = inputs.get(0).type().get(); + int dimensionToInsert = (int)axis.asDouble(); + if (dimensionToInsert < 0) { + dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; + } + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + expandDimensions = new ArrayList<>(); + int dimensionIndex = 0; + for (TensorType.Dimension dimension : inputType.dimensions()) { + if (dimensionIndex == dimensionToInsert) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + typeBuilder.add(dimension); + dimensionIndex++; + } + + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(2)) { + return null; + } + + // multiply with a generated tensor created from the reduced dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (String name : expandDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); + for (String name : expandDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + expandDimensions = renamedDimensions; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java new file mode 100644 index 00000000000..d79707a42e6 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java @@ -0,0 +1,30 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Identity extends TensorFlowOperation { + + public Identity(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) + return null; + return inputs.get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) + return null; + return inputs.get(0).function().orElse(null); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java new file mode 100644 index 00000000000..aa27ba2684d --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java @@ -0,0 +1,79 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; +import java.util.function.DoubleBinaryOperator; + +public class Join extends TensorFlowOperation { + + private final DoubleBinaryOperator operator; + + public Join(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) { + super(node, inputs, port); + this.operator = operator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); + OrderedTensorType out = a.type().rank() >= b.type().rank() ? a : b; + return out; + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + Optional<TensorFunction> aFunction = inputs.get(0).function(); + Optional<TensorFunction> bFunction = inputs.get(1).function(); + if (!aFunction.isPresent() || !bFunction.isPresent()) { + return null; + } + + // The dimension renaming below takes care of broadcasting. + + return new com.yahoo.tensor.functions.Join(aFunction.get(), bFunction.get(), operator); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + + // Well now we have potentially entered the wonderful world of "broadcasting" + // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html + // I'm not able to extract from that any unambiguous specification of which dimensions + // should be "stretched" when the tensor do not have the same number of dimensions. + // From trying this with TensorFlow it appears that the second tensor is matched to the + // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. + // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). + + TensorType a = inputs.get(0).type().get().type(); + TensorType b = inputs.get(1).type().get().type(); + if (a.rank() < b.rank()) { + TensorType temp = a; + a = b; + b = temp; + } + int sizeDifference = a.rank() - b.rank(); + for (int i = 0; i < b.rank(); ++i) { + String bDim = b.dimensions().get(i).name(); + String aDim = a.dimensions().get(i + sizeDifference).name(); + renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java new file mode 100644 index 00000000000..105d65b3d69 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java @@ -0,0 +1,38 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; +import java.util.function.DoubleUnaryOperator; + +public class Map extends TensorFlowOperation { + + private final DoubleUnaryOperator operator; + + public Map(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) { + super(node, inputs, port); + this.operator = operator; + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + return inputs.get(0).type().get(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + Optional<TensorFunction> input = inputs.get(0).function(); + return new com.yahoo.tensor.functions.Map(input.get(), operator); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java new file mode 100644 index 00000000000..ac4f78653d6 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java @@ -0,0 +1,74 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class Matmul extends TensorFlowOperation { + + public Matmul(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node); + typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); + typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); + return typeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + OrderedTensorType aType = inputs.get(0).type().get(); + OrderedTensorType bType = inputs.get(1).type().get(); + if (aType.type().rank() < 2 || bType.type().rank() < 2) + throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2"); + if (aType.type().rank() != bType.type().rank()) + throw new IllegalArgumentException("Tensors in matmul must have the same rank"); + + Optional<TensorFunction> aFunction = inputs.get(0).function(); + Optional<TensorFunction> bFunction = inputs.get(1).function(); + if (!aFunction.isPresent() || !bFunction.isPresent()) { + return null; + } + return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(2)) { + return; + } + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); + List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + + String aDim0 = aDimensions.get(0).name(); + String aDim1 = aDimensions.get(1).name(); + String bDim0 = bDimensions.get(0).name(); + String bDim1 = bDimensions.get(1).name(); + + // The second dimension of a should have the same name as the first dimension of b + renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this); + + // The first dimension of a should have a different name than the second dimension of b + renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this); + + // For efficiency, the dimensions to join over should be innermost - soft constraint + renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); + renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java new file mode 100644 index 00000000000..dfe0796d9b8 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java @@ -0,0 +1,112 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +public class Mean extends TensorFlowOperation { + + private List<String> reduceDimensions; + + public Mean(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + TensorFlowOperation reductionIndices = inputs.get(1); + if (!reductionIndices.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Mean in " + node.getName() + ": " + + "reduction indices must be a constant."); + } + Tensor indices = reductionIndices.getConstantValue().get().asTensor(); + reduceDimensions = new ArrayList<>(); + + OrderedTensorType inputType = inputs.get(0).type().get(); + for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int dimensionIndex = cell.getValue().intValue(); + if (dimensionIndex < 0) { + dimensionIndex = inputType.dimensions().size() - dimensionIndex; + } + reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name()); + } + return reducedType(inputType, shouldKeepDimensions()); + } + + // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity. + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + TensorFunction inputFunction = inputs.get(0).function().get(); + TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); + if (shouldKeepDimensions()) { + // multiply with a generated tensor created from the reduced dimensions + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (String name : reduceDimensions) { + typeBuilder.indexed(name, 1); + } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply()); + } + return output; + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size()); + for (String name : reduceDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + reduceDimensions = renamedDimensions; + } + + private boolean shouldKeepDimensions() { + AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims"); + return keepDimsAttr != null && keepDimsAttr.getB(); + } + + private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + for (TensorType.Dimension dimension: inputType.type().dimensions()) { + if (!reduceDimensions.contains(dimension.name())) { + builder.add(dimension); + } else if (keepDimensions) { + builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); + } + } + return builder.build(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java new file mode 100644 index 00000000000..d3561716725 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java @@ -0,0 +1,36 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Merge extends TensorFlowOperation { + + public Merge(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + for (TensorFlowOperation operation : inputs) { + if (operation.type().isPresent()) { + return operation.type().get(); + } + } + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + for (TensorFlowOperation operation : inputs) { + if (operation.function().isPresent()) { + return operation.function().get(); + } + } + return null; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java new file mode 100644 index 00000000000..acf5d13b057 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java @@ -0,0 +1,32 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; + +public class NoOp extends TensorFlowOperation { + + public NoOp(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, Collections.emptyList(), port); // don't propagate inputs + } + + @Override + protected OrderedTensorType lazyGetType() { + return null; + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java new file mode 100644 index 00000000000..dadce395faf --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java @@ -0,0 +1,57 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.VariableTensor; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Placeholder extends TensorFlowOperation { + + private OrderedTensorType standardNamingType; // using standard naming convention: d0, d1, ... + + public Placeholder(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + standardNamingType = OrderedTensorType.fromTensorFlowType(node); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + } + + @Override + protected TensorFunction lazyGetFunction() { + TensorFunction output = new VariableTensor(vespaName(), standardNamingType.type()); + if (!standardNamingType.equals(type)) { + List<String> renameFrom = standardNamingType.dimensionNames(); + List<String> renameTo = type.dimensionNames(); + output = new Rename(output, renameFrom, renameTo); + } + return output; + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isInput() { + return true; + } + + @Override + public boolean isConstant() { + return false; + } + + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java new file mode 100644 index 00000000000..ab091b77a65 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java @@ -0,0 +1,50 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class PlaceholderWithDefault extends TensorFlowOperation { + + public PlaceholderWithDefault(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + return inputs().get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + // This should be a call to the macro we add below, but for now + // we treat this as as identity function and just pass the constant. + return inputs.get(0).function().orElse(null); + } + + @Override + public Optional<RankingExpression> macro() { + // For now, it is much more efficient to assume we always will return + // the default value, as we can prune away large parts of the expression + // tree by having it calculated as a constant. If a case arises where + // it is important to support this, implement this. + return Optional.empty(); + } + + @Override + public boolean isConstant() { + return true; // not true if we add to macro + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java new file mode 100644 index 00000000000..9b3e28ce56b --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java @@ -0,0 +1,135 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; +import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; +import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; +import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; + +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; + +public class Reshape extends TensorFlowOperation { + + public Reshape(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + TensorFlowOperation newShape = inputs.get(1); + if (!newShape.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Reshape in " + node.getName() + ": " + + "shape input must be a constant."); + } + Tensor shape = newShape.getConstantValue().get().asTensor(); + + OrderedTensorType inputType = inputs.get(0).type().get(); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node); + int dimensionIndex = 0; + for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { + Tensor.Cell cell = cellIterator.next(); + int size = cell.getValue().intValue(); + if (size < 0) { + size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / + tensorSize(inputType.type()).intValue(); + } + outputTypeBuilder.add(TensorType.Dimension.indexed( + String.format("%s_%d", vespaName(), dimensionIndex), size)); + dimensionIndex++; + } + return outputTypeBuilder.build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputTypesPresent(2)) { + return null; + } + if (!allInputFunctionsPresent(2)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + TensorFunction inputFunction = inputs.get(0).function().get(); + return reshape(inputFunction, inputType.type(), type.type()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { + if (!tensorSize(inputType).equals(tensorSize(outputType))) { + throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); + } + + // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, + // then use the dimension order of the new shape to roll back into a tensor. + // Here we create a transformation tensor that is multiplied with the from tensor to map into + // the new shape. We have to introduce temporary dimension names and rename back if dimension names + // in the new and old tensor type overlap. + + ExpressionNode unrollFrom = unrollTensorExpression(inputType); + ExpressionNode unrollTo = unrollTensorExpression(outputType); + ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo); + + TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); + Generate transformTensor = new Generate(transformationType, + new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); + + TensorFunction outputFunction = new Reduce( + new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); + + return outputFunction; + } + + private static ExpressionNode unrollTensorExpression(TensorType type) { + if (type.rank() == 0) { + return new ConstantNode(DoubleValue.zero); + } + List<ExpressionNode> children = new ArrayList<>(); + List<ArithmeticOperator> operators = new ArrayList<>(); + int size = 1; + for (int i = type.dimensions().size() - 1; i >= 0; --i) { + TensorType.Dimension dimension = type.dimensions().get(i); + children.add(0, new ReferenceNode(dimension.name())); + if (size > 1) { + operators.add(0, ArithmeticOperator.MULTIPLY); + children.add(0, new ConstantNode(new DoubleValue(size))); + } + size *= TensorConverter.dimensionSize(dimension); + if (i > 0) { + operators.add(0, ArithmeticOperator.PLUS); + } + } + return new ArithmeticNode(children, operators); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java new file mode 100644 index 00000000000..6a29d428cf3 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java @@ -0,0 +1,89 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.function.DoubleBinaryOperator; + +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize; +import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize; + +public class Select extends TensorFlowOperation { + + public Select(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(3)) { + return null; + } + OrderedTensorType a = inputs.get(1).type().get(); + OrderedTensorType b = inputs.get(2).type().get(); + if ((a.type().rank() != b.type().rank()) || !(tensorSize(a.type()).equals(tensorSize(b.type())))) { + throw new IllegalArgumentException("'Select': input tensors must have the same shape"); + } + return a; + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(3)) { + return null; + } + TensorFlowOperation conditionOperation = inputs().get(0); + TensorFunction a = inputs().get(1).function().get(); + TensorFunction b = inputs().get(2).function().get(); + + // Shortcut: if we know during import which tensor to select, do that directly here. + if (conditionOperation.getConstantValue().isPresent()) { + Tensor condition = conditionOperation.getConstantValue().get().asTensor(); + if (condition.type().rank() == 0) { + return ((int) condition.asDouble() == 0) ? b : a; + } + if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) { + return condition.cellIterator().next().getValue().intValue() == 0 ? b : a; + } + } + + // The task is to select cells from 'x' or 'y' based on 'condition'. + // If 'condition' is 0 (false), select from 'y', if 1 (true) select + // from 'x'. We do this by individually joining 'x' and 'y' with + // 'condition', and then joining the resulting two tensors. + + TensorFunction conditionFunction = conditionOperation.function().get(); + TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply()); + TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() { + @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); } + @Override public String toString() { return "f(a,b)(a * (1-b))"; } + }); + return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add()); + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + if (!allInputTypesPresent(3)) { + return; + } + List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions(); + List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions(); + + String aDim0 = aDimensions.get(0).name(); + String aDim1 = aDimensions.get(1).name(); + String bDim0 = bDimensions.get(0).name(); + String bDim1 = bDimensions.get(1).name(); + + // These tensors should have the same dimension names + renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this); + renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java new file mode 100644 index 00000000000..8f4313022e0 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java @@ -0,0 +1,55 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Shape extends TensorFlowOperation { + + public Shape(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + createConstantValue(); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + return new OrderedTensorType.Builder(node) + .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) + .build(); + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public boolean isConstant() { + return true; + } + + private void createConstantValue() { + if (!allInputTypesPresent(1)) { + return; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type()); + List<TensorType.Dimension> inputDimensions = inputType.dimensions(); + for (int i = 0; i < inputDimensions.size(); i++) { + builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L)); + } + this.setConstantValue(new TensorValue(builder.build())); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java new file mode 100644 index 00000000000..d7750b52fc3 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java @@ -0,0 +1,84 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +public class Squeeze extends TensorFlowOperation { + + private List<String> squeezeDimensions; + + public Squeeze(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(1)) { + return null; + } + OrderedTensorType inputType = inputs.get(0).type().get(); + squeezeDimensions = new ArrayList<>(); + + AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims"); + if (squeezeDimsAttr == null) { + squeezeDimensions = inputType.type().dimensions().stream(). + filter(dim -> TensorConverter.dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } else { + squeezeDimensions = squeezeDimsAttr.getList().getIList().stream(). + map(i -> i < 0 ? inputType.type().dimensions().size() - i : i). + map(i -> inputType.type().dimensions().get(i.intValue())). + filter(dim -> TensorConverter.dimensionSize(dim) == 1). + map(TensorType.Dimension::name). + collect(Collectors.toList()); + } + return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType); + } + + @Override + protected TensorFunction lazyGetFunction() { + if (!allInputFunctionsPresent(1)) { + return null; + } + TensorFunction inputFunction = inputs.get(0).function().get(); + return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); + } + + @Override + public void renameDimensions(DimensionRenamer renamer) { + super.renameDimensions(renamer); + List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size()); + for (String name : squeezeDimensions) { + Optional<String> newName = renamer.dimensionNameOf(name); + if (!newName.isPresent()) { + return; // presumably, already renamed + } + renamedDimensions.add(newName.get()); + } + squeezeDimensions = renamedDimensions; + } + + private OrderedTensorType reducedType(OrderedTensorType inputType) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node); + for (TensorType.Dimension dimension: inputType.type().dimensions()) { + if ( ! squeezeDimensions.contains(dimension.name())) { + builder.add(dimension); + } + } + return builder.build(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java new file mode 100644 index 00000000000..1cc0e1936eb --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java @@ -0,0 +1,48 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; +import java.util.Optional; + +public class Switch extends TensorFlowOperation { + + public Switch(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + if (!allInputTypesPresent(2)) { + return null; + } + Optional<OrderedTensorType> predicate = inputs.get(1).type(); + if (predicate.get().type().rank() != 0) { + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + "predicate must be a scalar"); + } + return inputs.get(0).type().orElse(null); + } + + @Override + protected TensorFunction lazyGetFunction() { + TensorFlowOperation predicateOperation = inputs().get(1); + if (!predicateOperation.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + "predicate must be a constant"); + } + if (port < 0 || port > 1) { + throw new IllegalArgumentException("Switch in " + node.getName() + ": " + + "choice should be boolean"); + } + + double predicate = predicateOperation.getConstantValue().get().asDouble(); + return predicate == port ? inputs().get(0).function().get() : null; + } + +} + + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java new file mode 100644 index 00000000000..fd9dfd167fb --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -0,0 +1,136 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Optional; +import java.util.function.Function; + +/** + * Wraps a TensorFlow node and produces the respective Vespa tensor operation. + * During import, a graph of these operations are constructed. Then, the + * types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper + * Vespa expressions can be extracted. + * + * @author lesters + */ +public abstract class TensorFlowOperation { + + protected final NodeDef node; + protected final int port; + protected final List<TensorFlowOperation> inputs; + protected final List<TensorFlowOperation> outputs = new ArrayList<>(); + protected final List<String> importWarnings = new ArrayList<>(); + + protected OrderedTensorType type; + protected TensorFunction function; + + private Value constantValue = null; + private List<TensorFlowOperation> controlInputs = Collections.emptyList(); + + TensorFlowOperation(NodeDef node, List<TensorFlowOperation> inputs, int port) { + this.node = node; + this.port = port; + this.inputs = Collections.unmodifiableList(inputs); + this.inputs.forEach(i -> i.outputs.add(this)); + } + + protected abstract OrderedTensorType lazyGetType(); + protected abstract TensorFunction lazyGetFunction(); + + /** Returns the Vespa tensor type of this operation if it exists */ + public Optional<OrderedTensorType> type() { + if (type == null) { + type = lazyGetType(); + } + OrderedTensorType.verifyType(node, type); + return Optional.ofNullable(type); + } + + /** Returns the Vespa tensor function implementing all operations from this node with inputs */ + public Optional<TensorFunction> function() { + if (function == null) { + if (isConstant()) { + ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")"); + function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); + } else { + function = lazyGetFunction(); + } + } + return Optional.ofNullable(function); + } + + /** Return TensorFlow node */ + public NodeDef node() { return node; } + + /** Return unmodifiable list of inputs */ + public List<TensorFlowOperation> inputs() { return inputs; } + + /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */ + public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); } + + /** Returns a Vespa ranking expression that should be added as a macro */ + public Optional<RankingExpression> macro() { return Optional.empty(); } + + /** Add dimension name constraints for this operation */ + public void addDimensionNameConstraints(DimensionRenamer renamer) { } + + /** Performs dimension rename for this operation */ + public void renameDimensions(DimensionRenamer renamer) { type = OrderedTensorType.rename(type, renamer); } + + /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ + public boolean isInput() { return false; } + + /** Return true if this node is constant */ + public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); } + + /** Sets the constant value */ + public void setConstantValue(Value value) { constantValue = value; } + + /** Gets the constant value if it exists */ + public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); } + + /** Sets the external control inputs */ + public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; } + + /** Retrieve the control inputs for this operation */ + public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); } + + /** Retrieve the valid Vespa name of this node */ + public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; } + + /** Retrieve the list of warnings produced during its lifetime */ + public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } + + boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) { + if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) { + return false; + } + if (inputs.size() != expected) { + throw new IllegalArgumentException("Expected " + expected + " inputs " + + "for '" + node.getName() + "', got " + inputs.size()); + } + return inputs.stream().map(func).allMatch(Optional::isPresent); + } + + boolean allInputTypesPresent(int expected) { + return verifyInputs(expected, TensorFlowOperation::type); + } + + boolean allInputFunctionsPresent(int expected) { + return verifyInputs(expected, TensorFlowOperation::function); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java new file mode 100644 index 00000000000..6f377c4bda2 --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java @@ -0,0 +1,40 @@ +// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.TensorFunction; +import org.tensorflow.framework.NodeDef; + +import java.util.List; + +public class Variable extends TensorFlowOperation { + + public Variable(NodeDef node, List<TensorFlowOperation> inputs, int port) { + super(node, inputs, port); + } + + @Override + protected OrderedTensorType lazyGetType() { + return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_"); + } + + @Override + protected TensorFunction lazyGetFunction() { + return null; // will be added by function() since this is constant. + } + + @Override + public void addDimensionNameConstraints(DimensionRenamer renamer) { + for (TensorType.Dimension dimension : type.type().dimensions()) { + renamer.addDimension(dimension.name()); + } + } + + @Override + public boolean isConstant() { + return true; + } + +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java new file mode 100644 index 00000000000..ebcfde54c70 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java @@ -0,0 +1,49 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer; +import org.junit.Test; + +import static junit.framework.TestCase.assertTrue; + +public class DimensionRenamerTest { + + @Test + public void testMnistRenaming() { + DimensionRenamer renamer = new DimensionRenamer(); + + renamer.addDimension("first_dimension_of_x"); + renamer.addDimension("second_dimension_of_x"); + renamer.addDimension("first_dimension_of_w"); + renamer.addDimension("second_dimension_of_w"); + renamer.addDimension("first_dimension_of_b"); + + // which dimension to join on matmul + renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null); + + // other dimensions in matmul can't be equal + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null); + + // for efficiency, put dimension to join on innermost + renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null); + renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null); + + // bias + renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null); + + renamer.solve(); + + String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get(); + String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get(); + String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get(); + String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get(); + String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get(); + + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0); + assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0); + assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0); + assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0); + assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0); + + + } +} diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index 3b25bfe1b1e..f64d697d9b9 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -18,11 +18,6 @@ public class DropoutImportTestCase { public void testDropoutImport() { TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved"); - // Check (provided) macros - assertEquals(1, model.get().macros().size()); - assertTrue(model.get().macros().containsKey("training_input")); - assertEquals("constant(\"training_input\")", model.get().macros().get("training_input").getRoot().toString()); - // Check required macros assertEquals(1, model.get().requiredMacros().size()); assertTrue(model.get().requiredMacros().containsKey("X")); @@ -37,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/BiasAdd", output.getName()); - assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs_kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs_bias\"), d0, d1), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java index ad5abd4c03d..60dd3865aa1 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java @@ -22,15 +22,15 @@ public class MnistSoftmaxImportTestCase { // Check constants assertEquals(2, model.get().largeConstants().size()); - Tensor constant0 = model.get().largeConstants().get("Variable"); + Tensor constant0 = model.get().largeConstants().get("Variable_read"); assertNotNull(constant0); - assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(), + assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(), constant0.type()); assertEquals(7840, constant0.size()); - Tensor constant1 = model.get().largeConstants().get("Variable_1"); + Tensor constant1 = model.get().largeConstants().get("Variable_1_read"); assertNotNull(constant1); - assertEquals(new TensorType.Builder().indexed("d0", 10).build(), + assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type()); assertEquals(10, constant1.size()); @@ -59,12 +59,10 @@ public class MnistSoftmaxImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))", + assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))", output.getRoot().toString()); // Test execution - model.assertEqualResult("Placeholder", "Variable/read"); - model.assertEqualResult("Placeholder", "Variable_1/read"); model.assertEqualResult("Placeholder", "MatMul"); model.assertEqualResult("Placeholder", "add"); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index ae7714b271a..1691756a64d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.tensorflow.SavedModelBundle; @@ -47,8 +48,11 @@ public class TestableTensorFlowModel { private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { Session.Runner runner = model.session().runner(); - org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, - FloatBuffer.allocate(d0Size * d1Size)); + FloatBuffer fb = FloatBuffer.allocate(d0Size * d1Size); + for (int i = 0; i < d1Size; ++i) { + fb.put(i, (float)(i * 1.0 / d1Size)); + } + org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb); runner.feed(inputName, placeholder); List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run(); assertEquals(1, results.size()); @@ -66,7 +70,7 @@ public class TestableTensorFlowModel { Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build()); for (int d0 = 0; d0 < d0Size; d0++) for (int d1 = 0; d1 < d1Size; d1++) - b.cell(0, d0, d1); + b.cell(d1 * 1.0 / d1Size, d0, d1); return b.build(); } |