aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Lester Solbakken <lesters@oath.com> 2018-03-01 09:20:09 +0100
committer Lester Solbakken <lesters@oath.com> 2018-03-01 09:20:09 +0100
commit 413259f26b658617a7482a72926088751adab521 (patch)
tree 527330a94052f2abdbe43e9e01c6137c16683694
parent bc3ccdb3552d0d3ff5dcc463308614e72e6abd3e (diff)
Add batch dimension reduction for TensorFlow import
-rw-r--r-- config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java 88
-rw-r--r-- config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java 14
2 files changed, 98 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index f16697b5ba6..864cd823728 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -23,10 +23,18 @@ import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import java.io.BufferedReader;
@@ -37,9 +45,11 @@ import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
+import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.Set;
import java.util.logging.Logger;
/**
@@ -197,7 +207,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
*/
private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros,
RankProfile profile, QueryProfileRegistry queryProfiles) {
- List<String> macroNames = new ArrayList<>();
+ Set<String> macroNames = new HashSet<>();
addMacroNamesIn(expression.getRoot(), macroNames);
for (String macroName : macroNames) {
TensorType requiredType = requiredMacros.get(macroName);
@@ -223,10 +233,84 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
"' of type " + requiredType +
" which must be produced by a macro in the rank profile, but " +
"this macro produces type " + actualType);
+
+ // Check if batch dimensions can be reduced out.
+ reduceBatchDimensions(expression, macro, actualType);
+ }
+ }
+
+    /**
+     * Reduces out size-1 dimensions of a macro's tensor type, so a macro that
+     * supplies a single exemplar is evaluated without its batch dimension.
+     *
+     * Nothing happens unless the type has more than one dimension and at least
+     * one of them has size exactly 1. Otherwise the expression root is replaced
+     * by a rewritten tree: the size-1 dimensions are summed away where the macro
+     * is referenced, and expanded again around the whole expression.
+     *
+     * @param expression the expression whose root is rewritten in place
+     * @param macro the macro whose references are rewritten
+     * @param type the tensor type the macro produces
+     */
+    private void reduceBatchDimensions(RankingExpression expression, RankProfile.Macro macro, TensorType type) {
+        if (type.dimensions().size() <= 1) return;
+
+        // Collect the names of all dimensions bound to exactly one element;
+        // dimensions without a size never match (they fall back to -1)
+        List<String> singletonDimensions = new ArrayList<>();
+        for (TensorType.Dimension candidate : type.dimensions()) {
+            boolean hasSizeOne = candidate.size().orElse(-1L) == 1;
+            if (hasSizeOne) singletonDimensions.add(candidate.name());
+        }
+        if (singletonDimensions.isEmpty()) return;
+
+        ExpressionNode reducedAtInput = reduceBatchDimensionsAtInput(expression.getRoot(), macro, singletonDimensions);
+        // todo: determine when we can skip this
+        ExpressionNode expandedAtOutput = expandBatchDimensionsAtOutput(reducedAtInput, singletonDimensions);
+        expression.setRoot(expandedAtOutput);
+    }
+
+    /**
+     * Walks the expression tree and sums away the given dimensions wherever the
+     * macro is referenced.
+     *
+     * Three cases are rewritten:
+     * - a rename whose only child is a reference to the macro is replaced by a
+     *   reduce over that rename function
+     * - a bare reference to the macro is wrapped as a tensor function argument
+     *   and reduced
+     * - any other composite node is rebuilt from its recursively rewritten children
+     * All other nodes are returned as they are.
+     *
+     * @param node the (sub)expression to rewrite
+     * @param macro the macro whose references should be reduced
+     * @param reduceDimensions names of the dimensions to sum away
+     * @return the rewritten (sub)expression
+     */
+    private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node,
+                                                        RankProfile.Macro macro,
+                                                        List<String> reduceDimensions) {
+        // Case 1: rename(<macro reference>, ...) - reduce the result of the rename
+        if (node instanceof TensorFunctionNode) {
+            TensorFunctionNode functionNode = (TensorFunctionNode) node;
+            TensorFunction function = functionNode.function();
+            if (function instanceof Rename) {
+                List<ExpressionNode> arguments = functionNode.children();
+                boolean renamesSingleReference = arguments.size() == 1 && arguments.get(0) instanceof ReferenceNode;
+                if (renamesSingleReference
+                    && ((ReferenceNode) arguments.get(0)).getName().equals(macro.getName())) {
+                    return reduceBatchDimensionExpression(function, reduceDimensions);
+                }
+            }
+        }
+
+        // Case 2: a direct reference to the macro
+        if (node instanceof ReferenceNode
+            && ((ReferenceNode) node).getName().equals(macro.getName())) {
+            return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), reduceDimensions);
+        }
+
+        // Leaves which are not macro references are left untouched
+        if ( ! (node instanceof CompositeNode)) return node;
+
+        // Case 3: recurse into the children and rebuild the composite
+        CompositeNode composite = (CompositeNode) node;
+        List<ExpressionNode> originalChildren = composite.children();
+        List<ExpressionNode> rewrittenChildren = new ArrayList<>(originalChildren.size());
+        originalChildren.forEach(child ->
+                rewrittenChildren.add(reduceBatchDimensionsAtInput(child, macro, reduceDimensions)));
+        return composite.setChildren(rewrittenChildren);
+    }
+
+    /**
+     * Wraps the given tensor function in a reduce which sums over the given dimensions.
+     *
+     * @param function the tensor function to reduce
+     * @param reduceDimensions names of the dimensions to sum over
+     * @return an expression node holding the reduce function
+     */
+    private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) {
+        Reduce sumOverDimensions = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+        return new TensorFunctionNode(sumOverDimensions);
+    }
+
+ /**
+ * Adds the given dimensions back onto the result of an expression, each with size 1.
+ *
+ * A tensor with one indexed dimension of size 1 per given name is generated from
+ * the constant expression 1, and joined with the expression by multiplication.
+ *
+ * @param node the expression to expand
+ * @param reduceDimensions names of the dimensions to add to the result
+ * @return an expression node holding the join of the expression and the generated tensor
+ */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node,
+ List<String> reduceDimensions) {
+ // Build a type having each of the given dimensions as an indexed dimension of size 1
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (String name : reduceDimensions) {
+ typeBuilder.indexed(name, 1);
}
+ TensorType generatedType = typeBuilder.build();
+ // The generator expression is the constant 1, so multiplying by the generated
+ // tensor should leave values unchanged - NOTE(review): confirm against Generate semantics
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+ Generate generatedFunction = new Generate(generatedType,
+ new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+ // Join by multiplication with the generated tensor
+ Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+ return new TensorFunctionNode(expand);
}
- private void addMacroNamesIn(ExpressionNode node, List<String> names) {
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names) {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode)node;
if (referenceNode.getOutput() == null) // macro references cannot specify outputs
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 19d37c5fb44..beba8ade1d8 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -252,10 +252,20 @@ public class RankingExpressionWithTensorFlowTestCase {
}
@Test
+    public void testTensorFlowReduceBatchDimension() {
+        // The macro produces a tensor whose d0 dimension has size 1
+        RankProfileSearchFixture fixture = fixtureWith("tensor(d0[1],d1[784])(0.0)",
+                                                       "tensorflow('mnist_softmax/saved')");
+
+        // Expected: d0 is summed away right after the rename of the input,
+        // and multiplied back in as tensor(d0[1])(1.0) around the whole expression
+        String expectedExpression =
+                "join(join(reduce(join(" +
+                "reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), " +
+                "constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), " +
+                "constant(\"layer_Variable_1_read\"), f(a,b)(a + b)), " +
+                "tensor(d0[1])(1.0), f(a,b)(a * b))";
+        fixture.assertFirstPhaseExpression(expectedExpression, "my_profile");
+
+        Optional<Long> expectedBiasSize = Optional.of(10L);
+        Optional<Long> expectedWeightsSize = Optional.of(7840L);
+        assertLargeConstant("layer_Variable_1_read", fixture, expectedBiasSize);
+        assertLargeConstant("layer_Variable_read", fixture, expectedWeightsSize);
+    }
+
+ @Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
+ final String expression = "join(join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
- RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)",
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')",
null,
null,