aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@yahoo-inc.com>2018-01-22 11:42:35 +0100
committerJon Bratseth <bratseth@yahoo-inc.com>2018-01-22 11:42:35 +0100
commit92abbad9207758578a4c56b6c9fe7f332a6546ee (patch)
tree37d448d8357587e1e620c3babd02ac6ba2f9f654 /searchlib/src/test/java/com
parent59594cb7ff0d97164eff542f184afe576e342a4b (diff)
Parse generated tensor function trees
To make generated tensor function trees transparent to the config model we need to convert each tensor function node to the corresponding ranking expression node. This is most easily done by parsing the tensor function tree string output as a ranking expression (something which is required to always work in any case).
Diffstat (limited to 'searchlib/src/test/java/com')
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java102
1 files changed, 5 insertions, 97 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
index 3ec074dc653..e22e4a36bab 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java
@@ -60,10 +60,7 @@ public class TensorflowImportTestCase {
RankingExpression output = signature.outputExpression("y");
assertNotNull(output);
assertEquals("add", output.getName());
- assertEquals("" +
- "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
- "rename(constant(Variable_1), d0, d1), " +
- "f(a,b)(a + b))",
+ assertEquals("join(rename(reduce(join(Placeholder, rename(constant('Variable'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('Variable_1'), d0, d1), f(a,b)(a + b))",
toNonPrimitiveString(output));
// Test execution
@@ -139,97 +136,8 @@ public class TensorflowImportTestCase {
assertNotNull(output);
assertEquals("dnn/outputs/add", output.getName());
assertEquals("" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "map(" +
- "join(" +
- "rename(" +
- "matmul(" +
- "X, " +
- "rename(" +
- "constant(dnn/hidden1/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden1/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(if(a < 0, exp(a)-1, a))" +
- "), " +
- "rename(" +
- "constant(dnn/hidden2/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden2/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(max(0, a))" +
- "), " +
- "rename(" +
- "constant(dnn/hidden3/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/hidden3/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- "), " +
- "f(a)(1 / (1 + exp(-a)))" +
- "), " +
- "rename(" +
- "constant(dnn/outputs/weights), " +
- "(d0, d1), " +
- "(d1, d3)" +
- "), " +
- "d1" +
- "), " +
- "d3, " +
- "d1" +
- "), " +
- "rename(" +
- "constant(dnn/outputs/bias), " +
- "d0, " +
- "d1" +
- "), " +
- "f(a,b)(a + b)" +
- ")",
- toNonPrimitiveString(output));
+ "join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(X, rename(constant('dnn/hidden1/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden1/bias'), d0, d1), f(a,b)(a + b)), f(a)(if (a < 0, exp(a) - 1, a))), rename(constant('dnn/hidden2/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden2/bias'), d0, d1), f(a,b)(a + b)), f(a)(max(0,a))), rename(constant('dnn/hidden3/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden3/bias'), d0, d1), f(a,b)(a + b)), f(a)(1 / (1 + exp(-a)))), rename(constant('dnn/outputs/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/outputs/bias'), d0, d1), f(a,b)(a + b))",
+ toNonPrimitiveString(output));
// Test constants
assertEqualResult(model, result, "X", "dnn/hidden1/weights/read");
@@ -262,7 +170,7 @@ public class TensorflowImportTestCase {
Tensor placeholder = placeholderArgument();
context.put(inputName, new TensorValue(placeholder));
Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor();
- assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult);
+ assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
}
private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) {
@@ -276,7 +184,7 @@ public class TensorflowImportTestCase {
private Context contextFrom(TensorFlowModel result) {
MapContext context = new MapContext();
- result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
+ result.constants().forEach((name, tensor) -> context.put("constant('" + name + "')", new TensorValue(tensor)));
return context;
}