diff options
author | Lester Solbakken <lesters@oath.com> | 2018-03-08 14:10:52 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-03-08 14:14:45 +0100 |
commit | 76bd625f358a61c5962ff15cf14e321b49f26ae1 (patch) | |
tree | 9ebdc581f44d0b7539dce24049aa1be73413bfda /config-model | |
parent | 1f6f506795c4f05364327ca6a8d1c370ccbbd1e7 (diff) |
Add batch dimension reduction in generated macros
Check for inputs/placeholders containing batch dimensions inside
generated macros, and insert reduce operations. Expand those
dimensions back at the top-level expression.
Diffstat (limited to 'config-model')
2 files changed, 113 insertions, 64 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 2c177633590..ce4dd679f59 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -31,6 +31,7 @@ import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Generate; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Reduce; @@ -115,11 +116,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); - verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles); - store.writeConverted(expression); + verifyRequiredMacros(expression, model, profile, queryProfiles); + addGeneratedMacros(model, profile); + reduceBatchDimensions(expression, model, profile, queryProfiles); - model.macros().forEach((k, v) -> transformMacro(store, profile, constantsReplacedByMacros, k, v)); + model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v)); + store.writeConverted(expression); return expression.getRoot(); } @@ -133,7 +136,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } for (Pair<String, RankingExpression> macro : store.readMacros()) { - addMacroToProfile(profile, 
macro.getFirst(), macro.getSecond()); + addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond()); } return store.readConverted().getRoot(); @@ -221,16 +224,15 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } - private void transformMacro(ModelStore store, RankProfile profile, - Set<String> constantsReplacedByMacros, - String macroName, RankingExpression expression) { + private void transformGeneratedMacro(ModelStore store, RankProfile profile, + Set<String> constantsReplacedByMacros, + String macroName, RankingExpression expression) { expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); store.writeMacro(macroName, expression); - addMacroToProfile(profile, macroName, expression); } - private void addMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { if (profile.getMacros().containsKey(macroName)) { throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists."); } @@ -251,12 +253,12 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil * Verify that the macros referred in the given expression exists in the given rank profile, * and return tensors of the types specified in requiredMacros. 
*/ - private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros, + private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model, RankProfile profile, QueryProfileRegistry queryProfiles) { Set<String> macroNames = new HashSet<>(); - addMacroNamesIn(expression.getRoot(), macroNames); + addMacroNamesIn(expression.getRoot(), macroNames, model); for (String macroName : macroNames) { - TensorType requiredType = requiredMacros.get(macroName); + TensorType requiredType = model.requiredMacros().get(macroName); if (requiredType == null) continue; // Not a required macro RankProfile.Macro macro = profile.getMacros().get(macroName); @@ -279,65 +281,126 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil "' of type " + requiredType + " which must be produced by a macro in the rank profile, but " + "this macro produces type " + actualType); - - // Check if batch dimensions can be reduced out. - reduceBatchDimensions(expression, macro, actualType); } } /** - * If the macro specifies that a single exemplar should be - * evaluated, we can reduce the batch dimension out. + * Add the generated macros to the rank profile */ - private void reduceBatchDimensions(RankingExpression expression, RankProfile.Macro macro, TensorType type) { - if (type.dimensions().size() > 1) { - List<String> reduceDimensions = new ArrayList<>(); - for (TensorType.Dimension dimension : type.dimensions()) { - if (dimension.size().orElse(-1L) == 1) { - reduceDimensions.add(dimension.name()); - } + private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) { + model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v)); + } + + /** + * Check if batch dimensions of inputs can be reduced out. If the input + * macro specifies that a single exemplar should be evaluated, we can + * reduce the batch dimension out. 
+ */ + private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model, + RankProfile profile, QueryProfileRegistry queryProfiles) { + TypeContext<Reference> typeContext = profile.typeContext(queryProfiles); + TensorType typeBeforeReducing = expression.getRoot().type(typeContext); + + // Check generated macros for inputs to reduce + Set<String> macroNames = new HashSet<>(); + addMacroNamesIn(expression.getRoot(), macroNames, model); + for (String macroName : macroNames) { + if ( ! model.macros().containsKey(macroName)) { + continue; } - if (reduceDimensions.size() > 0) { - ExpressionNode root = expression.getRoot(); - root = reduceBatchDimensionsAtInput(root, macro, reduceDimensions); - root = expandBatchDimensionsAtOutput(root, reduceDimensions); // todo: determine when we can skip this - expression.setRoot(root); + RankProfile.Macro macro = profile.getMacros().get(macroName); + if (macro == null) { + throw new IllegalArgumentException("Model refers to generated macro '" + macroName + + "but this macro is not present in " + profile); } + RankingExpression macroExpression = macro.getRankingExpression(); + macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext)); } + + // Check expression for inputs to reduce + ExpressionNode root = expression.getRoot(); + root = reduceBatchDimensionsAtInput(root, model, typeContext); + TensorType typeAfterReducing = root.type(typeContext); + root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing); + expression.setRoot(root); } - private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, - RankProfile.Macro macro, - List<String> reduceDimensions) { + private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model, + TypeContext<Reference> typeContext) { if (node instanceof TensorFunctionNode) { TensorFunction tensorFunction = ((TensorFunctionNode) node).function(); if (tensorFunction instanceof 
Rename) { List<ExpressionNode> children = ((TensorFunctionNode)node).children(); if (children.size() == 1 && children.get(0) instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) children.get(0); - if (referenceNode.getName().equals(macro.getName())) { - return reduceBatchDimensionExpression(tensorFunction, reduceDimensions); + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(tensorFunction, typeContext); } } } } if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; - if (referenceNode.getName().equals(macro.getName())) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), reduceDimensions); + if (model.requiredMacros().containsKey(referenceNode.getName())) { + return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); } } if (node instanceof CompositeNode) { List<ExpressionNode> children = ((CompositeNode)node).children(); List<ExpressionNode> transformedChildren = new ArrayList<>(children.size()); for (ExpressionNode child : children) { - transformedChildren.add(reduceBatchDimensionsAtInput(child, macro, reduceDimensions)); + transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext)); } return ((CompositeNode)node).setChildren(transformedChildren); } return node; } + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) { + TensorFunction result = function; + TensorType type = function.type(context); + if (type.dimensions().size() > 1) { + List<String> reduceDimensions = new ArrayList<>(); + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.size().orElse(-1L) == 1) { + reduceDimensions.add(dimension.name()); + } + } + if (reduceDimensions.size() > 0) { + result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions); + } + } + return new TensorFunctionNode(result); + } + + /** + * If 
batch dimensions have been reduced away above, bring them back here + * for any following computation of the tensor. + * Todo: determine when this is not necessary! + */ + private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) { + if (after.equals(before)) { + return node; + } + TensorType.Builder typeBuilder = new TensorType.Builder(); + for (TensorType.Dimension dimension : before.dimensions()) { + if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) { + typeBuilder.indexed(dimension.name(), 1); + } + } + TensorType expandDimensionsType = typeBuilder.build(); + if (expandDimensionsType.dimensions().size() > 0) { + ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(0)); + Generate generatedFunction = new Generate(expandDimensionsType, + new GeneratorLambdaFunctionNode(expandDimensionsType, + generatedExpression) + .asLongListToDoubleOperator()); + Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.add()); + return new TensorFunctionNode(expand); + } + return node; + } + /** * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. * This method does that for the given expression and returns the result. 
@@ -367,33 +430,19 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return node; } - private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) { - return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions)); - } - - private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, - List<String> reduceDimensions) { - TensorType.Builder typeBuilder = new TensorType.Builder(); - for (String name : reduceDimensions) { - typeBuilder.indexed(name, 1); - } - TensorType generatedType = typeBuilder.build(); - ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1)); - Generate generatedFunction = new Generate(generatedType, - new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); - return new TensorFunctionNode(expand); - } - - private void addMacroNamesIn(ExpressionNode node, Set<String> names) { + private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) { if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode)node; - if (referenceNode.getOutput() == null) // macro references cannot specify outputs + if (referenceNode.getOutput() == null) { // macro references cannot specify outputs names.add(referenceNode.getName()); + if (model.macros().containsKey(referenceNode.getName())) { + addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model); + } + } } else if (node instanceof CompositeNode) { for (ExpressionNode child : ((CompositeNode)node).children()) - addMacroNamesIn(child, names); + addMacroNamesIn(child, names, model); } } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java 
b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index 06912a980a8..2cadbbf50e7 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -296,7 +296,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testTensorFlowReduceBatchDimension() { - final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(expression, "my_profile"); @@ -306,12 +306,12 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMacroGeneration() { - final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))"; - final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))"; + final String expression = "join(join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), 
constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))"; final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", - "tensorflow('mnist/saved')"); + "tensorflow('mnist/saved')", null, null, "input", new StoringApplicationPackage(applicationDir)); search.assertFirstPhaseExpression(expression, "my_profile"); search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile"); search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile"); @@ -319,8 +319,8 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))"; - final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))"; + final String expression = "join(join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), 
f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))"; final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); |