summaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/prelude/query/NearestNeighborItem.java
blob: e8fa70afd1b679357a0f5a8a09ab9d8466f6021c (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.prelude.query;

import com.google.common.annotations.Beta;
import com.yahoo.compress.IntegerCompressor;

import java.nio.ByteBuffer;
import java.util.Objects;

/**
 * Represents a query item matching the K nearest neighbors in a multi-dimensional vector space.
 * The query point vector is referenced by the name of a tensor passed as a query rank feature;
 * specifying "myvector" as the name means the query must set "ranking.features.query(myvector)".
 * This rank feature must be configured with the correct tensor type in the active query profile.
 * The field name (AKA the index name) given must be an attribute, with the exact same tensor type.
 *
 * @author arnej
 */
@Beta
public class NearestNeighborItem extends SimpleTaggableItem {

    /** The K number of hits to produce; defaults to 0. */
    private int targetNumHits = 0;

    /** Extra hits to explore in the HNSW algorithm; defaults to 0. */
    private int hnswExploreAdditionalHits = 0;

    /** Whether approximate matching is allowed; defaults to true. */
    private boolean approximate = true;

    /** The attribute field (index) holding the document vectors. */
    private String field;

    /** The name of the query rank feature tensor holding the query point. */
    private String queryTensorName;

    /**
     * Creates a nearest neighbor item.
     *
     * @param fieldName the attribute field (index) to search
     * @param queryTensorName the name of the query tensor rank feature holding the query point
     */
    public NearestNeighborItem(String fieldName, String queryTensorName) {
        this.field = fieldName;
        this.queryTensorName = queryTensorName;
    }

    /** Returns the K number of hits to produce */
    public int getTargetNumHits() { return targetNumHits; }

    /** Returns the field name */
    public String getIndexName() { return field; }

    /** Returns the number of extra hits to explore in HNSW algorithm */
    public int getHnswExploreAdditionalHits() { return hnswExploreAdditionalHits; }

    /** Returns whether approximation is allowed */
    public boolean getAllowApproximate() { return approximate; }

    /** Returns the name of the query tensor */
    public String getQueryTensorName() { return queryTensorName; }

    /** Sets the K number of hits to produce */
    public void setTargetNumHits(int target) { this.targetNumHits = target; }

    /** Sets the number of extra hits to explore in HNSW algorithm */
    public void setHnswExploreAdditionalHits(int num) { this.hnswExploreAdditionalHits = num; }

    /** Sets whether approximation is allowed */
    public void setAllowApproximate(boolean value) { this.approximate = value; }

    /** Sets the attribute field (index) to search */
    @Override
    public void setIndexName(String index) { this.field = index; }

    @Override
    public ItemType getItemType() { return ItemType.NEAREST_NEIGHBOR; }

    @Override
    public String getName() { return "NEAREST_NEIGHBOR"; }

    /** This item always counts as a single term. */
    @Override
    public int getTermCount() { return 1; }

    /**
     * Serializes this item into the given buffer.
     * The write order below is a wire format. NOTE(review): presumably it must match the
     * backend's stack dump decoder, so do not reorder without confirming the reader side.
     * {@link IntegerCompressor#putCompressedPositiveNumber} is used for the numeric values,
     * so they are expected to be non-negative.
     *
     * @return the number of encoded stack dump items (always 1)
     */
    @Override
    public int encode(ByteBuffer buffer) {
        super.encodeThis(buffer);
        putString(field, buffer);
        putString(queryTensorName, buffer);
        IntegerCompressor.putCompressedPositiveNumber(targetNumHits, buffer);
        IntegerCompressor.putCompressedPositiveNumber((approximate ? 1 : 0), buffer);
        IntegerCompressor.putCompressedPositiveNumber(hnswExploreAdditionalHits, buffer);
        return 1;  // number of encoded stack dump items
    }

    @Override
    protected void appendBodyString(StringBuilder buffer) {
        buffer.append("{field=").append(field);
        buffer.append(",queryTensorName=").append(queryTensorName);
        buffer.append(",hnsw.exploreAdditionalHits=").append(hnswExploreAdditionalHits);
        buffer.append(",approximate=").append(approximate);
        buffer.append(",targetNumHits=").append(targetNumHits).append("}");
    }

    /**
     * Two nearest neighbor items are equal if the inherited state is equal and they search
     * the same field with the same query tensor and the same K, exploration and approximation
     * settings.
     */
    @Override
    public boolean equals(Object o) {
        if (o == this) return true;
        if ( ! super.equals(o)) return false;
        // Guard the cast in case the inherited equals does not check the runtime class.
        if ( ! (o instanceof NearestNeighborItem)) return false;
        NearestNeighborItem other = (NearestNeighborItem) o;
        return targetNumHits == other.targetNumHits
               && hnswExploreAdditionalHits == other.hnswExploreAdditionalHits
               && approximate == other.approximate
               && Objects.equals(field, other.field)
               && Objects.equals(queryTensorName, other.queryTensorName);
    }

    /** Combines the inherited hash with this item's own fields, consistent with {@link #equals}. */
    @Override
    public int hashCode() {
        return Objects.hash(super.hashCode(), targetNumHits, hnswExploreAdditionalHits,
                            approximate, field, queryTensorName);
    }

}