1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
|
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Ignore;
import org.junit.Test;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;
import java.nio.FloatBuffer;
import java.util.List;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
/**
 * Verifies that the saved MNIST softmax TensorFlow model can be imported, and that the imported
 * expressions evaluate to the same tensors as TensorFlow itself produces for the same operations.
 *
 * @author bratseth
 */
public class Mnist_SoftmaxTestCase {

    /** Width of the model input; must be the same on the TensorFlow side and the Vespa side. */
    private static final int INPUT_SIZE = 784;

    /**
     * Imports the saved model and checks, in order: that no warnings were logged, the type and size
     * of the two imported constants, the single "serving_default" signature with its input and output,
     * the textual form of the output expression, and finally that selected operations evaluate to
     * the same result as when executed by TensorFlow.
     */
    @Ignore
    @Test
    public void testImporting() {
        String modelDir = "src/test/files/integration/tensorflow/mnist_softmax/saved";
        ImportResult result = new TensorFlowImporter().importModel(modelDir);

        // Check logged messages
        result.warnings().forEach(System.err::println);
        assertEquals(0, result.warnings().size());

        // Check constants
        assertEquals(2, result.constants().size());

        Tensor constant0 = result.constants().get("Variable");
        assertNotNull(constant0);
        assertEquals(new TensorType.Builder().indexed("d0", INPUT_SIZE).indexed("d1", 10).build(),
                     constant0.type());
        assertEquals(7840, constant0.size());

        Tensor constant1 = result.constants().get("Variable_1");
        assertNotNull(constant1);
        assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
                     constant1.type());
        assertEquals(10, constant1.size());

        // Check signatures
        assertEquals(1, result.signatures().size());
        ImportResult.Signature signature = result.signatures().get("serving_default");
        assertNotNull(signature);

        // ... signature inputs
        assertEquals(1, signature.inputs().size());
        TensorType argument0 = signature.inputType("x");
        assertNotNull(argument0);
        assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", INPUT_SIZE).build(), argument0);

        // ... signature outputs
        assertEquals(1, signature.outputs().size());
        RankingExpression output = signature.outputs().get("y");
        assertNotNull(output);
        assertEquals("y", output.getName());
        assertEquals("" +
                     "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
                     "rename(constant(Variable_1), d0, d1), " +
                     "f(a,b)(a + b))",
                     toNonPrimitiveString(output));

        // Test execution
        // TODO: Pass imported result instead of re-importing
        String signatureName = "serving_default";
        assertEqualResult(modelDir, signatureName, "Variable/read");
        assertEqualResult(modelDir, signatureName, "Variable_1/read");
        // TODO: Assert that argument fed is as expected assertEqualResult(modelDir, signatureName, "Placeholder");
        assertEqualResult(modelDir, signatureName, "MatMul");
        assertEqualResult(modelDir, signatureName, "add");
    }

    /**
     * Asserts that evaluating the imported expression for the given operation yields the same tensor
     * as executing that operation directly in TensorFlow.
     *
     * @param modelDir      the directory of the saved model
     * @param signatureName the name of the signature holding the output to evaluate
     * @param operationName the name of the TensorFlow operation to import and compare
     */
    private void assertEqualResult(String modelDir, String signatureName, String operationName) {
        ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName);
        Tensor tfResult = tensorFlowExecute(modelDir, operationName);

        Context context = contextFrom(result);
        context.put("Placeholder", new TensorValue(placeholderArgument()));
        Tensor vespaResult = result.signatures()
                                   .get(signatureName)
                                   .outputs()
                                   .get(operationName)
                                   .evaluate(context)
                                   .asTensor();

        // TensorFlow is the reference: it goes in the 'expected' slot so failures read the right way around
        assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
    }

    /**
     * Loads the saved model, feeds a zero-filled [1, INPUT_SIZE] tensor as "Placeholder", fetches the
     * given operation and returns its single result converted to a Vespa tensor.
     * The model, the fed tensor and the fetched tensors are all closed before returning.
     *
     * @param modelDir      the directory of the saved model
     * @param operationName the name of the operation to fetch
     * @return the result of the operation, as a Vespa tensor
     */
    private Tensor tensorFlowExecute(String modelDir, String operationName) {
        // FloatBuffer.allocate produces a buffer whose elements are all zero
        try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
             org.tensorflow.Tensor<?> placeholder =
                     org.tensorflow.Tensor.create(new long[]{ 1, INPUT_SIZE }, FloatBuffer.allocate(INPUT_SIZE))) {
            Session.Runner runner = model.session().runner();
            runner.feed("Placeholder", placeholder);
            List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
            try {
                assertEquals(1, results.size());
                // NOTE(review): assumes toVespaTensor copies the values out, so that closing the
                // TensorFlow tensor afterwards is safe - confirm against TensorConverter
                return new TensorConverter().toVespaTensor(results.get(0));
            } finally {
                for (org.tensorflow.Tensor<?> fetched : results) {
                    fetched.close();
                }
            }
        }
    }

    /**
     * Creates an evaluation context in which each imported constant is bound
     * under the name "constant(" + name + ")".
     */
    private Context contextFrom(ImportResult result) {
        MapContext context = new MapContext();
        result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
        return context;
    }

    /** Returns the string form of the tensor function at the root of the given expression. */
    private String toNonPrimitiveString(RankingExpression expression) {
        // toString on the wrapping expression will map to primitives, which is harder to read
        return ((TensorFunctionNode)expression.getRoot()).function().toString();
    }

    /**
     * Builds the Vespa-side input tensor of type d0[1],d1[INPUT_SIZE].
     * Presumably cell(value, indexes...) is the overload used here, making every cell 0 to match
     * the zero-filled buffer fed to TensorFlow - verify against Tensor.Builder.
     */
    private Tensor placeholderArgument() {
        Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", INPUT_SIZE).build());
        for (int i = 0; i < INPUT_SIZE; i++) {
            b.cell(0, 0, i);
        }
        return b.build();
    }

}
|