diff options
Diffstat (limited to 'config-model')
2 files changed, 98 insertions, 4 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index f16697b5ba6..864cd823728 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -23,10 +23,18 @@ import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.ScalarFunctions; +import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.serialization.TypedBinaryFormat; import java.io.BufferedReader; @@ -37,9 +45,11 @@ import java.io.UncheckedIOException; import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.Set; import java.util.logging.Logger; /** @@ -197,7 +207,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil */ private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros, RankProfile profile, QueryProfileRegistry queryProfiles) { - List<String> macroNames = new ArrayList<>(); + Set<String> macroNames = new HashSet<>(); addMacroNamesIn(expression.getRoot(), macroNames); for (String macroName : macroNames) { TensorType requiredType = requiredMacros.get(macroName); @@ -223,10 +233,84 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil "' of type " + requiredType + " which must be produced by a macro in the rank profile, but " + "this macro produces type " + actualType); + + // Check if batch dimensions can be reduced out. + reduceBatchDimensions(expression, macro, actualType); + } + } + + /** + * If the macro specifies that a single exemplar should be + * evaluated, we can reduce the batch dimension out. + */ + private void reduceBatchDimensions(RankingExpression expression, RankProfile.Macro macro, TensorType type) { + if (type.dimensions().size() > 1) { + List<String> reduceDimensions = new ArrayList<>(); + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.size().orElse(-1L) == 1) { + reduceDimensions.add(dimension.name()); + } + } + if (reduceDimensions.size() > 0) { + ExpressionNode root = expression.getRoot(); + root = reduceBatchDimensionsAtInput(root, macro, reduceDimensions); + root = expandBatchDimensionsAtOutput(root, reduceDimensions); // todo: determine when we can skip this + expression.setRoot(root); + } + } + } + + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, + RankProfile.Macro macro, + List<String> reduceDimensions) { + if (node instanceof TensorFunctionNode) { + TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); + if (tensorFunction instanceof Rename) { + List<ExpressionNode> children = ((TensorFunctionNode)node).children(); + if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) children.get(0); + if (referenceNode.getName().equals(macro.getName())) { + return reduceBatchDimensionExpression(tensorFunction, reduceDimensions); + } + } + } + } + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + if (referenceNode.getName().equals(macro.getName())) { + return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), reduceDimensions); + } + } + if (node instanceof CompositeNode) { + List<ExpressionNode> children = ((CompositeNode)node).children(); + List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); + for (ExpressionNode child : children) { + transformedChildren.add(reduceBatchDimensionsAtInput(child, macro, reduceDimensions)); + } + return ((CompositeNode)node).setChildren(transformedChildren); + } + return node; + } + + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) { + return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions)); + } + + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, + List<String> reduceDimensions) { + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (String name : reduceDimensions) { + typeBuilder.indexed(name, 1); } + TensorType generatedType = typeBuilder.build(); + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); + Generate generatedFunction = new Generate(generatedType, + new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); + Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); + return new TensorFunctionNode(expand); } - private void addMacroNamesIn(ExpressionNode node, List<String> names) { + private void addMacroNamesIn(ExpressionNode node, Set<String> names) { if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode)node; if (referenceNode.getOutput() == null) // macro references cannot specify outputs diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 19d37c5fb44..beba8ade1d8 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -252,10 +252,20 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test + public void testTensorFlowReduceBatchDimension() { + final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", + "tensorflow('mnist_softmax/saved')"); + search.assertFirstPhaseExpression(expression, "my_profile"); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + assertLargeConstant("layer_Variable_read", search, Optional.of(7840L)); + } + + @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; + final String expression = "join(join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); - RankProfileSearchFixture search = fixtureWith("tensor(d0[2],d1[784])(0.0)", + RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')", null, null, |