aboutsummaryrefslogtreecommitdiffstats
path: root/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzer.java
blob: b4d5a1e4edc424c755eafd8931d854d2597ace05 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.search.predicate.annotator;

import com.yahoo.document.predicate.Conjunction;
import com.yahoo.document.predicate.Disjunction;
import com.yahoo.document.predicate.FeatureConjunction;
import com.yahoo.document.predicate.FeatureRange;
import com.yahoo.document.predicate.FeatureSet;
import com.yahoo.document.predicate.Negation;
import com.yahoo.document.predicate.Predicate;
import com.yahoo.document.predicate.PredicateHash;
import com.yahoo.search.predicate.index.Feature;
import com.yahoo.search.predicate.index.conjunction.IndexableFeatureConjunction;

import java.util.HashMap;
import java.util.Map;

/**
 * This class analyzes a predicate tree to determine two characteristics:
 *  1) The sub-tree size of each operand of a conjunction node (plus the size of the whole tree).
 *  2) The min-feature value: a lower bound of the number of terms required to satisfy a predicate. This lower bound is
 *     an estimate which is guaranteed to be less than or equal to the real lower bound.
 *
 * @author bjorncs
 */
public class PredicateTreeAnalyzer {

    // Slack subtracted from the accumulated min-feature estimate before rounding it up. The estimate is a sum of
    // fractions (1 / occurrence count) computed in floating point, and rounding errors can push a sum whose exact value
    // is an integer slightly above that integer. For example, ((((2.2 + 0.2) + 0.2) + 0.2) + 0.2) evaluates to
    // 3.000000000000001 rather than 3.0, and Math.ceil would then overshoot by one, breaking the lower-bound guarantee.
    // Biasing the value downwards before rounding up can only make the estimate smaller, which keeps it a valid
    // lower bound.
    private static final double MIN_FEATURE_ROUNDING_SLACK = 1e-9;

    /**
     * Analyzes a predicate tree in two depth-first passes: the first collects sub-tree sizes and feature occurrence
     * counts, the second uses those counts to compute the min-feature value.
     *
     * @param predicate The predicate tree.
     * @return a result object containing the min-feature value, the tree size and sub-tree sizes.
     * @throws UnsupportedOperationException if the tree contains a node type this analyzer cannot handle.
     */
    public static PredicateTreeAnalyzerResult analyzePredicateTree(Predicate predicate) {
        AnalyzerContext context = new AnalyzerContext();
        // The first pass must run before the second: it fills in the occurrence counts the second pass divides by.
        int treeSize = aggregatePredicateStatistics(predicate, false, context);
        // A tree containing any negated leaf gets one extra required term.
        int minFeature = roundUpMinFeature(findMinFeature(predicate, false, context))
                + (context.hasNegationPredicate ? 1 : 0);
        return new PredicateTreeAnalyzerResult(minFeature, treeSize, context.subTreeSizes);
    }

    // Rounds the fractional min-feature estimate up to an int, tolerating tiny floating-point errors that would
    // otherwise push an exact integer up to the next one (see MIN_FEATURE_ROUNDING_SLACK).
    // An estimate of 0.0 yields ceil(-slack) == -0.0, which casts to 0.
    private static int roundUpMinFeature(double estimate) {
        return (int) Math.ceil(estimate - MIN_FEATURE_ROUNDING_SLACK);
    }

    // First analysis pass. Traverses tree in depth-first order. Determines the sub-tree sizes and counts the occurrences
    // of each feature (used by min-feature calculation in second pass).
    // Sub-tree sizes are recorded only for the direct operands of Conjunction nodes.
    // A non-negated leaf contributes 1 to the size, a negated leaf contributes 2 and flags the tree as containing a
    // negation. Negated leaves are not counted as feature/conjunction occurrences.
    // Returns the size of the analyzed subtree.
    private static int aggregatePredicateStatistics(Predicate predicate, boolean isNegated, AnalyzerContext context) {
        if (predicate instanceof Negation) {
            // A negation node adds no size of its own; it only flips the negation state for its operand.
            return aggregatePredicateStatistics(((Negation) predicate).getOperand(), !isNegated, context);
        } else if (predicate instanceof Conjunction) {
            return ((Conjunction)predicate).getOperands().stream()
                    .mapToInt(child -> {
                        int size = aggregatePredicateStatistics(child, isNegated, context);
                        context.subTreeSizes.put(child, size);
                        return size;
                    }).sum();
        } else if (predicate instanceof FeatureConjunction) {
            if (isNegated) {
                context.hasNegationPredicate = true;
                return 2;
            }
            // Count the number of identical feature conjunctions - use the id from IndexableFeatureConjunction as key
            IndexableFeatureConjunction ifc = new IndexableFeatureConjunction((FeatureConjunction) predicate);
            incrementOccurrence(context.conjunctionOccurrences, ifc.id);
            // Handled as leaf in interval algorithm - count a single child
            return 1;
        } else if (predicate instanceof Disjunction) {
            return ((Disjunction)predicate).getOperands().stream()
                    .mapToInt(child -> aggregatePredicateStatistics(child, isNegated, context)).sum();
        } else if (predicate instanceof FeatureSet) {
            if (isNegated) {
                context.hasNegationPredicate = true;
                return 2;
            } else {
                FeatureSet featureSet = (FeatureSet) predicate;
                // Every (key, value) pair of the set counts as one occurrence of that feature.
                for (String value : featureSet.getValues()) {
                    incrementOccurrence(context.featureOccurrences, Feature.createHash(featureSet.getKey(), value));
                }
                return 1;
            }
        } else if (predicate instanceof FeatureRange) {
            if (isNegated) {
                context.hasNegationPredicate = true;
                return 2;
            } else {
                // Ranges are counted per key only.
                incrementOccurrence(context.featureOccurrences, PredicateHash.hash64(((FeatureRange) predicate).getKey()));
                return 1;
            }
        } else {
            throw new UnsupportedOperationException("Cannot handle predicate of type " + predicate.getClass().getSimpleName());
        }
    }

