aboutsummaryrefslogtreecommitdiffstats
path: root/documentapi/src/test/java/com/yahoo/documentapi/messagebus/protocol/test/MessagesTestBase.java
blob: 71cae9d136a5219c53a48a07339418fc22014125 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.documentapi.messagebus.protocol.test;

import com.yahoo.component.Version;
import com.yahoo.document.DocumentTypeManager;
import com.yahoo.document.DocumentTypeManagerConfigurer;
import com.yahoo.documentapi.messagebus.protocol.DocumentProtocol;
import com.yahoo.messagebus.Routable;
import org.junit.Test;

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.util.*;

import static org.junit.Assert.*;

/**
 * @author Simon Thoresen Hult
 */
public abstract class MessagesTestBase {

    /** Selects which implementation's serialized files to read ("java" or "cpp" in the file name). */
    protected enum Language {
        JAVA,
        CPP
    }
    /** Every {@link Language}, for tests that check files produced by all implementations. */
    protected static final Set<Language> LANGUAGES = EnumSet.allOf(Language.class);

    protected final DocumentTypeManager docMan = new DocumentTypeManager();
    protected final DocumentProtocol protocol = new DocumentProtocol(docMan, null);

    /** Configures {@link #docMan} from the document config at {@code ./test/cfg/testdoc.cfg}. */
    public MessagesTestBase() {
        DocumentTypeManagerConfigurer.configure(docMan, "file:./test/cfg/testdoc.cfg");
    }

    /**
     * Runs every registered test in ascending key order. If coverage checking is enabled, it then
     * asserts that the registered keys are exactly the routable types the protocol reports for
     * {@link #version()}.
     */
    @Test
    public void requireThatTestsPass() throws Exception {
        // TreeMap gives a deterministic, key-sorted execution order.
        Map<Integer, RunnableTest> tests = new TreeMap<>();
        registerTests(tests);
        for (RunnableTest test : tests.values()) {
            test.run();
        }
        if (shouldTestCoverage()) {
            assertCoverage(protocol.getRoutableTypes(version()), new ArrayList<>(tests.keySet()));
        }
    }

    /**
     * Returns the protocol version used for encoding, decoding and test-file naming.
     *
     * @return the version under test
     */
    protected abstract Version version();

    /**
     * Registers the tests to run, keyed by the routable type each one exercises.
     * The keys are compared against the protocol's routable types when coverage is checked.
     *
     * @param out map to register tests into
     */
    protected abstract void registerTests(Map<Integer, RunnableTest> out);

    /**
     * Returns whether {@link #requireThatTestsPass()} should verify that every registered
     * routable type is tested, and vice versa.
     */
    protected abstract boolean shouldTestCoverage();

    /**
     * Encodes the given routable using the version under test.
     *
     * @param routable the routable to encode
     * @return the encoded data
     */
    public byte[] encode(Routable routable) {
        return protocol.encode(version(), routable);
    }

    /**
     * Decodes the given byte array using the version under test.
     *
     * @param data the data to decode
     * @return the decoded routable, as returned by the protocol
     */
    public Routable decode(byte[] data) {
        return protocol.decode(version(), data);
    }

    /**
     * Resolves a test data file name to a path; delegates entirely to {@code TestFileUtil.getPath}.
     *
     * @param filename the bare file name
     * @return the path string to read from or write to
     */
    public String getPath(String filename) {
        return TestFileUtil.getPath(filename);
    }

    /** Returns true iff a file exists at {@code path} and its bytes equal {@code dataToWrite}. */
    private boolean fileContentIsUnchanged(String path, byte[] dataToWrite) throws IOException {
        if (!Files.exists(Paths.get(path))) {
            return false;
        }
        byte[] existingData = TestFileUtil.readFile(path);
        return Arrays.equals(existingData, dataToWrite);
    }

    /** Transforms serialized bytes before they are written to disk, e.g. to simulate corruption. */
    @FunctionalInterface
    public interface DataTamper {

        /**
         * Returns the (possibly modified) data to write.
         *
         * @param data the serialized bytes
         * @return the bytes to write instead; must be non-null and non-empty
         */
        byte[] tamperWith(byte[] data);

        /**
         * Returns a copy of {@code data} with the last {@code bytes} bytes removed.
         * At least one byte must remain.
         */
        static byte[] truncate(byte[] data, int bytes) {
            assertTrue("cannot truncate by a negative byte count: " + bytes, bytes >= 0);
            int newLength = data.length - bytes;
            assertTrue(newLength > 0);
            return Arrays.copyOf(data, newLength);
        }

