diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-22 11:42:35 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2018-01-22 11:42:35 +0100 |
commit | 92abbad9207758578a4c56b6c9fe7f332a6546ee (patch) | |
tree | 37d448d8357587e1e620c3babd02ac6ba2f9f654 /searchlib/src/test/java/com | |
parent | 59594cb7ff0d97164eff542f184afe576e342a4b (diff) |
Parse generated tensor function trees
To make generated tensor function trees transparent to
the config model we need to convert each tensor function node
to the corresponding ranking expression node.
This is most easily done by parsing the tensor function
tree string output as a ranking expression (something
which is required to always work in any case).
Diffstat (limited to 'searchlib/src/test/java/com')
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java | 102 |
1 files changed, 5 insertions, 97 deletions
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java index 3ec074dc653..e22e4a36bab 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorflowImportTestCase.java @@ -60,10 +60,7 @@ public class TensorflowImportTestCase { RankingExpression output = signature.outputExpression("y"); assertNotNull(output); assertEquals("add", output.getName()); - assertEquals("" + - "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " + - "rename(constant(Variable_1), d0, d1), " + - "f(a,b)(a + b))", + assertEquals("join(rename(reduce(join(Placeholder, rename(constant('Variable'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('Variable_1'), d0, d1), f(a,b)(a + b))", toNonPrimitiveString(output)); // Test execution @@ -139,97 +136,8 @@ public class TensorflowImportTestCase { assertNotNull(output); assertEquals("dnn/outputs/add", output.getName()); assertEquals("" + - "join(" + - "rename(" + - "matmul(" + - "map(" + - "join(" + - "rename(" + - "matmul(" + - "map(" + - "join(" + - "rename(" + - "matmul(" + - "map(" + - "join(" + - "rename(" + - "matmul(" + - "X, " + - "rename(" + - "constant(dnn/hidden1/weights), " + - "(d0, d1), " + - "(d1, d3)" + - "), " + - "d1" + - "), " + - "d3, " + - "d1" + - "), " + - "rename(" + - "constant(dnn/hidden1/bias), " + - "d0, " + - "d1" + - "), " + - "f(a,b)(a + b)" + - "), " + - "f(a)(if(a < 0, exp(a)-1, a))" + - "), " + - "rename(" + - "constant(dnn/hidden2/weights), " + - "(d0, d1), " + - "(d1, d3)" + - "), " + - "d1" + - "), " + - "d3, " + - "d1" + - "), " + - "rename(" + - "constant(dnn/hidden2/bias), " + - "d0, " + - "d1" + - "), " + - "f(a,b)(a + b)" + - "), " + - "f(a)(max(0, a))" + - "), " + - "rename(" + - "constant(dnn/hidden3/weights), " + - "(d0, d1), " + - "(d1, d3)" + - "), " + - "d1" + - "), " + - "d3, " + - "d1" + - "), " + - "rename(" + - "constant(dnn/hidden3/bias), " + - "d0, " + - "d1" + - "), " + - "f(a,b)(a + b)" + - "), " + - "f(a)(1 / (1 + exp(-a)))" + - "), " + - "rename(" + - "constant(dnn/outputs/weights), " + - "(d0, d1), " + - "(d1, d3)" + - "), " + - "d1" + - "), " + - "d3, " + - "d1" + - "), " + - "rename(" + - "constant(dnn/outputs/bias), " + - "d0, " + - "d1" + - "), " + - "f(a,b)(a + b)" + - ")", - toNonPrimitiveString(output)); + "join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(map(join(rename(reduce(join(X, rename(constant('dnn/hidden1/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden1/bias'), d0, d1), f(a,b)(a + b)), f(a)(if (a < 0, exp(a) - 1, a))), rename(constant('dnn/hidden2/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden2/bias'), d0, d1), f(a,b)(a + b)), f(a)(max(0,a))), rename(constant('dnn/hidden3/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/hidden3/bias'), d0, d1), f(a,b)(a + b)), f(a)(1 / (1 + exp(-a)))), rename(constant('dnn/outputs/weights'), (d0, d1), (d1, d3)), f(a,b)(a * b)), sum, d1), d3, d1), rename(constant('dnn/outputs/bias'), d0, d1), f(a,b)(a + b))", + toNonPrimitiveString(output)); // Test constants assertEqualResult(model, result, "X", "dnn/hidden1/weights/read"); @@ -262,7 +170,7 @@ public class TensorflowImportTestCase { Tensor placeholder = placeholderArgument(); context.put(inputName, new TensorValue(placeholder)); Tensor vespaResult = result.expressions().get(operationName).evaluate(context).asTensor(); - assertEquals("Operation '" + operationName + "' produces equal results", vespaResult, tfResult); + assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult); } private Tensor tensorFlowExecute(SavedModelBundle model, String inputName, String operationName) { @@ -276,7 +184,7 @@ public class TensorflowImportTestCase { private Context contextFrom(TensorFlowModel result) { MapContext context = new MapContext(); - result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor))); + result.constants().forEach((name, tensor) -> context.put("constant('" + name + "')", new TensorValue(tensor))); return context; } |