From e3f5b2812b51896d4a4d5304e6e8c7060e60f68a Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Wed, 7 Mar 2018 15:49:00 +0100 Subject: Allow macros to replace TenorFlow variables Also, remove quoting of constant arguments generated in TensorFlow as that is unnecessary now and is interpreted as a string constant argument to a macro. --- .../TensorFlowFeatureConverter.java | 81 ++++++++++++++++++---- 1 file changed, 68 insertions(+), 13 deletions(-) (limited to 'config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index e81d22cefe9..2c177633590 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -9,9 +9,11 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -51,6 +53,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -85,10 +88,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer tensorFlowImporter.importModel(store.tensorFlowModelDir())); + // Add constants + Set constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + // Find the specified expression Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles); store.writeConverted(expression); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v)); - model.macros().forEach((k, v) -> transformMacro(store, profile, k, v)); + model.macros().forEach((k, v) -> transformMacro(store, profile, constantsReplacedByMacros, k, v)); return expression.getRoot(); } @@ -189,17 +197,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + "The required type of this is " + constantValue.type() + + ", but the macro returns " + macroType); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + + Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); + if (!profile.getSearch().getRankingConstants().containsKey(constantName)) { + log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } } } - private void transformMacro(ModelStore store, RankProfile profile, String macroName, RankingExpression expression) { + private void transformMacro(ModelStore store, RankProfile profile, + Set constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); store.writeMacro(macroName, expression); addMacroToProfile(profile, macroName, expression); } @@ -312,6 +338,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List reduceDimensions) { return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions)); } -- cgit v1.2.3