aboutsummaryrefslogtreecommitdiffstats
path: root/container-search/src/main/java/com/yahoo/search/query/rewrite/rewriters/GenericExpansionRewriter.java
blob: 40b549ea4acadf09bcd07f62f42105b5aee26cf7 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.query.rewrite.rewriters;

import java.io.*;
import java.util.*;
import java.util.logging.Logger;

import com.yahoo.component.annotation.Inject;
import com.yahoo.component.chain.dependencies.Provides;
import com.yahoo.fsa.FSA;
import com.yahoo.search.query.rewrite.*;
import com.yahoo.search.*;
import com.yahoo.component.ComponentId;
import com.yahoo.filedistribution.fileacquirer.FileAcquirer;
import com.yahoo.search.query.rewrite.RewritesConfig;
import com.yahoo.prelude.querytransform.PhraseMatcher;

/**
 * This rewriter would add rewrites to entities (e.g abbreviation, synonym, etc)<br>
 * to boost precision
 * - FSA dict: [normalized original query]\t[rewrite 1]\t[rewrite 2]\t[etc]<br>
 * - Features:<br>
 *   RewritesAsUnitEquiv flag: add proximity boosted rewrites<br>
 *   PartialPhraseMatch flag: whether to match whole phrase or partial phrase<br>
 *   MaxRewrites flag: the maximum number of rewrites to be added<br>
 *
 * @author Karen Sze Wing Lee
 */
@Provides("GenericExpansionRewriter")
public class GenericExpansionRewriter extends QueryRewriteSearcher {

    // Flag for skipping this rewriter if the query has been rewritten
    private final boolean SKIP_REWRITER_IF_REWRITTEN = false;

    // Name of the rewriter
    public static final String REWRITER_NAME = "GenericExpansionRewriter";

    // Generic expansion dictionary name
    public static final String GENERIC_EXPAND_DICT = "GenericExpansion";

    // Default generic expansion dictionary file name
    public static final String GENERIC_EXPAND_DICT_FILENAME = "GenericExpansionRewriter.fsa";

    // PhraseMatcher created from FSA dict
    private PhraseMatcher phraseMatcher;

    private Logger logger;


    /**
     * Constructor for GenericExpansionRewriter.
     * Load configs using default format
     */
    @Inject
    public GenericExpansionRewriter(ComponentId id,
                        FileAcquirer fileAcquirer,
                        RewritesConfig config) {
        super(id, fileAcquirer, config);
    }

    /**
     * Constructor for GenericExpansionRewriter unit test.
     * Load configs using default format
     */
    public GenericExpansionRewriter(RewritesConfig config,
                        HashMap<String, File> fileList) {
        super(config, fileList);
    }

    /**
     * Instance creation time config loading besides FSA.
     * Create PhraseMatcher from FSA dict
     */
    public boolean configure(FileAcquirer fileAcquirer,
                             RewritesConfig config,
                             HashMap<String, File> fileList) {
        logger = Logger.getLogger(GenericExpansionRewriter.class.getName());
        FSA fsa = (FSA)rewriterDicts.get(GENERIC_EXPAND_DICT);
        if (fsa==null) {
            RewriterUtils.error(logger, "Error retrieving FSA dictionary: " + GENERIC_EXPAND_DICT);
            return false;
        }
        // Create Phrase Matcher
        RewriterUtils.log(logger, "Creating PhraseMatcher");
        try {
            phraseMatcher = new PhraseMatcher(fsa, false);
        } catch (IllegalArgumentException e) {
            RewriterUtils.error(logger, "Error creating phrase matcher");
            return false;
        }

        // Match single word as well
        phraseMatcher.setMatchSingleItems(true);

        // Return all matches instead of only the longest match
        phraseMatcher.setMatchAll(true);

        return true;
    }

