diff options
author | Lester Solbakken <lesters@oath.com> | 2018-03-05 12:54:15 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-03-05 12:54:45 +0100 |
commit | 3dc6c980c74ff9b280a840374c85026297de89a3 (patch) | |
tree | aeb062e63a9993ab190e3be6335eb00308c36446 /searchlib/src/test/java/com/yahoo | |
parent | 352999b8b295b218b5d4cc4a51b39feea21e0350 (diff) |
Generate macros for TensorFlow nodes with multiple outputs
Diffstat (limited to 'searchlib/src/test/java/com/yahoo')
2 files changed, 30 insertions, 1 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java index f64d697d9b9..c09b1f2b606 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/DropoutImportTestCase.java @@ -32,7 +32,7 @@ public class DropoutImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("outputs/BiasAdd", output.getName()); - assertEquals("join(reduce(join(rename(X, (d0, d1), (d0, d2)), constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", + assertEquals("join(reduce(join(tf_macro_X, constant(\"outputs_kernel_read\"), f(a,b)(a * b)), sum, d2), constant(\"outputs_bias_read\"), f(a,b)(a + b))", output.getRoot().toString()); model.assertEqualResult("X", output.getName()); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java index 1691756a64d..9f372d8d6f5 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TestableTensorFlowModel.java @@ -1,10 +1,14 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.tensorflow; +import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.tensorflow.SavedModelBundle; @@ -42,6 +46,9 @@ public class TestableTensorFlowModel { Context context = contextFrom(model); Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); + + model.macros().forEach((k,v) -> evaluateMacro(context, model, k)); + Tensor vespaResult = model.expressions().get(operationName).evaluate(context).asTensor(); assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); } @@ -74,4 +81,26 @@ public class TestableTensorFlowModel { return b.build(); } + private void evaluateMacro(Context context, TensorFlowModel model, String macroName) { + if (!context.names().contains(macroName)) { + RankingExpression e = model.macros().get(macroName); + evaluateMacroDependencies(context, model, e.getRoot()); + context.put(macroName, new TensorValue(e.evaluate(context).asTensor())); + } + } + + private void evaluateMacroDependencies(Context context, TensorFlowModel model, ExpressionNode node) { + if (node instanceof ReferenceNode) { + String name = node.toString(); + if (model.macros().containsKey(name)) { + evaluateMacro(context, model, name); + } + } + else if (node instanceof CompositeNode) { + for (ExpressionNode child : ((CompositeNode)node).children()) { + evaluateMacroDependencies(context, model, child); + } + } + } + } |