aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java60
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java15
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java11
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java12
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java29
7 files changed, 123 insertions, 8 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
index 864cd823728..271fdbbfce3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
@@ -110,6 +110,7 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
model.largeConstants().forEach((k, v) -> transformLargeConstant(store, profile, k, v));
+ model.macros().forEach((k, v) -> transformMacro(store, profile, k, v));
return expression.getRoot();
}
@@ -123,6 +124,10 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
profile.getSearch().addRankingConstant(constant);
}
+ for (Pair<String, RankingExpression> macro : store.readMacros()) {
+ addMacroToProfile(profile, macro.getFirst(), macro.getSecond());
+ }
+
return store.readConverted().getRoot();
}
@@ -194,6 +199,21 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
+ private void transformMacro(ModelStore store, RankProfile profile, String macroName, RankingExpression expression) {
+ store.writeMacro(macroName, expression);
+ addMacroToProfile(profile, macroName, expression);
+ }
+
+ private void addMacroToProfile(RankProfile profile, String macroName, RankingExpression expression) {
+ if (profile.getMacros().containsKey(macroName)) {
+ throw new IllegalArgumentException("Generated TensorFlow macro '" + macroName + "' already exists.");
+ }
+ profile.addMacro(macroName, false); // todo: inline if only used once
+ RankProfile.Macro macro = profile.getMacros().get(macroName);
+ macro.setRankingExpression(expression);
+ macro.setTextualExpression(expression.getRoot().toString());
+ }
+
private String skippedOutputsDescription(TensorFlowModel.Signature signature) {
if (signature.skippedOutputs().isEmpty()) return "";
StringBuilder b = new StringBuilder(": ");
@@ -382,6 +402,39 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
}
}
+ /** Adds this macro expression to the application package to it can be read later. */
+ public void writeMacro(String name, RankingExpression expression) {
+ application.getFile(arguments.macrosPath()).appendFile(name + "\t" +
+ expression.getRoot().toString() + "\n");
+ }
+
+ /** Reads the previously stored macro expressions for these arguments */
+ public List<Pair<String, RankingExpression>> readMacros() {
+ try {
+ ApplicationFile file = application.getFile(arguments.macrosPath());
+ if (!file.exists()) return Collections.emptyList();
+
+ List<Pair<String, RankingExpression>> macros = new ArrayList<>();
+ BufferedReader reader = new BufferedReader(file.createReader());
+ String line;
+ while (null != (line = reader.readLine())) {
+ String[] parts = line.split("\t");
+ String name = parts[0];
+ try {
+ RankingExpression expression = new RankingExpression(parts[1]);
+ macros.add(new Pair<>(name, expression));
+ }
+ catch (ParseException e) {
+ throw new IllegalStateException("Could not parse " + arguments.expressionPath(), e);
+ }
+ }
+ return macros;
+ }
+ catch (IOException e) {
+ throw new UncheckedIOException(e);
+ }
+ }
+
/**
* Reads the information about all the large (aka ranking) constants stored in the application package
* (the constant value itself is replicated with file distribution).
@@ -510,8 +563,13 @@ public class TensorFlowFeatureConverter extends ExpressionTransformer<RankProfil
return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR.append(modelPath).append("constants");
}
+ /** Path to the macros file */
+ public Path macrosPath() {
+ return ApplicationPackage.MODELS_GENERATED_DIR.append(modelPath).append("macros.txt");
+ }
+
public Path expressionPath() {
- return ApplicationPackage.MODELS_GENERATED_REPLICATED_DIR
+ return ApplicationPackage.MODELS_GENERATED_DIR
.append(modelPath).append("expressions").append(expressionFileName());
}
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index beba8ade1d8..bd2bbf5c6d5 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -263,7 +263,7 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(join(join(0.009999999776482582, join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(a * b)), join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b)), f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b)), tensor(d0[1])(1.0), f(a,b)(a * b))";
+ final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')",
@@ -295,6 +295,19 @@ public class RankingExpressionWithTensorFlowTestCase {
}
}
+ @Test
+ public void testMacroGeneration() {
+ final String expression = "join(reduce(join(join(join(constant(\"dnn_hidden2_Const\"), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(\"dnn_outputs_weights_read\"), f(a,b)(a * b)), sum, d2), constant(\"dnn_outputs_bias_read\"), f(a,b)(a + b))";
+ final String macroExpression1 = "join(reduce(join(rename(input, (d0, d1), (d0, d4)), constant(\"dnn_hidden1_weights_read\"), f(a,b)(a * b)), sum, d4), constant(\"dnn_hidden1_bias_read\"), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(\"dnn_hidden2_weights_read\"), f(a,b)(a * b)), sum, d3), constant(\"dnn_hidden2_bias_read\"), f(a,b)(a + b))";
+
+ RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
+ "tensorflow('mnist/saved')");
+ search.assertFirstPhaseExpression(expression, "my_profile");
+ search.assertMacro(macroExpression1, "tf_macro_dnn_hidden1_add", "my_profile");
+ search.assertMacro(macroExpression2, "tf_macro_dnn_hidden2_add", "my_profile");
+ }
+
private void assertSmallConstant(String name, TensorType type, RankProfileSearchFixture search) {
Value value = search.rankProfile("my_profile").getConstants().get(name);
assertNotNull(value);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 7116d430502..9ff88103f12 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -82,8 +82,6 @@ public class TensorFlowImporter {
findDimensionNames(model, index);
importExpressions(model, index, bundle);
- // nodes with multiple outputs are calculated multiple times. consider adding macros for those.
-
reportWarnings(model, index);
return model;
@@ -241,7 +239,14 @@ public class TensorFlowImporter {
private static void importMacroExpression(TensorFlowModel model, TensorFlowOperation operation) {
if (operation.macro().isPresent()) {
- model.macro(operation.vespaName(), operation.macro().get());
+ TensorFunction function = operation.macro().get();
+ try {
+ model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
+ }
+ catch (ParseException e) {
+ throw new RuntimeException("Tensorflow function " + function +
+ " cannot be parsed as a ranking expression", e);
+ }
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
index ab091b77a65..4e5709505ce 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/PlaceholderWithDefault.java
@@ -34,7 +34,7 @@ public class PlaceholderWithDefault extends TensorFlowOperation {
}
@Override
- public Optional<RankingExpression> macro() {
+ public Optional<TensorFunction> macro() {
// For now, it is much more efficient to assume we always will return
// the default value, as we can prune away large parts of the expression
// tree by having it calculated as a constant. If a case arises where
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
index fd9dfd167fb..9e8f6df3e2c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.Ord
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
import org.tensorflow.framework.NodeDef;
@@ -28,6 +29,8 @@ import java.util.function.Function;
*/
public abstract class TensorFlowOperation {
+ protected final static String MACRO_PREFIX = "tf_macro_";
+
protected final NodeDef node;
protected final int port;
protected final List<TensorFlowOperation> inputs;
@@ -36,6 +39,7 @@ public abstract class TensorFlowOperation {
protected OrderedTensorType type;
protected TensorFunction function;
+ protected TensorFunction macro = null;
private Value constantValue = null;
private List<TensorFlowOperation> controlInputs = Collections.emptyList();
@@ -65,6 +69,9 @@ public abstract class TensorFlowOperation {
if (isConstant()) {
ExpressionNode constant = new ReferenceNode("constant(\"" + vespaName() + "\")");
function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
+ } else if (outputs.size() > 1) {
+ macro = lazyGetFunction();
+ function = new VariableTensor(macroName(), type.type());
} else {
function = lazyGetFunction();
}
@@ -82,7 +89,7 @@ public abstract class TensorFlowOperation {
public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); }
/** Returns a Vespa ranking expression that should be added as a macro */
- public Optional<RankingExpression> macro() { return Optional.empty(); }
+ public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); }
/** Add dimension name constraints for this operation */
public void addDimensionNameConstraints(DimensionRenamer renamer) { }
@@ -111,6 +118,9 @@ public abstract class TensorFlowOperation {
/** Retrieve the valid Vespa name of this node */
public String vespaName() { return node.getName() != null ? node.getName().replace('/', '_') : null; }
+ /** Retrieve the valid Vespa name of this node if it is a macro */
+ public String macroName() { return vespaName() != null ? MACRO_PREFIX + vespaName() : null; }
+
/** Retrieve the list of warnings produced during its lifetime */
public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index f64d697d9b9..c09b1f2b606 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/BiasAdd", output.getName());
- assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(tf_macro_X, constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 1691756a64d..9f372d8d6f5 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -1,10 +1,14 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.tensorflow.SavedModelBundle;
@@ -42,6 +46,9 @@ public class TestableTensorFlowModel {
Context context = contextFrom(model);
Tensor placeholder = placeholderArgument();
context.put(inputName, new TensorValue(placeholder));
+
+ model.macros().forEach((k,v) -> evaluateMacro(context, model, k));
+
Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
@@ -74,4 +81,26 @@ public class TestableTensorFlowModel {
return b.build();
}
+ private void evaluateMacro(Context context, TensorFlowModel model, String macroName) {
+ if (!context.names().contains(macroName)) {
+ RankingExpression e = model.macros().get(macroName);
+ evaluateMacroDependencies(context, model, e.getRoot());
+ context.put(macroName, new TensorValue(e.evaluate(context).asTensor()));
+ }
+ }
+
+ private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ String name = node.toString();
+ if (model.macros().containsKey(name)) {
+ evaluateMacro(context, model, name);
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children()) {
+ evaluateMacroDependencies(context, model, child);
+ }
+ }
+ }
+
}