diff options
author | Lester Solbakken <lesters@oath.com> | 2018-03-01 09:20:09 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-03-01 09:20:09 +0100 |
commit | 413259f26b658617a7482a72926088751adab521 (patch) | |
tree | 527330a94052f2abdbe43e9e01c6137c16683694 /config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | |
parent | bc3ccdb3552d0d3ff5dcc463308614e72e6abd3e (diff) |
Add batch dimension reduction for TensorFlow import
Diffstat (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 88 |
1 files changed, 86 insertions, 2 deletions
```diff
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index f16697b5ba6..864cd823728 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -23,10 +23,18 @@ import com.yahoo.searchlib.rankingexpression.rule.Arguments;
 import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
 import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
 import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
 import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
 import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
 import com.yahoo.tensor.Tensor;
 import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.functions.Generate;
+import com.yahoo.tensor.functions.Join;
+import com.yahoo.tensor.functions.Reduce;
+import com.yahoo.tensor.functions.Rename;
+import com.yahoo.tensor.functions.ScalarFunctions;
+import com.yahoo.tensor.functions.TensorFunction;
 import com.yahoo.tensor.serialization.TypedBinaryFormat;
 
 import java.io.BufferedReader;
@@ -37,9 +45,11 @@ import java.io.UncheckedIOException;
 import java.util.ArrayList;
 import java.util.Collections;
 import java.util.HashMap;
+import java.util.HashSet;
 import java.util.List;
 import java.util.Map;
 import java.util.Optional;
+import java.util.Set;
 import java.util.logging.Logger;
 
 /**
@@ -197,7 +207,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
      */
     private void verifyRequiredMacros(RankingExpression expression, Map<String, TensorType> requiredMacros,
                                       RankProfile profile, QueryProfileRegistry queryProfiles) {
-        List<String> macroNames = new ArrayList<>();
+        Set<String> macroNames = new HashSet<>();
         addMacroNamesIn(expression.getRoot(), macroNames);
         for (String macroName : macroNames) {
             TensorType requiredType = requiredMacros.get(macroName);
@@ -223,10 +233,84 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
                                                    "' of type " + requiredType +
                                                    " which must be produced by a macro in the rank profile, but " +
                                                    "this macro produces type " + actualType);
+
+            // Check if batch dimensions can be reduced out.
+            reduceBatchDimensions(expression, macro, actualType);
+        }
+    }
+
+    /**
+     * If the macro specifies that a single exemplar should be
+     * evaluated, we can reduce the batch dimension out.
+     */
+    private void reduceBatchDimensions(RankingExpression expression, RankProfile.Macro macro, TensorType type) {
+        if (type.dimensions().size() > 1) {
+            List<String> reduceDimensions = new ArrayList<>();
+            for (TensorType.Dimension dimension : type.dimensions()) {
+                if (dimension.size().orElse(-1L) == 1) {
+                    reduceDimensions.add(dimension.name());
+                }
+            }
+            if (reduceDimensions.size() > 0) {
+                ExpressionNode root = expression.getRoot();
+                root = reduceBatchDimensionsAtInput(root, macro, reduceDimensions);
+                root = expandBatchDimensionsAtOutput(root, reduceDimensions); // todo: determine when we can skip this
+                expression.setRoot(root);
+            }
+        }
+    }
+
+    private ExpressionNode reduceBatchDimensionsAtInput(ExpressionNode node,
+                                                        RankProfile.Macro macro,
+                                                        List<String> reduceDimensions) {
+        if (node instanceof TensorFunctionNode) {
+            TensorFunction tensorFunction = ((TensorFunctionNode) node).function();
+            if (tensorFunction instanceof Rename) {
+                List<ExpressionNode> children = ((TensorFunctionNode)node).children();
+                if (children.size() == 1 && children.get(0) instanceof ReferenceNode) {
+                    ReferenceNode referenceNode = (ReferenceNode) children.get(0);
+                    if (referenceNode.getName().equals(macro.getName())) {
+                        return reduceBatchDimensionExpression(tensorFunction, reduceDimensions);
+                    }
+                }
+            }
+        }
+        if (node instanceof ReferenceNode) {
+            ReferenceNode referenceNode = (ReferenceNode) node;
+            if (referenceNode.getName().equals(macro.getName())) {
+                return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), reduceDimensions);
+            }
+        }
+        if (node instanceof CompositeNode) {
+            List<ExpressionNode> children = ((CompositeNode)node).children();
+            List<ExpressionNode> transformedChildren = new ArrayList<>(children.size());
+            for (ExpressionNode child : children) {
+                transformedChildren.add(reduceBatchDimensionsAtInput(child, macro, reduceDimensions));
+            }
+            return ((CompositeNode)node).setChildren(transformedChildren);
+        }
+        return node;
+    }
+
+    private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) {
+        return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions));
+    }
+
+    private ExpressionNode expandBatchDimensionsAtOutput(ExpressionNode node,
+                                                         List<String> reduceDimensions) {
+        TensorType.Builder typeBuilder = new TensorType.Builder();
+        for (String name : reduceDimensions) {
+            typeBuilder.indexed(name, 1);
         }
+        TensorType generatedType = typeBuilder.build();
+        ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
+        Generate generatedFunction = new Generate(generatedType,
+                new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
+        Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+        return new TensorFunctionNode(expand);
     }
 
-    private void addMacroNamesIn(ExpressionNode node, List<String> names) {
+    private void addMacroNamesIn(ExpressionNode node, Set<String> names) {
         if (node instanceof ReferenceNode) {
             ReferenceNode referenceNode = (ReferenceNode)node;
             if (referenceNode.getOutput() == null) // macro references cannot specify outputs
```