aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/searchers/InputCheckingSearcher.java
blob: fafc29e267c9561f113aa01142cac368ea49f8a0 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.searchers;

import java.nio.ByteBuffer;
import java.nio.charset.CharacterCodingException;
import java.nio.charset.CharsetDecoder;
import java.nio.charset.CodingErrorAction;
import java.nio.charset.StandardCharsets;
import java.util.HashMap;
import java.util.ListIterator;
import java.util.Map;
import java.util.logging.Logger;

import java.util.logging.Level;

import com.yahoo.component.chain.dependencies.Before;
import com.yahoo.metrics.simple.Counter;
import com.yahoo.metrics.simple.MetricReceiver;
import com.yahoo.prelude.query.CompositeItem;
import com.yahoo.prelude.query.Item;
import com.yahoo.prelude.query.PhraseItem;
import com.yahoo.prelude.query.TermItem;
import com.yahoo.prelude.query.WordItem;
import com.yahoo.processing.IllegalInputException;
import com.yahoo.search.Query;
import com.yahoo.search.Result;
import com.yahoo.search.Searcher;
import com.yahoo.search.result.ErrorMessage;
import com.yahoo.search.searchchain.Execution;
import com.yahoo.search.searchchain.PhaseNames;
import com.yahoo.yolean.Exceptions;

/**
 * Check whether the query tree seems to be "well formed". In other words, run heuristics against
 * the input data to see whether the query should be sent to the search backend.
 *
 * <p>A query failing any heuristic is answered directly with an illegal-query error instead of
 * being passed on to the rest of the search chain, and the counter for that heuristic is incremented.</p>
 *
 * @author Steinar Knutsen
 */
@Before(PhaseNames.BACKEND)
public class InputCheckingSearcher extends Searcher {

    private static final Logger log = Logger.getLogger(InputCheckingSearcher.class.getName());

    /** A phrase is rejected when the same term occurs more than this many times in a row. */
    private static final int MAX_REPEATED_CONSECUTIVE_TERMS_IN_PHRASE = 5;

    /** A phrase is rejected when the same term occurs more than this many times in total. */
    private static final int MAX_REPEATED_TERMS_IN_PHRASE = 10;

    /** The double encoding check is only run when there are more single character user terms than this. */
    private static final int MAX_SINGLE_CHARACTER_USER_TERMS = 4;

    private final Counter utfRejections;
    private final Counter repeatedConsecutiveTermsInPhraseRejections;
    private final Counter repeatedTermsInPhraseRejections;

    /**
     * Creates the searcher and declares one rejection counter per heuristic.
     *
     * @param metrics the receiver the rejection counters are declared on
     */
    public InputCheckingSearcher(MetricReceiver metrics) {
        utfRejections = metrics.declareCounter("double_encoded_utf8_rejections");
        repeatedTermsInPhraseRejections = metrics.declareCounter("repeated_terms_in_phrase_rejections");
        repeatedConsecutiveTermsInPhraseRejections = metrics.declareCounter("repeated_consecutive_terms_in_phrase_rejections");
    }

    /**
     * Runs all heuristics on the query and only passes it on if none of them rejects it.
     *
     * @param query the query to check
     * @param execution the execution used to pass an accepted query on
     * @return an illegal-query error result if a heuristic rejected the query,
     *         otherwise the result of continuing the execution
     */
    @Override
    public Result search(Query query, Execution execution) {
        try {
            checkQuery(query);
        } catch (IllegalInputException e) {
            // Rejections are expected for bad input, so they are only logged at FINE level
            log.log(Level.FINE, () -> "Rejected query '" + query.toString() + "' on cause of: " + Exceptions.toMessageString(e));
            return new Result(query, ErrorMessage.createIllegalQuery(e.getMessage()));
        }
        return execution.search(query);
    }

    /**
     * Applies every heuristic in turn.
     *
     * @throws IllegalInputException if any heuristic rejects the query
     */
    private void checkQuery(Query query) {
        doubleEncodedUtf8(query);
        checkPhrases(query.getModel().getQueryTree().getRoot());
        // add new heuristics here
    }

    /**
     * Recursively visits the given item and runs the phrase heuristics on every phrase found.
     * The children of a phrase are not descended into.
     *
     * @throws IllegalInputException if a phrase is rejected
     */
    private void checkPhrases(Item queryItem) {
        if (queryItem instanceof PhraseItem phrase) {
            repeatedConsecutiveTermsInPhraseCheck(phrase);
            repeatedTermsInPhraseCheck(phrase);
        } else if (queryItem instanceof CompositeItem asComposite) {
            for (ListIterator<Item> i = asComposite.getItemIterator(); i.hasNext();) {
                checkPhrases(i.next());
            }
        }
    }

