summary refs log tree commit diff stats
path: root/container-search/src/test/java/com/yahoo/fs4/test/RankFeaturesTestCase.java
blob: 4671ada6545026654fbb7b466d4c7d552fdef422 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.fs4.test;

import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.search.query.ranking.RankFeatures;
import com.yahoo.search.query.ranking.RankProperties;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.serialization.TypedBinaryFormat;
import com.yahoo.text.Utf8;
import org.junit.Test;

import java.nio.ByteBuffer;
import java.util.*;

import static org.junit.Assert.assertEquals;

/**
 * Tests for {@link RankProperties}, and for the binary encoding of tensor-valued {@link RankFeatures}
 * after they have been prepared into rank properties.
 *
 * @author geirst
 */
public class RankFeaturesTestCase {

    @Test
    public void requireThatRankPropertiesTakesBothStringAndObject() {
        RankProperties p = new RankProperties();
        p.put("string", "b");
        // A non-String value must be stored as its string form; Integer.valueOf replaces the
        // deprecated Integer(int) constructor without changing what is passed.
        p.put("object", Integer.valueOf(7));
        assertEquals("7", p.get("object").get(0));
        assertEquals("b", p.get("string").get(0));
    }

    @Test
    public void requireThatSingleTensorIsBinaryEncoded() {
        TensorType type = new TensorType.Builder().mapped("x").mapped("y").mapped("z").build();
        Tensor tensor = Tensor.from(type, "{ {x:a, y:b, z:c}:2.0, {x:a, y:b, z:c2}:3.0 }");
        // Both key spellings are expected to normalize to the same bare name.
        assertTensorEncodingAndDecoding(type, "query(my_tensor)", "my_tensor", tensor);
        assertTensorEncodingAndDecoding(type, "$my_tensor", "my_tensor", tensor);
    }

    @Test
    public void requireThatMultipleTensorsAreBinaryEncoded() {
        TensorType type = new TensorType.Builder().mapped("x").mapped("y").mapped("z").build();
        Tensor tensor1 = Tensor.from(type, "{ {x:a, y:b, z:c}:2.0, {x:a, y:b, z:c2}:3.0 }");
        Tensor tensor2 = Tensor.from(type, "{ {x:a, y:b, z:c}:5.0 }");
        assertTensorEncodingAndDecoding(type, Arrays.asList(
                new Entry("query(tensor1)", "tensor1", tensor1),
                new Entry("$tensor2", "tensor2", tensor2)));
    }

    /** One tensor feature: the key it is set with, the key it is expected under after encoding, and its value. */
    private static final class Entry {
        final String key;
        final String normalizedKey;
        final Tensor tensor;
        Entry(String key, String normalizedKey, Tensor tensor) {
            this.key = key;
            this.normalizedKey = normalizedKey;
            this.tensor = tensor;
        }
    }

    /**
     * Puts all entries into rank features, prepares them into rank properties, encodes and decodes those,
     * and verifies that every tensor round-trips under its normalized key together with a "tensor" type entry.
     */
    private static void assertTensorEncodingAndDecoding(TensorType type, List<Entry> entries) {
        RankProperties properties = createRankPropertiesWithTensors(entries);
        assertEquals(entries.size(), properties.asMap().size());

        Map<String, Object> decodedProperties = decode(type, encode(properties));
        assertEquals(entries.size() * 2, properties.asMap().size()); // tensor type info has been added
        assertEquals(entries.size() * 2, decodedProperties.size());
        for (Entry entry : entries) {
            assertEquals(entry.tensor, decodedProperties.get(entry.normalizedKey));
            assertEquals("tensor", decodedProperties.get(entry.normalizedKey + ".type"));
        }
    }

    /** Single-entry convenience overload of {@link #assertTensorEncodingAndDecoding(TensorType, List)}. */
    private static void assertTensorEncodingAndDecoding(TensorType type, String key, String normalizedKey, Tensor tensor) {
        assertTensorEncodingAndDecoding(type, Arrays.asList(new Entry(key, normalizedKey, tensor)));
    }

    /** Returns fresh rank properties populated by preparing rank features that hold the given tensors. */
    private static RankProperties createRankPropertiesWithTensors(List<Entry> entries) {
        RankFeatures features = new RankFeatures();
        for (Entry entry : entries) {
            features.put(entry.key, entry.tensor);
        }
        RankProperties properties = new RankProperties();
        features.prepare(properties);
        return properties;
    }

    /**
     * Encodes the properties and returns exactly the bytes that were written.
     * NOTE(review): the fixed 512-byte capacity is assumed to suffice for the small tensors in these tests;
     * larger fixtures would need a bigger buffer.
     */
    private static byte[] encode(RankProperties properties) {
        ByteBuffer buffer = ByteBuffer.allocate(512);
        properties.encode(buffer, true);
        buffer.flip(); // limit := bytes written, position := 0
        byte[] result = new byte[buffer.remaining()];
        buffer.get(result);
        return result;
    }

    /**
     * Decodes the bytes produced by {@link #encode}. The layout read here is: an int-length-prefixed map name,
     * an int entry count, then for each entry an int-length-prefixed key followed by an int-length-prefixed value.
     * Values whose key ends in ".type" are read as UTF-8 strings; all others are decoded as tensors of {@code type}.
     */
    private static Map<String, Object> decode(TensorType type, byte[] encodedProperties) {
        GrowableByteBuffer buffer = GrowableByteBuffer.wrap(encodedProperties);
        readLengthPrefixed(buffer); // the map name is not checked, only skipped
        int numEntries = buffer.getInt();
        Map<String, Object> result = new HashMap<>();
        for (int i = 0; i < numEntries; ++i) {
            String key = Utf8.toString(readLengthPrefixed(buffer));
            byte[] value = readLengthPrefixed(buffer);
            // Suffix match: type entries are stored as "<name>.type"; a plain contains() would also
            // misclassify a tensor whose name merely contains ".type".
            if (key.endsWith(".type")) {
                result.put(key, Utf8.toString(value));
            } else {
                result.put(key, TypedBinaryFormat.decode(Optional.of(type), GrowableByteBuffer.wrap(value)));
            }
        }
        return result;
    }

    /** Reads an int length followed by that many bytes, and returns those bytes. */
    private static byte[] readLengthPrefixed(GrowableByteBuffer buffer) {
        byte[] bytes = new byte[buffer.getInt()];
        buffer.get(bytes);
        return bytes;
    }
}