diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-07 15:49:00 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-03-07 15:49:00 +0100 |
commit | e3f5b2812b51896d4a4d5304e6e8c7060e60f68a (patch) | |
tree | 50f86020fe0b4c295bffd14bb877251293593f28 /config-model | |
parent | 584dcf0b4a54b5e5a70696c15ee0c2bfe63ab656 (diff) |
Allow macros to replace TenorFlow variables
Also, remove quoting of constant arguments generated
in TensorFlow as that is unnecessary now and is
interpreted as a string constant argument to a macro.
Diffstat (limited to 'config-model')
2 files changed, 129 insertions, 21 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index e81d22cefe9..2c177633590 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -9,9 +9,11 @@ import com.yahoo.config.model.application.provider.FilesApplicationPackage; import com.yahoo.io.IOUtils; import com.yahoo.path.Path; import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.searchdefinition.FeatureNames; import com.yahoo.searchdefinition.RankProfile; import com.yahoo.searchdefinition.RankingConstant; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -51,6 +53,7 @@ import java.util.Map; import java.util.Optional; import java.util.Set; import java.util.logging.Logger; +import java.util.stream.Collectors; /** * Replaces instances of the tensorflow(model-path, signature, output) @@ -85,10 +88,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil try { ModelStore store = new ModelStore(context.rankProfile().getSearch().sourceApplication(), feature.getArguments()); - if (store.hasStoredModel()) - return transformFromStoredModel(store, context.rankProfile()); - else // not converted yet - access TensorFlow model files + if ( ! store.hasStoredModel()) // not converted yet - access TensorFlow model files return transformFromTensorFlowModel(store, context.rankProfile(), context.queryProfiles()); + else + return transformFromStoredModel(store, context.rankProfile()); } catch (IllegalArgumentException | UncheckedIOException e) { throw new IllegalArgumentException("Could not use tensorflow model from " + feature, e); @@ -101,16 +104,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil TensorFlowModel model = importedModels.computeIfAbsent(store.arguments().modelPath(), k -> tensorFlowImporter.importModel(store.tensorFlowModelDir())); + // Add constants + Set<String> constantsReplacedByMacros = new HashSet<>(); + model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); + model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, queryProfiles, + constantsReplacedByMacros, k, v)); + // Find the specified expression Signature signature = chooseSignature(model, store.arguments().signature()); String output = chooseOutput(signature, store.arguments().output()); RankingExpression expression = model.expressions().get(output); + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); verifyRequiredMacros(expression, model.requiredMacros(), profile, queryProfiles); store.writeConverted(expression); - model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); - model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v)); - model.macros().forEach((k, v) -> transformMacro(store, profile, k, v)); + model.macros().forEach((k, v) -> transformMacro(store, profile, constantsReplacedByMacros, k, v)); return expression.getRoot(); } @@ -189,17 +197,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil profile.addConstant(constantName, asValue(constantValue)); } - private void transformLargeConstant(ModelStore store, RankProfile profile, String constantName, Tensor constantValue) { - Path constantPath = store.writeLargeConstant(constantName, constantValue); + private void transformLargeConstant(ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, + Set<String> constantsReplacedByMacros, + String constantName, Tensor constantValue) { + RankProfile.Macro macroOverridingConstant = profile.getMacros().get(constantName); + if (macroOverridingConstant != null) { + TensorType macroType = macroOverridingConstant.getRankingExpression().type(profile.typeContext(queryProfiles)); + if ( ! macroType.equals(constantValue.type())) + throw new IllegalArgumentException("Macro '" + constantName + "' replaces the constant with this name. " + + "The required type of this is " + constantValue.type() + + ", but the macro returns " + macroType); + constantsReplacedByMacros.add(constantName); // will replace constant(constantName) by constantName later + } + else { + + Path constantPath = store.writeLargeConstant(constantName, constantValue); - if ( ! profile.getSearch().getRankingConstants().containsKey(constantName)) { - log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); - profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), - constantPath.toString())); + if (!profile.getSearch().getRankingConstants().containsKey(constantName)) { + log.info("Adding constant '" + constantName + "' of type " + constantValue.type()); + profile.getSearch().addRankingConstant(new RankingConstant(constantName, constantValue.type(), + constantPath.toString())); + } } } - private void transformMacro(ModelStore store, RankProfile profile, String macroName, RankingExpression expression) { + private void transformMacro(ModelStore store, RankProfile profile, + Set<String> constantsReplacedByMacros, + String macroName, RankingExpression expression) { + + expression = replaceConstantsByMacros(expression, constantsReplacedByMacros); store.writeMacro(macroName, expression); addMacroToProfile(profile, macroName, expression); } @@ -312,6 +338,35 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return node; } + /** + * If a constant c is overridden by a macro, we need to replace instances of "constant(c)" by "c" in expressions. + * This method does that for the given expression and returns the result. + */ + private RankingExpression replaceConstantsByMacros(RankingExpression expression, + Set<String> constantsReplacedByMacros) { + if (constantsReplacedByMacros.isEmpty()) return expression; + return new RankingExpression(expression.getName(), + replaceConstantsByMacros(expression.getRoot(), constantsReplacedByMacros)); + } + + private ExpressionNode replaceConstantsByMacros(ExpressionNode node, Set<String> constantsReplacedByMacros) { + if (node instanceof ReferenceNode) { + Reference reference = ((ReferenceNode)node).reference(); + if (FeatureNames.isSimpleFeature(reference) && reference.name().equals("constant")) { + String argument = reference.simpleArgument().get(); + if (constantsReplacedByMacros.contains(argument)) + return new ReferenceNode(argument); + } + } + if (node instanceof CompositeNode) { // not else: this matches some of the same nodes as the outer if above + CompositeNode composite = (CompositeNode)node; + return composite.setChildren(composite.children().stream() + .map(child -> replaceConstantsByMacros(child, constantsReplacedByMacros)) + .collect(Collectors.toList())); + } + return node; + } + private ExpressionNode reduceBatchDimensionExpression(TensorFunction function, List<String> reduceDimensions) { return new TensorFunctionNode(new Reduce(function, Reduce.Aggregator.sum, reduceDimensions)); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index c650151980c..b46e0fef2bf 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -42,7 +42,7 @@ import static org.junit.Assert.*; public class RankingExpressionWithTensorFlowTestCase { private final Path applicationDir = Path.fromString("src/test/integration/tensorflow/"); - private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b))"; + private final String vespaExpression = "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b))"; @After public void removeGeneratedConstantTensorFiles() { @@ -252,8 +252,51 @@ public class RankingExpressionWithTensorFlowTestCase { } @Test + public void testImportingFromStoredExpressionsWithMacroOverridingConstant() throws IOException { + String rankProfile = + " rank-profile my_profile {\n" + + " macro Placeholder() {\n" + + " expression: tensor(d0[2],d1[784])(0.0)\n" + + " }\n" + + " macro layer_Variable_read() {\n" + + " expression: tensor(d1[10],d2[784])(0.0)\n" + + " }\n" + + " first-phase {\n" + + " expression: tensorflow('mnist_softmax/saved')" + + " }\n" + + " }"; + + + String vespaExpressionWithoutConstant = + "join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), layer_Variable_read, f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b))"; + RankProfileSearchFixture search = fixtureWith(rankProfile, new StoringApplicationPackage(applicationDir)); + search.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + + assertNull("Constant overridden by macro is not added", + search.search().getRankingConstants().get("layer_Variable_read")); + assertLargeConstant("layer_Variable_1_read", search, Optional.of(10L)); + + // At this point the expression is stored - copy application to another location which do not have a models dir + Path storedApplicationDirectory = applicationDir.getParentPath().append("copy"); + try { + storedApplicationDirectory.toFile().mkdirs(); + IOUtils.copyDirectory(applicationDir.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile(), + storedApplicationDirectory.append(ApplicationPackage.MODELS_GENERATED_DIR).toFile()); + StoringApplicationPackage storedApplication = new StoringApplicationPackage(storedApplicationDirectory); + RankProfileSearchFixture searchFromStored = fixtureWith(rankProfile, storedApplication); + searchFromStored.assertFirstPhaseExpression(vespaExpressionWithoutConstant, "my_profile"); + assertNull("Constant overridden by macro is not added", + searchFromStored.search().getRankingConstants().get("layer_Variable_read")); + assertLargeConstant("layer_Variable_1_read", searchFromStored, Optional.of(10L)); + } + finally { + IOUtils.recursiveDeleteDir(storedApplicationDirectory.toFile()); + } + } + + @Test public void testTensorFlowReduceBatchDimension() { - final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(\"layer_Variable_read\"), f(a,b)(a * b)), sum, d2), constant(\"layer_Variable_1_read\"), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(join(reduce(join(reduce(rename(Placeholder, (d0, d1), (d0, d2)), sum, d0), constant(layer_Variable_read), f(a,b)(a * b)), sum, d2), constant(layer_Variable_1_read), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist_softmax/saved')"); search.assertFirstPhaseExpression(expression, "my_profile"); @@ -263,9 +306,9 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testMacroGeneration() { - final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; - final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))"; + final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))"; RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')"); @@ -276,9 +319,9 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; - final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))"; - final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))"; + final String expression = "join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", @@ -383,6 +426,16 @@ public class RankingExpressionWithTensorFlowTestCase { } } + private RankProfileSearchFixture fixtureWith(String rankProfile, StoringApplicationPackage application) { + try { + return new RankProfileSearchFixture(application, application.getQueryProfiles(), + rankProfile, null, null); + } + catch (ParseException e) { + throw new IllegalArgumentException(e); + } + } + private static class StoringApplicationPackage extends MockApplicationPackage { private final File root; |