summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-03-08 16:51:10 +0100
committerLester Solbakken <lesters@oath.com>2018-03-08 17:01:07 +0100
commit7ce1417a1a0c4cd3cd72d903f0c9ffde93baade8 (patch)
treec7b23f9bab1481f31b0eadeb7e4f34ef75939c2b /config-model
parent697f91c9753b25b5074be2a4d99c10987c5cac62 (diff)
Make TensorFlow import joins compatible with broadcasting
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java8
1 files changed, 4 insertions, 4 deletions
diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
index 2cadbbf50e7..90a20746a0a 100644
--- a/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
+++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/RankingExpressionWithTensorFlowTestCase.java
@@ -306,9 +306,9 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testMacroGeneration() {
- final String expression = "join(join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))";
+ final String expression = "join(join(reduce(join(join(join(tf_macro_dnn_hidden2_add, reduce(constant(dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(tf_macro_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))";
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",
"tensorflow('mnist/saved')", null, null, "input", new StoringApplicationPackage(applicationDir));
@@ -319,9 +319,9 @@ public class RankingExpressionWithTensorFlowTestCase {
@Test
public void testImportingFromStoredExpressionsWithSmallConstants() throws IOException {
- final String expression = "join(join(reduce(join(join(join(constant(dnn_hidden2_Const), tf_macro_dnn_hidden2_add, f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))";
+ final String expression = "join(join(reduce(join(join(join(tf_macro_dnn_hidden2_add, reduce(constant(dnn_hidden2_Const), sum, d2), f(a,b)(a * b)), tf_macro_dnn_hidden2_add, f(a,b)(max(a,b))), constant(dnn_outputs_weights_read), f(a,b)(a * b)), sum, d2), constant(dnn_outputs_bias_read), f(a,b)(a + b)), tensor(d0[1])(0.0), f(a,b)(a + b))";
final String macroExpression1 = "join(reduce(join(reduce(rename(input, (d0, d1), (d0, d4)), sum, d0), constant(dnn_hidden1_weights_read), f(a,b)(a * b)), sum, d4), constant(dnn_hidden1_bias_read), f(a,b)(a + b))";
- final String macroExpression2 = "join(reduce(join(join(join(0.009999999776482582, tf_macro_dnn_hidden1_add, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))";
+ final String macroExpression2 = "join(reduce(join(join(join(tf_macro_dnn_hidden1_add, 0.009999999776482582, f(a,b)(a * b)), tf_macro_dnn_hidden1_add, f(a,b)(max(a,b))), constant(dnn_hidden2_weights_read), f(a,b)(a * b)), sum, d3), constant(dnn_hidden2_bias_read), f(a,b)(a + b))";
StoringApplicationPackage application = new StoringApplicationPackage(applicationDir);
RankProfileSearchFixture search = fixtureWith("tensor(d0[1],d1[784])(0.0)",