aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-03-05 12:54:15 +0100
committerLester Solbakken <lesters@oath.com>2018-03-05 12:54:45 +0100
commit3dc6c980c74ff9b280a840374c85026297de89a3 (patch)
treeaeb062e63a9993ab190e3be6335eb00308c36446 /searchlib/src/test/java/com/yahoo
parent352999b8b295b218b5d4cc4a51b39feea21e0350 (diff)
Generate macros for TensorFlow nodes with multiple outputs
Diffstat (limited to 'searchlib/src/test/java/com/yahoo')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java29
2 files changed, 30 insertions, 1 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
index f64d697d9b9..c09b1f2b606 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java
@@ -32,7 +32,7 @@ public class DropoutImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("outputs/BiasAdd", output.getName());
- assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))",
+ assertEquals("join(reduce(join(tf_macro_X, constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))",
output.getRoot().toString());
model.assertEqualResult("X", output.getName());
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
index 1691756a64d..9f372d8d6f5 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java
@@ -1,10 +1,14 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.tensorflow.SavedModelBundle;
@@ -42,6 +46,9 @@ public class TestableTensorFlowModel {
Context context = contextFrom(model);
Tensor placeholder = placeholderArgument();
context.put(inputName, new TensorValue(placeholder));
+
+ model.macros().forEach((k,v) -> evaluateMacro(context, model, k));
+
Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor();
assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
@@ -74,4 +81,26 @@ public class TestableTensorFlowModel {
return b.build();
}
+ private void evaluateMacro(Context context, TensorFlowModel model, String macroName) {
+ if (!context.names().contains(macroName)) {
+ RankingExpression e = model.macros().get(macroName);
+ evaluateMacroDependencies(context, model, e.getRoot());
+ context.put(macroName, new TensorValue(e.evaluate(context).asTensor()));
+ }
+ }
+
+ private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) {
+ if (node instanceof ReferenceNode) {
+ String name = node.toString();
+ if (model.macros().containsKey(name)) {
+ evaluateMacro(context, model, name);
+ }
+ }
+ else if (node instanceof CompositeNode) {
+ for (ExpressionNode child : ((CompositeNode)node).children()) {
+ evaluateMacroDependencies(context, model, child);
+ }
+ }
+ }
+
}