diff options
7 files changed, 123 insertions, 8 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java index 864cd823728..271fdbbfce3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java @@ -110,6 +110,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v)); + model.macros().forEach((k, v) -> transformMacro(store, profile, k, v)); return expression.getRoot(); } @@ -123,6 +124,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil profile.getSearch().addRankingConstant(constant); } + for (Pair<String, RankingExpression> macro : store.readMacros()) { + addMacroToProfile(profile, macro.getFirst(), macro.getSecond()); + } + return store.readConverted().getRoot(); } @@ -194,6 +199,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } + private void transformMacro(ModelStore store, RankProfile profile, String macroName, RankingExpression expression) { + store.writeMacro(macroName, expression); + addMacroToProfile(profile, macroName, expression); + } + + private void addMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) { + if (profile.getMacros().containsKey(macroName)) { + throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists."); + } + profile.addMacro(macroName, false); // todo: inline if only used once + RankProfile.Macro macro = profile.getMacros().get(macroName); + macro.setRankingExpression(expression); + macro.setTextualExpression(expression.getRoot().toString()); + } + private String skippedOutputsDescription(TensorFlowModel.Signature signature) { if (signature.skippedOutputs().isEmpty()) return ""; StringBuilder b = new StringBuilder(": "); @@ -382,6 +402,39 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil } } + /** Adds this macro expression to the application package to it can be read later. */ + public void writeMacro(String name, RankingExpression expression) { + application.getFile(arguments.macrosPath()).appendFile(name + "\t" + + expression.getRoot().toString() + "\n"); + } + + /** Reads the previously stored macro expressions for these arguments */ + public List<Pair<String, RankingExpression>> readMacros() { + try { + ApplicationFile file = application.getFile(arguments.macrosPath()); + if (!file.exists()) return Collections.emptyList(); + + List<Pair<String, RankingExpression>> macros = new ArrayList<>(); + BufferedReader reader = new BufferedReader(file.createReader()); + String line; + while (null != (line = reader.readLine())) { + String[] parts = line.split("\t"); + String name = parts[0]; + try { + RankingExpression expression = new RankingExpression(parts[1]); + macros.add(new Pair<>(name, expression)); + } + catch (ParseException e) { + throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e); + } + } + return macros; + } + catch (IOException e) { + throw new UncheckedIOException(e); + } + } + /** * Reads the information about all the large (aka ranking) constants stored in the application package * (the constant value itself is replicated with file distribution). @@ -510,8 +563,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants"); } + /** Path to the macros file */ + public Path macrosPath() { + return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("macros.txt"); + } + public Path expressionPath() { - return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR + return ApplicationPackage.MODELS_GENERATED_DIR .append(modelPath).append("expressions").append(expressionFileName()); } diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java index beba8ade1d8..bd2bbf5c6d5 100644 --- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java @@ -263,7 +263,7 @@ public class RankingExpressionWithTensorFlowTestCase { @Test public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException { - final String expression = "join(join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))"; + final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; StoringApplicationPackage application = new StoringApplicationPackage(applicationDir); RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", "tensorflow('mnist/saved')", @@ -295,6 +295,19 @@ public class RankingExpressionWithTensorFlowTestCase { } } + @Test + public void testMacroGeneration() { + final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))"; + final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))"; + final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))"; + + RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)", + "tensorflow('mnist/saved')"); + search.assertFirstPhaseExpression(expression, "my_profile"); + search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile"); + search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile"); + } + private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) { Value value = search.rankProfile("my_profile").getConstants().get(name); assertNotNull(value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 7116d430502..9ff88103f12 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -82,8 +82,6 @@ public class TensorFlowImporter { findDimensionNames(model, index); importExpressions(model, index, bundle); - // nodes with multiple outputs are calculated multiple times. consider adding macros for those. - reportWarnings(model, index); return model; @@ -241,7 +239,14 @@ public class TensorFlowImporter { private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) { if (operation.macro().isPresent()) { - model.macro(operation.vespaName(), operation.macro().get()); + TensorFunction function = operation.macro().get(); + try { + model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); + } + catch (ParseException e) { + throw new RuntimeException("Tensorflow function " + function + + " cannot be parsed as a ranking expression", e); + } } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java index ab091b77a65..4e5709505ce 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java @@ -34,7 +34,7 @@ public class PlaceholderWithDefault extends TensorFlowOperation { } @Override - public Optional<RankingExpression> macro() { + public Optional<TensorFunction> macro() { // For now, it is much more efficient to assume we always will return // the default value, as we can prune away large parts of the expression // tree by having it calculated as a constant. If a case arises where diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java index fd9dfd167fb..9e8f6df3e2c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.Ord import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; import org.tensorflow.framework.NodeDef; @@ -28,6 +29,8 @@ import java.util.function.Function; */ public abstract class TensorFlowOperation { + protected final static String MACRO_PREFIX = "tf_macro_"; + protected final NodeDef node; protected final int port; protected final List<TensorFlowOperation> inputs; @@ -36,6 +39,7 @@ public abstract class TensorFlowOperation { protected OrderedTensorType type; protected TensorFunction function; + protected TensorFunction macro = null; private Value constantValue = null; private List<TensorFlowOperation> controlInputs = Collections.emptyList(); @@ -65,6 +69,9 @@ public abstract class TensorFlowOperation { if (isConstant()) { ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")"); function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); + } else if (outputs.size() > 1) { + macro = lazyGetFunction(); + function = new VariableTensor(macroName(), type.type()); } else { function = lazyGetFunction(); } @@ -82,7 +89,7 @@ public abstract class TensorFlowOperation { public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); } /** Returns a Vespa ranking expression that should be added as a macro */ - public Optional<RankingExpression> macro() { return Optional.empty(); } + public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); } /** Add dimension name constraints for this operation */ public void addDimensionNameConstraints(DimensionRenamer renamer) { } @@ -111,6 +118,9 @@ public abstract class TensorFlowOperation { /** Retrieve the valid Vespa name of this node */ public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; } + /** Retrieve the valid Vespa name of this node if it is a macro */ + public String macroName() { return vespaName() != null ? MACRO_PREFIX + vespaName() : null; } + /** Retrieve the list of warnings produced during its lifetime */ public List<String> warnings() { return Collections.unmodifiableList(importWarnings); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index f64d697d9b9..c09b1f2b606 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/BiasAdd", output.getName()); - assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", + assertEquals("join(reduce(join(tf_macro_X, constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 1691756a64d..9f372d8d6f5 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -1,10 +1,14 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.tensorflow.SavedModelBundle; @@ -42,6 +46,9 @@ public class TestableTensorFlowModel { Context context = contextFrom(model); Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); + + model.macros().forEach((k,v) -> evaluateMacro(context, model, k)); + Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); } @@ -74,4 +81,26 @@ public class TestableTensorFlowModel { return b.build(); } + private void evaluateMacro(Context context, TensorFlowModel model, String macroName) { + if (!context.names().contains(macroName)) { + RankingExpression e = model.macros().get(macroName); + evaluateMacroDependencies(context, model, e.getRoot()); + context.put(macroName, new TensorValue(e.evaluate(context).asTensor())); + } + } + + private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) { + if (node instanceof ReferenceNode) { + String name = node.toString(); + if (model.macros().containsKey(name)) { + evaluateMacro(context, model, name); + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) { + evaluateMacroDependencies(context, model, child); + } + } + } + } |