aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-04-20 15:27:01 +0200
committerLester Solbakken <lesters@oath.com>2018-04-20 15:27:01 +0200
commit12800de64f4c0631e2cdd38f0bce5357d15f9ea7 (patch)
treefdccef89eb652594d02bc56a79d49367c9fa1d50 /searchlib/src/main/java/com
parent2eebd9206d26608253844bcb2cf84f64b5f20553 (diff)
Add Tensorflow concat operation
Diffstat (limited to 'searchlib/src/main/java/com')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java116
2 files changed, 118 insertions, 0 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
index d5a3d2d69a3..977b18b9ab3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OperationMapper.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ConcatV2;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.ExpandDims;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Identity;
@@ -35,6 +36,7 @@ public class OperationMapper {
public static TensorFlowOperation get(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
switch (node.getOp().toLowerCase()) {
// array ops
+ case "concatv2": return new ConcatV2(modelName, node, inputs, port);
case "const": return new Const(modelName, node, inputs, port);
case "expanddims": return new ExpandDims(modelName, node, inputs, port);
case "identity": return new Identity(modelName, node, inputs, port);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java
new file mode 100644
index 00000000000..d3bc4453edb
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/ConcatV2.java
@@ -0,0 +1,116 @@
+// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
+import org.tensorflow.framework.NodeDef;
+
+import java.util.ArrayList;
+import java.util.List;
+import java.util.Optional;
+
+public class ConcatV2 extends TensorFlowOperation {
+
+ private String concatDimensionName;
+
+ public ConcatV2(String modelName, NodeDef node, List<TensorFlowOperation> inputs, int port) {
+ super(modelName, node, inputs, port);
+ }
+
+ @Override
+ protected OrderedTensorType lazyGetType() {
+ if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
+ return null;
+ }
+
+ TensorFlowOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
+ if (!concatDimOp.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ "concat dimension must be a constant.");
+ }
+ Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
+ if (concatDimTensor.type().rank() != 0) {
+ throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ "concat dimension must be a scalar.");
+ }
+
+ OrderedTensorType aType = inputs.get(0).type().get();
+
+ int concatDim = (int)concatDimTensor.asDouble();
+ long concatDimSize = aType.dimensions().get(concatDim).size().orElse(-1L);
+
+ for (int i = 1; i < inputs.size() - 1; ++i) {
+ OrderedTensorType bType = inputs.get(i).type().get();
+ if (bType.rank() != aType.rank()) {
+ throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ "inputs must have save rank.");
+ }
+ for (int j = 0; j < aType.rank(); ++j) {
+ long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
+ long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
+ if (j == concatDim) {
+ concatDimSize += dimSizeB;
+ } else if (dimSizeA != dimSizeB) {
+ throw new IllegalArgumentException("ConcatV2 in " + node.getName() + ": " +
+ "input dimension " + j + " differs in input tensors.");
+ }
+ }
+ }
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(node);
+ int dimensionIndex = 0;
+ for (TensorType.Dimension dimension : aType.dimensions()) {
+ if (dimensionIndex == concatDim) {
+ concatDimensionName = dimension.name();
+ typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize));
+ } else {
+ typeBuilder.add(dimension);
+ }
+ dimensionIndex++;
+ }
+ return typeBuilder.build();
+ }
+
+ @Override
+ protected TensorFunction lazyGetFunction() {
+ if (!inputs.stream().map(TensorFlowOperation::function).allMatch(Optional::isPresent)) {
+ return null;
+ }
+ TensorFunction result = inputs.get(0).function().get();
+ for (int i = 1; i < inputs.size() - 1; ++i) {
+ TensorFunction b = inputs.get(i).function().get();
+ result = new com.yahoo.tensor.functions.Concat(result, b, concatDimensionName);
+ }
+ return result;
+ }
+
+ @Override
+ public void addDimensionNameConstraints(DimensionRenamer renamer) {
+ if (!inputs.stream().map(TensorFlowOperation::type).allMatch(Optional::isPresent)) {
+ return;
+ }
+ OrderedTensorType a = inputs.get(0).type().get();
+ for (int i = 1; i < inputs.size() - 1; ++i) {
+ OrderedTensorType b = inputs.get(i).type().get();
+ String bDim = b.dimensions().get(i).name();
+ String aDim = a.dimensions().get(i).name();
+ renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
+ }
+ }
+
+ @Override
+ public void renameDimensions(DimensionRenamer renamer) {
+ super.renameDimensions(renamer);
+ concatDimensionName = renamer.dimensionNameOf(concatDimensionName).orElse(concatDimensionName);
+ }
+
+}