summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
author Lester Solbakken <lesters@oath.com> 2018-02-05 16:04:42 +0100
committer Lester Solbakken <lesters@oath.com> 2018-02-22 12:54:34 +0100
commit b1f46fcd0495dbce905fb8b7318781f4cf5965a7 (patch)
tree d0a0506fe66e5af4af2a927101a0eb9ed9420d38 /searchlib
parent e307df56eaaf5b0ebca5aefb7f7e0c5c3a970bdb (diff)
Refactor TensorFlow import and add dimension renaming.
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java 132
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java 715
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java 155
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java 396
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java 30
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java 210
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java 108
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java 237
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java 224
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java 93
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java 107
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java 30
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java 79
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java 38
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java 74
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java 112
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java 36
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java 32
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java 57
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java 50
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java 135
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java 89
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java 55
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java 84
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java 48
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java 136
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java 40
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java 49
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java 7
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java 12
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java 10
31 files changed, 2393 insertions, 1187 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java
deleted file mode 100644
index 5f0c016881a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/AttrValueConverter.java
+++ /dev/null
@@ -1,132 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.TensorProto;
-import org.tensorflow.framework.TensorShapeProto;
-
-/**
- * @author lesters
- */
-public class AttrValueConverter {
-
- public static Tensor toVespaTensor(NodeDef tfNode, String attr) {
- if (!tfNode.getAttrMap().containsKey(attr)) {
- throw new IllegalArgumentException(tfNode.getName() + " has no attribute called " + attr);
- }
- AttrValue attrValue = tfNode.getAttrMap().get(attr);
- switch (attrValue.getValueCase()) {
- case TENSOR:
- return buildFromTensor(attrValue);
- case B:
- return buildFromSingleValue(attrValue.getB() ? 1.0 : 0.0);
- case F:
- return buildFromSingleValue(attrValue.getF());
- case I:
- return buildFromSingleValue(attrValue.getI());
- }
-
- throw new IllegalArgumentException(tfNode.getName() +
- ": unsupported attribute type: '" + attrValue.getValueCase().toString() + "'");
- }
-
- private static Tensor buildFromSingleValue(double value) {
- TensorType type = new TensorType.Builder().build();
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- builder.cellByDirectIndex(0, value);
- return builder.build();
- }
-
- private static Tensor buildFromTensor(AttrValue attrValue) {
- TensorProto tensorProto = attrValue.getTensor();
- TensorType type = toVespaTensorType(tensorProto.getTensorShape());
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- Values values = valuesOf(tensorProto);
- for (int i = 0; i < values.size(); ++i) {
- builder.cellByDirectIndex(i, values.get(i));
- }
- Tensor tensor = builder.build();
- return tensor;
- }
-
- private static Values valuesOf(TensorProto tensorProto) {
- switch (tensorProto.getDtype()) {
- case DT_BOOL:
- return new BoolValues(tensorProto);
- case DT_HALF:
- return new HalfValues(tensorProto);
- case DT_INT16:
- case DT_INT32:
- return new IntValues(tensorProto);
- case DT_INT64:
- return new Int64Values(tensorProto);
- case DT_FLOAT:
- return new FloatValues(tensorProto);
- case DT_DOUBLE:
- return new DoubleValues(tensorProto);
- }
-
- throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
- }
-
- public static TensorType toVespaTensorType(TensorShapeProto shapeProto) {
- TensorType.Builder b = new TensorType.Builder();
- for (TensorShapeProto.Dim dimension : shapeProto.getDimList()) {
- int dimensionSize = (int)dimension.getSize();
- if (dimensionSize >= 0)
- b.indexed("d" + b.rank(), dimensionSize);
- else
- b.indexed("d" + b.rank()); // unbound size
- }
- return b.build();
- }
-
- private static abstract class Values {
- protected final TensorProto tensorProto;
- protected Values(TensorProto tensorProto) { this.tensorProto = tensorProto; }
- abstract double get(int i);
- abstract int size();
- }
-
- private static class BoolValues extends Values {
- BoolValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; }
- @Override int size() { return tensorProto.getBoolValCount(); }
- }
-
- private static class HalfValues extends Values {
- HalfValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getHalfVal(i); }
- @Override int size() { return tensorProto.getHalfValCount(); }
- }
-
- private static class IntValues extends Values {
- IntValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getIntVal(i); }
- @Override int size() { return tensorProto.getIntValCount(); }
- }
-
- private static class Int64Values extends Values {
- Int64Values(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getInt64Val(i); }
- @Override int size() { return tensorProto.getInt64ValCount(); }
- }
-
- private static class FloatValues extends Values {
- FloatValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getFloatVal(i); }
- @Override int size() { return tensorProto.getFloatValCount(); }
- }
-
- private static class DoubleValues extends Values {
- DoubleValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getDoubleVal(i); }
- @Override int size() { return tensorProto.getDoubleValCount(); }
- }
-
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
deleted file mode 100644
index ef82045e771..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OperationMapper.java
+++ /dev/null
@@ -1,715 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.google.common.collect.ImmutableList;
-import com.yahoo.searchlib.rankingexpression.RankingExpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
-import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
-import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
-import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.Matmul;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import com.yahoo.tensor.functions.Softmax;
-import com.yahoo.tensor.functions.TensorFunction;
-import org.tensorflow.Session;
-import org.tensorflow.framework.AttrValue;
-
-import java.util.ArrayList;
-import java.util.Iterator;
-import java.util.List;
-import java.util.Optional;
-import java.util.function.DoubleBinaryOperator;
-import java.util.function.DoubleUnaryOperator;
-import java.util.function.Function;
-import java.util.stream.Collectors;
-import java.util.stream.Stream;
-
-/**
- * Contains mappings of TensorFlow operations to the corresponding Vespa tensor functions.
- *
- * @author bratseth
- * @author lesters
- */
-class OperationMapper {
-
- // A note on conversion from implicitly numbered to explicitly named dimensions:
- //
- // Vespa tensor dimensions are explicitly named and thus have an explicit notion of being
- // 'the same' or not of some dimension in another tensor. Since TF lacks this, each operation
- // comes with a built-in definition of sameness. We mirror this by wrapping the Vespa tensor operation
- // around dimension renaming operations which mirrors those built into the TF operation definitions.
- //
- // To do this we need a naming convention: We maintain a naming of each tensor where the 'outermost'
- // dimension is named 'd0', the second outer most 'd1' and so on. Arguments are renamed to match the operation
- // and the result is then renamed again (if necessary) to recover this convention across a full nested
- // computation.
- //
- // This requires us to track tensor types throughout the conversion.
-
-
- // Supported TensorFlow operations
- enum Operation {
-
- // TODO: move the implementations to specific files as we support more operations
-
- /*
- * array ops
- */
- CONST (OperationMapper::constant),
- EXPANDDIMS (OperationMapper::expandDims),
- IDENTITY (OperationMapper::identity),
- PLACEHOLDER (OperationMapper::placeholder),
- PLACEHOLDERWITHDEFAULT (OperationMapper::placeholderWithDefault),
- RESHAPE (OperationMapper::reshape),
- SQUEEZE (OperationMapper::squeeze),
-
- /*
- * control flow
- */
- MERGE (OperationMapper::merge),
- SWITCH (OperationMapper::switchOp),
-
- /*
- * math ops
- */
- ADD (OperationMapper::add),
- ADD_N (OperationMapper::add),
- ACOS (OperationMapper::acos),
- DIV (OperationMapper::div),
- REALDIV (OperationMapper::div),
- FLOOR (OperationMapper::floor),
- MATMUL (OperationMapper::matmul),
- MAXIMUM (OperationMapper::maximum),
- MEAN (OperationMapper::mean),
- REDUCEMEAN (OperationMapper::mean),
- MUL (OperationMapper::mul),
- MULTIPLY (OperationMapper::mul),
- RSQRT (OperationMapper::rsqrt),
- SELECT (OperationMapper::select),
- WHERE3 (OperationMapper::select),
- SIGMOID (OperationMapper::sigmoid),
- SQUAREDDIFFERENCE (OperationMapper::squaredDifference),
- SUB (OperationMapper::sub),
- SUBTRACT (OperationMapper::sub),
-
- /*
- * nn ops
- */
- BIASADD (OperationMapper::add),
- ELU (OperationMapper::elu),
- RELU (OperationMapper::relu),
- SELU (OperationMapper::selu),
- SOFTMAX (OperationMapper::softMax),
-
- /*
- * state ops
- */
- VARIABLE (OperationMapper::variable),
- VARIABLEV2 (OperationMapper::variable),
-
- /*
- * evaluation no-ops
- */
- STOPGRADIENT (OperationMapper::identity),
- NOOP (OperationMapper::noOp);
-
-
- private final Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func;
-
- Operation(Function<TensorFlowImporter.Parameters, Optional<TypedTensorFunction>> func) {
- this.func = func;
- }
-
- Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) {
- return func.apply(params);
- }
-
- }
-
- static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params) {
- Optional<Operation> operation = Stream.of(Operation.values())
- .filter(op -> op.name().equalsIgnoreCase(params.node().getOp()))
- .findFirst();
- if (operation.isPresent()) {
- return operation.get().map(params);
- }
- params.signature().importWarning("TensorFlow operation '" + params.node().getOp() +
- "' in node '" + params.node().getName() + "' is not supported.");
- return Optional.empty();
- }
-
-
- // Operations ---------------------------------
-
- private static Optional<TypedTensorFunction> constant(TensorFlowImporter.Parameters params) {
- Tensor value = AttrValueConverter.toVespaTensor(params.node(), "value");
- if (value.type().rank() == 0) {
- TypedTensorFunction output = new TypedTensorFunction(value.type(),
- new TensorFunctionNode.TensorFunctionExpressionNode(
- new ConstantNode(new DoubleValue(value.asDouble()))));
- return Optional.of(output);
- }
- return createConstant(params, value);
- }
-
- private static Optional<TypedTensorFunction> expandDims(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
-
- Tensor axis = getConstantTensor(params, params.node().getInput(1));
- if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("Axis argument to ExpandDims must be a scalar");
- }
-
- TensorFunction inputFunction = inputs.get(0).get().function();
- TensorType inputType = inputs.get(0).get().type();
-
- int dimensionToInsert = (int)axis.asDouble();
- if (dimensionToInsert < 0) {
- dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
- }
-
- TensorType.Builder outputTypeBuilder = new TensorType.Builder();
- int dimensionIndex = 0;
- for (int i = 0; i < inputType.dimensions().size() + 1; ++i) {
- String name = String.format("temp_%d", i);
- Long size;
- if (i == dimensionToInsert) {
- size = 1L;
- } else {
- size = dimensionSize(inputType.dimensions().get(dimensionIndex));
- dimensionIndex++;
- }
- outputTypeBuilder.indexed(name, size);
- }
-
- return reshape(inputFunction, inputType, outputTypeBuilder.build());
- }
-
- private static Optional<TypedTensorFunction> identity(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 1)) {
- return Optional.empty();
- }
- return params.inputs().get(0);
- }
-
- private static Optional<TypedTensorFunction> placeholder(TensorFlowImporter.Parameters params) {
- String name = params.node().getName();
- String vespaName = toVespaName(params.node().getName());
- TensorType type = params.result().arguments().get(name);
- if (type == null) {
- throw new IllegalArgumentException("A 'placeholder' node is referencing placeholder '" + name +
- "', but there is no such placeholder");
- }
- params.result().requiredMacro(vespaName, type);
- // Included literally in the expression and so must be produced by a separate macro in the rank profile
- TypedTensorFunction output = new TypedTensorFunction(type, new VariableTensor(vespaName, type));
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> placeholderWithDefault(TensorFlowImporter.Parameters params) {
- String name = toVespaName(params.node().getInput(0));
- Tensor defaultValue = getConstantTensor(params, params.node().getInput(0));
- params.result().largeConstant(name, defaultValue);
- params.result().macro(name, new RankingExpression(name, new ReferenceNode("constant(\"" + name + "\")")));
- // The default value will be provided by the macro. Users can override macro to change value.
- TypedTensorFunction output = new TypedTensorFunction(defaultValue.type(), new VariableTensor(name));
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> reshape(TensorFlowImporter.Parameters params) {
- if ( ! checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- Tensor shape = getConstantTensor(params, params.node().getInput(1));
-
- TensorFunction inputFunction = inputs.get(0).get().function();
- TensorType inputType = inputs.get(0).get().type();
-
- TensorType.Builder outputTypeBuilder = new TensorType.Builder();
- int dimensionIndex = 0;
- for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int size = cell.getValue().intValue();
- if (size < 0) {
- size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / tensorSize(inputType).intValue();
- }
- outputTypeBuilder.indexed(String.format("temp_%d", dimensionIndex), size);
- dimensionIndex++;
- }
- return reshape(inputFunction, inputType, outputTypeBuilder.build());
- }
-
- private static Optional<TypedTensorFunction> squeeze(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 1)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
-
- TensorFunction inputFunction = inputs.get(0).get().function();
- TensorType inputType = inputs.get(0).get().type();
- List<String> squeezeDimensions;
-
- AttrValue squeezeDimsAttr = params.node().getAttrMap().get("squeeze_dims");
- if (squeezeDimsAttr == null) {
- squeezeDimensions = inputType.dimensions().stream().
- filter(dim -> dimensionSize(dim) == 1).
- map(TensorType.Dimension::name).
- collect(Collectors.toList());
- } else {
- squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
- map(i -> i < 0 ? inputType.dimensions().size() - i : i).
- map(i -> inputType.dimensions().get(i.intValue())).
- filter(dim -> dimensionSize(dim) == 1).
- map(TensorType.Dimension::name).
- collect(Collectors.toList());
- }
-
- if (squeezeDimensions.isEmpty()) {
- return inputs.get(0);
- }
-
- TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
- TensorType outputType = Reduce.outputType(inputType, squeezeDimensions);
- TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> merge(TensorFlowImporter.Parameters params) {
- return params.inputs().stream()
- .filter(Optional::isPresent)
- .findFirst()
- .orElse(Optional.empty());
- }
-
- private static Optional<TypedTensorFunction> switchOp(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- Tensor predicate = getConstantTensor(params, params.node().getInput(1));
- if (predicate.type().rank() != 0) {
- throw new IllegalArgumentException("'switch': predicate must be a scalar");
- }
- double pred = predicate.asDouble();
- int output = params.port().length() > 0 ? Integer.parseInt(params.port()) : 0;
- if (output < 0 || output > 1) {
- throw new IllegalArgumentException("'switch': predicate is not boolean");
- }
- if (pred == output) {
- return inputs.get(0);
- }
- return Optional.empty();
- }
-
- private static Optional<TypedTensorFunction> add(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.add());
- }
-
- private static Optional<TypedTensorFunction> acos(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.acos());
- }
-
- private static Optional<TypedTensorFunction> div(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.divide());
- }
-
- private static Optional<TypedTensorFunction> floor(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.floor());
- }
-
- private static Optional<TypedTensorFunction> matmul(TensorFlowImporter.Parameters params) {
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
-
- TypedTensorFunction a = inputs.get(0).get();
- TypedTensorFunction b = inputs.get(1).get();
- if (a.type().rank() < 2 || b.type().rank() < 2)
- throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
- if (a.type().rank() != b.type().rank())
- throw new IllegalArgumentException("Tensors in matmul must have the same rank");
-
- String afterLastDim = "d" + (a.type().rank() + 1);
- // Let the first dimension of the second tensor be the same as the second dimension of the first
- // and the second dimension of the second argument be not present in the first argument, while leaving the
- // rest of the dimensions the same. Such is the way of implicit dimension name tensor multiplication.
-
- // TODO: Check if transpose_a or transpose_b is set true and rename differently accordingly
-
- Rename renamedB = new Rename(b.function(), ImmutableList.of("d0", "d1"),
- ImmutableList.of("d1", afterLastDim));
- Matmul matmul = new Matmul(a.function(), renamedB, "d1");
- TypedTensorFunction output = new TypedTensorFunction(Matmul.outputType(a.type(), b.type(), "d1"),
- new Rename(matmul, afterLastDim, "d1"));
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> maximum(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.max());
- }
-
- private static Optional<TypedTensorFunction> mean(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TensorFunction inputFunction = inputs.get(0).get().function();
- TensorType inputType = inputs.get(0).get().type();
-
- Tensor reductionIndices = getConstantTensor(params, params.node().getInput(1));
- List<String> reduceDimensions = new ArrayList<>();
- for (Iterator<Tensor.Cell> cellIterator = reductionIndices.cellIterator(); cellIterator.hasNext();) {
- Tensor.Cell cell = cellIterator.next();
- int dimensionIndex = cell.getValue().intValue();
- if (dimensionIndex < 0) {
- dimensionIndex = inputType.dimensions().size() - dimensionIndex;
- }
- reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
- }
-
- TensorType outputType = Reduce.outputType(inputType, reduceDimensions);
- TensorFunction outputFunction = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
-
- if (shouldKeepDimensions(params)) {
- return reshape(outputFunction, outputType, keepDimensionType(inputType, reduceDimensions));
- }
- TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> mul(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.multiply());
- }
-
- private static Optional<TypedTensorFunction> rsqrt(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.rsqrt());
- }
-
- private static Optional<TypedTensorFunction> select(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 3)) {
- return Optional.empty();
- }
- Tensor condition = getConstantTensor(params, params.node().getInput(0));
-
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TypedTensorFunction x = inputs.get(1).get();
- TypedTensorFunction y = inputs.get(2).get();
- if ((x.type().rank() != y.type().rank()) || !(tensorSize(x.type()).equals(tensorSize(y.type())))) {
- throw new IllegalArgumentException("'Select': input tensors must have the same shape");
- }
-
- if (condition.type().rank() == 0) {
- return Optional.of((int)condition.asDouble() == 0 ? y : x);
- }
- if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
- return Optional.of(condition.cellIterator().next().getValue().intValue() == 0 ? y : x);
- }
-
- // The task is to select cells from 'x' or 'y' based on 'condition'.
- // If 'condition' is 0 (false), select from 'y', if 1 (true) select
- // from 'x'. We do this by individually joining 'x' and 'y' with
- // 'condition', and then joining the resulting two tensors.
-
- Optional<TypedTensorFunction> conditionFunction = importConstantTensor(params, params.node().getInput(0));
- if (!conditionFunction.isPresent()) {
- return Optional.empty();
- }
- TensorFunction xCond = new Join(x.function(), conditionFunction.get().function(), ScalarFunctions.multiply());
- TensorFunction yCond = new Join(y.function(), conditionFunction.get().function(), new DoubleBinaryOperator() {
- @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
- @Override public String toString() { return "f(a,b)(a * (1-b))"; }
- });
- TensorFunction outputFunction = new Join(xCond, yCond, ScalarFunctions.add());
- TypedTensorFunction output = new TypedTensorFunction(x.type(), outputFunction);
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> sigmoid(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.sigmoid());
- }
-
- private static Optional<TypedTensorFunction> squaredDifference(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.squareddifference());
- }
-
- private static Optional<TypedTensorFunction> sub(TensorFlowImporter.Parameters params) {
- return join(params, ScalarFunctions.subtract());
- }
-
- private static Optional<TypedTensorFunction> elu(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.elu());
- }
-
- private static Optional<TypedTensorFunction> relu(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.relu());
- }
-
- private static Optional<TypedTensorFunction> selu(TensorFlowImporter.Parameters params) {
- return map(params, ScalarFunctions.selu());
- }
-
- private static Optional<TypedTensorFunction> softMax(TensorFlowImporter.Parameters params) {
- if (!checkInputs(params, 1)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TypedTensorFunction a = inputs.get(0).get();
- // TODO: Read the "dim" parameter and use it to decide dimension if set and != -1
- String dimension = "d" + (a.type().rank() - 1);
- Softmax softmax = new Softmax(a.function(), dimension);
- TypedTensorFunction output = new TypedTensorFunction(Softmax.outputType(a.type(), dimension), softmax);
- return Optional.of(output);
- }
-
- private static Optional<TypedTensorFunction> variable(TensorFlowImporter.Parameters params) {
- return importConstantTensor(params, params.node().getName());
- }
-
- private static Optional<TypedTensorFunction> noOp(TensorFlowImporter.Parameters params) {
- return Optional.empty();
- }
-
- /*
- * Utility
- */
-
- private static Optional<TypedTensorFunction> join(TensorFlowImporter.Parameters params, DoubleBinaryOperator doubleFunction) {
- if (!checkInputs(params, 2)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
-
- TypedTensorFunction a = inputs.get(0).get();
- TypedTensorFunction b = inputs.get(1).get();
-
- if (a.type().rank() == 0 && b.type().rank() > 0) {
- return Optional.of(new TypedTensorFunction(b.type(), new Join(a.function(), b.function(), doubleFunction)));
- }
- if (b.type().rank() == 0 && a.type().rank() > 0) {
- return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)));
- }
- if (a.type().rank() == b.type().rank()) {
- return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), b.function(), doubleFunction)));
- }
-
- // Well now we have entered the wonderful world of "broadcasting"
- // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
- // I'm not able to extract from that any unambiguous specification of which dimensions
- // should be "stretched" when the tensor do not have the same number of dimensions.
- // From trying this with TensorFlow it appears that the second tensor is matched to the
- // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true.
- // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first).
-
- if (a.type().rank() > b.type().rank()) {
- TensorFunction renameFunction = renameForBroadcast(a, b);
- return Optional.of(new TypedTensorFunction(a.type(), new Join(a.function(), renameFunction, doubleFunction)));
- }
- TensorFunction renameFunction = renameForBroadcast(b, a);
- return Optional.of(new TypedTensorFunction(b.type(), new Join(renameFunction, b.function(), doubleFunction)));
- }
-
- private static TensorFunction renameForBroadcast(TypedTensorFunction a, TypedTensorFunction b) {
- List<String> renameFrom = new ArrayList<>();
- List<String> renameTo = new ArrayList<>();
- int sizeDifference = a.type().rank() - b.type().rank();
- for (int i = 0; i < b.type().rank(); i++) {
- renameFrom.add(b.type().dimensions().get(i).name());
- renameTo.add("d" + (sizeDifference + i));
- }
- return new Rename(b.function(), renameFrom, renameTo);
- }
-
- private static Optional<TypedTensorFunction> map(TensorFlowImporter.Parameters params, DoubleUnaryOperator doubleFunction) {
- if (!checkInputs(params, 1)) {
- return Optional.empty();
- }
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- TypedTensorFunction a = inputs.get(0).get();
- TensorType resultType = com.yahoo.tensor.functions.Map.outputType(a.type());
- com.yahoo.tensor.functions.Map function = new com.yahoo.tensor.functions.Map(a.function(), doubleFunction);
- return Optional.of(new TypedTensorFunction(resultType, function));
- }
-
- private static Optional<TypedTensorFunction> createConstant(TensorFlowImporter.Parameters params, Tensor constant) {
- String name = toVespaName(params.node().getName());
- if (constant.type().rank() == 0 || constant.size() <= 1) {
- params.result().smallConstant(name, constant);
- } else {
- params.result().largeConstant(name, constant);
- }
- TypedTensorFunction output = new TypedTensorFunction(constant.type(),
- new TensorFunctionNode.TensorFunctionExpressionNode(
- new ReferenceNode("constant(\"" + name + "\")")));
- return Optional.of(output);
- }
-
- private static Tensor getConstantTensor(TensorFlowImporter.Parameters params, String name) {
- String vespaName = toVespaName(name);
- if (params.result().smallConstants().containsKey(vespaName)) {
- return params.result().smallConstants().get(vespaName);
- }
- if (params.result().largeConstants().containsKey(vespaName)) {
- return params.result().largeConstants().get(vespaName);
- }
- Session.Runner fetched = params.model().session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name + ", but got " +
- importedTensors.size());
- return TensorConverter.toVespaTensor(importedTensors.get(0));
- }
-
- private static Optional<TypedTensorFunction> importConstantTensor(TensorFlowImporter.Parameters params, String name) {
- AttrValue shapes = params.node().getAttrMap().get("_output_shapes");
- if (shapes == null)
- throw new IllegalArgumentException("'" + name + "' is missing a tensor shape");
- Tensor constant = getConstantTensor(params, name);
- return createConstant(params, constant);
- }
-
- private static Optional<TypedTensorFunction> reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!tensorSize(inputType).equals(tensorSize(outputType))) {
- throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
- }
-
- // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
- // then use the dimension order of the new shape to roll back into a tensor.
- // Here we create a transformation tensor that is multiplied with the from tensor to map into
- // the new shape. We have to introduce temporary dimension names and rename back if dimension names
- // in the new and old tensor type overlap.
-
- ExpressionNode unrollFrom = unrollTensorExpression(inputType);
- ExpressionNode unrollTo = unrollTensorExpression(outputType);
- ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
-
- TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
- Generate transformTensor = new Generate(transformationType,
- new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
-
- TensorFunction outputFunction = new Reduce(
- new Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
- TypedTensorFunction output = checkNamingConvention(outputType, outputFunction);
- return Optional.of(output);
- }
-
- private static ExpressionNode unrollTensorExpression(TensorType type) {
- if (type.rank() == 0) {
- return new ConstantNode(DoubleValue.zero);
- }
- List<ExpressionNode> children = new ArrayList<>();
- List<ArithmeticOperator> operators = new ArrayList<>();
- int size = 1;
- for (int i = type.dimensions().size() - 1; i >= 0; --i) {
- TensorType.Dimension dimension = type.dimensions().get(i);
- children.add(0, new ReferenceNode(dimension.name()));
- if (size > 1) {
- operators.add(0, ArithmeticOperator.MULTIPLY);
- children.add(0, new ConstantNode(new DoubleValue(size)));
- }
- size *= dimensionSize(dimension);
- if (i > 0) {
- operators.add(0, ArithmeticOperator.PLUS);
- }
- }
- return new ArithmeticNode(children, operators);
- }
-
- private static boolean shouldKeepDimensions(TensorFlowImporter.Parameters params) {
- AttrValue keepDimsAttr = params.node().getAttrMap().get("keep_dims");
- return keepDimsAttr != null && keepDimsAttr.getB();
- }
-
- private static TensorType keepDimensionType(TensorType inputType, List<String> reduceDimensions) {
- TensorType.Builder builder = new TensorType.Builder();
- for (TensorType.Dimension dimension: inputType.dimensions()) {
- String name = dimension.name();
- Long size = dimensionSize(dimension);
- if (reduceDimensions.contains(name)) {
- size = 1L;
- }
- builder.indexed(name, size);
- }
- return builder.build();
- }
-
- private static TypedTensorFunction checkNamingConvention(TensorType type, TensorFunction function) {
- for (int i = 0; i < type.dimensions().size(); ++i) {
- String correct = String.format("d%d", i);
- String current = type.dimensions().get(i).name();
- if (!current.equals(correct)) {
- return fixNamingConvention(type, function);
- }
- }
- return new TypedTensorFunction(type, function);
- }
-
- private static TypedTensorFunction fixNamingConvention(TensorType type, TensorFunction function) {
- TensorType.Builder correctType = new TensorType.Builder();
- List<String> from = new ArrayList<>();
- List<String> to = new ArrayList<>();
- for (int i = 0; i < type.dimensions().size(); ++i) {
- String correct = String.format("d%d", i);
- String current = type.dimensions().get(i).name();
- if (!current.equals(correct)) {
- from.add(current);
- to.add(correct);
- }
- correctType.indexed(correct, dimensionSize(type.dimensions().get(i)));
- }
- if (from.size() > 0) {
- function = new Rename(function, from, to);
- type = correctType.build();
- }
- return new TypedTensorFunction(type, function);
- }
-
- private static Long tensorSize(TensorType type) {
- Long size = 1L;
- for (TensorType.Dimension dimension : type.dimensions()) {
- size *= dimensionSize(dimension);
- }
- return size;
- }
-
- private static Long dimensionSize(TensorType.Dimension dim) {
- return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
- }
-
- private static boolean checkInputs(TensorFlowImporter.Parameters params, int expected) {
- List<Optional<TypedTensorFunction>> inputs = params.inputs();
- if (!inputs.stream().allMatch(Optional::isPresent)) {
- return false;
- }
- if (inputs.size() != expected) {
- params.signature().importWarning("Expected " + expected +
- " arguments to " + params.node().getOp() + ", but got " + inputs.size());
- return false;
- }
- return true;
- }
-
- public static String toVespaName(String name) {
- return name != null ? name.replace('/', '_') : null;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
deleted file mode 100644
index b88ffce275a..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java
+++ /dev/null
@@ -1,155 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.nio.ByteBuffer;
-import java.nio.DoubleBuffer;
-import java.nio.FloatBuffer;
-import java.nio.IntBuffer;
-import java.nio.LongBuffer;
-
-
-/**
- * Converts TensorFlow tensors into Vespa tensors.
- *
- * @author bratseth
- */
-public class TensorConverter {
-
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
- TensorType type = toVespaTensorType(tfTensor.shape());
- Values values = readValuesOf(tfTensor);
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- for (int i = 0; i < values.size(); i++)
- builder.cellByDirectIndex(i, values.get(i));
- return builder.build();
- }
-
- private static TensorType toVespaTensorType(long[] shape) {
- TensorType.Builder b = new TensorType.Builder();
- int dimensionIndex = 0;
- for (long dimensionSize : shape) {
- if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
- b.indexed("d" + (dimensionIndex++), dimensionSize);
- }
- return b.build();
- }
-
- private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
- switch (tfTensor.dataType()) {
- case DOUBLE: return new DoubleValues(tfTensor);
- case FLOAT: return new FloatValues(tfTensor);
- case BOOL: return new BoolValues(tfTensor);
- case UINT8: return new IntValues(tfTensor);
- case INT32: return new IntValues(tfTensor);
- case INT64: return new LongValues(tfTensor);
- default:
- throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
- }
- }
-
- /** Allows reading values from buffers of various numeric types as bytes */
- private static abstract class Values {
-
- private final int size;
-
- protected Values(int size) {
- this.size = size;
- }
-
- abstract double get(int i);
-
- int size() { return size; }
-
- }
-
- private static class DoubleValues extends Values {
-
- private final DoubleBuffer values;
-
- DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = DoubleBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
- private static class FloatValues extends Values {
-
- private final FloatBuffer values;
-
- FloatValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = FloatBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
- private static class BoolValues extends Values {
-
- private final ByteBuffer values;
-
- BoolValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = ByteBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
- private static class IntValues extends Values {
-
- private final IntBuffer values;
-
- IntValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = IntBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
- private static class LongValues extends Values {
-
- private final LongBuffer values;
-
- LongValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = LongBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
-
- @Override
- double get(int i) {
- return values.get(i);
- }
-
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index c97ee2b1514..7116d430502 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -2,10 +2,20 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
-import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.yolean.Exceptions;
import org.tensorflow.SavedModelBundle;
+import org.tensorflow.Session;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.MetaGraphDef;
import org.tensorflow.framework.NodeDef;
@@ -24,6 +34,7 @@ import java.util.stream.Collectors;
* Converts a saved TensorFlow model into a ranking expression and set of constants.
*
* @author bratseth
+ * @author lesters
*/
public class TensorFlowImporter {
@@ -57,196 +68,303 @@ public class TensorFlowImporter {
}
}
- private TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle model) {
- TensorFlowModel result = new TensorFlowModel();
+ /**
+ * Imports the TensorFlow graph by first importing the tensor types, then
+ * finding a suitable set of dimension names for each
+ * placeholder/constant/variable, then importing the expressions.
+ */
+ private static TensorFlowModel importGraph(MetaGraphDef graph, SavedModelBundle bundle) {
+ TensorFlowModel model = new TensorFlowModel();
+ OperationIndex index = new OperationIndex();
+
+ importSignatures(graph, model);
+ importNodes(graph, model, index);
+ findDimensionNames(model, index);
+ importExpressions(model, index, bundle);
+
+ // nodes with multiple outputs are calculated multiple times. consider adding macros for those.
+
+ reportWarnings(model, index);
+
+ return model;
+ }
+
+ private static void importSignatures(MetaGraphDef graph, TensorFlowModel model) {
for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
- TensorFlowModel.Signature signature = result.signature(signatureEntry.getKey()); // Prefer key over "methodName"
+ String signatureName = signatureEntry.getKey();
+ TensorFlowModel.Signature signature = model.signature(signatureName);
+
+ Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
+ for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
+ String inputName = input.getKey();
+ signature.input(inputName, namePartOf(input.getValue().getName()));
+ }
- importInputs(signatureEntry.getValue().getInputsMap(), signature);
- for (Map.Entry<String, TensorInfo> output : signatureEntry.getValue().getOutputsMap().entrySet()) {
+ Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
+ for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
String outputName = output.getKey();
- try {
- NodeDef node = getNode(namePartOf(output.getValue().getName()), graph.getGraphDef());
- Parameters params = createParameters(graph.getGraphDef(), model, result, signature, node, "");
-
- // Commonly, there are multiple paths through a TensorFlow graph, for instance for
- // training and testing/evaluation. Examples are dropout and batch norm. For Vespa
- // we are not concerned with training paths, so we can ignore non-supported operations
- // as long as they are on a path that will not be evaluated run time. Operations
- // that fail import will not have a value present in the optionals. However, the
- // final output node must have value present. It is an error if it does not.
-
- Optional<TypedTensorFunction> outputFunction = importNode(params);
- if (!outputFunction.isPresent()) {
- throw new IllegalArgumentException(signature.importWarnings().stream().collect(Collectors.joining("\n")));
- }
- signature.output(outputName, namePartOf(output.getValue().getName()));
- }
- catch (IllegalArgumentException e) {
- signature.skippedOutput(outputName, Exceptions.toMessageString(e));
- }
+ signature.output(outputName, namePartOf(output.getValue().getName()));
}
}
- return result;
}
- private void importInputs(Map<String, TensorInfo> inputInfoMap, TensorFlowModel.Signature signature) {
- inputInfoMap.forEach((key, value) -> {
- String argumentName = namePartOf(value.getName());
- TensorType argumentType = AttrValueConverter.toVespaTensorType(value.getTensorShape());
- // Arguments are (Placeholder) nodes, so not local to the signature:
- signature.owner().argument(argumentName, argumentType);
- signature.input(key, argumentName);
- });
+ private static boolean isSignatureInput(TensorFlowModel model, TensorFlowOperation operation) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String inputName : signature.inputs().values()) {
+ if (inputName.equals(operation.node().getName())) {
+ return true;
+ }
+ }
+ }
+ return false;
}
- /** Recursively convert a graph of TensorFlow nodes into a Vespa tensor function expression tree */
- private Optional<TypedTensorFunction> importNode(Parameters params) {
- String nodeName = params.node().getName();
- if (params.imported().containsKey(nodeName)) {
- return Optional.of(params.imported().get(nodeName));
+ private static boolean isSignatureOutput(TensorFlowModel model, TensorFlowOperation operation) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ if (outputName.equals(operation.node().getName())) {
+ return true;
+ }
+ }
}
+ return false;
+ }
- Optional<TypedTensorFunction> function = OperationMapper.map(params);
- if ( ! function.isPresent()) {
- return Optional.empty();
- }
- if ( ! controlDependenciesArePresent(params)) {
- return Optional.empty();
+ private static void importNodes(MetaGraphDef graph, TensorFlowModel model, OperationIndex index) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ importNode(outputName, graph.getGraphDef(), index);
+ }
}
- params.imported().put(nodeName, function.get());
+ }
- try {
- // We add all intermediate nodes imported as separate expressions. Only those referenced in a signature output
- // will be used. We parse the TensorFunction here to convert it to a RankingExpression tree
- params.result().expression(nodeName,
- new RankingExpression(nodeName, function.get().function().toString()));
- return function;
+ private static TensorFlowOperation importNode(String name, GraphDef graph, OperationIndex index) {
+ if (index.alreadyImported(name)) {
+ return index.get(name);
}
- catch (ParseException e) {
- throw new RuntimeException("Tensorflow function " + function.get().function() +
- " cannot be parsed as a ranking expression", e);
+ NodeDef node = getTensorFlowNodeFromGraph(namePartOf(name), graph);
+ List<TensorFlowOperation> inputs = importNodeInputs(node, graph, index);
+ TensorFlowOperation operation = OperationMapper.get(node, inputs, portPartOf(name));
+ index.put(name, operation);
+
+ List<TensorFlowOperation> controlInputs = importControlInputs(node, graph, index);
+ if (controlInputs.size() > 0) {
+ operation.setControlInputs(controlInputs);
}
- }
- private boolean controlDependenciesArePresent(Parameters params) {
- return params.node().getInputList().stream()
- .filter(TensorFlowImporter::isControlDependency)
- .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName))))
- .allMatch(Optional::isPresent);
+ return operation;
}
- private static boolean isControlDependency(String nodeName) {
- return nodeName.startsWith("^");
+ private static List<TensorFlowOperation> importNodeInputs(NodeDef node, GraphDef graph, OperationIndex index) {
+ return node.getInputList().stream()
+ .filter(name -> ! isControlDependency(name))
+ .map(name -> importNode(name, graph, index))
+ .collect(Collectors.toList());
}
- private List<Optional<TypedTensorFunction>> importArguments(Parameters params) {
- return params.node().getInputList().stream()
- .filter(nodeName -> !isControlDependency(nodeName))
- .map(nodeName -> importNode(params.copy(getNode(namePartOf(nodeName), params.graph()), indexPartOf(nodeName))))
+ private static List<TensorFlowOperation> importControlInputs(NodeDef node, GraphDef graph, OperationIndex index) {
+ return node.getInputList().stream()
+ .filter(name -> isControlDependency(name))
+ .map(name -> importNode(name, graph, index))
.collect(Collectors.toList());
}
- private NodeDef getNode(String name, GraphDef graph) {
- return graph.getNodeList().stream()
- .filter(node -> node.getName().equals(name))
- .findFirst()
- .orElseThrow(() -> new IllegalArgumentException("Could not find node '" + name + "'"));
+ private static boolean isControlDependency(String name) {
+ return name.startsWith("^");
}
- /**
- * A method signature input and output has the form name:index.
- * This returns the name part without the index.
- */
- private static String namePartOf(String name) {
- name = name.startsWith("^") ? name.substring(1) : name;
- return name.split(":")[0];
+ /** Find dimension names to avoid excessive renaming while evaluating the model. */
+ private static void findDimensionNames(TensorFlowModel model, OperationIndex index) {
+ DimensionRenamer renamer = new DimensionRenamer();
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ addDimensionNameConstraints(index.get(output), renamer);
+ }
+ }
+ renamer.solve();
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ renameDimensions(index.get(output), renamer);
+ }
+ }
}
- /**
- * This return the index part. Indexes are used for nodes with
- * multiple outputs.
- */
- private static String indexPartOf(String name) {
- int i = name.indexOf(":");
- return i < 0 ? "" : name.substring(i + 1);
+ private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
+ operation.addDimensionNameConstraints(renamer);
+ }
}
+ private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) {
+ if (operation.type().isPresent()) {
+ operation.inputs().forEach(input -> renameDimensions(input, renamer));
+ operation.renameDimensions(renamer);
+ }
+ }
- private Parameters createParameters(GraphDef graph,
- SavedModelBundle model,
- TensorFlowModel result,
- TensorFlowModel.Signature signature,
- NodeDef node,
- String port) {
- return new Parameters(this, graph, model, result, signature, new HashMap<>(), node, port);
+ private static void importExpressions(TensorFlowModel model, OperationIndex index, SavedModelBundle bundle) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String outputName : signature.outputs().values()) {
+ try {
+ Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle);
+ if (!function.isPresent()) {
+ signature.skippedOutput(outputName, "No valid output function could be found.");
+ }
+ }
+ catch (IllegalArgumentException e) {
+ signature.skippedOutput(outputName, Exceptions.toMessageString(e));
+ }
+ }
+ }
}
- /** Parameter object to hold important data while importing */
- static final class Parameters {
+ private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
+ if (!operation.type().isPresent()) {
+ return Optional.empty();
+ }
+ if (operation.isConstant()) {
+ return importConstant(model, operation, bundle);
+ }
- private final TensorFlowImporter owner;
- private final GraphDef graph;
- private final SavedModelBundle model;
- private final TensorFlowModel result;
- private final TensorFlowModel.Signature signature;
- private final Map<String, TypedTensorFunction> imported;
- private final NodeDef node;
- private final String port;
+ importInputExpressions(operation, model, bundle);
+ importRankingExpression(model, operation);
+ importInputExpression(model, operation);
+ importMacroExpression(model, operation);
- private Parameters(TensorFlowImporter owner,
- GraphDef graph,
- SavedModelBundle model,
- TensorFlowModel result,
- TensorFlowModel.Signature signature,
- Map<String, TypedTensorFunction> imported,
- NodeDef node,
- String port) {
- this.owner = owner;
- this.graph = graph;
- this.model = model;
- this.result = result;
- this.signature = signature;
- this.imported = imported;
- this.node = node;
- this.port = port;
- }
+ return operation.function();
+ }
- GraphDef graph() {
- return this.graph;
+ private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
+ operation.inputs().forEach(input -> importExpression(input, model, bundle));
+ }
+
+ private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.macro().isPresent()) {
+ model.macro(operation.vespaName(), operation.macro().get());
}
+ }
- SavedModelBundle model() {
- return this.model;
+ private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, SavedModelBundle bundle) {
+ String name = operation.vespaName();
+ if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
+ return operation.function();
}
- TensorFlowModel result() {
- return this.result;
+ Tensor tensor;
+ if (operation.getConstantValue().isPresent()) {
+ Value value = operation.getConstantValue().get();
+ if ( ! (value instanceof TensorValue)) {
+ return operation.function(); // scalar values are inserted directly into the expression
+ }
+ tensor = value.asTensor();
+ } else {
+ Session.Runner fetched = bundle.session().runner().fetch(operation.node().getName());
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if (importedTensors.size() != 1) {
+ throw new IllegalStateException("Expected 1 tensor from fetching " + operation.node().getName() + ", but got " +
+ importedTensors.size());
+ }
+ // Here we use the type from the operation, which will have correct dimension names after name resolving
+ tensor = TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get());
+ operation.setConstantValue(new TensorValue(tensor));
}
- TensorFlowModel.Signature signature() {
- return this.signature;
+ if (tensor.type().rank() == 0 || tensor.size() <= 1) {
+ model.smallConstant(operation.vespaName(), tensor);
+ } else {
+ model.largeConstant(operation.vespaName(), tensor);
}
+ return operation.function();
+ }
+
+ private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.function().isPresent()) {
+ String name = operation.node().getName();
+ if (!model.expressions().containsKey(operation.node().getName())) {
+ TensorFunction function = operation.function().get();
+
+ // Make sure output adheres to standard naming convention
+ if (isSignatureOutput(model, operation)) {
+ OrderedTensorType operationType = operation.type().get();
+ OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
+ if ( ! operationType.equals(standardNamingType)) {
+ List<String> renameFrom = operationType.dimensionNames();
+ List<String> renameTo = standardNamingType.dimensionNames();
+ function = new Rename(function, renameFrom, renameTo);
+ }
+ }
- Map<String, TypedTensorFunction> imported() {
- return this.imported;
+ try {
+ // We add all intermediate nodes imported as separate expressions. Only
+ // those referenced in a signature output will be used. We parse the
+ // TensorFunction here to convert it to a RankingExpression tree.
+ model.expression(name, new RankingExpression(name, function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
+ }
}
+ }
- NodeDef node() {
- return node;
+ private static void importInputExpression(TensorFlowModel model, TensorFlowOperation operation) {
+ if (operation.isInput() && isSignatureInput(model, operation)) {
+ // All inputs must have dimensions with standard naming convention: d0, d1, ...
+ OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node());
+ model.argument(operation.node().getName(), standardNamingConvention.type());
+ model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
}
+ }
- String port() {
- return port;
+ private static void reportWarnings(TensorFlowModel model, OperationIndex index) {
+ for (TensorFlowModel.Signature signature : model.signatures().values()) {
+ for (String output : signature.outputs().values()) {
+ reportWarnings(index.get(output), signature);
+ }
}
+ }
- Parameters copy(NodeDef node, String port) {
- return new Parameters(this.owner, this.graph, this.model, this.result, this.signature, this.imported, node, port);
+ private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) {
+ for (String warning : operation.warnings()) {
+ signature.importWarning(warning);
}
+ }
- List<Optional<TypedTensorFunction>> inputs() {
- return owner.importArguments(this);
+ private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
+ for (NodeDef node : graph.getNodeList()) {
+ if (node.getName().equals(name)) {
+ return node;
+ }
}
+ throw new IllegalArgumentException("Could not find node '" + name + "'");
+ }
+
+ /**
+ * A method signature input and output has the form name:index.
+ * This returns the name part without the index.
+ */
+ private static String namePartOf(String name) {
+ name = name.startsWith("^") ? name.substring(1) : name;
+ return name.split(":")[0];
+ }
+
+ /**
+ * This returns the output port part. Indexes are used for nodes with
+ * multiple outputs.
+ */
+ private static int portPartOf(String name) {
+ int i = name.indexOf(":");
+ return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
+ }
+
+
+ private static class OperationIndex {
+ private final Map<String, TensorFlowOperation> index = new HashMap<>();
+ public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
+ public TensorFlowOperation get(String key) { return index.get(key); }
+ public boolean alreadyImported(String key) { return index.containsKey(key); }
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
deleted file mode 100644
index 600225bfe76..00000000000
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TypedTensorFunction.java
+++ /dev/null
@@ -1,30 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
-
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.TensorFunction;
-
-/**
- * A tensor function returning a specific tensor type
- *
- * @author bratseth
- */
-final class TypedTensorFunction {
-
- private final TensorType type;
- private final TensorFunction function;
-
- public TypedTensorFunction(TensorType type, TensorFunction function) {
- this.type = type;
- this.function = function;
- }
-
- public TensorType type() {
- return type;
- }
-
- public TensorFunction function() {
- return function;
- }
-
-}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
new file mode 100644
index 00000000000..c1665d066a4
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/DimensionRenamer.java
@@ -0,0 +1,210 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+
+import java.util.ArrayDeque;
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
+
+/**
+ * A constraint satisfier to find suitable dimension names to reduce the
+ * amount of necessary renaming during evaluation of an imported model.
+ *
+ * @author lesters
+ */
+public class DimensionRenamer {
+
+    /** Prefix put in front of the solved number to form the resulting dimension name. */
+    private final String dimensionPrefix;
+
+    /** The candidate numbers still allowed for each registered dimension name. */
+    private final Map<String, List<Integer>> variables = new HashMap<>();
+
+    /** Directed constraints between dimension names. Each added constraint is stored in both directions. */
+    private final Map<Arc, Constraint> constraints = new HashMap<>();
+
+    /** The number assigned to each dimension name, filled in by {@link #solve(int)}. */
+    private final Map<String, Integer> renames = new HashMap<>();
+
+    /** Number of arcs processed so far, accumulated over all arc consistency passes. */
+    private int iterations = 0;
+
+    /** Creates a renamer which produces names with the prefix "d". */
+    public DimensionRenamer() {
+        this("d");
+    }
+
+    /**
+     * Creates a renamer.
+     *
+     * @param dimensionPrefix the prefix put in front of the solved number in every resulting name
+     */
+    public DimensionRenamer(String dimensionPrefix) {
+        this.dimensionPrefix = dimensionPrefix;
+    }
+
+    /**
+     * Add a dimension name variable. Adding a name that is already registered has no effect.
+     */
+    public void addDimension(String name) {
+        variables.computeIfAbsent(name, d -> new ArrayList<>());
+    }
+
+    /**
+     * Add a constraint between dimension names. The mirrored constraint is registered as well.
+     * Arcs are identified by their (from, to) pair only, so a later constraint on the same
+     * pair replaces an earlier one regardless of which operation it came from.
+     *
+     * @param from      the dimension name passed as first argument to the predicate
+     * @param to        the dimension name passed as second argument to the predicate
+     * @param pred      the relation which must hold between the numbers assigned to the two names
+     * @param operation the operation the constraint originates from
+     */
+    public void addConstraint(String from, String to, Constraint pred, TensorFlowOperation operation) {
+        Arc arc = new Arc(from, to, operation);
+        Arc opposite = arc.opposite();
+        constraints.put(arc, pred);
+        constraints.put(opposite, (x,y) -> pred.test(y, x)); // make constraint graph symmetric
+    }
+
+    /**
+     * Retrieve resulting name of dimension after solving for constraints.
+     *
+     * @return the prefix followed by the assigned number, or empty if no number is assigned to this name
+     */
+    public Optional<String> dimensionNameOf(String name) {
+        if (!renames.containsKey(name)) {
+            return Optional.empty();
+        }
+        return Optional.of(String.format("%s%d", dimensionPrefix, renames.get(name)));
+    }
+
+    /**
+     * Perform iterative arc consistency until we have found a solution. After
+     * an initial iteration, the variables (dimensions) will have multiple
+     * valid values. Find a single valid assignment by iteratively locking one
+     * dimension after another, and running the arc consistency algorithm
+     * multiple times.
+     *
+     * This requires having constraints that result in an absolute ordering:
+     * equals, lesserThan and greaterThan do that, but adding notEquals does
+     * not typically result in a guaranteed ordering. If that is needed, the
+     * algorithm below needs to be adapted with a backtracking (tree) search
+     * to find solutions.
+     *
+     * @param maxIterations the number of processed arcs after which solving is abandoned.
+     *                      The limit is checked once per dimension, so it can be exceeded
+     *                      within a single pass before it is detected.
+     * @throws IllegalArgumentException if some dimension is left without any valid value,
+     *                                  or the iteration limit is exceeded
+     */
+    public void solve(int maxIterations) {
+        initialize();
+
+        // Todo: evaluate possible improved efficiency by using a heuristic such as min-conflicts
+
+        for (String dimension : variables.keySet()) {
+            List<Integer> values = variables.get(dimension);
+            if (values.size() > 1) {
+                if (!ac3()) {
+                    throw new IllegalArgumentException("Dimension renamer unable to find a solution.");
+                }
+                // Lock this dimension to the smallest value which survived the pass.
+                // Replacing the value of an existing key is not a structural change,
+                // so this is safe while iterating over the key set.
+                values.sort(Integer::compare);
+                variables.put(dimension, Collections.singletonList(values.get(0)));
+            }
+            renames.put(dimension, variables.get(dimension).get(0));
+            if (iterations > maxIterations) {
+                throw new IllegalArgumentException("Dimension renamer unable to find a solution within " +
+                        maxIterations + " iterations");
+            }
+        }
+
+        // Todo: handle failure more gracefully:
+        //       If a solution can't be found, look at the operation node in the arc
+        //       with the most remaining constraints, and inject a rename operation.
+        //       Then run this algorithm again.
+    }
+
+    /** Solves with an iteration limit of 100000. */
+    public void solve() {
+        solve(100000);
+    }
+
+    /** Gives every dimension the candidate values 0 .. (number of dimensions - 1). */
+    private void initialize() {
+        for (Map.Entry<String, List<Integer>> variable : variables.entrySet()) {
+            List<Integer> values = variable.getValue();
+            for (int i = 0; i < variables.size(); ++i) {
+                values.add(i);  // invariant: values are in increasing order
+            }
+        }
+    }
+
+    /**
+     * Runs one arc consistency pass over all constraints, pruning candidate values
+     * which have no supporting value at the other end of an arc.
+     *
+     * @return false if some dimension was left without candidates, true otherwise
+     */
+    private boolean ac3() {
+        Deque<Arc> workList = new ArrayDeque<>(constraints.keySet());
+        while (!workList.isEmpty()) {
+            Arc arc = workList.pop();
+            iterations += 1;
+            if (revise(arc)) {
+                if (variables.get(arc.from).isEmpty()) {
+                    return false;  // no solution found
+                }
+                // The domain of arc.from shrank: recheck every arc pointing at it,
+                // except the one coming straight back from arc.to.
+                for (Arc constraint : constraints.keySet()) {
+                    if (arc.from.equals(constraint.to) && !arc.to.equals(constraint.from)) {
+                        workList.add(constraint);
+                    }
+                }
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Removes the candidates of arc.from which are not supported by any candidate of arc.to.
+     * Removal goes through the iterator (not removeIf) so that an immutable, already locked
+     * domain only fails if a value actually has to be removed.
+     *
+     * @return whether any candidate was removed
+     */
+    private boolean revise(Arc arc) {
+        Constraint constraint = constraints.get(arc);
+        List<Integer> toValues = variables.get(arc.to);
+        boolean revised = false;
+        for (Iterator<Integer> fromIterator = variables.get(arc.from).iterator(); fromIterator.hasNext(); ) {
+            Integer from = fromIterator.next();
+            if (!hasSupport(constraint, from, toValues)) {
+                fromIterator.remove();
+                revised = true;
+            }
+        }
+        return revised;
+    }
+
+    /** Returns whether the constraint holds for 'from' paired with at least one of the given values. */
+    private static boolean hasSupport(Constraint constraint, Integer from, List<Integer> toValues) {
+        for (Integer to : toValues) {
+            if (constraint.test(from, to)) {
+                return true;  // one supporting value is enough, no need to scan the rest
+            }
+        }
+        return false;
+    }
+
+    /** A binary relation between the numbers assigned to two dimension names. */
+    @FunctionalInterface
+    public interface Constraint {
+        boolean test(Integer x, Integer y);
+    }
+
+    /** Constraint requiring both dimensions to get the same number. */
+    public static boolean equals(Integer x, Integer y) {
+        return Objects.equals(x, y);
+    }
+
+    /** Constraint requiring the first dimension to get a smaller number than the second. */
+    public static boolean lesserThan(Integer x, Integer y) {
+        return x < y;
+    }
+
+    /** Constraint requiring the first dimension to get a larger number than the second. */
+    public static boolean greaterThan(Integer x, Integer y) {
+        return x > y;
+    }
+
+    /**
+     * A directed edge between two dimension names. The originating operation is carried
+     * along but takes no part in equality or hashing.
+     */
+    private static class Arc {
+
+        private final String from;
+        private final String to;
+        private final TensorFlowOperation operation;
+
+        Arc(String from, String to, TensorFlowOperation operation) {
+            this.from = from;
+            this.to = to;
+            this.operation = operation;
+        }
+
+        /** Returns the arc going the other way, for the same operation. */
+        Arc opposite() {
+            return new Arc(to, from, operation);
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(from, to);
+        }
+
+        @Override
+        public boolean equals(Object obj) {
+            if (!(obj instanceof Arc)) {  // also covers null
+                return false;
+            }
+            Arc other = (Arc) obj;
+            return Objects.equals(from, other.from) && Objects.equals(to, other.to);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("%s -> %s", from, to);
+        }
+    }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
new file mode 100644
index 00000000000..0fe73fad8ce
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -0,0 +1,108 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Join;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Map;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Matmul;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Mean;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Merge;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.NoOp;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Placeholder;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.PlaceholderWithDefault;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Reshape;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Select;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Shape;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Squeeze;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Switch;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+/**
+ * Maps from TensorFlow operations to Vespa operations.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class OperationMapper {
+
+    /**
+     * Creates the import operation corresponding to a TensorFlow node.
+     * The node's op name is matched case-insensitively. Op names without a
+     * mapping here result in a {@link NoOp}.
+     *
+     * @param node   the TensorFlow node to map
+     * @param inputs the already created operations feeding this node
+     * @param port   passed on unchanged to the created operation
+     * @return a new operation instance for the node
+     */
+    public static TensorFlowOperation get(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+        // Lower-case with a fixed locale: with the default locale the result depends on
+        // the host settings (e.g. Turkish maps 'I' to a dotless 'i'), which would make
+        // ops such as "Identity" miss their case label and silently become a NoOp.
+        switch (node.getOp().toLowerCase(java.util.Locale.ROOT)) {
+            /*
+             * array ops
+             */
+            case "const":       return new Const(node, inputs, port);
+            case "expanddims":  return new ExpandDims(node, inputs, port);
+            case "identity":    return new Identity(node, inputs, port);
+            case "placeholder": return new Placeholder(node, inputs, port);
+            case "placeholderwithdefault": return new PlaceholderWithDefault(node, inputs, port);
+            case "reshape":     return new Reshape(node, inputs, port);
+            case "shape":       return new Shape(node, inputs, port);
+            case "squeeze":     return new Squeeze(node, inputs, port);
+
+            /*
+             * control flow
+             */
+            case "merge":       return new Merge(node, inputs, port);
+            case "switch":      return new Switch(node, inputs, port);
+
+            /*
+             * math ops
+             */
+            case "add":         return new Join(node, inputs, port, ScalarFunctions.add());
+            case "add_n":       return new Join(node, inputs, port, ScalarFunctions.add());
+            case "acos":        return new Map(node, inputs, port, ScalarFunctions.acos());
+            case "div":         return new Join(node, inputs, port, ScalarFunctions.divide());
+            case "realdiv":     return new Join(node, inputs, port, ScalarFunctions.divide());
+            case "floor":       return new Map(node, inputs, port, ScalarFunctions.floor());
+            case "matmul":      return new Matmul(node, inputs, port);
+            case "maximum":     return new Join(node, inputs, port, ScalarFunctions.max());
+            case "mean":        return new Mean(node, inputs, port);
+            case "reducemean":  return new Mean(node, inputs, port);
+            case "mul":         return new Join(node, inputs, port, ScalarFunctions.multiply());
+            case "multiply":    return new Join(node, inputs, port, ScalarFunctions.multiply());
+            case "rsqrt":       return new Map(node, inputs, port, ScalarFunctions.rsqrt());
+            case "select":      return new Select(node, inputs, port);
+            case "where3":      return new Select(node, inputs, port);
+            case "sigmoid":     return new Map(node, inputs, port, ScalarFunctions.sigmoid());
+            case "squareddifference": return new Join(node, inputs, port, ScalarFunctions.squareddifference());
+            case "sub":         return new Join(node, inputs, port, ScalarFunctions.subtract());
+            case "subtract":    return new Join(node, inputs, port, ScalarFunctions.subtract());
+
+            /*
+             * nn ops
+             */
+            case "biasadd":     return new Join(node, inputs, port, ScalarFunctions.add());
+            case "elu":         return new Map(node, inputs, port, ScalarFunctions.elu());
+            case "relu":        return new Map(node, inputs, port, ScalarFunctions.relu());
+            case "selu":        return new Map(node, inputs, port, ScalarFunctions.selu());
+
+            /*
+             * random ops
+             */
+
+            /*
+             * state ops
+             */
+            case "variable":    return new Variable(node, inputs, port);
+            case "variablev2":  return new Variable(node, inputs, port);
+
+            /*
+             * evaluation no-ops
+             */
+            case "stopgradient":return new Identity(node, inputs, port);
+            case "noop":        return new NoOp(node, inputs, port);
+        }
+        // Anything not listed above is imported as a no-op.
+        return new NoOp(node, inputs, port);
+    }
+
+}
+
+
+
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
new file mode 100644
index 00000000000..3742e443a06
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
@@ -0,0 +1,237 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+import org.tensorflow.framework.TensorShapeProto;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * A Vespa tensor type is ordered by the lexicographical ordering of dimension
+ * names. TensorFlow tensors have an explicit ordering of their dimensions.
+ * During import, we need to track the Vespa dimension that matches the
+ * corresponding TensorFlow dimension as the ordering can change after
+ * dimension renaming. That is the purpose of this class.
+ *
+ * @author lesters
+ */
+public class OrderedTensorType {
+
+    /** The Vespa type built from the dimensions. Its dimension order may differ from the TensorFlow order. */
+    private final TensorType type;
+
+    /** The dimensions in TensorFlow order. Unmodifiable. */
+    private final List<TensorType.Dimension> dimensions;
+
+    /** Per dimension (TensorFlow order): the product of the sizes of all dimensions following it. */
+    private final long[] innerSizesTensorFlow;
+
+    /** Per dimension (Vespa order): the product of the sizes of all dimensions following it. */
+    private final long[] innerSizesVespa;
+
+    /** Maps a TensorFlow dimension index to the index of the same dimension in the Vespa type. Null for rank 0. */
+    private final int[] dimensionMap;
+
+    private OrderedTensorType(List<TensorType.Dimension> dimensions) {
+        this.dimensions = Collections.unmodifiableList(dimensions);
+        this.type = new TensorType.Builder(dimensions).build();
+        this.innerSizesTensorFlow = new long[dimensions.size()];
+        this.innerSizesVespa = new long[dimensions.size()];
+        this.dimensionMap = createDimensionMap();  // also fills in the two inner size arrays
+    }
+
+    /** Returns the Vespa tensor type. */
+    public TensorType type() {
+        return this.type;
+    }
+
+    /** Returns the dimensions in TensorFlow order, as an unmodifiable list. */
+    public List<TensorType.Dimension> dimensions() {
+        return dimensions;
+    }
+
+    /** Returns the dimension names in TensorFlow order. */
+    public List<String> dimensionNames() {
+        return dimensions.stream().map(TensorType.Dimension::name).collect(Collectors.toList());
+    }
+
+    /**
+     * Computes the inner sizes for both orderings and the TensorFlow to Vespa dimension index mapping.
+     * A dimension without a size contributes a factor of -1 to the inner sizes.
+     *
+     * @return the mapping, or null if there are no dimensions
+     */
+    private int[] createDimensionMap() {
+        int numDimensions = dimensions.size();
+        if (numDimensions == 0) {
+            return null;
+        }
+        innerSizesTensorFlow[numDimensions - 1] = 1;
+        innerSizesVespa[numDimensions - 1] = 1;
+        for (int i = numDimensions - 2; i >= 0; --i) {
+            innerSizesTensorFlow[i] = dimensions().get(i+1).size().orElse(-1L) * innerSizesTensorFlow[i+1];
+            innerSizesVespa[i] = type.dimensions().get(i+1).size().orElse(-1L) * innerSizesVespa[i+1];
+        }
+        int[] mapping = new int[numDimensions];
+        for (int i = 0; i < numDimensions; ++i) {
+            TensorType.Dimension dim1 = dimensions().get(i);
+            for (int j = 0; j < numDimensions; ++j) {
+                TensorType.Dimension dim2 = type.dimensions().get(j);
+                if (dim1.equals(dim2)) {
+                    mapping[i] = j;
+                    break;
+                }
+            }
+        }
+        return mapping;
+    }
+
+    /**
+     * When dimension ordering between Vespa and TensorFlow differs, i.e.
+     * after dimension renaming, use the dimension map to read in values
+     * so that they are correctly laid out in memory for Vespa.
+     * Used when importing tensors from TensorFlow.
+     *
+     * @param index the position of a value in the TensorFlow ordering
+     * @return the position of the same value in the Vespa ordering, 0 if there are no dimensions
+     */
+    public int toDirectIndex(int index) {
+        if (dimensions.size() == 0) {
+            return 0;
+        }
+        if (dimensionMap == null) {
+            throw new IllegalArgumentException("Dimension map is not available");
+        }
+        int directIndex = 0;
+        long rest = index;
+        for (int i = 0; i < dimensions.size(); ++i) {
+            // Peel off the address along TensorFlow dimension i and weigh it
+            // by the stride of the matching Vespa dimension.
+            long address = rest / innerSizesTensorFlow[i];
+            directIndex += innerSizesVespa[dimensionMap[i]] * address;
+            rest %= innerSizesTensorFlow[i];
+        }
+        return directIndex;
+    }
+
+    /**
+     * Two ordered types are equal when they have equal dimensions in the same (TensorFlow) order.
+     * Types of different rank are never equal.
+     */
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (!(obj instanceof OrderedTensorType)) {  // also covers null
+            return false;
+        }
+        OrderedTensorType other = (OrderedTensorType) obj;
+        // List equality checks both the size and the pairwise equality of the elements
+        return dimensions.equals(other.dimensions);
+    }
+
+    /** Consistent with {@link #equals(Object)}: derived from the ordered dimensions only. */
+    @Override
+    public int hashCode() {
+        return dimensions.hashCode();
+    }
+
+    /**
+     * Verifies that the given type agrees with the first output shape declared by the node,
+     * both in rank and in the size of each dimension. Does nothing if the type is null.
+     *
+     * @throws IllegalArgumentException if the rank or any dimension size differs, or the node
+     *                                  has no usable _output_shapes attribute
+     */
+    public static void verifyType(NodeDef node, OrderedTensorType type) {
+        if (type == null) {
+            return;
+        }
+        TensorShapeProto shape = tensorFlowShape(node);
+        if (shape != null && type.type != null) {
+            if (shape.getDimCount() != type.type.rank()) {
+                throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
+                        "does not match Vespa shape");
+            }
+            for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions.size(); ++tensorFlowIndex) {
+                int vespaIndex = type.dimensionMap[tensorFlowIndex];
+                TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
+                TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+                if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
+                    throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
+                            "does not match Vespa dimensions");
+                }
+            }
+        }
+    }
+
+    /**
+     * Returns the first shape listed in the node's _output_shapes attribute.
+     *
+     * @throws IllegalArgumentException if the attribute is missing or is not a list
+     */
+    private static TensorShapeProto tensorFlowShape(NodeDef node) {
+        AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
+        if (attrValueList == null) {
+            throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+                    "does not exist");
+        }
+        if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
+            throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
+                    "is not of expected type");
+        }
+        List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
+        return shapeList.get(0); // support multiple outputs?
+    }
+
+    /**
+     * Returns a type where every dimension carries the name the renamer has found for it,
+     * keeping the dimension order, kinds and sizes.
+     *
+     * @return the renamed type, or the given type unchanged if the renamer has no name
+     *         for one of its dimensions
+     */
+    public static OrderedTensorType rename(OrderedTensorType type, DimensionRenamer renamer) {
+        List<TensorType.Dimension> renamedDimensions = new ArrayList<>(type.dimensions.size());
+        for (TensorType.Dimension dimension : type.dimensions) {
+            String oldName = dimension.name();
+            Optional<String> newName = renamer.dimensionNameOf(oldName);
+            if (!newName.isPresent())
+                return type; // presumably, already renamed
+            TensorType.Dimension.Type dimensionType = dimension.type();
+            if (dimensionType == TensorType.Dimension.Type.indexedBound) {
+                renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
+            } else if (dimensionType == TensorType.Dimension.Type.indexedUnbound) {
+                renamedDimensions.add(TensorType.Dimension.indexed(newName.get()));
+            } else if (dimensionType == TensorType.Dimension.Type.mapped) {
+                renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
+            }
+        }
+        return new OrderedTensorType(renamedDimensions);
+    }
+
+    /** Creates a type from the node's declared output shape, with dimensions named d0, d1, ... */
+    public static OrderedTensorType fromTensorFlowType(NodeDef node) {
+        return fromTensorFlowType(node, "d");  // standard naming convention: d0, d1, ...
+    }
+
+    /**
+     * Creates a type from the node's declared output shape. Dimension i is named by the prefix
+     * followed by i. It is indexed bound when TensorFlow declares a non-negative size, and
+     * indexed unbound otherwise.
+     */
+    public static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
+        Builder builder = new Builder(node);
+        TensorShapeProto shape = tensorFlowShape(node);
+        for (int i = 0; i < shape.getDimCount(); ++ i) {
+            String dimensionName = dimensionPrefix + i;
+            TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
+            if (tensorFlowDimension.getSize() >= 0) {
+                builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
+            } else {
+                builder.add(TensorType.Dimension.indexed(dimensionName));
+            }
+        }
+        return builder.build();
+    }
+
+    /**
+     * Builds an ordered type one dimension at a time, checking each added dimension
+     * against the node's declared output shape at the same position.
+     */
+    public static class Builder {
+
+        private final TensorShapeProto shape;
+        private final List<TensorType.Dimension> dimensions;
+
+        public Builder(NodeDef node) {
+            this.shape = tensorFlowShape(node);
+            this.dimensions = new ArrayList<>(shape.getDimCount());
+        }
+
+        /**
+         * Appends a dimension as the next one in TensorFlow order.
+         *
+         * @throws IllegalArgumentException if the kind or size of the dimension disagrees with
+         *                                  what TensorFlow declares at this position
+         */
+        public Builder add(TensorType.Dimension vespaDimension) {
+            int index = dimensions.size();
+            TensorShapeProto.Dim tensorFlowDimension = shape.getDim(index);
+            long size = tensorFlowDimension.getSize();
+            if (size >= 0) {
+                if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
+                    throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+                            "dimension types");
+                }
+                if (!vespaDimension.size().isPresent()) {
+                    throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
+                            "not have a size");
+                }
+                if (vespaDimension.size().get() != size) {
+                    throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+                            "dimension sizes. TensorFlow: " + size + " Vespa: " + vespaDimension.size().get());
+                }
+            } else {
+                if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
+                    throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
+                            "dimension types");
+                }
+            }
+            this.dimensions.add(vespaDimension);
+            return this;
+        }
+
+        public OrderedTensorType build() {
+            return new OrderedTensorType(dimensions);
+        }
+
+    }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
new file mode 100644
index 00000000000..3f55e622fdf
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/TensorConverter.java
@@ -0,0 +1,224 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import org.tensorflow.framework.TensorProto;
+
+import java.nio.ByteBuffer;
+import java.nio.DoubleBuffer;
+import java.nio.FloatBuffer;
+import java.nio.IntBuffer;
+import java.nio.LongBuffer;
+
+
+/**
+ * Converts TensorFlow tensors into Vespa tensors.
+ *
+ * @author bratseth
+ * @author lesters
+ */
+public class TensorConverter {
+
+ public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
+ return toVespaTensor(tfTensor, "d");
+ }
+
+ public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
+ Values values = readValuesOf(tfTensor);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ for (int i = 0; i < values.size(); i++)
+ builder.cellByDirectIndex(i, values.get(i));
+ return builder.build();
+ }
+
+ public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
+ Values values = readValuesOf(tfTensor);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
+ for (int i = 0; i < values.size(); i++) {
+ builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i));
+ }
+ return builder.build();
+ }
+
+ public static Tensor toVespaTensor(TensorProto tensorProto, TensorType type) {
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
+ Values values = readValuesOf(tensorProto);
+ for (int i = 0; i < values.size(); ++i) {
+ builder.cellByDirectIndex(i, values.get(i));
+ }
+ return builder.build();
+ }
+
+ private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) {
+ TensorType.Builder b = new TensorType.Builder();
+ int dimensionIndex = 0;
+ for (long dimensionSize : shape) {
+ if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
+ b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
+ }
+ return b.build();
+ }
+
+ public static Long tensorSize(TensorType type) {
+ Long size = 1L;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ size *= dimensionSize(dimension);
+ }
+ return size;
+ }
+
+ public static Long dimensionSize(TensorType.Dimension dim) {
+ return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
+ }
+
+ private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
+ switch (tfTensor.dataType()) {
+ case DOUBLE: return new DoubleValues(tfTensor);
+ case FLOAT: return new FloatValues(tfTensor);
+ case BOOL: return new BoolValues(tfTensor);
+ case UINT8: return new IntValues(tfTensor);
+ case INT32: return new IntValues(tfTensor);
+ case INT64: return new LongValues(tfTensor);
+ }
+ throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
+ tfTensor.dataType() + " to a Vespa tensor");
+ }
+
+ private static Values readValuesOf(TensorProto tensorProto) {
+ switch (tensorProto.getDtype()) {
+ case DT_BOOL:
+ return new ProtoBoolValues(tensorProto);
+ case DT_HALF:
+ return new ProtoHalfValues(tensorProto);
+ case DT_INT16:
+ case DT_INT32:
+ return new ProtoIntValues(tensorProto);
+ case DT_INT64:
+ return new ProtoInt64Values(tensorProto);
+ case DT_FLOAT:
+ return new ProtoFloatValues(tensorProto);
+ case DT_DOUBLE:
+ return new ProtoDoubleValues(tensorProto);
+ }
+ throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
+ }
+
+ /** Allows reading values from buffers of various numeric types as bytes */
+ private static abstract class Values {
+ abstract double get(int i);
+ abstract int size();
+ }
+
+ private static abstract class TensorFlowValues extends Values {
+ private final int size;
+ TensorFlowValues(int size) {
+ this.size = size;
+ }
+ @Override int size() { return this.size; }
+ }
+
+ private static class DoubleValues extends TensorFlowValues {
+ private final DoubleBuffer values;
+ DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = DoubleBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static class FloatValues extends TensorFlowValues {
+ private final FloatBuffer values;
+ FloatValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = FloatBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static class BoolValues extends TensorFlowValues {
+ private final ByteBuffer values;
+ BoolValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = ByteBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static class IntValues extends TensorFlowValues {
+ private final IntBuffer values;
+ IntValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = IntBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static class LongValues extends TensorFlowValues {
+ private final LongBuffer values;
+ LongValues(org.tensorflow.Tensor<?> tfTensor) {
+ super(tfTensor.numElements());
+ values = LongBuffer.allocate(tfTensor.numElements());
+ tfTensor.writeTo(values);
+ }
+ @Override double get(int i) {
+ return values.get(i);
+ }
+ }
+
+ private static abstract class ProtoValues extends Values {
+ protected final TensorProto tensorProto;
+ protected ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
+ }
+
+ private static class ProtoBoolValues extends ProtoValues {
+ ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; }
+ @Override int size() { return tensorProto.getBoolValCount(); }
+ }
+
+ private static class ProtoHalfValues extends ProtoValues {
+ ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getHalfVal(i); }
+ @Override int size() { return tensorProto.getHalfValCount(); }
+ }
+
+ private static class ProtoIntValues extends ProtoValues {
+ ProtoIntValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getIntVal(i); }
+ @Override int size() { return tensorProto.getIntValCount(); }
+ }
+
+ private static class ProtoInt64Values extends ProtoValues {
+ ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getInt64Val(i); }
+ @Override int size() { return tensorProto.getInt64ValCount(); }
+ }
+
+ private static class ProtoFloatValues extends ProtoValues {
+ ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getFloatVal(i); }
+ @Override int size() { return tensorProto.getFloatValCount(); }
+ }
+
+ private static class ProtoDoubleValues extends ProtoValues {
+ ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); }
+ @Override double get(int i) { return tensorProto.getDoubleVal(i); }
+ @Override int size() { return tensorProto.getDoubleValCount(); }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
new file mode 100644
index 00000000000..7decef51ab7
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Const.java
@@ -0,0 +1,93 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+public class Const extends TensorFlowOperation {
+
+ public Const(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ setConstantValue(value());
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ }
+
+ @Override
+ public Optional<TensorFunction> function() {
+ if (function == null) {
+ function = lazyGetFunction();
+ }
+ return Optional.ofNullable(function);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ ExpressionNode expressionNode;
+ if (type.type().rank() == 0 && getConstantValue().isPresent()) {
+ expressionNode = new ConstantNode(getConstantValue().get().asDoubleValue());
+ } else {
+ expressionNode = new ReferenceNode("constant(\"" + vespaName() + "\")");
+ }
+ return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ setConstantValue(value());
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+ private Value value() {
+ if (!node.getAttrMap().containsKey("value")) {
+ throw new IllegalArgumentException("Node '" + node.getName() + "' of type " +
+ "const has missing 'value' attribute");
+ }
+ AttrValue attrValue = node.getAttrMap().get("value");
+ if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
+ return new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type().get().type()));
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
+ return new BooleanValue(attrValue.getB());
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
+ return new DoubleValue(attrValue.getI());
+ }
+ if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
+ return new DoubleValue(attrValue.getF());
+ }
+ throw new IllegalArgumentException("Requesting value of constant in " +
+ node.getName() + " but type is not recognized.");
+ }
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
new file mode 100644
index 00000000000..c1ad21f41d8
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ExpandDims.java
@@ -0,0 +1,107 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+ /**
+  * Imports the TensorFlow ExpandDims operation: inserts a dimension of size 1
+  * into the type of the first input at the position given by the second (constant, scalar) input.
+  */
+ public class ExpandDims extends TensorFlowOperation {
+
+     /** Names of the size-1 dimensions added by this operation. Assigned in lazyGetType(). */
+     private List<String> expandDimensions;
+
+     public ExpandDims(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+     }
+
+     /**
+      * Builds the output type: the input dimensions with one extra indexed dimension of size 1.
+      *
+      * @return the output type, or null if the input types are not yet available
+      * @throws IllegalArgumentException if the axis is not a constant scalar or is out of range
+      */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+
+         TensorFlowOperation axisOperation = inputs().get(1);
+         if (!axisOperation.getConstantValue().isPresent()) {
+             throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+                                                "axis must be a constant.");
+         }
+         Tensor axis = axisOperation.getConstantValue().get().asTensor();
+         if (axis.type().rank() != 0) {
+             throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+                                                "axis argument must be a scalar.");
+         }
+
+         OrderedTensorType inputType = inputs.get(0).type().get();
+         int inputRank = inputType.dimensions().size();
+         int requestedAxis = (int)axis.asDouble();
+
+         // A negative axis counts from the end of the *output* shape, which has one
+         // more dimension than the input: -1 means "append as the last dimension".
+         int dimensionToInsert = requestedAxis < 0 ? inputRank + 1 + requestedAxis : requestedAxis;
+         if (dimensionToInsert < 0 || dimensionToInsert > inputRank) {
+             throw new IllegalArgumentException("ExpandDims in " + node.getName() + ": " +
+                                                "axis " + requestedAxis + " is out of range for an input of rank " +
+                                                inputRank + ".");
+         }
+
+         OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+         expandDimensions = new ArrayList<>();
+         int dimensionIndex = 0;
+         for (TensorType.Dimension dimension : inputType.dimensions()) {
+             if (dimensionIndex == dimensionToInsert) {
+                 addExpandedDimension(typeBuilder, dimensionIndex);
+             }
+             typeBuilder.add(dimension);
+             dimensionIndex++;
+         }
+         // The loop above only covers positions in front of an existing dimension
+         if (dimensionToInsert == inputRank) {
+             addExpandedDimension(typeBuilder, inputRank);
+         }
+
+         return typeBuilder.build();
+     }
+
+     /** Adds a new indexed dimension of size 1, named after this operation and the given position. */
+     private void addExpandedDimension(OrderedTensorType.Builder typeBuilder, int position) {
+         String name = String.format("%s_%d", vespaName(), position);
+         expandDimensions.add(name);
+         typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
+     }
+
+     /**
+      * Builds the function: the input joined (by multiplication) with a generated
+      * tensor of ones that spans the added dimensions.
+      *
+      * @return the function, or null if the input functions are not yet available
+      */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         if (!allInputFunctionsPresent(2)) {
+             return null;
+         }
+
+         // multiply with a generated tensor created from the expanded dimensions
+         TensorType.Builder typeBuilder = new TensorType.Builder();
+         for (String name : expandDimensions) {
+             typeBuilder.indexed(name, 1);
+         }
+         TensorType generatedType = typeBuilder.build();
+         ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+         Generate generatedFunction = new Generate(generatedType,
+                 new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+         return new com.yahoo.tensor.functions.Join(inputs().get(0).function().get(), generatedFunction, ScalarFunctions.multiply());
+     }
+
+     /** Registers the name of every dimension in the output type with the renamer. */
+     @Override
+     public void addDimensionNameConstraints(DimensionRenamer renamer) {
+         for (TensorType.Dimension dimension : type.type().dimensions()) {
+             renamer.addDimension(dimension.name());
+         }
+     }
+
+     /** Renames the output type and maps the remembered expand dimension names to their new names. */
+     @Override
+     public void renameDimensions(DimensionRenamer renamer) {
+         super.renameDimensions(renamer);
+         List<String> renamedDimensions = new ArrayList<>(expandDimensions.size());
+         for (String name : expandDimensions) {
+             Optional<String> newName = renamer.dimensionNameOf(name);
+             if (!newName.isPresent()) {
+                 return; // presumably, already renamed
+             }
+             renamedDimensions.add(newName.get());
+         }
+         expandDimensions = renamedDimensions;
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
new file mode 100644
index 00000000000..d79707a42e6
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Identity.java
@@ -0,0 +1,30 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+ /** An operation which passes the type and function of its single input through unchanged. */
+ public class Identity extends TensorFlowOperation {
+
+     public Identity(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+     }
+
+     /** Returns the type of the input, or null if it is not available. */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         return allInputTypesPresent(1) ? inputs.get(0).type().orElse(null) : null;
+     }
+
+     /** Returns the function of the input, or null if it is not available. */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         return allInputFunctionsPresent(1) ? inputs.get(0).function().orElse(null) : null;
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
new file mode 100644
index 00000000000..aa27ba2684d
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java
@@ -0,0 +1,79 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.DoubleBinaryOperator;
+
+ /** Combines its two inputs cell by cell using the scalar operator given at construction. */
+ public class Join extends TensorFlowOperation {
+
+     private final DoubleBinaryOperator operator;
+
+     public Join(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleBinaryOperator operator) {
+         super(node, inputs, port);
+         this.operator = operator;
+     }
+
+     /** Returns the type of the input with the highest rank (the first input on a tie), or null if types are missing. */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         OrderedTensorType first = inputs.get(0).type().get();
+         OrderedTensorType second = inputs.get(1).type().get();
+         if (first.type().rank() >= second.type().rank()) {
+             return first;
+         }
+         return second;
+     }
+
+     /** Returns the join of both input functions, or null while either of them is unavailable. */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         Optional<TensorFunction> first = inputs.get(0).function();
+         Optional<TensorFunction> second = inputs.get(1).function();
+         if (first.isPresent() && second.isPresent()) {
+             // Broadcasting is taken care of by the dimension name constraints added below.
+             return new com.yahoo.tensor.functions.Join(first.get(), second.get(), operator);
+         }
+         return null;
+     }
+
+     /**
+      * Constrains each dimension of the lower-rank input to have the same name as
+      * one of the trailing dimensions of the higher-rank input.
+      */
+     @Override
+     public void addDimensionNameConstraints(DimensionRenamer renamer) {
+         if (!allInputTypesPresent(2)) {
+             return;
+         }
+
+         // Here we potentially enter the world of "broadcasting":
+         // https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html
+         // That document gives no unambiguous specification of which dimensions should be
+         // "stretched" when the tensors do not have the same number of dimensions.
+         // Experiments with TensorFlow suggest that the second tensor is matched against the
+         // "end" (highest numbered) dimensions of the first, but it is not certain that this
+         // holds in general. In any case, the dimensions of the smaller tensor are aligned
+         // with the last dimensions of the larger one (instead of the first, which is the default).
+
+         TensorType firstType = inputs.get(0).type().get().type();
+         TensorType secondType = inputs.get(1).type().get().type();
+         boolean secondIsLarger = firstType.rank() < secondType.rank();
+         TensorType larger = secondIsLarger ? secondType : firstType;
+         TensorType smaller = secondIsLarger ? firstType : secondType;
+
+         int offset = larger.rank() - smaller.rank();
+         for (int position = 0; position < smaller.rank(); ++position) {
+             String smallerName = smaller.dimensions().get(position).name();
+             String largerName = larger.dimensions().get(position + offset).name();
+             renamer.addConstraint(largerName, smallerName, DimensionRenamer::equals, this);
+         }
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
new file mode 100644
index 00000000000..105d65b3d69
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Map.java
@@ -0,0 +1,38 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+import java.util.function.DoubleUnaryOperator;
+
+ /** Applies the scalar operator given at construction to every cell of its single input. */
+ public class Map extends TensorFlowOperation {
+
+     private final DoubleUnaryOperator operator;
+
+     public Map(NodeDef node, List<TensorFlowOperation> inputs, int port, DoubleUnaryOperator operator) {
+         super(node, inputs, port);
+         this.operator = operator;
+     }
+
+     /** Returns the type of the input, or null if it is not available. */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (allInputTypesPresent(1)) {
+             return inputs.get(0).type().get();
+         }
+         return null;
+     }
+
+     /** Returns a map function over the function of the input, or null if that is not available. */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         if (allInputFunctionsPresent(1)) {
+             Optional<TensorFunction> argument = inputs.get(0).function();
+             return new com.yahoo.tensor.functions.Map(argument.get(), operator);
+         }
+         return null;
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
new file mode 100644
index 00000000000..ac4f78653d6
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Matmul.java
@@ -0,0 +1,74 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+ /**
+  * Imports matrix multiplication: joins the two inputs over the second dimension
+  * of the first input, which is constrained to share its name with the first dimension of the second.
+  */
+ public class Matmul extends TensorFlowOperation {
+
+     public Matmul(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+     }
+
+     /**
+      * Builds the output type from the first dimension of the first input
+      * and the second dimension of the second input.
+      *
+      * @return the output type, or null if the input types are not yet available
+      * @throws IllegalArgumentException if either input has rank below 2
+      */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         OrderedTensorType aType = inputs.get(0).type().get();
+         OrderedTensorType bType = inputs.get(1).type().get();
+         // Validate before indexing into the dimension lists, so that a malformed
+         // graph yields a descriptive error instead of an IndexOutOfBoundsException
+         requireRankAtLeastTwo(aType, bType);
+
+         OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+         typeBuilder.add(aType.dimensions().get(0));
+         typeBuilder.add(bType.dimensions().get(1));
+         return typeBuilder.build();
+     }
+
+     /**
+      * Builds the matmul function over the second dimension of the first input.
+      *
+      * @return the function, or null if input types or functions are not yet available
+      * @throws IllegalArgumentException if either input has rank below 2 or the ranks differ
+      */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         OrderedTensorType aType = inputs.get(0).type().get();
+         OrderedTensorType bType = inputs.get(1).type().get();
+         requireRankAtLeastTwo(aType, bType);
+         if (aType.type().rank() != bType.type().rank()) {
+             throw new IllegalArgumentException("Tensors in matmul must have the same rank");
+         }
+
+         Optional<TensorFunction> aFunction = inputs.get(0).function();
+         Optional<TensorFunction> bFunction = inputs.get(1).function();
+         if (!aFunction.isPresent() || !bFunction.isPresent()) {
+             return null;
+         }
+         return new com.yahoo.tensor.functions.Matmul(aFunction.get(), bFunction.get(), aType.dimensions().get(1).name());
+     }
+
+     /**
+      * Adds the naming constraints between the first two dimensions of each input.
+      *
+      * @throws IllegalArgumentException if either input has rank below 2
+      */
+     @Override
+     public void addDimensionNameConstraints(DimensionRenamer renamer) {
+         if (!allInputTypesPresent(2)) {
+             return;
+         }
+         requireRankAtLeastTwo(inputs.get(0).type().get(), inputs.get(1).type().get());
+         List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
+         List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+
+         String aDim0 = aDimensions.get(0).name();
+         String aDim1 = aDimensions.get(1).name();
+         String bDim0 = bDimensions.get(0).name();
+         String bDim1 = bDimensions.get(1).name();
+
+         // The second dimension of a should have the same name as the first dimension of b
+         renamer.addConstraint(aDim1, bDim0, DimensionRenamer::equals, this);
+
+         // The first dimension of a should have a different name than the second dimension of b
+         renamer.addConstraint(aDim0, bDim1, DimensionRenamer::lesserThan, this);
+
+         // For efficiency, the dimensions to join over should be innermost - soft constraint
+         renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
+         renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
+     }
+
+     /** Throws IllegalArgumentException unless both types have at least two dimensions. */
+     private static void requireRankAtLeastTwo(OrderedTensorType aType, OrderedTensorType bType) {
+         if (aType.type().rank() < 2 || bType.type().rank() < 2) {
+             throw new IllegalArgumentException("Tensors in matmul must have rank of at least 2");
+         }
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
new file mode 100644
index 00000000000..dfe0796d9b8
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Mean.java
@@ -0,0 +1,112 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.Optional;
+
+ /**
+  * Imports the TensorFlow Mean operation: averages the first input over the
+  * dimensions whose indices are given by the second (constant) input.
+  */
+ public class Mean extends TensorFlowOperation {
+
+     /** Names of the dimensions to reduce over. Assigned in lazyGetType(). */
+     private List<String> reduceDimensions;
+
+     public Mean(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+     }
+
+     /**
+      * Resolves the reduction indices to dimension names and builds the reduced type.
+      *
+      * @return the output type, or null if the input types are not yet available
+      * @throws IllegalArgumentException if the reduction indices are not constant or are out of range
+      */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         TensorFlowOperation reductionIndices = inputs.get(1);
+         if (!reductionIndices.getConstantValue().isPresent()) {
+             throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
+                                                "reduction indices must be a constant.");
+         }
+         Tensor indices = reductionIndices.getConstantValue().get().asTensor();
+         reduceDimensions = new ArrayList<>();
+
+         OrderedTensorType inputType = inputs.get(0).type().get();
+         int inputRank = inputType.dimensions().size();
+         for (Iterator<Tensor.Cell> cellIterator = indices.cellIterator(); cellIterator.hasNext();) {
+             Tensor.Cell cell = cellIterator.next();
+             int requestedIndex = cell.getValue().intValue();
+             // A negative index counts backwards from the last dimension: -1 is the last one
+             int dimensionIndex = requestedIndex < 0 ? inputRank + requestedIndex : requestedIndex;
+             if (dimensionIndex < 0 || dimensionIndex >= inputRank) {
+                 throw new IllegalArgumentException("Mean in " + node.getName() + ": " +
+                                                    "reduction index " + requestedIndex +
+                                                    " is out of range for an input of rank " + inputRank + ".");
+             }
+             reduceDimensions.add(inputType.dimensions().get(dimensionIndex).name());
+         }
+         return reducedType(inputType, shouldKeepDimensions());
+     }
+
+     // todo: optimization: if keepDims and one reduce dimension that has size 1: same as identity.
+
+     /**
+      * Builds an avg reduce over the reduce dimensions. When dimensions are kept, the
+      * result is joined with a generated tensor of ones spanning the reduced dimensions,
+      * which reintroduces them with size 1.
+      *
+      * @return the function, or null if the input types or the input function are not yet available
+      */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         Optional<TensorFunction> inputFunction = inputs.get(0).function();
+         if (!inputFunction.isPresent()) {
+             return null;
+         }
+         TensorFunction output = new Reduce(inputFunction.get(), Reduce.Aggregator.avg, reduceDimensions);
+         if (shouldKeepDimensions()) {
+             // multiply with a generated tensor created from the reduced dimensions
+             TensorType.Builder typeBuilder = new TensorType.Builder();
+             for (String name : reduceDimensions) {
+                 typeBuilder.indexed(name, 1);
+             }
+             TensorType generatedType = typeBuilder.build();
+             ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+             Generate generatedFunction = new Generate(generatedType,
+                     new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+             output = new com.yahoo.tensor.functions.Join(output, generatedFunction, ScalarFunctions.multiply());
+         }
+         return output;
+     }
+
+     /** Renames the output type and maps the remembered reduce dimension names to their new names. */
+     @Override
+     public void renameDimensions(DimensionRenamer renamer) {
+         super.renameDimensions(renamer);
+         List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size());
+         for (String name : reduceDimensions) {
+             Optional<String> newName = renamer.dimensionNameOf(name);
+             if (!newName.isPresent()) {
+                 return; // presumably, already renamed
+             }
+             renamedDimensions.add(newName.get());
+         }
+         reduceDimensions = renamedDimensions;
+     }
+
+     /** Returns true if the node has a 'keep_dims' attribute whose boolean value is set. */
+     private boolean shouldKeepDimensions() {
+         AttrValue keepDimsAttr = node.getAttrMap().get("keep_dims");
+         return keepDimsAttr != null && keepDimsAttr.getB();
+     }
+
+     /**
+      * Returns the input type with the reduce dimensions removed or,
+      * if dimensions are kept, replaced by indexed dimensions of size 1.
+      */
+     private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
+         OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+         for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+             if (!reduceDimensions.contains(dimension.name())) {
+                 builder.add(dimension);
+             } else if (keepDimensions) {
+                 builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
+             }
+         }
+         return builder.build();
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
new file mode 100644
index 00000000000..d3561716725
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Merge.java
@@ -0,0 +1,36 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+ /** Forwards the type and function of the first input, in input order, that has one available. */
+ public class Merge extends TensorFlowOperation {
+
+     public Merge(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+     }
+
+     /** Returns the type of the first input which has a type, or null if none has. */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         return inputs.stream()
+                      .filter(operation -> operation.type().isPresent())
+                      .map(operation -> operation.type().get())
+                      .findFirst()
+                      .orElse(null);
+     }
+
+     /** Returns the function of the first input which has a function, or null if none has. */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         return inputs.stream()
+                      .filter(operation -> operation.function().isPresent())
+                      .map(operation -> operation.function().get())
+                      .findFirst()
+                      .orElse(null);
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
new file mode 100644
index 00000000000..acf5d13b057
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/NoOp.java
@@ -0,0 +1,32 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.Collections;
+import java.util.List;
+
+ /**
+  * An operation which produces neither a type nor a function.
+  * The inputs passed to the constructor are dropped rather than handed to the superclass.
+  */
+ public class NoOp extends TensorFlowOperation {
+
+     public NoOp(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         // don't propagate inputs: the superclass gets an empty list instead
+         super(node, Collections.emptyList(), port);
+     }
+
+     /** Always null: this operation has no type. */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         return null;
+     }
+
+     /** Always null: this operation has no function. */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         return null;
+     }
+
+     /** This operation always reports itself as constant. */
+     @Override
+     public boolean isConstant() {
+         return true;
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
new file mode 100644
index 00000000000..dadce395faf
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Placeholder.java
@@ -0,0 +1,57 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.VariableTensor;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+ /**
+  * An input to the model: a variable tensor named after this operation, renamed
+  * to this operation's own dimension names when those differ from the standard ones.
+  */
+ public class Placeholder extends TensorFlowOperation {
+
+     /** The type of the node using the standard naming convention: d0, d1, ... */
+     private final OrderedTensorType standardNamingType;
+
+     public Placeholder(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+         standardNamingType = OrderedTensorType.fromTensorFlowType(node);
+     }
+
+     /** Returns the type of the node with dimension names prefixed by the name of this operation. */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         String dimensionPrefix = vespaName() + "_";
+         return OrderedTensorType.fromTensorFlowType(node, dimensionPrefix);
+     }
+
+     /** Returns the variable tensor, wrapped in a rename if this type differs from the standard one. */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         TensorFunction variable = new VariableTensor(vespaName(), standardNamingType.type());
+         if (standardNamingType.equals(type)) {
+             return variable;
+         }
+         List<String> standardNames = standardNamingType.dimensionNames();
+         List<String> actualNames = type.dimensionNames();
+         return new Rename(variable, standardNames, actualNames);
+     }
+
+     /** Registers the name of every dimension in this operation's type with the renamer. */
+     @Override
+     public void addDimensionNameConstraints(DimensionRenamer renamer) {
+         type.type().dimensions().forEach(dimension -> renamer.addDimension(dimension.name()));
+     }
+
+     @Override
+     public boolean isInput() {
+         return true;
+     }
+
+     @Override
+     public boolean isConstant() {
+         return false;
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
new file mode 100644
index 00000000000..ab091b77a65
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
@@ -0,0 +1,50 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+ /** A placeholder with a default value, currently imported as the default value itself. */
+ public class PlaceholderWithDefault extends TensorFlowOperation {
+
+     public PlaceholderWithDefault(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+     }
+
+     /** Returns the type of the input, or null if it is not available. */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         return allInputTypesPresent(1) ? inputs().get(0).type().orElse(null) : null;
+     }
+
+     /** Returns the function of the input, or null if it is not available. */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         // This should be a call to the macro returned by macro(), but for now
+         // we treat this as an identity function and just pass the constant.
+         return allInputFunctionsPresent(1) ? inputs.get(0).function().orElse(null) : null;
+     }
+
+     /** Returns empty: no macro is generated for this operation at present. */
+     @Override
+     public Optional<RankingExpression> macro() {
+         // For now, it is much more efficient to assume we always will return
+         // the default value, as we can prune away large parts of the expression
+         // tree by having it calculated as a constant. If a case arises where
+         // it is important to support this, implement this.
+         return Optional.empty();
+     }
+
+     @Override
+     public boolean isConstant() {
+         return true; // not true if we add to macro
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
new file mode 100644
index 00000000000..9b3e28ce56b
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Reshape.java
@@ -0,0 +1,135 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
+import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
+import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.Iterator;
+import java.util.List;
+import java.util.stream.Collectors;
+
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+
+ /**
+  * Imports the TensorFlow Reshape operation: gives the first input the shape
+  * specified by the second (constant) input.
+  */
+ public class Reshape extends TensorFlowOperation {
+
+     public Reshape(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+     }
+
+     /**
+      * Builds the output type: one indexed dimension per cell of the shape tensor,
+      * named after this operation and the position of the dimension.
+      * A negative size in the shape is inferred from the size of the input.
+      *
+      * @return the output type, or null if the input types are not yet available
+      * @throws IllegalArgumentException if the shape is not constant, or a negative size cannot be inferred
+      */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         TensorFlowOperation newShape = inputs.get(1);
+         if (!newShape.getConstantValue().isPresent()) {
+             throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
+                                                "shape input must be a constant.");
+         }
+         Tensor shape = newShape.getConstantValue().get().asTensor();
+
+         OrderedTensorType inputType = inputs.get(0).type().get();
+         OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(node);
+         int dimensionIndex = 0;
+         for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
+             Tensor.Cell cell = cellIterator.next();
+             int size = cell.getValue().intValue();
+             if (size < 0) {
+                 size = inferDimensionSize(shape, inputType);
+             }
+             outputTypeBuilder.add(TensorType.Dimension.indexed(
+                     String.format("%s_%d", vespaName(), dimensionIndex), size));
+             dimensionIndex++;
+         }
+         return outputTypeBuilder.build();
+     }
+
+     /**
+      * Computes the size of the single negative entry in the shape: the size of the input
+      * divided by the product of the other entries.
+      */
+     private int inferDimensionSize(Tensor shape, OrderedTensorType inputType) {
+         // With exactly one negative entry of -1 the product of all entries is the
+         // negated product of the explicitly given sizes
+         int knownProduct = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble();
+         if (knownProduct <= 0) {
+             throw new IllegalArgumentException("Reshape in " + node.getName() + ": " +
+                                                "unable to infer the size of a negative dimension from the given shape.");
+         }
+         return tensorSize(inputType.type()).intValue() / knownProduct;
+     }
+
+     /** Returns the reshape of the input function, or null if input types or functions are not yet available. */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         if (!allInputFunctionsPresent(2)) {
+             return null;
+         }
+         OrderedTensorType inputType = inputs.get(0).type().get();
+         TensorFunction inputFunction = inputs.get(0).function().get();
+         return reshape(inputFunction, inputType.type(), type.type());
+     }
+
+     /** Registers the name of every dimension in the output type with the renamer. */
+     @Override
+     public void addDimensionNameConstraints(DimensionRenamer renamer) {
+         for (TensorType.Dimension dimension : type.type().dimensions()) {
+             renamer.addDimension(dimension.name());
+         }
+     }
+
+     /**
+      * Returns a function which reshapes the output of the given function from the input type to the output type.
+      *
+      * @throws IllegalArgumentException if the two types do not have the same size
+      */
+     public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
+         if (!tensorSize(inputType).equals(tensorSize(outputType))) {
+             throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
+         }
+
+         // Conceptually, reshaping consists of unrolling a tensor to an array using the dimension order,
+         // then using the dimension order of the new shape to roll it back into a tensor.
+         // Here we create a transformation tensor that is multiplied with the input tensor to map
+         // into the new shape. We have to introduce temporary dimension names and rename back if
+         // dimension names in the new and old tensor type overlap.
+
+         ExpressionNode unrollFrom = unrollTensorExpression(inputType);
+         ExpressionNode unrollTo = unrollTensorExpression(outputType);
+         ExpressionNode transformExpression = new ComparisonNode(unrollFrom, TruthOperator.EQUAL, unrollTo);
+
+         TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
+         Generate transformTensor = new Generate(transformationType,
+                 new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
+
+         TensorFunction outputFunction = new Reduce(
+                 new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
+                 Reduce.Aggregator.sum,
+                 inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+
+         return outputFunction;
+     }
+
+     /**
+      * Builds the expression giving the position of a cell in the unrolled tensor: a sum
+      * over the dimensions of the dimension name, multiplied by the combined size of all later dimensions.
+      */
+     private static ExpressionNode unrollTensorExpression(TensorType type) {
+         if (type.rank() == 0) {
+             return new ConstantNode(DoubleValue.zero);
+         }
+         List<ExpressionNode> children = new ArrayList<>();
+         List<ArithmeticOperator> operators = new ArrayList<>();
+         int size = 1;
+         // Walk from the last dimension to the first, prepending terms, while
+         // 'size' accumulates the combined size of the dimensions already visited
+         for (int i = type.dimensions().size() - 1; i >= 0; --i) {
+             TensorType.Dimension dimension = type.dimensions().get(i);
+             children.add(0, new ReferenceNode(dimension.name()));
+             if (size > 1) {
+                 operators.add(0, ArithmeticOperator.MULTIPLY);
+                 children.add(0, new ConstantNode(new DoubleValue(size)));
+             }
+             size *= TensorConverter.dimensionSize(dimension);
+             if (i > 0) {
+                 operators.add(0, ArithmeticOperator.PLUS);
+             }
+         }
+         return new ArithmeticNode(children, operators);
+     }
+
+ }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
new file mode 100644
index 00000000000..6a29d428cf3
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Select.java
@@ -0,0 +1,89 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.function.DoubleBinaryOperator;
+
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.dimensionSize;
+import static com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter.tensorSize;
+
+public class Select extends TensorFlowOperation {
+
+ public Select(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(3)) {
+ return null;
+ }
+ OrderedTensorType a = inputs.get(1).type().get();
+ OrderedTensorType b = inputs.get(2).type().get();
+ if ((a.type().rank() != b.type().rank()) || !(tensorSize(a.type()).equals(tensorSize(b.type())))) {
+ throw new IllegalArgumentException("'Select': input tensors must have the same shape");
+ }
+ return a;
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(3)) {
+ return null;
+ }
+ TensorFlowOperation conditionOperation = inputs().get(0);
+ TensorFunction a = inputs().get(1).function().get();
+ TensorFunction b = inputs().get(2).function().get();
+
+ // Shortcut: if we know during import which tensor to select, do that directly here.
+ if (conditionOperation.getConstantValue().isPresent()) {
+ Tensor condition = conditionOperation.getConstantValue().get().asTensor();
+ if (condition.type().rank() == 0) {
+ return ((int) condition.asDouble() == 0) ? b : a;
+ }
+ if (condition.type().rank() == 1 && dimensionSize(condition.type().dimensions().get(0)) == 1) {
+ return condition.cellIterator().next().getValue().intValue() == 0 ? b : a;
+ }
+ }
+
+ // The task is to select cells from 'x' or 'y' based on 'condition'.
+ // If 'condition' is 0 (false), select from 'y', if 1 (true) select
+ // from 'x'. We do this by individually joining 'x' and 'y' with
+ // 'condition', and then joining the resulting two tensors.
+
+ TensorFunction conditionFunction = conditionOperation.function().get();
+ TensorFunction aCond = new com.yahoo.tensor.functions.Join(a, conditionFunction, ScalarFunctions.multiply());
+ TensorFunction bCond = new com.yahoo.tensor.functions.Join(b, conditionFunction, new DoubleBinaryOperator() {
+ @Override public double applyAsDouble(double a, double b) { return a * (1.0 - b); }
+ @Override public String toString() { return "f(a,b)(a * (1-b))"; }
+ });
+ return new com.yahoo.tensor.functions.Join(aCond, bCond, ScalarFunctions.add());
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!allInputTypesPresent(3)) {
+ return;
+ }
+ List<TensorType.Dimension> aDimensions = inputs.get(1).type().get().dimensions();
+ List<TensorType.Dimension> bDimensions = inputs.get(2).type().get().dimensions();
+
+ String aDim0 = aDimensions.get(0).name();
+ String aDim1 = aDimensions.get(1).name();
+ String bDim0 = bDimensions.get(0).name();
+ String bDim1 = bDimensions.get(1).name();
+
+ // These tensors should have the same dimension names
+ renamer.addConstraint(aDim0, bDim0, DimensionRenamer::equals, this);
+ renamer.addConstraint(aDim1, bDim1, DimensionRenamer::equals, this);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
new file mode 100644
index 00000000000..8f4313022e0
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Shape.java
@@ -0,0 +1,55 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+public class Shape extends TensorFlowOperation {
+
+ public Shape(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ createConstantValue();
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ return new OrderedTensorType.Builder(node)
+ .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
+ .build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+ private void createConstantValue() {
+ if (!allInputTypesPresent(1)) {
+ return;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type().get().type());
+ List<TensorType.Dimension> inputDimensions = inputType.dimensions();
+ for (int i = 0; i < inputDimensions.size(); i++) {
+ builder.cellByDirectIndex(i, inputDimensions.get(i).size().orElse(-1L));
+ }
+ this.setConstantValue(new TensorValue(builder.build()));
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
new file mode 100644
index 00000000000..d7750b52fc3
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Squeeze.java
@@ -0,0 +1,84 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+public class Squeeze extends TensorFlowOperation {
+
+ private List<String> squeezeDimensions;
+
+ public Squeeze(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!allInputTypesPresent(1)) {
+ return null;
+ }
+ OrderedTensorType inputType = inputs.get(0).type().get();
+ squeezeDimensions = new ArrayList<>();
+
+ AttrValue squeezeDimsAttr = node.getAttrMap().get("squeeze_dims");
+ if (squeezeDimsAttr == null) {
+ squeezeDimensions = inputType.type().dimensions().stream().
+ filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ map(TensorType.Dimension::name).
+ collect(Collectors.toList());
+ } else {
+ squeezeDimensions = squeezeDimsAttr.getList().getIList().stream().
+ map(i -> i < 0 ? inputType.type().dimensions().size() - i : i).
+ map(i -> inputType.type().dimensions().get(i.intValue())).
+ filter(dim -> TensorConverter.dimensionSize(dim) == 1).
+ map(TensorType.Dimension::name).
+ collect(Collectors.toList());
+ }
+ return squeezeDimensions.isEmpty() ? inputType : reducedType(inputType);
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!allInputFunctionsPresent(1)) {
+ return null;
+ }
+ TensorFunction inputFunction = inputs.get(0).function().get();
+ return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ List<String> renamedDimensions = new ArrayList<>(squeezeDimensions.size());
+ for (String name : squeezeDimensions) {
+ Optional<String> newName = renamer.dimensionNameOf(name);
+ if (!newName.isPresent()) {
+ return; // presumably, already renamed
+ }
+ renamedDimensions.add(newName.get());
+ }
+ squeezeDimensions = renamedDimensions;
+ }
+
+ private OrderedTensorType reducedType(OrderedTensorType inputType) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(node);
+ for (TensorType.Dimension dimension: inputType.type().dimensions()) {
+ if ( ! squeezeDimensions.contains(dimension.name())) {
+ builder.add(dimension);
+ }
+ }
+ return builder.build();
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
new file mode 100644
index 00000000000..1cc0e1936eb
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Switch.java
@@ -0,0 +1,48 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+import java.util.Optional;
+
+ /**
+  * Imports the TensorFlow 'Switch' operation. The first input is the data and the
+  * second is the predicate, which must be a constant scalar. The data is forwarded
+  * only on the output port that equals the predicate value.
+  */
+ public class Switch extends TensorFlowOperation {
+
+     public Switch(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+         super(node, inputs, port);
+     }
+
+     /**
+      * Returns the type of the data input, or null while input types are unavailable.
+      *
+      * @throws IllegalArgumentException if the predicate is not a scalar
+      */
+     @Override
+     protected OrderedTensorType lazyGetType() {
+         if (!allInputTypesPresent(2)) {
+             return null;
+         }
+         Optional<OrderedTensorType> predicate = inputs.get(1).type();
+         if (predicate.get().type().rank() != 0) {
+             throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+                                                "predicate must be a scalar");
+         }
+         return inputs.get(0).type().orElse(null);
+     }
+
+     /**
+      * Returns the function of the data input if the predicate equals this port,
+      * otherwise null. Also returns null when the data input's function is not
+      * available yet.
+      *
+      * @throws IllegalArgumentException if the predicate is not constant or the port is not 0 or 1
+      */
+     @Override
+     protected TensorFunction lazyGetFunction() {
+         TensorFlowOperation predicateOperation = inputs().get(1);
+         if (!predicateOperation.getConstantValue().isPresent()) {
+             throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+                                                "predicate must be a constant");
+         }
+         if (port < 0 || port > 1) {
+             throw new IllegalArgumentException("Switch in " + node.getName() + ": " +
+                                                "choice should be boolean");
+         }
+
+         double predicate = predicateOperation.getConstantValue().get().asDouble();
+         if (predicate != port) {
+             return null; // this port is not the selected branch
+         }
+         // The data function may not be available yet; report that as null, like
+         // lazyGetType() does for the type, rather than failing on an empty Optional.
+         return inputs().get(0).function().orElse(null);
+     }
+
+ }
+
+
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
new file mode 100644
index 00000000000..fd9dfd167fb
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -0,0 +1,136 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * Wraps a TensorFlow node and produces the respective Vespa tensor operation.
+ * During import, a graph of these operations are constructed. Then, the
+ * types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper
+ * Vespa expressions can be extracted.
+ *
+ * @author lesters
+ */
+public abstract class TensorFlowOperation {
+
+ protected final NodeDef node;
+ protected final int port;
+ protected final List<TensorFlowOperation> inputs;
+ protected final List<TensorFlowOperation> outputs = new ArrayList<>();
+ protected final List<String> importWarnings = new ArrayList<>();
+
+ protected OrderedTensorType type;
+ protected TensorFunction function;
+
+ private Value constantValue = null;
+ private List<TensorFlowOperation> controlInputs = Collections.emptyList();
+
+ TensorFlowOperation(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ this.node = node;
+ this.port = port;
+ this.inputs = Collections.unmodifiableList(inputs);
+ this.inputs.forEach(i -> i.outputs.add(this));
+ }
+
+ protected abstract OrderedTensorType lazyGetType();
+ protected abstract TensorFunction lazyGetFunction();
+
+ /** Returns the Vespa tensor type of this operation if it exists */
+ public Optional<OrderedTensorType> type() {
+ if (type == null) {
+ type = lazyGetType();
+ }
+ OrderedTensorType.verifyType(node, type);
+ return Optional.ofNullable(type);
+ }
+
+ /** Returns the Vespa tensor function implementing all operations from this node with inputs */
+ public Optional<TensorFunction> function() {
+ if (function == null) {
+ if (isConstant()) {
+ ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")");
+ function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
+ } else {
+ function = lazyGetFunction();
+ }
+ }
+ return Optional.ofNullable(function);
+ }
+
+ /** Return TensorFlow node */
+ public NodeDef node() { return node; }
+
+ /** Return unmodifiable list of inputs */
+ public List<TensorFlowOperation> inputs() { return inputs; }
+
+ /** Return unmodifiable list of outputs. If a node has multiple outputs, consider adding a macro. */
+ public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); }
+
+ /** Returns a Vespa ranking expression that should be added as a macro */
+ public Optional<RankingExpression> macro() { return Optional.empty(); }
+
+ /** Add dimension name constraints for this operation */
+ public void addDimensionNameConstraints(DimensionRenamer renamer) { }
+
+ /** Performs dimension rename for this operation */
+ public void renameDimensions(DimensionRenamer renamer) { type = OrderedTensorType.rename(type, renamer); }
+
+ /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
+ public boolean isInput() { return false; }
+
+ /** Return true if this node is constant */
+ public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); }
+
+ /** Sets the constant value */
+ public void setConstantValue(Value value) { constantValue = value; }
+
+ /** Gets the constant value if it exists */
+ public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }
+
+ /** Sets the external control inputs */
+ public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; }
+
+ /** Retrieve the control inputs for this operation */
+ public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }
+
+ /** Retrieve the valid Vespa name of this node */
+ public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; }
+
+ /** Retrieve the list of warnings produced during its lifetime */
+ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
+
+ boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
+ if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
+ return false;
+ }
+ if (inputs.size() != expected) {
+ throw new IllegalArgumentException("Expected " + expected + " inputs " +
+ "for '" + node.getName() + "', got " + inputs.size());
+ }
+ return inputs.stream().map(func).allMatch(Optional::isPresent);
+ }
+
+ boolean allInputTypesPresent(int expected) {
+ return verifyInputs(expected, TensorFlowOperation::type);
+ }
+
+ boolean allInputFunctionsPresent(int expected) {
+ return verifyInputs(expected, TensorFlowOperation::function);
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
new file mode 100644
index 00000000000..6f377c4bda2
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Variable.java
@@ -0,0 +1,40 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.List;
+
+public class Variable extends TensorFlowOperation {
+
+ public Variable(NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ return OrderedTensorType.fromTensorFlowType(node, vespaName() + "_");
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ return null; // will be added by function() since this is constant.
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ for (TensorType.Dimension dimension : type.type().dimensions()) {
+ renamer.addDimension(dimension.name());
+ }
+ }
+
+ @Override
+ public boolean isConstant() {
+ return true;
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
new file mode 100644
index 00000000000..ebcfde54c70
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DimensionRenamerTest.java
@@ -0,0 +1,49 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import org.junit.Test;
+
+import static junit.framework.TestCase.assertTrue;
+
+public class DimensionRenamerTest {
+
+ @Test
+ public void testMnistRenaming() {
+ DimensionRenamer renamer = new DimensionRenamer();
+
+ renamer.addDimension("first_dimension_of_x");
+ renamer.addDimension("second_dimension_of_x");
+ renamer.addDimension("first_dimension_of_w");
+ renamer.addDimension("second_dimension_of_w");
+ renamer.addDimension("first_dimension_of_b");
+
+ // which dimension to join on matmul
+ renamer.addConstraint("second_dimension_of_x", "first_dimension_of_w", DimensionRenamer::equals, null);
+
+ // other dimensions in matmul can't be equal
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_w", DimensionRenamer::lesserThan, null);
+
+ // for efficiency, put dimension to join on innermost
+ renamer.addConstraint("first_dimension_of_x", "second_dimension_of_x", DimensionRenamer::lesserThan, null);
+ renamer.addConstraint("first_dimension_of_w", "second_dimension_of_w", DimensionRenamer::greaterThan, null);
+
+ // bias
+ renamer.addConstraint("second_dimension_of_w", "first_dimension_of_b", DimensionRenamer::equals, null);
+
+ renamer.solve();
+
+ String firstDimensionOfXName = renamer.dimensionNameOf("first_dimension_of_x").get();
+ String secondDimensionOfXName = renamer.dimensionNameOf("second_dimension_of_x").get();
+ String firstDimensionOfWName = renamer.dimensionNameOf("first_dimension_of_w").get();
+ String secondDimensionOfWName = renamer.dimensionNameOf("second_dimension_of_w").get();
+ String firstDimensionOfBName = renamer.dimensionNameOf("first_dimension_of_b").get();
+
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfXName) < 0);
+ assertTrue(firstDimensionOfWName.compareTo(secondDimensionOfWName) > 0);
+ assertTrue(secondDimensionOfXName.compareTo(firstDimensionOfWName) == 0);
+ assertTrue(firstDimensionOfXName.compareTo(secondDimensionOfWName) < 0);
+ assertTrue(secondDimensionOfWName.compareTo(firstDimensionOfBName) == 0);
+
+
+ }
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index 3b25bfe1b1e..f64d697d9b9 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -18,11 +18,6 @@ public class DropoutImportTestCase {
public void testDropoutImport() {
TestableTensorFlowModel model = new TestableTensorFlowModel("src/test/files/integration/tensorflow/dropout/saved");
- // Check (provided) macros
- assertEquals(1, model.get().macros().size());
- assertTrue(model.get().macros().containsKey("training_input"));
- assertEquals("constant(\"training_input\")", model.get().macros().get("training_input").getRoot().toString());
-
// Check required macros
assertEquals(1, model.get().requiredMacros().size());
assertTrue(model.get().requiredMacros().containsKey("X"));
@@ -37,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/BiasAdd", output.getName());
- assertEquals("join(rename(reduce(join(X, rename(constant(\"outputs_kernel\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"outputs_bias\"), d0, d1), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
index ad5abd4c03d..60dd3865aa1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/MnistSoftmaxImportTestCase.java
@@ -22,15 +22,15 @@ public class MnistSoftmaxImportTestCase {
// Check constants
assertEquals(2, model.get().largeConstants().size());
- Tensor constant0 = model.get().largeConstants().get("Variable");
+ Tensor constant0 = model.get().largeConstants().get("Variable_read");
assertNotNull(constant0);
- assertEquals(new TensorType.Builder().indexed("d0", 784).indexed("d1", 10).build(),
+ assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
- Tensor constant1 = model.get().largeConstants().get("Variable_1");
+ Tensor constant1 = model.get().largeConstants().get("Variable_1_read");
assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
constant1.type());
assertEquals(10, constant1.size());
@@ -59,12 +59,10 @@ public class MnistSoftmaxImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("join(rename(reduce(join(Placeholder, rename(constant(\"Variable\"), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant(\"Variable_1\"), d0, d1), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"Variable_1_read\"), f(a,b)(a + b))",
output.getRoot().toString());
// Test execution
- model.assertEqualResult("Placeholder", "Variable/read");
- model.assertEqualResult("Placeholder", "Variable_1/read");
model.assertEqualResult("Placeholder", "MatMul");
model.assertEqualResult("Placeholder", "add");
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index ae7714b271a..1691756a64d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -4,6 +4,7 @@ package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.tensorflow.SavedModelBundle;
@@ -47,8 +48,11 @@ public class TestableTensorFlowModel {
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
Session.Runner runner = model.session().runner();
- org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size },
- FloatBuffer.allocate(d0Size * d1Size));
+ FloatBuffer fb = FloatBuffer.allocate(d0Size * d1Size);
+ for (int i = 0; i < d1Size; ++i) {
+ fb.put(i, (float)(i * 1.0 / d1Size));
+ }
+ org.tensorflow.Tensor<?> placeholder = org.tensorflow.Tensor.create(new long[]{ d0Size, d1Size }, fb);
runner.feed(inputName, placeholder);
List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
assertEquals(1, results.size());
@@ -66,7 +70,7 @@ public class TestableTensorFlowModel {
Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", d0Size).indexed("d1", d1Size).build());
for (int d0 = 0; d0 < d0Size; d0++)
for (int d1 = 0; d1 < d1Size; d1++)
- b.cell(0, d0, d1);
+ b.cell(d1 * 1.0 / d1Size, d0, d1);
return b.build();
}