summaryrefslogtreecommitdiffstats
path: root/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/Mnist_SoftmaxTestCase.java
blob: 53989af44605880097608c2995f58ed473da1695 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
package com.yahoo.searchlib.rankingexpression.integration.tensorflow;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Ignore;
import org.junit.Test;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Session;

import java.nio.FloatBuffer;
import java.util.List;

import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;

/**
 * Tests importing the MNIST softmax TensorFlow saved model into Vespa ranking expressions,
 * and verifies that selected imported operations evaluate to the same tensors as when
 * TensorFlow itself executes them on the same input.
 *
 * @author bratseth
 */
public class Mnist_SoftmaxTestCase {

    /** Directory of the TensorFlow saved model this test imports. */
    private static final String MODEL_DIR = "src/test/files/integration/tensorflow/mnist_softmax/saved";

    /** The signature under which the model is imported and evaluated. */
    private static final String SIGNATURE_NAME = "serving_default";

    /** Name of the model's input operation; fed with the same all-zero argument in TensorFlow and in Vespa. */
    private static final String PLACEHOLDER_NAME = "Placeholder";

    /** Size of the input placeholder's second dimension (d1); shared so both engines get the same input shape. */
    private static final int INPUT_SIZE = 784;

    @Ignore
    @Test
    public void testImporting() {
        ImportResult result = new TensorFlowImporter().importModel(MODEL_DIR);

        // Check logged messages
        result.warnings().forEach(System.err::println);
        assertEquals(0, result.warnings().size());

        // Check constants
        assertEquals(2, result.constants().size());

        Tensor constant0 = result.constants().get("Variable");
        assertNotNull(constant0);
        assertEquals(new TensorType.Builder().indexed("d0", INPUT_SIZE).indexed("d1", 10).build(),
                     constant0.type());
        assertEquals(7840, constant0.size());

        Tensor constant1 = result.constants().get("Variable_1");
        assertNotNull(constant1);
        assertEquals(new TensorType.Builder().indexed("d0", 10).build(),
                     constant1.type());
        assertEquals(10, constant1.size());

        // Check signatures
        assertEquals(1, result.signatures().size());
        ImportResult.Signature signature = result.signatures().get(SIGNATURE_NAME);
        assertNotNull(signature);

        // ... signature inputs
        assertEquals(1, signature.inputs().size());
        TensorType argument0 = signature.inputType("x");
        assertNotNull(argument0);
        assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", INPUT_SIZE).build(), argument0);

        // ... signature outputs
        assertEquals(1, signature.outputs().size());
        RankingExpression output = signature.outputs().get("y");
        assertNotNull(output);
        assertEquals("y", output.getName());
        assertEquals("" +
                     "join(rename(matmul(Placeholder, rename(constant(Variable), (d0, d1), (d1, d3)), d1), d3, d1), " +
                     "rename(constant(Variable_1), d0, d1), " +
                     "f(a,b)(a + b))",
                     toNonPrimitiveString(output));

        // Test execution: each operation is evaluated by both TensorFlow and Vespa and the results compared
        // TODO: Pass imported result instead of re-importing
        assertEqualResult(MODEL_DIR, SIGNATURE_NAME, "Variable/read");
        assertEqualResult(MODEL_DIR, SIGNATURE_NAME, "Variable_1/read");
        // TODO: Assert that argument fed is as expected assertEqualResult(MODEL_DIR, SIGNATURE_NAME, PLACEHOLDER_NAME);
        assertEqualResult(MODEL_DIR, SIGNATURE_NAME, "MatMul");
        assertEqualResult(MODEL_DIR, SIGNATURE_NAME, "add");
    }

    /**
     * Imports the given operation from the model and evaluates it in Vespa, binding the model's
     * constants and an all-zero placeholder argument. Asserts that the result equals what
     * TensorFlow computes for the same operation and input.
     *
     * @param modelDir the saved model directory
     * @param signatureName the signature to import the operation under
     * @param operationName the TensorFlow operation to import, execute and compare
     */
    private void assertEqualResult(String modelDir, String signatureName, String operationName) {
        ImportResult result = new TensorFlowImporter().importNode(modelDir, signatureName, operationName);

        Tensor tfResult = tensorFlowExecute(modelDir, operationName);

        ImportResult.Signature signature = result.signatures().get(signatureName);
        assertNotNull("Signature '" + signatureName + "' is imported", signature);
        RankingExpression expression = signature.outputs().get(operationName);
        assertNotNull("Operation '" + operationName + "' is imported", expression);

        Context context = contextFrom(result);
        context.put(PLACEHOLDER_NAME, new TensorValue(placeholderArgument()));
        Tensor vespaResult = expression.evaluate(context).asTensor();

        // JUnit's argument order is (message, expected, actual); TensorFlow's output is the reference
        assertEquals("Operation '" + operationName + "' produces equal results", tfResult, vespaResult);
    }

    /**
     * Runs the given operation in TensorFlow on the saved model, feeding an all-zero placeholder
     * of shape [1, INPUT_SIZE]. Returns the single fetched result converted to a Vespa tensor.
     * The loaded model, the fed input tensor and all fetched output tensors are closed before
     * this returns, so no native TensorFlow resources are leaked.
     *
     * @param modelDir the saved model directory, loaded with the "serve" tag
     * @param operationName the operation whose output to fetch
     * @return the fetched output as a Vespa tensor
     */
    private Tensor tensorFlowExecute(String modelDir, String operationName) {
        try (SavedModelBundle model = SavedModelBundle.load(modelDir, "serve");
             org.tensorflow.Tensor<?> placeholder =
                     org.tensorflow.Tensor.create(new long[]{ 1, INPUT_SIZE }, FloatBuffer.allocate(INPUT_SIZE))) {
            Session.Runner runner = model.session().runner();
            runner.feed(PLACEHOLDER_NAME, placeholder);
            List<org.tensorflow.Tensor<?>> results = runner.fetch(operationName).run();
            try {
                assertEquals(1, results.size());
                return new TensorConverter().toVespaTensor(results.get(0));
            }
            finally {
                // Output tensors are owned by the caller of run() and must be released explicitly
                results.forEach(org.tensorflow.Tensor::close);
            }
        }
    }

    /**
     * Creates an evaluation context binding each imported constant under the feature name
     * {@code constant(<name>)}, which is how the imported expressions reference them.
     */
    private Context contextFrom(ImportResult result) {
        MapContext context = new MapContext();
        result.constants().forEach((name, tensor) -> context.put("constant(" + name + ")", new TensorValue(tensor)));
        return context;
    }

    /** Returns the tensor function form of the expression's root, which is easier to read than its primitive form. */
    private String toNonPrimitiveString(RankingExpression expression) {
        // toString on the wrapping expression will map to primitives, which is harder to read
        return ((TensorFunctionNode)expression.getRoot()).function().toString();
    }

    /**
     * Returns the Vespa counterpart of the input fed to TensorFlow in {@link #tensorFlowExecute}:
     * a [1, INPUT_SIZE] indexed tensor with every cell set to 0.
     */
    private Tensor placeholderArgument() {
        Tensor.Builder b = Tensor.Builder.of(new TensorType.Builder().indexed("d0", 1).indexed("d1", INPUT_SIZE).build());
        for (int i = 0; i < INPUT_SIZE; i++)
            b.cell(0, 0, i); // value 0 at address (d0=0, d1=i), matching the zero-filled FloatBuffer
        return b.build();
    }

}