summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-03-08 14:10:52 +0100
committerLester Solbakken <lesters@oath.com>2018-03-08 14:14:45 +0100
commit76bd625f358a61c5962ff15cf14e321b49f26ae1 (patch)
tree9ebdc581f44d0b7539dce24049aa1be73413bfda /config-model
parent1f6f506795c4f05364327ca6a8d1c370ccbbd1e7 (diff)
Add batch dimension reduction in generated macros
Check for inputs/placeholders containing batch dimensions inside generated macros, and insert reduce operations. Expand those dimensions back at the top level expression.
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java165
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java12
2 files changed, 113 insertions, 64 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 2c177633590..ce4dd679f59 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -31,6 +31,7 @@ import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Reduce;
@@ -115,11 +116,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
String output = chooseOutput(signature, store.arguments().output());
RankingExpression expression = model.expressions().get(output);
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
- verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles);
- store.writeConverted(expression);
+ verifyRequiredMacros(expression, model, profile, queryProfiles);
+ addGeneratedMacros(model, profile);
+ reduceBatchDimensions(expression, model, profile, queryProfiles);
- model.macros().forEach((k, v) -> transformMacro(store, profile, constantsReplacedByMacros, k, v));
+ model.macros().forEach((k, v) -> transformGeneratedMacro(store, profile, constantsReplacedByMacros, k, v));
+ store.writeConverted(expression);
return expression.getRoot();
}
@@ -133,7 +136,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
for (Pair<String, RankingExpression> macro : store.readMacros()) {
- addMacroToProfile(profile, macro.getFirst(), macro.getSecond());
+ addGeneratedMacroToProfile(profile, macro.getFirst(), macro.getSecond());
}
return store.readConverted().getRoot();
@@ -221,16 +224,15 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
- private void transformMacro(ModelStore store, RankProfile profile,
- Set<String> constantsReplacedByMacros,
- String macroName, RankingExpression expression) {
+ private void transformGeneratedMacro(ModelStore store, RankProfile profile,
+ Set<String> constantsReplacedByMacros,
+ String macroName, RankingExpression expression) {
expression = replaceConstantsByMacros(expression, constantsReplacedByMacros);
store.writeMacro(macroName, expression);
- addMacroToProfile(profile, macroName, expression);
}
- private void addMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ private void addGeneratedMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
if (profile.getMacros().containsKey(macroName)) {
throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists.");
}
@@ -251,12 +253,12 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
* Verify that the macros referred in the given expression exists in the given rank profile,
* and return tensors of the types specified in requiredMacros.
*/
- private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros,
+ private void verifyRequiredMacros(RankingExpression expression, TensorFlowModel model,
RankProfile profile, QueryProfileRegistry queryProfiles) {
Set<String> macroNames = new HashSet<>();
- addMacroNamesIn(expression.getRoot(), macroNames);
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
for (String macroName : macroNames) {
- TensorType requiredType = requiredMacros.get(macroName);
+ TensorType requiredType = model.requiredMacros().get(macroName);
if (requiredType == null) continue; // Not a required macro
RankProfile.Macro macro = profile.getMacros().get(macroName);
@@ -279,65 +281,126 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
"' of type " + requiredType +
" which must be produced by a macro in the rank profile, but " +
"this macro produces type " + actualType);
-
- // Check if batch dimensions can be reduced out.
- reduceBatchDimensions(expression, macro, actualType);
}
}
/**
- * If the macro specifies that a single exemplar should be
- * evaluated, we can reduce the batch dimension out.
+ * Add the generated macros to the rank profile
*/
- private void reduceBatchDimensions(RankingExpression expression, RankProfile.Macro macro, TensorType type) {
- if (type.dimensions().size() > 1) {
- List<String> reduceDimensions = new ArrayList<>();
- for (TensorType.Dimension dimension : type.dimensions()) {
- if (dimension.size().orElse(-1L) == 1) {
- reduceDimensions.add(dimension.name());
- }
+ private void addGeneratedMacros(TensorFlowModel model, RankProfile profile) {
+ model.macros().forEach((k, v) -> addGeneratedMacroToProfile(profile, k, v));
+ }
+
+ /**
+ * Check if batch dimensions of inputs can be reduced out. If the input
+ * macro specifies that a single exemplar should be evaluated, we can
+ * reduce the batch dimension out.
+ */
+ private void reduceBatchDimensions(RankingExpression expression, TensorFlowModel model,
+ RankProfile profile, QueryProfileRegistry queryProfiles) {
+ TypeContext<Reference> typeContext = profile.typeContext(queryProfiles);
+ TensorType typeBeforeReducing = expression.getRoot().type(typeContext);
+
+ // Check generated macros for inputs to reduce
+ Set<String> macroNames = new HashSet<>();
+ addMacroNamesIn(expression.getRoot(), macroNames, model);
+ for (String macroName : macroNames) {
+ if ( ! model.macros().containsKey(macroName)) {
+ continue;
}
- if (reduceDimensions.size() > 0) {
- ExpressionNode root = expression.getRoot();
- root = reduceBatchDimensionsAtInput(root, macro, reduceDimensions);
- root = expandBatchDimensionsAtOutput(root, reduceDimensions); // todo: determine when we can skip this
- expression.setRoot(root);
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ if (macro == null) {
+                throw new IllegalArgumentException("Model refers to generated macro '" + macroName +
+                        "' but this macro is not present in " + profile);
}
+ RankingExpression macroExpression = macro.getRankingExpression();
+ macroExpression.setRoot(reduceBatchDimensionsAtInput(macroExpression.getRoot(), model, typeContext));
}
+
+ // Check expression for inputs to reduce
+ ExpressionNode root = expression.getRoot();
+ root = reduceBatchDimensionsAtInput(root, model, typeContext);
+ TensorType typeAfterReducing = root.type(typeContext);
+ root = expandBatchDimensionsAtOutput(root, typeBeforeReducing, typeAfterReducing);
+ expression.setRoot(root);
}
- private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node,
- RankProfile.Macro macro,
- List<String> reduceDimensions) {
+ private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node, TensorFlowModel model,
+ TypeContext<Reference> typeContext) {
if (node instanceof TensorFunctionNode) {
TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
if (tensorFunction instanceof Rename) {
List<ExpressionNode> children = ((TensorFunctionNode)node).children();
if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) children.get(0);
- if (referenceNode.getName().equals(macro.getName())) {
- return reduceBatchDimensionExpression(tensorFunction, reduceDimensions);
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(tensorFunction, typeContext);
}
}
}
}
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
- if (referenceNode.getName().equals(macro.getName())) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), reduceDimensions);
+ if (model.requiredMacros().containsKey(referenceNode.getName())) {
+ return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
}
}
if (node instanceof CompositeNode) {
List<ExpressionNode> children = ((CompositeNode)node).children();
List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
for (ExpressionNode child : children) {
- transformedChildren.add(reduceBatchDimensionsAtInput(child, macro, reduceDimensions));
+ transformedChildren.add(reduceBatchDimensionsAtInput(child, model, typeContext));
}
return ((CompositeNode)node).setChildren(transformedChildren);
}
return node;
}
+ private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, TypeContext<Reference> context) {
+ TensorFunction result = function;
+ TensorType type = function.type(context);
+ if (type.dimensions().size() > 1) {
+ List<String> reduceDimensions = new ArrayList<>();
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1) {
+ reduceDimensions.add(dimension.name());
+ }
+ }
+ if (reduceDimensions.size() > 0) {
+ result = new Reduce(function, Reduce.Aggregator.sum, reduceDimensions);
+ }
+ }
+ return new TensorFunctionNode(result);
+ }
+
+ /**
+ * If batch dimensions have been reduced away above, bring them back here
+ * for any following computation of the tensor.
+ * Todo: determine when this is not necessary!
+ */
+ private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node, TensorType before, TensorType after) {
+ if (after.equals(before)) {
+ return node;
+ }
+ TensorType.Builder typeBuilder = new TensorType.Builder();
+ for (TensorType.Dimension dimension : before.dimensions()) {
+ if (dimension.size().orElse(-1L) == 1 && !after.dimensionNames().contains(dimension.name())) {
+ typeBuilder.indexed(dimension.name(), 1);
+ }
+ }
+ TensorType expandDimensionsType = typeBuilder.build();
+ if (expandDimensionsType.dimensions().size() > 0) {
+ ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(0));
+ Generate generatedFunction = new Generate(expandDimensionsType,
+ new GeneratorLambdaFunctionNode(expandDimensionsType,
+ generatedExpression)
+ .asLongListToDoubleOperator());
+ Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.add());
+ return new TensorFunctionNode(expand);
+ }
+ return node;
+ }
+
/**
* If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions.
* This method does that for the given expression and returns the result.
@@ -367,33 +430,19 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
return node;
}
- private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) {
- return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions));
- }
-
- private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node,
- List<String> reduceDimensions) {
- TensorType.Builder typeBuilder = new TensorType.Builder();
- for (String name : reduceDimensions) {
- typeBuilder.indexed(name, 1);
- }
- TensorType generatedType = typeBuilder.build();
- ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
- Generate generatedFunction = new Generate(generatedType,
- new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
- return new TensorFunctionNode(expand);
- }
-
- private void addMacroNamesIn(ExpressionNode node, Set<String> names) {
+ private void addMacroNamesIn(ExpressionNode node, Set<String> names, TensorFlowModel model) {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode)node;
- if (referenceNode.getOutput() == null) // macro references cannot specify outputs
+ if (referenceNode.getOutput() == null) { // macro references cannot specify outputs
names.add(referenceNode.getName());
+ if (model.macros().containsKey(referenceNode.getName())) {
+ addMacroNamesIn(model.macros().get(referenceNode.getName()).getRoot(), names, model);
+ }
+ }
}
else if (node instanceof CompositeNode) {
for (ExpressionNode child : ((CompositeNode)node).children())
- addMacroNamesIn(child, names);
+ addMacroNamesIn(child, names, model);
}
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 06912a980a8..2cadbbf50e7 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -296,7 +296,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testTensorFlowReduceBatchDimension() {
- final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist_softmax/saved')");
search.assertFirstPhaseExpression(expression, "my_profile");
@@ -306,12 +306,12 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testMacroGeneration() {
- final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))";
- final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))";
+ final String expression = "join(join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))";
+ final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))";
final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
- "tensorflow('mnist/saved')");
+ "tensorflow('mnist/saved')", null, null, "input", new StoringApplicationPackage(applicationDir));
search.assertFirstPhaseExpression(expression, "my_profile");
search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile");
search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile");
@@ -319,8 +319,8 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))";
- final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))";
+ final String expression = "join(join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))";
+ final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))";
final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);