    // Second analysis pass. Traverses tree in depth-first order. Determines the min-feature value.
    // Conjunctions sum their operands, disjunctions take the minimum of their operands, negated leaves contribute 0,
    // and a non-negated leaf contributes 1 / (number of occurrences counted for it in the first pass), so that a
    // feature used in several places is not counted more than once in total.
    // A Disjunction without operands makes OptionalDouble.getAsDouble() throw NoSuchElementException.
    private static double findMinFeature(Predicate predicate, boolean isNegated, AnalyzerContext context) {
        if (predicate instanceof Conjunction) {
            // Sum of children values.
            return ((Conjunction) predicate).getOperands().stream()
                    .mapToDouble(child -> findMinFeature(child, isNegated, context))
                    .sum();
        } else if (predicate instanceof FeatureConjunction) {
            if (isNegated) {
                return 0.0;
            }
            // The FeatureConjunction is handled as a leaf node in the interval algorithm.
            IndexableFeatureConjunction ifc = new IndexableFeatureConjunction((FeatureConjunction) predicate);
            return 1.0 / context.conjunctionOccurrences.get(ifc.id);
        } else if (predicate instanceof Disjunction) {
            // Minimum value of children.
            return ((Disjunction) predicate).getOperands().stream()
                    .mapToDouble(child -> findMinFeature(child, isNegated, context))
                    .min()
                    .getAsDouble();
        } else if (predicate instanceof Negation) {
            return findMinFeature(((Negation) predicate).getOperand(), !isNegated, context);
        } else if (predicate instanceof FeatureSet) {
            if (isNegated) {
                return 0.0;
            }
            // Any one value of the set satisfies it, so take the cheapest (most shared) value, capped at 1.0.
            double minFeature = 1.0;
            FeatureSet featureSet = (FeatureSet) predicate;
            for (String value : featureSet.getValues()) {
                long featureHash = Feature.createHash(featureSet.getKey(), value);
                // Dividing by the occurrence count spreads a feature used in several places of the tree over all of
                // those places, so it contributes at most 1 in total to a conjunction's sum.
                minFeature = Math.min(minFeature, 1.0 / context.featureOccurrences.get(featureHash));
            }
            return minFeature;
        } else if (predicate instanceof FeatureRange) {
            if (isNegated) {
                return 0.0;
            }
            return 1.0 / context.featureOccurrences.get(PredicateHash.hash64(((FeatureRange) predicate).getKey()));
        } else {
            throw new UnsupportedOperationException("Cannot handle predicate of type " + predicate.getClass().getSimpleName());
        }
    }

    // Increments the occurrence count stored for the given hash, starting at 1 for a new hash.
    private static void incrementOccurrence(Map<Long, Integer> featureOccurrences, long featureHash) {
        featureOccurrences.merge(featureHash, 1, Integer::sum);
    }

    // Data structure to hold aggregated data during analysis.
    private static class AnalyzerContext {
        // Mapping from feature hash to occurrence count.
        public final Map<Long, Integer> featureOccurrences = new HashMap<>();
        // Mapping from conjunction id to occurrence count.
        public final Map<Long, Integer> conjunctionOccurrences = new HashMap<>();
        // Mapping from predicate to sub-tree size.
        public final Map<Predicate, Integer> subTreeSizes = new HashMap<>();
        // Does the tree contain any negated leaf (feature set, range or feature conjunction)?
        public boolean hasNegationPredicate = false;
    }

}