    /**
     * Main logic of rewriter<br>
     * - Retrieve rewrites from FSA dict<br>
     * - rewrite query using features that are enabled by user
     */
    public HashMap<String, Object> rewrite(Query query,
                                           String dictKey) throws RuntimeException {

        Boolean rewritten = false;

        // Pass the original dict key to the next rewriter
        HashMap<String, Object> result = new HashMap<>();
        result.put(RewriterConstants.REWRITTEN, rewritten);
        result.put(RewriterConstants.DICT_KEY, dictKey);

        RewriterUtils.log(logger, query,
                         "In GenericExpansionRewriter, query used for dict retrieval=[" + dictKey + "]");

        // Retrieve flags for choosing between whole query match
        // or partial query match
        String partialPhraseMatch = getQPConfig(query, RewriterConstants.PARTIAL_PHRASE_MATCH);

        if(partialPhraseMatch==null) {
            RewriterUtils.error(logger, query, "Required param " + RewriterConstants.PARTIAL_PHRASE_MATCH +
                                               " is not set, skipping rewriter");
            throw new RuntimeException("Required param " + RewriterConstants.PARTIAL_PHRASE_MATCH +
                                       " is not set, skipping rewriter");
        }

        // Retrieve max number of rewrites allowed
        int maxNumRewrites = 0;
        String maxNumRewritesStr = getQPConfig(query, RewriterConstants.MAX_REWRITES);
        if(maxNumRewritesStr!=null) {
            maxNumRewrites = Integer.parseInt(maxNumRewritesStr);
            RewriterUtils.log(logger, query,
                              "Limiting max number of rewrites to: " + maxNumRewrites);
        } else {
            RewriterUtils.log(logger, query, "No limit on number of rewrites");
        }

        // Retrieve flags for choosing whether to add
        // the rewrites as phrase, default to false
        String rewritesAsUnitEquiv = getQPConfig(query, RewriterConstants.REWRITES_AS_UNIT_EQUIV);
        if(rewritesAsUnitEquiv==null) {
            rewritesAsUnitEquiv = "false";
        }

        Set<PhraseMatcher.Phrase> matches;

        // Partial Phrase Matching
        if(partialPhraseMatch.equalsIgnoreCase("true")) {
            RewriterUtils.log(logger, query, "Partial phrase matching");

            // Retrieve longest non overlapping matches
            matches = RewriterFeatures.getNonOverlappingPartialPhraseMatches(phraseMatcher, query);

        // Full Phrase Matching if set to anything else
        } else {
            RewriterUtils.log(logger, query, "Full phrase matching");

            // Retrieve longest non overlapping matches
            matches = RewriterFeatures.getNonOverlappingFullPhraseMatches(phraseMatcher, query);
        }

        if(matches==null) {
            return result;
        }

        // Add expansions to the query
        query = RewriterFeatures.addExpansions(query, matches, null, maxNumRewrites, false,
                                               rewritesAsUnitEquiv.equalsIgnoreCase("true"));

        rewritten = true;

        RewriterUtils.log(logger, query, "GenericExpansionRewriter final query: " + query.toDetailString());

        result.put(RewriterConstants.REWRITTEN, rewritten);

        return result;
    }

    /**
     * Get the flag which specifies whether this rewriter
     * should be skipped if the query has been rewritten
     *
     * @return true if rewriter should be skipped, false
     *         otherwise
     */
    public boolean getSkipRewriterIfRewritten() {
        return SKIP_REWRITER_IF_REWRITTEN;
    }

   /**
    * Get the name of the rewriter
    *
    * @return Name of the rewriter
    */
   public String getRewriterName() {
       return REWRITER_NAME;
   }

   /**
    * Get default FSA dictionary names
    *
    * @return Pair of FSA dictionary name and filename
    */
   public HashMap<String, String> getDefaultFSAs() {
       HashMap<String, String> defaultDicts = new HashMap<>();
       defaultDicts.put(GENERIC_EXPAND_DICT, GENERIC_EXPAND_DICT_FILENAME);
       return defaultDicts;
   }
}