diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-07 15:49:00 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-03-07 15:49:00 +0100 |
commit | e3f5b2812b51896d4a4d5304e6e8c7060e60f68a (patch) | |
tree | 50f86020fe0b4c295bffd14bb877251293593f28 /config-model/src/main | |
parent | 584dcf0b4a54b5e5a70696c15ee0c2bfe63ab656 (diff) |
Allow macros to replace TenorFlow variables
Also, remove quoting of constant arguments generated
in TensorFlow as that is unnecessary now and is
interpreted as a string constant argument to a macro.
Diffstat (limited to 'config-model/src/main')
-rw-r--r-- | config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java | 81 |
1 files changed, 68 insertions, 13 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index e81d22cefe9..2c177633590 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -9,9 +9,11 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -51,6 +53,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -85,10 +88,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); - if (store.hasStoredModel()) - return transformFromStoredModel(store, context.rankProfile()); - else // not converted yet - access TensorFlow model files + if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles()); + else + return transformFromStoredModel(store, context.rankProfile()); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); @@ -101,16 +104,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), k -> tensorFlowImporter.importModel(store.tensorFlowModelDir())); + // Add constants + Set<String> constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + // Find the specified expression Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles); store.writeConverted(expression); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v)); - model.macros().forEach((k, v) -> transformMacro(store, profile, k, v)); + model.macros().forEach((k, v) -> transformMacro(store, profile, constantsReplacedByMacros, k, v)); return expression.getRoot(); } @@ -189,17 +197,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil profile.addConstant(constantName, asValue(constantValue)); } - private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - Path constantPath = store.writeLargeConstant(constantName, constantValue); + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + "The required type of this is " + constantValue.type() + + ", but the macro returns " + macroType); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + + Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); + if (!profile.getSearch().getRankingConstants().containsKey(constantName)) { + log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } } } - private void transformMacro(ModelStore store, RankProfile profile, String macroName, RankingExpression expression) { + private void transformMacro(ModelStore store, RankProfile profile, + Set<String> constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); store.writeMacro(macroName, expression); addMacroToProfile(profile, macroName, expression); } @@ -312,6 +338,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return node; } + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. + */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set<String> constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) { return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions)); } |