    /**
     * Rejects phrases where the same term occurs more than
     * {@link #MAX_REPEATED_CONSECUTIVE_TERMS_IN_PHRASE} times in a row.
     * Terms are compared by their indexed string, and any non-term item breaks a run.
     *
     * @throws IllegalInputException if such a run is found
     */
    private void repeatedConsecutiveTermsInPhraseCheck(PhraseItem phrase) {
        int itemCount = phrase.getItemCount();
        if (itemCount <= MAX_REPEATED_CONSECUTIVE_TERMS_IN_PHRASE) {
            return; // too few items to possibly exceed the limit
        }

        String prev = null;
        int repeatedCount = 0; // number of times the current term has repeated its predecessor
        for (int i = 0; i < itemCount; ++i) {
            if (phrase.getItem(i) instanceof TermItem term) {
                String current = term.getIndexedString();
                if (prev != null) {
                    if (prev.equals(current)) {
                        repeatedCount++;
                        // N repeats of the predecessor means N + 1 occurrences in a row
                        if (repeatedCount >= MAX_REPEATED_CONSECUTIVE_TERMS_IN_PHRASE) {
                            repeatedConsecutiveTermsInPhraseRejections.add();
                            throw new IllegalInputException("More than " + MAX_REPEATED_CONSECUTIVE_TERMS_IN_PHRASE +
                                                            " occurrences of term '" + current +
                                                            "' in a row detected in phrase : " + phrase.toString());
                        }
                    } else {
                        repeatedCount = 0;
                    }
                }
                prev = current;
            } else {
                prev = null;
                repeatedCount = 0;
            }
        }
    }

    /**
     * Rejects phrases where the same term occurs more than {@link #MAX_REPEATED_TERMS_IN_PHRASE}
     * times in total, regardless of position. Terms are compared by their indexed string,
     * and non-term items are ignored.
     *
     * @throws IllegalInputException if a term occurs too many times
     */
    private void repeatedTermsInPhraseCheck(PhraseItem phrase) {
        int itemCount = phrase.getItemCount();
        if (itemCount <= MAX_REPEATED_TERMS_IN_PHRASE) {
            return; // too few items to possibly exceed the limit
        }

        Map<String, Integer> occurrences = new HashMap<>();
        for (int i = 0; i < itemCount; ++i) {
            if (phrase.getItem(i) instanceof TermItem term) {
                String current = term.getIndexedString();
                // merge stores 1 for the first occurrence and adds 1 for each later one
                int seen = occurrences.merge(current, 1, Integer::sum);
                if (seen > MAX_REPEATED_TERMS_IN_PHRASE) {
                    repeatedTermsInPhraseRejections.add();
                    throw new IllegalInputException("Phrase contains more than " + MAX_REPEATED_TERMS_IN_PHRASE +
                                                    " occurrences of term '" + current + "' in phrase : " + phrase.toString());
                }
            }
        }
    }

    /**
     * Rejects queries whose query string looks like UTF-8 which has been encoded more than once.
     *
     * <p>The check is only attempted when the query tree contains more than
     * {@link #MAX_SINGLE_CHARACTER_USER_TERMS} single character user terms. The query is rejected if
     * every character of the query string fits in one octet, at least one character is outside ASCII,
     * and the characters taken as octets form a valid UTF-8 sequence.</p>
     *
     * @throws IllegalInputException if the query string is determined to be double encoded
     */
    private void doubleEncodedUtf8(Query query) {
        int singleCharacterTerms = countSingleCharacterUserTerms(query.getModel().getQueryTree());
        if (singleCharacterTerms <= MAX_SINGLE_CHARACTER_USER_TERMS) {
            return;
        }
        String userInput = query.getModel().getQueryString();
        if (userInput == null) {
            return; // defensive: without a query string there is nothing to inspect
        }
        ByteBuffer asOctets = ByteBuffer.allocate(userInput.length());
        boolean asciiOnly = true;
        for (int i = 0; i < userInput.length(); ++i) {
            char c = userInput.charAt(i);
            if (c > 255) {
                return; // not double (or more) encoded
            }
            if (c > 127) {
                asciiOnly = false;
            }
            asOctets.put((byte) c);
        }
        if (asciiOnly) {
            return; // plain ASCII is valid UTF-8, but tells nothing about double encoding
        }
        asOctets.flip();
        CharsetDecoder decoder = StandardCharsets.UTF_8.newDecoder().onMalformedInput(CodingErrorAction.REPORT)
                                                                    .onUnmappableCharacter(CodingErrorAction.REPORT);
        // OK, unmappable character is sort of theoretical, but added to be explicit
        try {
            decoder.decode(asOctets);
        } catch (CharacterCodingException e) {
            return; // the octets are not valid UTF-8, so this is not double encoded
        }
        utfRejections.add();
        throw new IllegalInputException("The user input has been determined to be double encoded UTF-8."
                                        + " Please investigate whether this is a false positive.");
    }

    /**
     * Counts the word items below (and including) the given item which come from the query
     * and have a string value of exactly one character.
     *
     * @return the number of such words, 0 if there are none
     */
    private int countSingleCharacterUserTerms(Item queryItem) {
        if (queryItem instanceof CompositeItem asComposite) {
            int sum = 0;
            for (ListIterator<Item> i = asComposite.getItemIterator(); i.hasNext();) {
                sum += countSingleCharacterUserTerms(i.next());
            }
            return sum;
        } else if (queryItem instanceof WordItem word) {
            return (word.isFromQuery() && word.stringValue().length() == 1) ? 1 : 0;
        } else {
            return 0;
        }
    }

}