        /** Returns a copy of {@code data} with {@code bytes} zero bytes appended. */
        static byte[] pad(byte[] data, int bytes) {
            assertTrue("cannot pad by a negative byte count: " + bytes, bytes >= 0);
            // Arrays.copyOf zero-fills the extra tail.
            return Arrays.copyOf(data, data.length + bytes);
        }
    }

    /**
     * Serializes the given routable with the version under test, applies {@code tamper}, and writes
     * the result to {@code <version>-java-<filename>.dat}. The file is rewritten only if its content
     * differs. Finally, asserts that the written bytes decode to a routable of the same type.
     *
     * @param filename the logical name of the file to write to
     * @param routable the routable to serialize
     * @param tamper   allows tampering with the serialized data before it is written
     * @return the size of the written data in bytes
     */
    public int serialize(String filename, Routable routable, DataTamper tamper) {
        Version version = version();
        String path = getPath(version + "-java-" + filename + ".dat");
        byte[] data = protocol.encode(version, routable);
        data = tamper.tamperWith(data);
        assertNotNull(data);
        assertTrue(data.length > 0);
        try {
            if (fileContentIsUnchanged(path, data)) {
                System.out.println(String.format("Serialization for '%s' is unchanged; not overwriting it", path));
            } else {
                System.out.println(String.format("Serializing to '%s'..", path));
                // This only happens when protocol encoding has changed and takes place
                // during local development, not regular test runs.
                TestFileUtil.writeToFile(path, data);
            }
        } catch (IOException e) {
            throw new AssertionError("Failed to write serialized data to '" + path + "'", e);
        }
        // Check for null explicitly so a decode failure is reported as an assertion, not an NPE.
        Routable decoded = protocol.decode(version, data);
        assertNotNull("Data written to '" + path + "' could not be decoded", decoded);
        assertEquals(routable.getType(), decoded.getType());
        return data.length;
    }

    /**
     * Same as {@link #serialize(String, Routable, DataTamper)} without tampering.
     *
     * @param filename the logical name of the file to write to
     * @param routable the routable to serialize
     * @return the size of the written data in bytes
     */
    public int serialize(String filename, Routable routable) {
        return serialize(filename, routable, data -> data);
    }

    /**
     * Reads the content of the given file and creates a corresponding routable.
     *
     * @param filename the logical name of the file to read from
     * @param classId  the type that the routable must decode as
     * @param lang     selects whether the Java- or C++-produced file is read
     * @return the decoded routable
     */
    public Routable deserialize(String filename, int classId, Language lang) {
        Version version = version();
        String path = getPath(version + "-" + (lang == Language.JAVA ? "java" : "cpp") + "-" + filename + ".dat");
        System.out.println("Deserializing from '" + path + "'..");
        byte[] data;
        try {
            data = TestFileUtil.readFile(path);
        } catch (IOException e) {
            throw new AssertionError("Failed to read serialized data from '" + path + "'", e);
        }
        Routable ret = protocol.decode(version, data);
        assertNotNull(ret);
        assertEquals(classId, ret.getType());
        return ret;
    }

    /**
     * Asserts that {@code registered} and {@code tested} contain the same routable types.
     * Each mismatch is printed to stderr, and the failure message lists all of them.
     */
    private static void assertCoverage(List<Integer> registered, List<Integer> tested) {
        List<Integer> untested = new ArrayList<>();
        List<Integer> unregistered = new ArrayList<>(tested);
        for (Integer type : registered) {
            // 'type' is an Integer, so this is List.remove(Object), not remove(int index).
            if (!unregistered.remove(type)) {
                System.err.println("Routable type " + type + " is registered in DocumentProtocol but not tested.");
                untested.add(type);
            }
        }
        for (Integer type : unregistered) {
            System.err.println("Routable type " + type + " is tested but not registered in DocumentProtocol.");
        }
        assertTrue("Routable types registered but not tested: " + untested
                   + "; tested but not registered: " + unregistered,
                   untested.isEmpty() && unregistered.isEmpty());
    }

    /** A single message test; may throw any exception to signal failure. */
    @FunctionalInterface
    protected interface RunnableTest {

        void run() throws Exception;
    }
}