diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-06-15 23:09:44 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-06-15 23:09:44 +0200 |
commit | 72231250ed81e10d66bfe70701e64fa5fe50f712 (patch) | |
tree | 2728bba1131a6f6e5bdf95afec7d7ff9358dac50 /predicate-search |
Publish
Diffstat (limited to 'predicate-search')
67 files changed, 7469 insertions, 0 deletions
diff --git a/predicate-search/OWNERS b/predicate-search/OWNERS new file mode 100644 index 00000000000..569bf1cc3a1 --- /dev/null +++ b/predicate-search/OWNERS @@ -0,0 +1 @@ +bjorncs diff --git a/predicate-search/pom.xml b/predicate-search/pom.xml new file mode 100644 index 00000000000..fe81578c4d7 --- /dev/null +++ b/predicate-search/pom.xml @@ -0,0 +1,74 @@ +<?xml version="1.0"?> +<!-- Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. --> +<project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" + xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> + <modelVersion>4.0.0</modelVersion> + <parent> + <groupId>com.yahoo.vespa</groupId> + <artifactId>parent</artifactId> + <version>6-SNAPSHOT</version> + <relativePath>../parent/pom.xml</relativePath> + </parent> + <artifactId>predicate-search</artifactId> + <version>6-SNAPSHOT</version> + <packaging>jar</packaging> + <name>${project.artifactId}</name> + <dependencies> + <dependency> + <groupId>junit</groupId> + <artifactId>junit</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>org.mockito</groupId> + <artifactId>mockito-core</artifactId> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>predicate-search-core</artifactId> + <version>${project.version}</version> + </dependency> + <dependency> + <groupId>com.google.guava</groupId> + <artifactId>guava</artifactId> + </dependency> + <dependency> + <groupId>com.goldmansachs</groupId> + <artifactId>gs-collections</artifactId> + </dependency> + <dependency> + <groupId>io.airlift</groupId> + <artifactId>airline</artifactId> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-core</artifactId> + </dependency> + <dependency> + <groupId>com.fasterxml.jackson.core</groupId> + <artifactId>jackson-databind</artifactId> + </dependency> + </dependencies> + <build> + <plugins> + <plugin> + <artifactId>maven-assembly-plugin</artifactId> + <configuration> + <descriptorRefs> + <descriptorRef>jar-with-dependencies</descriptorRef> + </descriptorRefs> + </configuration> + <executions> + <execution> + <id>make-assembly</id> + <phase>package</phase> + <goals> + <goal>single</goal> + </goals> + </execution> + </executions> + </plugin> + </plugins> + </build> +</project> diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/Config.java b/predicate-search/src/main/java/com/yahoo/search/predicate/Config.java new file mode 100644 index 00000000000..23569baab73 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/Config.java @@ -0,0 +1,75 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate; + +import com.google.common.annotations.Beta; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; + +/** + * Configuration for a {@link PredicateIndexBuilder}/{@link PredicateIndex} instance. + * + * @author bjorncs + */ +@Beta +public class Config { + public final int arity; + public final long lowerBound; + public final long upperBound; + public final boolean useConjunctionAlgorithm; + + private Config(int arity, long lowerBound, long upperBound, boolean useConjunctionAlgorithm) { + this.arity = arity; + this.lowerBound = lowerBound; + this.upperBound = upperBound; + this.useConjunctionAlgorithm = useConjunctionAlgorithm; + } + + public void writeToOutputStream(DataOutputStream out) throws IOException { + out.writeInt(arity); + out.writeLong(lowerBound); + out.writeLong(upperBound); + out.writeBoolean(useConjunctionAlgorithm); + } + + public static Config fromInputStream(DataInputStream in) throws IOException { + int arity = in.readInt(); + long lowerBound = in.readLong(); + long upperBound = in.readLong(); + boolean useConjunctionAlgorithm = in.readBoolean(); + return new Config(arity, lowerBound, upperBound, useConjunctionAlgorithm); + } + + public static class Builder { + private int arity = 8; + private long lowerBound = Long.MIN_VALUE; + private long upperBound = Long.MAX_VALUE; + private boolean useConjunctionAlgorithm = false; + + public Builder setArity(int arity) { + this.arity = arity; + return this; + } + + public Builder setLowerBound(long lowerBound) { + this.lowerBound = lowerBound; + return this; + } + + public Builder setUpperBound(long upperBound) { + this.upperBound = upperBound; + return this; + } + + public Builder setUseConjunctionAlgorithm(boolean enabled) { + this.useConjunctionAlgorithm = enabled; + return this; + } + + public Config build() { + return new Config(arity, lowerBound, upperBound, useConjunctionAlgorithm); + } + + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/Hit.java b/predicate-search/src/main/java/com/yahoo/search/predicate/Hit.java new file mode 100644 index 00000000000..6568ef928e3 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/Hit.java @@ -0,0 +1,68 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate; + +import com.google.common.annotations.Beta; + +/** + * Represents a hit from the predicate search algorithm. + * Each hit is associated with a subquery bitmap, + * indicating which subqueries the hit represents. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +@Beta +public class Hit implements Comparable<Hit> { + private final int docId; + private final long subquery; + + public Hit(int docId) { + this(docId, SubqueryBitmap.DEFAULT_VALUE); + } + + public Hit(int docId, long subquery) { + this.docId = docId; + this.subquery = subquery; + } + + @Override + public String toString() { + if (subquery == SubqueryBitmap.DEFAULT_VALUE) { + return "" + docId; + } else { + return "[" + docId + ",0x" + Long.toHexString(subquery) + "]"; + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Hit hit = (Hit) o; + + if (docId != hit.docId) return false; + if (subquery != hit.subquery) return false; + + return true; + } + + @Override + public int hashCode() { + int result = docId; + result = 31 * result + (int) (subquery ^ (subquery >>> 32)); + return result; + } + + public int getDocId() { + return docId; + } + + public long getSubquery() { + return subquery; + } + + @Override + public int compareTo(Hit o) { + return Integer.compare(docId, o.docId); + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateIndex.java b/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateIndex.java new file mode 100644 index 00000000000..44ab9ea372e --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateIndex.java @@ -0,0 +1,225 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate; + +import com.google.common.annotations.Beta; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.search.predicate.index.*; +import com.yahoo.search.predicate.index.conjunction.ConjunctionHit; +import com.yahoo.search.predicate.index.conjunction.ConjunctionIndex; +import com.yahoo.search.predicate.serialization.SerializationHelper; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; +import java.util.stream.Stream; + +/** + * An index of {@link Predicate} objects. + * <p> + * Use a {@link PredicateQuery} to find the ids of documents that have matching Predicates. + * Create an instance of {@link PredicateIndex} using the {@link PredicateIndexBuilder}. + * </p><p> + * To build a {@link PredicateQuery} you add features and rangeFeatures with a 64-bit + * bitmap specifying which subqueries they appear in. + * </p><p> + * To perform a search, create a {@link Searcher} and call its {@link Searcher#search(PredicateQuery)} + * method, which returns a stream of {@link Hit} objects, + * each of which contains a document id and a 64-bit bitmap specifying which subqueries the hit is for. + * </p><p> + * Note that the {@link PredicateIndex} is thread-safe, but a {@link Searcher} is not. + * Each thread <strong>must</strong> use its own searcher. + * </p> + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +@Beta +public class PredicateIndex { + private static final int SERIALIZATION_FORMAT_VERSION = 3; + + private final PredicateRangeTermExpander expander; + private final int[] internalToExternalIdMapping; + private final byte[] minFeatureIndex; + private final short[] intervalEnds; + private final int highestIntervalEnd; + private final SimpleIndex intervalIndex; + private final SimpleIndex boundsIndex; + private final SimpleIndex conjunctionIntervalIndex; + private final PredicateIntervalStore intervalStore; + private final ConjunctionIndex conjunctionIndex; + private final int[] zeroConstraintDocuments; + private final Config config; + private final AtomicReference<CachedPostingListCounter> postingListCounter; + + + /** + * Package private as the index should be constructed using {@link PredicateIndexBuilder}. + */ + PredicateIndex( + Config config, + int[] internalToExternalIdMapping, + byte[] minFeatureIndex, + short[] intervalEnds, + int highestIntervalEnd, + SimpleIndex intervalIndex, + SimpleIndex boundsIndex, + SimpleIndex conjunctionIntervalIndex, + PredicateIntervalStore intervalStore, + ConjunctionIndex conjunctionIndex, + int[] zeroConstraintDocuments) { + this.internalToExternalIdMapping = internalToExternalIdMapping; + this.minFeatureIndex = minFeatureIndex; + this.intervalEnds = intervalEnds; + this.highestIntervalEnd = highestIntervalEnd; + this.intervalIndex = intervalIndex; + this.boundsIndex = boundsIndex; + this.conjunctionIntervalIndex = conjunctionIntervalIndex; + this.intervalStore = intervalStore; + this.conjunctionIndex = conjunctionIndex; + this.zeroConstraintDocuments = zeroConstraintDocuments; + this.expander = new PredicateRangeTermExpander(config.arity, config.lowerBound, config.upperBound); + this.config = config; + this.postingListCounter = new AtomicReference<>(new CachedPostingListCounter(internalToExternalIdMapping.length)); + } + + public void rebuildPostingListCache() { + postingListCounter.getAndUpdate(CachedPostingListCounter::rebuildCache); + } + + /** + * Create a new searcher. + */ + public Searcher searcher() { + return new Searcher(); + } + + public void writeToOutputStream(DataOutputStream out) throws IOException { + out.writeInt(SERIALIZATION_FORMAT_VERSION); + config.writeToOutputStream(out); + SerializationHelper.writeIntArray(internalToExternalIdMapping, out); + SerializationHelper.writeByteArray(minFeatureIndex, out); + SerializationHelper.writeShortArray(intervalEnds, out); + out.writeInt(highestIntervalEnd); + SerializationHelper.writeIntArray(zeroConstraintDocuments, out); + intervalIndex.writeToOutputStream(out); + boundsIndex.writeToOutputStream(out); + conjunctionIntervalIndex.writeToOutputStream(out); + intervalStore.writeToOutputStream(out); + conjunctionIndex.writeToOutputStream(out); + } + + public static PredicateIndex fromInputStream(DataInputStream in) throws IOException { + int version = in.readInt(); + if (version != SERIALIZATION_FORMAT_VERSION) { + throw new IllegalArgumentException(String.format( + "Invalid serialization format version. Expected %d, was %d.", SERIALIZATION_FORMAT_VERSION, version)); + } + Config config = Config.fromInputStream(in); + int[] internalToExternalIdMapping = SerializationHelper.readIntArray(in); + byte[] minFeatureIndex = SerializationHelper.readByteArray(in); + short[] intervalEnds = SerializationHelper.readShortArray(in); + int highestIntervalEnd = in.readInt(); + int[] zeroConstraintDocuments = SerializationHelper.readIntArray(in); + SimpleIndex intervalIndex = SimpleIndex.fromInputStream(in); + SimpleIndex boundsIndex = SimpleIndex.fromInputStream(in); + SimpleIndex conjunctionIntervalIndex = SimpleIndex.fromInputStream(in); + PredicateIntervalStore intervalStore = PredicateIntervalStore.fromInputStream(in); + ConjunctionIndex conjunctionIndex = ConjunctionIndex.fromInputStream(in); + return new PredicateIndex( + config, + internalToExternalIdMapping, + minFeatureIndex, + intervalEnds, + highestIntervalEnd, + intervalIndex, + boundsIndex, + conjunctionIntervalIndex, + intervalStore, + conjunctionIndex, + zeroConstraintDocuments + ); + } + + @Beta + public class Searcher { + private final byte[] nPostingListsForDocument; + private final ConjunctionIndex.Searcher conjunctionIndexSearcher; + + private Searcher() { + this.nPostingListsForDocument = new byte[internalToExternalIdMapping.length]; + this.conjunctionIndexSearcher = conjunctionIndex.searcher(); + } + + /** + * Retrieves a stream of hits for the given query. + * + * @param query Specifies the boolean variables that are true. + * @return A stream of hits. + */ + public Stream<Hit> search(PredicateQuery query) { + ArrayList<PostingList> postingLists = new ArrayList<>(); + for (PredicateQuery.Feature feature : query.getFeatures()) { + addIntervalPostingList(feature.featureHash, feature.subqueryBitmap, postingLists); + } + for (PredicateQuery.RangeFeature feature : query.getRangeFeatures()) { + expander.expand( + feature.key, + feature.value, + featureHash -> addIntervalPostingList(featureHash, feature.subqueryBitmap, postingLists), + (featureHash, value) -> addBoundsPostingList(featureHash, value, feature.subqueryBitmap, postingLists)); + } + addCompressedZStarPostingList(postingLists); + addConjunctionPostingLists(query, postingLists); + addZeroConstraintPostingList(postingLists); + + CachedPostingListCounter counter = postingListCounter.get(); + counter.registerUsage(postingLists); + counter.countPostingListsPerDocument(postingLists, nPostingListsForDocument); + return new PredicateSearch( + postingLists, nPostingListsForDocument, minFeatureIndex, intervalEnds, highestIntervalEnd).stream() + // Map to external id. Note that internal id for first document is 1. + .map(hit -> new Hit(internalToExternalIdMapping[hit.getDocId()], hit.getSubquery())); + } + + private void addCompressedZStarPostingList(List<PostingList> postingLists) { + SimpleIndex.Entry e = intervalIndex.getPostingList(Feature.Z_STAR_COMPRESSED_ATTRIBUTE_HASH); + if (e != null) { + postingLists.add(new ZstarCompressedPostingList(intervalStore, e.docIds, e.dataRefs)); + } + } + + private void addBoundsPostingList( + long featureHash, int value, long subqueryBitMap, List<PostingList> postingLists) { + SimpleIndex.Entry e = boundsIndex.getPostingList(featureHash); + if (e != null) { + postingLists.add(new BoundsPostingList(intervalStore, e.docIds, e.dataRefs, subqueryBitMap, value)); + } + } + + private void addIntervalPostingList(long featureHash, long subqueryBitMap, List<PostingList> postingLists) { + SimpleIndex.Entry e = intervalIndex.getPostingList(featureHash); + if (e != null) { + postingLists.add(new IntervalPostingList(intervalStore, e.docIds, e.dataRefs, subqueryBitMap)); + } + } + + private void addConjunctionPostingLists(PredicateQuery query, List<PostingList> postingLists) { + List<ConjunctionHit> hits = conjunctionIndexSearcher.search(query); + for (ConjunctionHit hit : hits) { + SimpleIndex.Entry e = conjunctionIntervalIndex.getPostingList(hit.conjunctionId); + if (e != null) { + postingLists.add(new IntervalPostingList(intervalStore, e.docIds, e.dataRefs, hit.subqueryBitmap)); + } + } + } + + private void addZeroConstraintPostingList(ArrayList<PostingList> postingLists) { + if (zeroConstraintDocuments.length > 0) { + postingLists.add(new ZeroConstraintPostingList(zeroConstraintDocuments)); + } + } + + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateIndexBuilder.java b/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateIndexBuilder.java new file mode 100644 index 00000000000..84940e54c02 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateIndexBuilder.java @@ -0,0 +1,269 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate; + +import com.google.common.annotations.Beta; +import com.google.common.base.Preconditions; +import com.google.common.primitives.Bytes; +import com.google.common.primitives.Ints; +import com.google.common.primitives.Shorts; +import com.yahoo.document.predicate.BooleanPredicate; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.search.predicate.annotator.PredicateTreeAnnotations; +import com.yahoo.search.predicate.annotator.PredicateTreeAnnotator; +import com.yahoo.search.predicate.index.Feature; +import com.yahoo.search.predicate.index.Interval; +import com.yahoo.search.predicate.index.IntervalWithBounds; +import com.yahoo.search.predicate.index.Posting; +import com.yahoo.search.predicate.index.PredicateIntervalStore; +import com.yahoo.search.predicate.index.PredicateOptimizer; +import com.yahoo.search.predicate.index.SimpleIndex; +import com.yahoo.search.predicate.index.conjunction.ConjunctionIndexBuilder; +import com.yahoo.search.predicate.index.conjunction.IndexableFeatureConjunction; + +import java.util.ArrayList; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toList; + +/** + * A builder for {@link PredicateIndex}. + * <p> + * When creating a PredicateIndexBuilder, you must specify an arity. This is used for + * range features, and is a trade-off of index size vs. query speed. Higher + * arities gives larger index but faster search. + * </p> + * <p> + * {@link #indexDocument(int, Predicate)} + * takes a document id and a predicate to insert into the index. + * Predicates should be specified using the predicate syntax described in the documentation. + * Create the {@link Predicate} objects using {@link Predicate#fromString(String)}. + * </p> + * <p> + * Use {@link #build()} to create an instance of {@link PredicateIndex}. + * </p> + * @author bjorncs + */ +@Beta +public class PredicateIndexBuilder { + + // Unique ids / mapping from internal to external id. LinkedHashSet as the insertion order is crucial. + private final Set<Integer> seenIds = new LinkedHashSet<>(); + private final List<Short> intervalEndsBuilder = new ArrayList<>(); + private final List<Byte> minFeatureIndexBuilder = new ArrayList<>(); + private final List<Integer> zeroConstraintDocuments = new ArrayList<>(); + private final SimpleIndex.Builder intervalIndexBuilder = new SimpleIndex.Builder(); + private final SimpleIndex.Builder boundsIndexBuilder = new SimpleIndex.Builder(); + private final SimpleIndex.Builder conjunctionIntervalIndexBuilder = new SimpleIndex.Builder(); + private final ConjunctionIndexBuilder conjunctionIndexBuilder = new ConjunctionIndexBuilder(); + private final PredicateIntervalStore.Builder intervalStoreBuilder; + private final PredicateOptimizer optimizer; + private final Config config; + private int documentIdCounter = 0; + private int nZStarDocuments = 0; + private int nZStarIntervals = 0; + private int highestIntervalEnd = 1; + + /** + * Creates a PredicateIndexBuilder with default upper and lower bounds. + * + * @param arity The arity to use when indexing range predicates. + * Small arity gives smaller index, but more expensive searches. + */ + public PredicateIndexBuilder(int arity) { + this(new Config.Builder().setArity(arity).build()); + } + + /** + * Creates a PredicateIndexBuilder. + * Limiting the range of possible values in range predicates reduces index size + * and increases search performance. + * + * @param arity The arity to use when indexing range predicates. + * Small arity gives smaller index, but more expensive searches. + * @param lowerBound The lower bound for the range of values used by range predicates. + * @param upperBound The upper bound for the range of values used by range predicates. + */ + public PredicateIndexBuilder(int arity, long lowerBound, long upperBound) { + this(new Config.Builder().setArity(arity).setLowerBound(lowerBound).setUpperBound(upperBound).build()); + } + + /** + * Creates a PredicateIndexBuilder based on a Config object. + * + * @param config Configuration for the PredicateIndexBuilder. + */ + public PredicateIndexBuilder(Config config) { + this.config = config; + this.optimizer = new PredicateOptimizer(config); + this.intervalStoreBuilder = new PredicateIntervalStore.Builder(); + } + + /** + * Indexes a predicate with the given id. + * + * @param docId A 32-bit document id, returned in the Hit objects when the predicate matches. + * @param predicate The predicate to index. + */ + public void indexDocument(int docId, Predicate predicate) { + if (documentIdCounter == Integer.MAX_VALUE) { + throw new IllegalStateException("Index is full, max number of documents is: " + Integer.MAX_VALUE); + } else if (seenIds.contains(docId)) { + throw new IllegalArgumentException("Document id is already in use: " + docId); + } else if (isNeverMatchingDocument(predicate)) { + return; + } + seenIds.add(docId); + predicate = optimizer.optimizePredicate(predicate); + int internalId = documentIdCounter++; + if (isAlwaysMatchingDocument(predicate)) { + indexZeroConstraintDocument(internalId); + } else { + indexDocument(internalId, PredicateTreeAnnotator.createPredicateTreeAnnotations(predicate)); + } + } + + private static boolean isAlwaysMatchingDocument(Predicate p) { + return p instanceof BooleanPredicate && ((BooleanPredicate) p).getValue(); + } + + private static boolean isNeverMatchingDocument(Predicate p) { + return p instanceof BooleanPredicate && !((BooleanPredicate) p).getValue(); + } + + private void indexZeroConstraintDocument(int docId) { + minFeatureIndexBuilder.add((byte) 0); + intervalEndsBuilder.add((short) Interval.ZERO_CONSTRAINT_RANGE); + zeroConstraintDocuments.add(docId); + } + + private void indexDocument(int docId, PredicateTreeAnnotations annotations) { + int minFeature = annotations.minFeature; + Preconditions.checkState(minFeature <= 0xFF, + "Predicate is too complex. Expected min-feature less than %d, was %d.", 0xFF, minFeature); + int intervalEnd = annotations.intervalEnd; + Preconditions.checkState(intervalEnd <= Interval.MAX_INTERVAL_END, + "Predicate is too complex. Expected min-feature less than %d, was %d.", + Interval.MAX_INTERVAL_END, intervalEnd); + highestIntervalEnd = Math.max(highestIntervalEnd, intervalEnd); + intervalEndsBuilder.add((short) intervalEnd); + minFeatureIndexBuilder.add((byte) minFeature); + indexDocumentFeatures(docId, annotations.intervalMap); + indexDocumentBoundsFeatures(docId, annotations.boundsMap); + indexDocumentConjunctions(docId, annotations.featureConjunctions); + aggregateZStarStatistics(annotations.intervalMap); + } + + private void aggregateZStarStatistics(Map<Long, List<Integer>> intervalMap) { + List<Integer> intervals = intervalMap.get(Feature.Z_STAR_COMPRESSED_ATTRIBUTE_HASH); + if (intervals != null) { + ++nZStarDocuments; + nZStarIntervals += intervals.size(); + } + } + + private void indexDocumentFeatures(int docId, Map<Long, List<Integer>> intervalMap) { + intervalMap.entrySet().stream() + .forEach(entry -> intervalIndexBuilder.insert(entry.getKey(), + new Posting(docId, + intervalStoreBuilder.insert(entry.getValue())))); + } + + private void indexDocumentBoundsFeatures(int docId, Map<Long, List<IntervalWithBounds>> boundsMap) { + boundsMap.entrySet().stream() + .forEach(entry -> boundsIndexBuilder.insert(entry.getKey(), + new Posting(docId, + intervalStoreBuilder.insert( + entry.getValue().stream().flatMap(IntervalWithBounds::stream).collect(toList()))))); + } + + private void indexDocumentConjunctions( + int docId, Map<IndexableFeatureConjunction, List<Integer>> featureConjunctions) { + for (Map.Entry<IndexableFeatureConjunction, List<Integer>> e : featureConjunctions.entrySet()) { + IndexableFeatureConjunction fc = e.getKey(); + List<Integer> intervals = e.getValue(); + Posting posting = new Posting(docId, intervalStoreBuilder.insert(intervals)); + conjunctionIntervalIndexBuilder.insert(fc.id, posting); + conjunctionIndexBuilder.indexConjunction(fc); + } + } + + public PredicateIndex build() { + return new PredicateIndex( + config, + Ints.toArray(seenIds), + Bytes.toArray(minFeatureIndexBuilder), + Shorts.toArray(intervalEndsBuilder), + highestIntervalEnd, + intervalIndexBuilder.build(), + boundsIndexBuilder.build(), + conjunctionIntervalIndexBuilder.build(), + intervalStoreBuilder.build(), + conjunctionIndexBuilder.build(), + Ints.toArray(zeroConstraintDocuments) + ); + } + + public int getZeroConstraintDocCount() { + return zeroConstraintDocuments.size(); + } + + /** + * Retrieve metrics about the current index. + * @return An object containing metrics. + */ + public PredicateIndexStats getStats() { + return new PredicateIndexStats(zeroConstraintDocuments, intervalIndexBuilder, + boundsIndexBuilder, intervalStoreBuilder, conjunctionIndexBuilder, nZStarDocuments, nZStarIntervals); + } + + /** + * A collection of metrics about the currently built {@link PredicateIndex}. + */ + public static class PredicateIndexStats { + private final Map<String, Object> metrics = new TreeMap<>(); + + public PredicateIndexStats( + List<Integer> zeroConstraintDocuments, + SimpleIndex.Builder intervalIndex, + SimpleIndex.Builder boundsIndex, + PredicateIntervalStore.Builder intervalStore, + ConjunctionIndexBuilder conjunctionIndex, + int nZStarDocuments, + int nZStarIntervals) { + Map<Integer, Integer> intervalStoreEntries = intervalStore.getEntriesForSize(); + metrics.put("Zero-constraint documents", zeroConstraintDocuments.size()); + metrics.put("Interval index keys", intervalIndex.getKeyCount()); + metrics.put("Interval index entries", intervalIndex.getEntryCount()); + metrics.put("Bounds index keys", boundsIndex.getKeyCount()); + metrics.put("Bounds index entries", boundsIndex.getEntryCount()); + metrics.put("Conjunction index feature count", conjunctionIndex.calculateFeatureCount()); + metrics.put("Conjunction index unique conjunction count", conjunctionIndex.getUniqueConjunctionCount()); + metrics.put("Conjunction index conjunction count", conjunctionIndex.getConjunctionsSeen()); + metrics.put("Conjunction index Z list size", conjunctionIndex.getZListSize()); + metrics.put("Interval store cache hits", intervalStore.getCacheHits()); + metrics.put("Interval store insert count", intervalStore.getTotalInserts()); + metrics.put("Interval store interval count", intervalStore.getNumberOfIntervals()); + metrics.put("Documents with ZStar intervals", nZStarDocuments); + metrics.put("Total ZStar intervals", nZStarIntervals); + intervalStoreEntries.entrySet().stream() + .filter(entry -> entry.getKey() != 0) + .forEach(entry -> metrics.put("Size " + entry.getKey() + " intervals", entry.getValue())); + } + + public void putValues(Map<String, Object> valueMap) { + valueMap.putAll(metrics); + } + + @Override + public String toString() { + return metrics.entrySet().stream() + .map(e -> String.format("%50s: %s", e.getKey(), e.getValue())) + .collect(joining("\n")); + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateQuery.java b/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateQuery.java new file mode 100644 index 00000000000..779cd6cfd47 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/PredicateQuery.java @@ -0,0 +1,84 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate; + +import com.google.common.annotations.Beta; + +import java.util.ArrayList; +import java.util.List; + +/** + * Represents a query in the form of a set of boolean variables that are considered true. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +@Beta +public class PredicateQuery { + + private final ArrayList<Feature> features = new ArrayList<>(); + private final ArrayList<RangeFeature> rangeFeatures = new ArrayList<>(); + + /** + * Adds a feature to the query. + * @param key Feature key + * @param value Feature value + */ + public void addFeature(String key, String value) { + addFeature(key, value, SubqueryBitmap.DEFAULT_VALUE); + } + /** + * Adds a feature to the query, e.g. gender = male. + * @param key Feature key + * @param value Feature value + * @param subqueryBitMap The subquery bitmap for which this term is true + */ + public void addFeature(String key, String value, long subqueryBitMap) { + features.add(new Feature(key, value, subqueryBitMap)); + } + public void addRangeFeature(String key, long value) { addRangeFeature(key, value, SubqueryBitmap.DEFAULT_VALUE);} + /** + * Adds a range feature to the query, e.g. age = 25. + * @param key Feature key + * @param value Feature value + * @param subqueryBitMap The subquery bitmap for which this term is true + */ + public void addRangeFeature(String key, long value, long subqueryBitMap) { + rangeFeatures.add(new RangeFeature(key, value, subqueryBitMap)); + } + /** + * @return A list of features + */ + public List<Feature> getFeatures() { return features; } + + /** + * @return A list of range features + */ + public List<RangeFeature> getRangeFeatures() { return rangeFeatures; } + + public static class Feature { + public final String key; + public final String value; + public final long subqueryBitmap; + public final long featureHash; + + public Feature(String key, String value, long subqueryBitmap) { + this.featureHash = com.yahoo.search.predicate.index.Feature.createHash(key, value); + this.subqueryBitmap = subqueryBitmap; + this.value = value; + this.key = key; + } + } + + public static class RangeFeature { + public final String key; + public final long value; + public final long subqueryBitmap; + + public RangeFeature(String key, long value, long subqueryBitmap) { + this.key = key; + this.value = value; + this.subqueryBitmap = subqueryBitmap; + } + } + +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzer.java b/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzer.java new file mode 100644 index 00000000000..2019bafa693 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzer.java @@ -0,0 +1,148 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.annotator; + +import com.yahoo.document.predicate.Conjunction; +import com.yahoo.document.predicate.Disjunction; +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.document.predicate.FeatureRange; +import com.yahoo.document.predicate.FeatureSet; +import com.yahoo.document.predicate.Negation; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.document.predicate.PredicateHash; +import com.yahoo.search.predicate.index.Feature; +import com.yahoo.search.predicate.index.conjunction.IndexableFeatureConjunction; + +import java.util.HashMap; +import java.util.Map; + +/** + * This class analyzes a predicate tree to determine two characteristics: + * 1) The sub-tree size for each conjunction/disjunction node. + * 2) The min-feature value: a lower bound of the number of term required to satisfy a predicate. This lower bound is + * an estimate which is guaranteed to be less than or equal to the real lower bound. + * + * @author bjorncs + */ +public class PredicateTreeAnalyzer { + + /** + * @param predicate The predicate tree. + * @return A result object containing the min-feature value, the tree size and sub-tree sizes. + */ + public static PredicateTreeAnalyzerResult analyzePredicateTree(Predicate predicate) { + AnalyzerContext context = new AnalyzerContext(); + int treeSize = aggregatePredicateStatistics(predicate, false, context); + int minFeature = ((int)Math.ceil(findMinFeature(predicate, false, context))) + (context.hasNegationPredicate ? 1 : 0); + return new PredicateTreeAnalyzerResult(minFeature, treeSize, context.subTreeSizes); + } + + // First analysis pass. Traverses tree in depth-first order. Determines the sub-tree sizes and counts the occurrences + // of each feature (used by min-feature calculation in second pass). + // Returns the size of the analyzed subtree. + private static int aggregatePredicateStatistics(Predicate predicate, boolean isNegated, AnalyzerContext context) { + if (predicate instanceof Negation) { + return aggregatePredicateStatistics(((Negation) predicate).getOperand(), !isNegated, context); + } else if (predicate instanceof Conjunction) { + return ((Conjunction)predicate).getOperands().stream() + .mapToInt(child -> { + int size = aggregatePredicateStatistics(child, isNegated, context); + context.subTreeSizes.put(child, size); + return size; + }).sum(); + } else if (predicate instanceof FeatureConjunction) { + if (isNegated) { + context.hasNegationPredicate = true; + return 2; + } + // Count the number of identical feature conjunctions - use the id from IndexableFeatureConjunction as key + IndexableFeatureConjunction ifc = new IndexableFeatureConjunction((FeatureConjunction) predicate); + incrementOccurrence(context.conjunctionOccurrences, ifc.id); + // Handled as leaf in interval algorithm - count a single child + return 1; + } else if (predicate instanceof Disjunction) { + return ((Disjunction)predicate).getOperands().stream() + .mapToInt(child -> aggregatePredicateStatistics(child, isNegated, context)).sum(); + } else if (predicate instanceof FeatureSet) { + if (isNegated) { + context.hasNegationPredicate = true; + return 2; + } else { + FeatureSet featureSet = (FeatureSet) predicate; + for (String value : featureSet.getValues()) { + incrementOccurrence(context.featureOccurrences, Feature.createHash(featureSet.getKey(), value)); + } + return 1; + } + } else if (predicate instanceof FeatureRange) { + if (isNegated) { + context.hasNegationPredicate = true; + return 2; + } else { + incrementOccurrence(context.featureOccurrences, PredicateHash.hash64(((FeatureRange) predicate).getKey())); + return 1; + } + } else { + throw new UnsupportedOperationException("Cannot handle predicate of type " + predicate.getClass().getSimpleName()); + } + } + + // Second analysis pass. Traverses tree in depth-first order. Determines the min-feature value. + private static double findMinFeature(Predicate predicate, boolean isNegated, AnalyzerContext context) { + if (predicate instanceof Conjunction) { + // Sum of children values. + return ((Conjunction) predicate).getOperands().stream() + .mapToDouble(child -> findMinFeature(child, isNegated, context)) + .sum(); + } else if (predicate instanceof FeatureConjunction) { + if (isNegated) { + return 0.0; + } + // The FeatureConjunction is handled as a leaf node in the interval algorithm. + IndexableFeatureConjunction ifc = new IndexableFeatureConjunction((FeatureConjunction) predicate); + return 1.0 / context.conjunctionOccurrences.get(ifc.id); + } else if (predicate instanceof Disjunction) { + // Minimum value of children. + return ((Disjunction) predicate).getOperands().stream() + .mapToDouble(child -> findMinFeature(child, isNegated, context)) + .min() + .getAsDouble(); + } else if (predicate instanceof Negation) { + return findMinFeature(((Negation) predicate).getOperand(), !isNegated, context); + } else if (predicate instanceof FeatureSet) { + if (isNegated) { + return 0.0; + } + double minFeature = 1.0; + FeatureSet featureSet = (FeatureSet) predicate; + for (String value : featureSet.getValues()) { + long featureHash = Feature.createHash(featureSet.getKey(), value); + // Clever mathematics to handle scenarios where same feature is used several places in predicate tree. + minFeature = Math.min(minFeature, 1.0 / context.featureOccurrences.get(featureHash)); + } + return minFeature; + } else if (predicate instanceof FeatureRange) { + if (isNegated) { + return 0.0; + } + return 1.0 / context.featureOccurrences.get(PredicateHash.hash64(((FeatureRange) predicate).getKey())); + } else { + throw new UnsupportedOperationException("Cannot handle predicate of type " + predicate.getClass().getSimpleName()); + } + } + + private static void incrementOccurrence(Map<Long, Integer> featureOccurrences, long featureHash) { + featureOccurrences.merge(featureHash, 1, Integer::sum); + } + + // Data structure to hold aggregated data during analysis. + private static class AnalyzerContext { + // Mapping from feature hash to occurrence count. + public final Map<Long, Integer> featureOccurrences = new HashMap<>(); + // Mapping from conjunction id to occurrence count. + public final Map<Long, Integer> conjunctionOccurrences = new HashMap<>(); + // Mapping from predicate to sub-tree size. + public final Map<Predicate, Integer> subTreeSizes = new HashMap<>(); + // Does the tree contain any Negation nodes? + public boolean hasNegationPredicate = false; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzerResult.java b/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzerResult.java new file mode 100644 index 00000000000..ab0991272dd --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzerResult.java @@ -0,0 +1,24 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.annotator; + +import com.yahoo.document.predicate.Predicate; + +import java.util.Map; + +/** + * Holds the results from {@link com.yahoo.search.predicate.annotator.PredicateTreeAnalyzer#analyzePredicateTree(com.yahoo.document.predicate.Predicate)}. + * + * @author bjorncs + */ +public class PredicateTreeAnalyzerResult { + + public final int minFeature; + public final int treeSize; + public final Map<Predicate, Integer> sizeMap; + + public PredicateTreeAnalyzerResult(int minFeature, int treeSize, Map<Predicate, Integer> sizeMap) { + this.minFeature = minFeature; + this.treeSize = treeSize; + this.sizeMap = sizeMap; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotations.java b/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotations.java new file mode 100644 index 00000000000..0edf505e7f1 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotations.java @@ -0,0 +1,35 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.annotator; + +import com.yahoo.search.predicate.index.IntervalWithBounds; +import com.yahoo.search.predicate.index.conjunction.IndexableFeatureConjunction; + +import java.util.List; +import java.util.Map; + +/** + * Holds annotations for all the features of a predicate. + * This is sufficient information to insert the predicate into a PredicateIndex. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public class PredicateTreeAnnotations { + public final int minFeature; + public final int intervalEnd; + public final Map<Long, List<Integer>> intervalMap; + public final Map<Long, List<IntervalWithBounds>> boundsMap; + public final Map<IndexableFeatureConjunction, List<Integer>> featureConjunctions; + + public PredicateTreeAnnotations( + int minFeature, + int intervalEnd, + Map<Long, List<Integer>> intervalMap, + Map<Long, List<IntervalWithBounds>> boundsMap, + Map<IndexableFeatureConjunction, List<Integer>> featureConjunctions) { + this.minFeature = minFeature; + this.intervalEnd = intervalEnd; + this.intervalMap = intervalMap; + this.boundsMap = boundsMap; + this.featureConjunctions = featureConjunctions; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotator.java b/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotator.java new file mode 100644 index 00000000000..9c34d459ab3 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotator.java @@ -0,0 +1,178 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.annotator; + +import com.yahoo.document.predicate.Conjunction; +import com.yahoo.document.predicate.Disjunction; +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.document.predicate.FeatureRange; +import com.yahoo.document.predicate.FeatureSet; +import com.yahoo.document.predicate.Negation; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.document.predicate.PredicateHash; +import com.yahoo.document.predicate.RangeEdgePartition; +import com.yahoo.document.predicate.RangePartition; +import com.yahoo.search.predicate.index.Feature; +import com.yahoo.search.predicate.index.conjunction.IndexableFeatureConjunction; +import com.yahoo.search.predicate.index.Interval; +import com.yahoo.search.predicate.index.IntervalWithBounds; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * Performs the labelling of the predicate tree. The algorithm is based on the label algorithm described in + * <a href="http://dl.acm.org/citation.cfm?id=1807171">Efficiently evaluating complex boolean expressions</a>. + * @author bjorncs + * @see <a href="http://dl.acm.org/citation.cfm?id=1807171">Efficiently evaluating complex boolean expressions</a> + */ +public class PredicateTreeAnnotator { + + private PredicateTreeAnnotator() {} + + /** + * Labels the predicate tree by constructing an interval mapping for each predicate node in the tree. + * @param predicate The predicate tree. + * @return Returns a result object containing the interval mapping and the min-feature value. + */ + public static PredicateTreeAnnotations createPredicateTreeAnnotations(Predicate predicate) { + PredicateTreeAnalyzerResult analyzerResult = PredicateTreeAnalyzer.analyzePredicateTree(predicate); + // The tree size is used as the interval range. + int intervalEnd = analyzerResult.treeSize; + AnnotatorContext context = new AnnotatorContext(intervalEnd, analyzerResult.sizeMap); + assignIntervalLabels(predicate, Interval.INTERVAL_BEGIN, intervalEnd, false, context); + return new PredicateTreeAnnotations( + analyzerResult.minFeature, intervalEnd, context.intervals, context.intervalsWithBounds, + context.featureConjunctions); + } + + /** + * Visits the predicate tree in depth-first order and assigns intervals for features in + * {@link com.yahoo.document.predicate.FeatureSet} and {@link com.yahoo.document.predicate.FeatureRange}. + */ + private static void assignIntervalLabels( + Predicate predicate, int begin, int end, boolean isNegated, AnnotatorContext context) { + // Assumes that all negations happen directly on leaf-nodes. + // Otherwise, conjunctions and disjunctions must be switched if negated (De Morgan's law). + if (predicate instanceof Conjunction) { + List<Predicate> children = ((Conjunction) predicate).getOperands(); + int current = begin; + for (int i = 0; i < children.size(); i++) { + Predicate child = children.get(i); + int subTreeSize = context.subTreeSizes.get(child); + if (i == children.size() - 1) { // Last child (and sometimes the only one) + assignIntervalLabels(child, current, end, isNegated, context); + // No need to update/touch current since this is the last child. + } else if (i == 0) { // First child + int next = context.leftNodeLeaves + subTreeSize + 1; + assignIntervalLabels(child, current, next - 1, isNegated, context); + current = next; + } else { // Middle children + int next = current + subTreeSize; + assignIntervalLabels(child, current, next - 1, isNegated, context); + current = next; + } + } + } else if (predicate instanceof FeatureConjunction) { + // Register FeatureConjunction as it was a FeatureSet with a single child. + // Note: FeatureConjunction should never be negated as AndOrSimplifier will push negations down to + // the leafs (FeatureSets). + int zStarEnd = isNegated ? calculateZStarIntervalEnd(end, context) : end; + IndexableFeatureConjunction indexable = new IndexableFeatureConjunction((FeatureConjunction)predicate); + int interval = Interval.fromBoundaries(begin, zStarEnd); + context.featureConjunctions.computeIfAbsent(indexable, (k) -> new ArrayList<>()).add(interval); + if (isNegated) { + registerZStarInterval(begin, end, zStarEnd, context); + } + context.leftNodeLeaves += 1; + } else if (predicate instanceof Disjunction) { + // All OR children will have the same {begin, end} values, and + // the values will be same as that of the parent OR node + for (Predicate child : ((Disjunction) predicate).getOperands()) { + assignIntervalLabels(child, begin, end, isNegated, context); + } + } else if (predicate instanceof FeatureSet) { + FeatureSet featureSet = (FeatureSet) predicate; + int zStarEnd = isNegated ? calculateZStarIntervalEnd(end, context) : end; + for (String value : featureSet.getValues()) { + long featureHash = Feature.createHash(featureSet.getKey(), value); + int interval = Interval.fromBoundaries(begin, zStarEnd); + registerFeatureInterval(featureHash, interval, context.intervals); + } + if (isNegated) { + registerZStarInterval(begin, end, zStarEnd, context); + } + context.leftNodeLeaves += 1; + } else if (predicate instanceof Negation) { + assignIntervalLabels(((Negation) predicate).getOperand(), begin, end, !isNegated, context); + } else if (predicate instanceof FeatureRange) { + FeatureRange featureRange = (FeatureRange) predicate; + int zStarEnd = isNegated ? calculateZStarIntervalEnd(end, context) : end; + int interval = Interval.fromBoundaries(begin, zStarEnd); + for (RangePartition partition : featureRange.getPartitions()) { + long featureHash = PredicateHash.hash64(partition.getLabel()); + registerFeatureInterval(featureHash, interval, context.intervals); + } + for (RangeEdgePartition edgePartition : featureRange.getEdgePartitions()) { + long featureHash = PredicateHash.hash64(edgePartition.getLabel()); + IntervalWithBounds intervalWithBounds = new IntervalWithBounds( + interval, (int) edgePartition.encodeBounds()); + registerFeatureInterval(featureHash, intervalWithBounds, context.intervalsWithBounds); + } + if (isNegated) { + registerZStarInterval(begin, end, zStarEnd, context); + } + context.leftNodeLeaves += 1; + } else { + throw new UnsupportedOperationException( + "Cannot handle predicate of type " + predicate.getClass().getSimpleName()); + } + } + + private static void registerZStarInterval(int begin, int end, int zStarIntervalEnd, AnnotatorContext context) { + int interval = Interval.fromZStar1Boundaries(begin - 1, zStarIntervalEnd); + registerFeatureInterval(Feature.Z_STAR_COMPRESSED_ATTRIBUTE_HASH, interval, context.intervals); + if (end - zStarIntervalEnd != 1) { + int extraInterval = Interval.fromZStar2Boundaries(end); + registerFeatureInterval(Feature.Z_STAR_COMPRESSED_ATTRIBUTE_HASH, extraInterval, context.intervals); + } + context.leftNodeLeaves += 1; + } + + private static int calculateZStarIntervalEnd(int end, AnnotatorContext context) { + if (!context.finalRangeUsed && end == context.intervalEnd) { + // Extend the first interval to intervalEnd - 1 to get a second Z* interval of size 1. + context.finalRangeUsed = true; + return context.intervalEnd - 1; + } + return context.leftNodeLeaves + 1; + } + + private static <T> void registerFeatureInterval(long featureHash, T interval, Map<Long, List<T>> intervals) { + intervals.computeIfAbsent(featureHash, (k) -> new ArrayList<>()).add(interval); + } + + // Data structure to hold aggregated data during traversal of the predicate tree. + private static class AnnotatorContext { + // End of interval + public final int intervalEnd; + // Mapping from feature to a list of intervals. + public final Map<Long, List<Integer>> intervals = new HashMap<>(); + // Mapping from a range feature to a list of intervals with bounds. + public final Map<Long, List<IntervalWithBounds>> intervalsWithBounds = new HashMap<>(); + // List of feature conjunctions from predicate + public final Map<IndexableFeatureConjunction, List<Integer>> featureConjunctions = new HashMap<>(); + // Mapping from predicate to sub-tree size. + public final Map<Predicate, Integer> subTreeSizes; + // Number of prior leaf nodes visited. + public int leftNodeLeaves = 0; + // Is final interval range used? (Only relevant for Z* interval) + public boolean finalRangeUsed = false; + + public AnnotatorContext(int intervalEnd, Map<Predicate, Integer> subTreeSizes) { + this.intervalEnd = intervalEnd; + this.subTreeSizes = subTreeSizes; + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/HitsVerificationBenchmark.java b/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/HitsVerificationBenchmark.java new file mode 100644 index 00000000000..1e63fed737d --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/HitsVerificationBenchmark.java @@ -0,0 +1,189 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.benchmarks; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.yahoo.search.predicate.Config; +import com.yahoo.search.predicate.Hit; +import com.yahoo.search.predicate.PredicateIndex; +import com.yahoo.search.predicate.PredicateIndexBuilder; +import com.yahoo.search.predicate.PredicateQuery; +import com.yahoo.search.predicate.serialization.PredicateQuerySerializer; +import com.yahoo.search.predicate.utils.VespaFeedParser; +import com.yahoo.search.predicate.utils.VespaQueryParser; +import io.airlift.airline.Arguments; +import io.airlift.airline.Command; +import io.airlift.airline.HelpOption; +import io.airlift.airline.Option; +import io.airlift.airline.SingleCommand; + +import javax.inject.Inject; +import java.io.BufferedInputStream; +import java.io.BufferedWriter; +import java.io.DataInputStream; +import java.io.FileInputStream; +import java.io.FileWriter; +import java.io.IOException; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.TreeMap; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import static com.yahoo.search.predicate.benchmarks.HitsVerificationBenchmark.BenchmarkArguments.*; +import static java.util.stream.Collectors.joining; +import static java.util.stream.Collectors.toList; + +/** + * A test that runs outputs the hits for each query into result file. + * + * @author bjorncs + */ +public class HitsVerificationBenchmark { + + public static void main(String[] rawArgs) throws IOException { + Optional<BenchmarkArguments> wrappedArgs = getArguments(rawArgs); + if (!wrappedArgs.isPresent()) return; + BenchmarkArguments args = wrappedArgs.get(); + Map<String, Object> output = new TreeMap<>(); + addArgsToOutput(output, args); + + Config config = new Config.Builder() + .setArity(args.arity) + .setUseConjunctionAlgorithm(args.algorithm == Algorithm.CONJUNCTION) + .build(); + + PredicateIndex index = getIndex(args, config, output); + + Stream<PredicateQuery> queries = parseQueries(args.format, args.queryFile); + int totalHits = runQueries(index, queries, args.outputFile); + output.put("Total hits", totalHits); + writeOutputToStandardOut(output); + } + + private static PredicateIndex getIndex(BenchmarkArguments args, Config config, Map<String, Object> output) throws IOException { + if (args.feedFile != null) { + PredicateIndexBuilder builder = new PredicateIndexBuilder(config); + AtomicInteger idCounter = new AtomicInteger(); + VespaFeedParser.parseDocuments( + args.feedFile, Integer.MAX_VALUE, p -> builder.indexDocument(idCounter.incrementAndGet(), p)); + builder.getStats().putValues(output); + return builder.build(); + } else { + try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(args.indexFile)))) { + long start = System.currentTimeMillis(); + PredicateIndex index = PredicateIndex.fromInputStream(in); + output.put("Time deserialize index", System.currentTimeMillis() - start); + return index; + } + } + } + + private static int runQueries( + PredicateIndex index, Stream<PredicateQuery> queries, String outputFile) throws IOException { + try (BufferedWriter writer = new BufferedWriter(new FileWriter(outputFile, false))) { + AtomicInteger i = new AtomicInteger(); + PredicateIndex.Searcher searcher = index.searcher(); + return queries.map(searcher::search) + .peek(hits -> {if (i.get() % 500 == 0) {index.rebuildPostingListCache();}}) + .mapToInt(hits -> writeHits(i.getAndIncrement(), hits, writer)) + .sum(); + + } + } + + private static Stream<PredicateQuery> parseQueries(Format format, String queryFile) + throws IOException { + PredicateQuerySerializer serializer = new PredicateQuerySerializer(); + return Files.lines(Paths.get(queryFile)) + .map(line -> + format == Format.JSON + ? serializer.fromJSON(line) + : VespaQueryParser.parseQueryFromQueryProperties(line)); + + } + + private static int writeHits(int i, Stream<Hit> hitStream, BufferedWriter writer) { + try { + List<Hit> hits = hitStream.collect(toList()); + writer.append(Integer.toString(i)) + .append(": ") + .append(hits.stream() + .map(hit -> String.format("(%d, 0x%x)", hit.getDocId(), hit.getSubquery())) + .collect(joining(", ", "[", "]"))) + .append("\n\n"); + return hits.size(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + private static Optional<BenchmarkArguments> getArguments(String[] rawArgs) { + BenchmarkArguments args = SingleCommand.singleCommand(BenchmarkArguments.class).parse(rawArgs); + if (args.helpOption.showHelpIfRequested()) { + return Optional.empty(); + } + if (args.feedFile == null && args.indexFile == null) { + System.err.println("Provide either a feed file or index file."); + return Optional.empty(); + } + return Optional.of(args); + + } + + private static void addArgsToOutput(Map<String, Object> output, BenchmarkArguments args) { + output.put("Arity", args.arity); + output.put("Algorithm", args.algorithm); + output.put("Query format", args.format); + output.put("Feed file", args.feedFile); + output.put("Query file", args.queryFile); + output.put("Output file", args.outputFile); + output.put("Index file", args.indexFile); + } + + private static void writeOutputToStandardOut(Map<String, Object> output) { + try { + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.enable(SerializationFeature.INDENT_OUTPUT); + objectMapper.writeValue(System.out, output); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Command(name = "hits-verifier", + description = "Java predicate search system test that outputs the returned hits for each query") + public static class BenchmarkArguments { + + public enum Format{JSON, VESPA} + public enum Algorithm{CONJUNCTION, INTERVALONLY} + + @Option(name = {"-a", "--arity"}, description = "Arity") + public int arity = 2; + + @Option(name = {"-al", "--algorithm"}, description = "Algorithm (CONJUNCTION or INTERVALONLY)") + public Algorithm algorithm = Algorithm.INTERVALONLY; + + @Option(name = {"-qf", "--query-format"}, description = + "Query format. Valid formats are either 'vespa' (obsolete query property format) or 'json'.") + public Format format = Format.VESPA; + + @Option(name = {"-ff", "--feed-file"}, description = "File path to feed file (Vespa XML feed)") + public String feedFile; + + @Option(name = {"-if", "--index-file"}, description = "File path to index file (Serialized index)") + public String indexFile; + + @Option(name = {"-quf", "--query-file"}, description = "File path to a query file") + public String queryFile; + + @Arguments(title = "Output file", description = "File path to output file") + public String outputFile; + + @Inject + public HelpOption helpOption; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/PredicateIndexBenchmark.java b/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/PredicateIndexBenchmark.java new file mode 100644 index 00000000000..f3518edd930 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/PredicateIndexBenchmark.java @@ -0,0 +1,297 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.benchmarks; + +import com.fasterxml.jackson.databind.ObjectMapper; +import com.fasterxml.jackson.databind.SerializationFeature; +import com.google.common.collect.Iterators; +import com.yahoo.search.predicate.Config; +import com.yahoo.search.predicate.PredicateIndex; +import com.yahoo.search.predicate.PredicateIndexBuilder; +import com.yahoo.search.predicate.PredicateQuery; +import com.yahoo.search.predicate.serialization.PredicateQuerySerializer; +import com.yahoo.search.predicate.utils.VespaFeedParser; +import com.yahoo.search.predicate.utils.VespaQueryParser; +import io.airlift.airline.Command; +import io.airlift.airline.HelpOption; +import io.airlift.airline.Option; +import io.airlift.airline.SingleCommand; + +import javax.inject.Inject; +import java.io.BufferedInputStream; +import java.io.BufferedOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.Date; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Random; +import java.util.TreeMap; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static com.yahoo.search.predicate.benchmarks.PredicateIndexBenchmark.BenchmarkArguments.Algorithm; +import static com.yahoo.search.predicate.benchmarks.PredicateIndexBenchmark.BenchmarkArguments.Format; +import static java.util.stream.Collectors.toList; + +/** + * A benchmark that tests the indexing and search performance. + * + * @author bjorncs + */ +public class PredicateIndexBenchmark { + + private static final Map<String, Object> output = new TreeMap<>(); + + public static void main(String[] rawArgs) throws IOException { + Optional<BenchmarkArguments> optionalArgs = getBenchmarkArguments(rawArgs); + if (!optionalArgs.isPresent()) return; + BenchmarkArguments args = optionalArgs.get(); + + putBenchmarkArgumentsToOutput(args); + + long start = System.currentTimeMillis(); + Config config = new Config.Builder() + .setArity(args.arity) + .setUseConjunctionAlgorithm(args.algorithm == Algorithm.CONJUNCTION) + .build(); + PredicateIndex index = getIndex(args, config); + if (args.indexOutputFile != null) { + writeIndexToFile(index, args.indexOutputFile); + } + if (args.queryFile != null) { + runQueries(args, index); + } + output.put("Total time", System.currentTimeMillis() - start); + output.put("Timestamp", new Date().toString()); + writeOutputToStandardOut(); + } + + private static Optional<BenchmarkArguments> getBenchmarkArguments(String[] rawArgs) { + BenchmarkArguments args = SingleCommand.singleCommand(BenchmarkArguments.class).parse(rawArgs); + if (args.helpOption.showHelpIfRequested()) { + return Optional.empty(); + } + if (args.feedFile == null && args.indexFile == null) { + System.err.println("Provide either a feed file or index file."); + return Optional.empty(); + } + return Optional.of(args); + } + + private static PredicateIndex getIndex(BenchmarkArguments args, Config config) throws IOException { + if (args.feedFile != null) { + PredicateIndexBuilder builder = new PredicateIndexBuilder(config); + long start = System.currentTimeMillis(); + AtomicInteger idCounter = new AtomicInteger(); + int documentCount = VespaFeedParser.parseDocuments( + args.feedFile, args.maxDocuments, p -> builder.indexDocument(idCounter.incrementAndGet(), p)); + output.put("Indexed document count", documentCount); + output.put("Time indexing documents", System.currentTimeMillis() - start); + builder.getStats().putValues(output); + + start = System.currentTimeMillis(); + PredicateIndex index = builder.build(); + output.put("Time prepare index", System.currentTimeMillis() - start); + return index; + } else { + try (DataInputStream in = new DataInputStream(new BufferedInputStream(new FileInputStream(args.indexFile)))) { + long start = System.currentTimeMillis(); + PredicateIndex index = PredicateIndex.fromInputStream(in); + output.put("Time deserialize index", System.currentTimeMillis() - start); + return index; + } + } + } + + private static void writeIndexToFile(PredicateIndex index, String indexOutputFile) throws IOException { + try (DataOutputStream out = new DataOutputStream(new BufferedOutputStream(new FileOutputStream(indexOutputFile)))) { + long start = System.currentTimeMillis(); + index.writeToOutputStream(out); + output.put("Time write index", System.currentTimeMillis() - start); + } + } + + private static void putBenchmarkArgumentsToOutput(BenchmarkArguments args) { + output.put("Arity", args.arity); + output.put("Max documents", args.maxDocuments); + output.put("Max queries", args.maxQueries); + output.put("Threads", args.nThreads); + output.put("Runtime", args.runtime); + output.put("Algorithm", args.algorithm); + output.put("Serialized index output file", args.indexOutputFile); + output.put("Feed file", args.feedFile); + output.put("Query file", args.queryFile); + output.put("Index file", args.indexFile); + output.put("Query format", args.format); + output.put("Warmup", args.warmup); + } + + private static void runQueries(BenchmarkArguments args, PredicateIndex index) throws IOException { + List<PredicateQuery> queries = parseQueries(args.queryFile, args.maxQueries, args.format); + long warmup1 = warmup(queries, index, args.nThreads, args.warmup / 2); + output.put("Time warmup before building posting cache", warmup1); + rebuildPostingListCache(index); + long warmup2 = warmup(queries, index, args.nThreads, args.warmup / 2); + output.put("Time warmup after building posting cache", warmup2); + searchIndex(queries, index, args.nThreads, args.runtime); + } + + private static void rebuildPostingListCache(PredicateIndex index) { + long start = System.currentTimeMillis(); + index.rebuildPostingListCache(); + output.put("Time rebuild posting list cache", System.currentTimeMillis() - start); + } + + private static List<PredicateQuery> parseQueries(String queryFile, int maxQueryCount, Format format) throws IOException { + long start = System.currentTimeMillis(); + List<PredicateQuery> queries = format == Format.VESPA ? + VespaQueryParser.parseQueries(queryFile, maxQueryCount) : + PredicateQuerySerializer.parseQueriesFromFile(queryFile, maxQueryCount); + output.put("Time parse queries", System.currentTimeMillis() - start); + output.put("Queries parsed", queries.size()); + return queries; + } + + private static long warmup(List<PredicateQuery> queries, PredicateIndex index, int nThreads, int warmup) { + ExecutorService executor = Executors.newFixedThreadPool(nThreads); + Random random = new Random(42); + for (int i = 0; i < nThreads; i++) { + List<PredicateQuery> shuffledQueries = new ArrayList<>(queries); + Collections.shuffle(shuffledQueries, random); + executor.submit(new QueryRunner(shuffledQueries, index.searcher())); + } + long start = System.currentTimeMillis(); + waitAndShutdown(warmup, executor); + return System.currentTimeMillis() - start; + } + + private static void searchIndex(List<PredicateQuery> queries, PredicateIndex index, int nThreads, int runtime) { + ExecutorService executor = Executors.newFixedThreadPool(nThreads); + Random random = new Random(42); + List<QueryRunner> runners = new ArrayList<>(); + for (int i = 0; i < nThreads; i++) { + List<PredicateQuery> shuffledQueries = new ArrayList<>(queries); + Collections.shuffle(shuffledQueries, random); + runners.add(new QueryRunner(shuffledQueries, index.searcher())); + } + long start = System.currentTimeMillis(); + List<Future<ResultMetrics>> futureResults = runners.stream().map(executor::submit).collect(toList()); + waitAndShutdown(runtime, executor); + long searchTime = System.currentTimeMillis() - start; + getResult(futureResults).writeMetrics(output, searchTime); + } + + private static void waitAndShutdown(int warmup, ExecutorService executor) { + try { + Thread.sleep(warmup * 1000); + executor.shutdownNow(); + executor.awaitTermination(2, TimeUnit.SECONDS); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + } + + private static ResultMetrics getResult(List<Future<ResultMetrics>> futureResults) { + try { + ResultMetrics combined = futureResults.get(0).get(); + for (int i = 1; i < futureResults.size(); i++) { + combined.combine(futureResults.get(i).get()); + } + return combined; + } catch (InterruptedException | ExecutionException e) { + throw new RuntimeException(e); + } + } + + private static class QueryRunner implements Callable<ResultMetrics> { + private final List<PredicateQuery> queries; + private final PredicateIndex.Searcher searcher; + + public QueryRunner(List<PredicateQuery> queries, PredicateIndex.Searcher seacher) { + this.queries = queries; + this.searcher = seacher; + } + + @Override + public ResultMetrics call() throws Exception { + Iterator<PredicateQuery> iterator = Iterators.cycle(queries); + ResultMetrics result = new ResultMetrics(); + while (!Thread.interrupted()) { + long start = System.nanoTime(); + long hits = searcher.search(iterator.next()).count(); + double latencyMilliseconds = (System.nanoTime() - start) / 1_000_000d; + result.registerResult(hits, latencyMilliseconds); + } + return result; + } + } + + private static void writeOutputToStandardOut() { + try { + ObjectMapper objectMapper = new ObjectMapper(); + objectMapper.enable(SerializationFeature.INDENT_OUTPUT); + objectMapper.writeValue(System.out, output); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + @Command(name = "benchmark", description = "Java predicate search library benchmark") + public static class BenchmarkArguments { + public enum Format{JSON, VESPA} + public enum Algorithm{CONJUNCTION, INTERVALONLY} + + @Option(name = {"-t", "--threads"}, description = "Number of search threads") + public int nThreads = 1; + + @Option(name = {"-a", "--arity"}, description = "Arity") + public int arity = 2; + + @Option(name = {"-r", "--runtime"}, description = "Number of seconds to run queries") + public int runtime = 30; + + @Option(name = {"-md", "--max-documents"}, + description = "The maximum number of documents to index from feed file") + public int maxDocuments = Integer.MAX_VALUE; + + @Option(name = {"-mq", "--max-queries"}, description = "The maximum number of queries to run from query file") + public int maxQueries = Integer.MAX_VALUE; + + @Option(name = {"-al", "--algorithm"}, description = "Algorithm (CONJUNCTION or INTERVALONLY)") + public Algorithm algorithm = Algorithm.INTERVALONLY; + + @Option(name = {"-w", "--warmup"}, description = "Warmup in seconds.") + public int warmup = 30; + + @Option(name = {"-qf", "--query-format"}, + description = "Query format. Valid formats are either 'VESPA' (obsolete query property format) or 'JSON'.") + public Format format = Format.VESPA; + + @Option(name = {"-ff", "--feed-file"}, description = "File path to feed file (Vespa XML feed)") + public String feedFile; + + @Option(name = {"-if", "--index-file"}, description = "File path to index file (Serialized index)") + public String indexFile; + + @Option(name = {"-wi", "--write-index"}, description = "Serialize index to the given file") + public String indexOutputFile; + + @Option(name = {"-quf", "--query-file"}, description = "File path to a query file") + public String queryFile; + + @Inject + public HelpOption helpOption; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/ResultMetrics.java b/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/ResultMetrics.java new file mode 100644 index 00000000000..801937c995f --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/benchmarks/ResultMetrics.java @@ -0,0 +1,84 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.benchmarks; + +import java.util.Map; + +/** + * Various metrics stored during query execution + * + * @author bjorncs + */ +public class ResultMetrics { + private static final int MAX_LATENCY = 100; // ms + private static final int RESOLUTION = 25; // sample points per ms + private static final int SLOTS = MAX_LATENCY * RESOLUTION; + + private long totalQueries = 0; + private long totalHits = 0; + private double maxLatency = Double.MIN_VALUE; + private double minLatency = Double.MAX_VALUE; + private final long[] latencyHistogram = new long[SLOTS]; + + public void registerResult(long hits, double latencyMilliseconds) { + if (latencyMilliseconds > maxLatency) { + maxLatency = latencyMilliseconds; + } + if (latencyMilliseconds < minLatency) { + minLatency = latencyMilliseconds; + } + totalHits += hits; + ++totalQueries; + int latencySlot = (int) Math.round(latencyMilliseconds * RESOLUTION); + // Note: extreme latency values are ignored in the histogram for simplicity + if (latencySlot < SLOTS) { + ++latencyHistogram[latencySlot]; + } + } + + public void combine(ResultMetrics other) { + totalQueries += other.totalQueries; + minLatency = Math.min(minLatency, other.minLatency); + maxLatency = Math.max(maxLatency, other.maxLatency); + totalHits += other.totalHits; + for (int i = 0; i < SLOTS; i++) { + latencyHistogram[i] += other.latencyHistogram[i]; + } + } + + public void writeMetrics(Map<String, Object> metricMap, long timeSearch) { + double qps = timeSearch == 0 ? 0 : (1000d * totalQueries / timeSearch); + metricMap.put("QPS", qps); + metricMap.put("Time search", timeSearch); + metricMap.put("Total hits", totalHits); + metricMap.put("Total queries", totalQueries); + metricMap.put("Max latency", latencyToString(maxLatency)); + metricMap.put("Min latency", latencyToString(minLatency)); + metricMap.put("99.9 percentile", latencyToString(percentile(0.999))); + metricMap.put("99 percentile", latencyToString(percentile(0.99))); + metricMap.put("90 percentile", latencyToString(percentile(0.90))); + metricMap.put("75 percentile", latencyToString(percentile(0.75))); + metricMap.put("50 percentile", latencyToString(percentile(0.50))); + } + + private double percentile(double percentile) { + int targetCount = (int) Math.round(totalQueries * percentile); + int currentCount = 0; + int index = 0; + while (currentCount < targetCount && index < SLOTS) { + currentCount += latencyHistogram[index]; + ++index; + } + if (index == SLOTS) { + return maxLatency; + } + return toLatency(currentCount == targetCount ? index + 1 : index); + } + + private static String latencyToString(double averageLatency) { + return String.format("%.2fms", averageLatency); + } + + private static double toLatency(int index) { + return (index + 0.5) / (double) RESOLUTION; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/BoundsPostingList.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/BoundsPostingList.java new file mode 100644 index 00000000000..d17b6589693 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/BoundsPostingList.java @@ -0,0 +1,49 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +/** + * Wraps a posting stream of IntervalWithBounds objects (for collapsed + * fixed tree leaf nodes) into a PostingList. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class BoundsPostingList extends MultiIntervalPostingList { + private final int valueDiff; + private final IntervalWithBounds intervalWithBounds = new IntervalWithBounds(); + private final PredicateIntervalStore store; + private int currentInterval; + + /** + * @param valueDiff Difference from the collapsed leaf node's actual value. + */ + public BoundsPostingList(PredicateIntervalStore store, int[] docIds, int[] dataRefs, long subquery, int valueDiff) { + super(docIds, dataRefs, subquery); + this.valueDiff = valueDiff; + this.store = store; + } + + @Override + protected boolean prepareIntervals(int dataRef) { + intervalWithBounds.setIntervalArray(store.get(dataRef), 0); + return nextInterval(); + } + + @Override + public boolean nextInterval() { + while (intervalWithBounds.hasValue()) { + if (intervalWithBounds.contains(valueDiff)) { + this.currentInterval = intervalWithBounds.getInterval(); + intervalWithBounds.nextValue(); + return true; + } + intervalWithBounds.nextValue(); + } + return false; + } + + @Override + public int getInterval() { + return currentInterval; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java new file mode 100644 index 00000000000..d19357cd8ab --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/CachedPostingListCounter.java @@ -0,0 +1,134 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.google.common.collect.MinMaxPriorityQueue; +import com.gs.collections.api.tuple.primitive.ObjectLongPair; +import com.gs.collections.impl.map.mutable.primitive.ObjectIntHashMap; +import com.gs.collections.impl.map.mutable.primitive.ObjectLongHashMap; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * Counts the number of posting lists per document id. + * Caches the most expensive posting list in a bit vector. + * + * @author bjorncs + */ +public class CachedPostingListCounter { + // Only use bit vector for counting if the documents covered is more than the threshold (relative to nDocuments) + private static final double THRESHOLD_USE_BIT_VECTOR = 1; + + private final int nDocuments; + private final ObjectLongHashMap<int[]> frequency = new ObjectLongHashMap<>(); + private final ObjectIntHashMap<int[]> postingListMapping; + private final int[] bitVector; + + public CachedPostingListCounter(int nDocuments) { + this.nDocuments = nDocuments; + this.postingListMapping = new ObjectIntHashMap<>(); + this.bitVector = new int[0]; + } + + private CachedPostingListCounter(ObjectIntHashMap<int[]> postingListMapping, int[] bitVector) { + this.nDocuments = bitVector.length; + this.postingListMapping = postingListMapping; + this.bitVector = bitVector; + } + + public synchronized void registerUsage(List<PostingList> postingLists) { + for (PostingList postingList : postingLists) { + frequency.updateValue(postingList.getDocIds(), 0, v -> v + 1); + } + } + + public void countPostingListsPerDocument(List<PostingList> postingLists, byte[] nPostingListsForDocument) { + Arrays.fill(nPostingListsForDocument, (byte) 0); + List<int[]> nonCachedPostingLists = new ArrayList<>(postingLists.size()); + List<int[]> cachedPostingLists = new ArrayList<>(postingLists.size()); + long nDocumentsCachedPostingLists = 0; + int postingListBitmap = 0; + for (PostingList postingList : postingLists) { + int[] docIds = postingList.getDocIds(); + int index = postingListMapping.getIfAbsent(docIds, -1); + if (index >= 0) { + cachedPostingLists.add(docIds); + postingListBitmap |= (1 << index); + nDocumentsCachedPostingLists += docIds.length; + } else { + nonCachedPostingLists.add(docIds); + } + } + if (postingListBitmap != 0) { + if (nDocumentsCachedPostingLists > nDocuments * THRESHOLD_USE_BIT_VECTOR) { + countUsingBitVector(nPostingListsForDocument, postingListBitmap); + } else { + nonCachedPostingLists.addAll(cachedPostingLists); + } + } + if (!nonCachedPostingLists.isEmpty()) { + countUsingDocIdIteration(nPostingListsForDocument, nonCachedPostingLists); + } + } + + private void countUsingBitVector(byte[] nPostingListsForDocument, int postingListBitmap) { + for (int docId = 0; docId < nDocuments; docId++) { + nPostingListsForDocument[docId] += Integer.bitCount(bitVector[docId] & postingListBitmap); + } + } + + private static void countUsingDocIdIteration(byte[] nPostingListsForDocument, List<int[]> nonCachedPostingLists) { + for (int[] docIds : nonCachedPostingLists) { + for (int docId : docIds) { + ++nPostingListsForDocument[docId]; + } + } + } + + public CachedPostingListCounter rebuildCache() { + MinMaxPriorityQueue<Entry> mostExpensive = MinMaxPriorityQueue + .maximumSize(32).expectedSize(32).create(); + synchronized (this) { + for (ObjectLongPair<int[]> p : frequency.keyValuesView()) { + mostExpensive.add(new Entry(p.getOne(), p.getTwo())); + } + } + ObjectIntHashMap<int[]> postingListMapping = new ObjectIntHashMap<>(); + int[] bitVector = new int[nDocuments]; + int length = mostExpensive.size(); + for (int i = 0; i < length; i++) { + Entry e = mostExpensive.removeFirst(); + int[] docIds = e.docIds; + postingListMapping.put(docIds, i); + for (int docId : docIds) { + bitVector[docId] |= (1 << i); + } + } + return new CachedPostingListCounter(postingListMapping, bitVector); + } + + int[] getBitVector() { + return bitVector; + } + + ObjectIntHashMap<int[]> getPostingListMapping() { + return postingListMapping; + } + + private static class Entry implements Comparable<Entry> { + public final int[] docIds; + public final double cost; + + private Entry(int[] docIds, long frequency) { + this.docIds = docIds; + this.cost = docIds.length * (double) frequency; + assert cost > 0; + } + + @Override + public int compareTo(Entry o) { + return -Double.compare(cost, o.cost); + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/Feature.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/Feature.java new file mode 100644 index 00000000000..6a998413ec0 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/Feature.java @@ -0,0 +1,20 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.document.predicate.PredicateHash; + +/** + * Utility class for feature related constants and methods. + * + * @author bjorncs + */ +public class Feature { + public static final String Z_STAR_COMPRESSED_ATTRIBUTE_NAME = "z-star-compressed"; + public static final long Z_STAR_COMPRESSED_ATTRIBUTE_HASH = PredicateHash.hash64(Z_STAR_COMPRESSED_ATTRIBUTE_NAME); + + private Feature() {} + + public static long createHash(String key, String value) { + return PredicateHash.hash64(key + "=" + value); + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/Interval.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/Interval.java new file mode 100644 index 00000000000..f63a13f3641 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/Interval.java @@ -0,0 +1,87 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +/** + * Utility class for interval related constants and methods. + * An interval consists of a begin and end value indicating the start and end of the interval. + * Both value are inclusive, eg (1,2) is an interval of size 2. + * + * There are 3 types of interval; normal, ZStar1 and ZStar2. + * + * Normal intervals have begin value in 16 MSB and end in 16 LSB. + * ZStar1 intervals have end value in 16 MSB and begin in 16 LSB. + * ZStar2 intervals have only an end value located at 16 LSB. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class Interval { + + public static final int INTERVAL_BEGIN = 0x01; + public static final int MAX_INTERVAL_END = 0xffff; + public static final int ZERO_CONSTRAINT_RANGE = 1; + + private Interval() {} + + public static int fromBoundaries(int begin, int end) { + assert begin >= INTERVAL_BEGIN && begin <= MAX_INTERVAL_END + && end >= INTERVAL_BEGIN && end <= MAX_INTERVAL_END : toString(begin, end); + return (begin << 16) | end; + } + + public static int fromZStar1Boundaries(int begin, int end) { + assert begin >= 0 && begin <= MAX_INTERVAL_END + && end >= INTERVAL_BEGIN && end <= MAX_INTERVAL_END : toString(end, begin); + return (end << 16) | begin; + } + + public static int fromZStar2Boundaries(int end) { + assert end >= INTERVAL_BEGIN && end <= MAX_INTERVAL_END : toString(0, end); + return end; + } + + public static boolean isZStar1Interval(int interval) { + return getBegin(interval) > getEnd(interval); + } + + public static boolean isZStar2Interval(int interval) { + return (interval & 0xffff0000) == 0; + } + + public static int getBegin(int interval) { + return interval >>> 16; + } + + public static int getEnd(int interval) { + return interval & 0xffff; + } + + public static int getZStar1Begin(int interval) { + return getEnd(interval); + } + + public static int getZStar1End(int interval) { + return getBegin(interval); + } + + public static int getZStar2End(int interval) { + return interval; + } + + /** + * @return A new ZStar1 interval with boundaries [end(zStar1)+1, end(zStar2)] + */ + public static int combineZStarIntervals(int zStar1, int zStar2) { + return zStar1 >>> 16 | zStar2 << 16; + } + + private static String toString(int begin, int end) { + if (begin == 0) { + return String.format("[%d]**", end); + } else if (begin > end) { + return String.format("[%d, %d]*", begin, end); + } + return String.format("[%d, %d]", begin, end); + } + +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/IntervalPostingList.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/IntervalPostingList.java new file mode 100644 index 00000000000..40e8bf39c98 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/IntervalPostingList.java @@ -0,0 +1,42 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +/** + * Implementation of PostingList for regular features that store + * their intervals and nothing else. + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class IntervalPostingList extends MultiIntervalPostingList { + private final PredicateIntervalStore store; + private int[] currentIntervals; + private int currentIntervalIndex; + private int currentInterval; + + public IntervalPostingList(PredicateIntervalStore store, int[] docIds, int[] dataRefs, long subquery) { + super(docIds, dataRefs, subquery); + this.store = store; + } + + @Override + protected boolean prepareIntervals(int dataRef) { + currentIntervals = store.get(dataRef); + currentIntervalIndex = 1; + currentInterval = currentIntervals[0]; + return true; + } + + @Override + public boolean nextInterval() { + if (currentIntervalIndex < currentIntervals.length) { + this.currentInterval = currentIntervals[currentIntervalIndex++]; + return true; + } + return false; + } + + @Override + public int getInterval() { + return currentInterval; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/IntervalWithBounds.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/IntervalWithBounds.java new file mode 100644 index 00000000000..1d21896e853 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/IntervalWithBounds.java @@ -0,0 +1,85 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import java.util.stream.Stream; + +/** + * Represents a collapsed leaf node in the fixed tree range representation. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class IntervalWithBounds { + + private int[] intervalBoundsArray; + private int arrayIndex; + + public IntervalWithBounds() { + setIntervalArray(null, 0); + } + public IntervalWithBounds(int interval, int bounds) { + setIntervalArray(new int[] {interval, bounds}, 0); + } + + public void setIntervalArray(int[] intervalBoundsArray, int arrayIndex) { + this.intervalBoundsArray = intervalBoundsArray; + this.arrayIndex = arrayIndex; + } + public boolean hasValue() { return arrayIndex < intervalBoundsArray.length - 1; } + public void nextValue() { arrayIndex += 2; } + + public Stream<Integer> stream() { return Stream.of(getInterval(), getBounds()); } + /** + * 16 MSB represents interval begin, 16 LSB represents interval end. + */ + public int getInterval() { + return intervalBoundsArray[arrayIndex]; + } + /* + * 2 MSB determines mode for remaining 30 bits. + * 10 => Greater or equal + * 01 => Less than + * 00 => 16 LSB > X >= 16 MSB + */ + public int getBounds() { + return intervalBoundsArray[arrayIndex + 1]; + } + + /** + * Checks if a value is contained within the specified bounds. + * @param value Value to check against + * @return true if value is contained within the specified bounds + */ + public boolean contains(int value) { + int bounds = getBounds(); + if ((bounds & 0x80000000) != 0) { + return value >= (bounds & 0x3fffffff); + } else if ((bounds & 0x40000000) != 0) { + return value < (bounds & 0x3fffffff); + } else { + return (value >= (bounds >> 16)) && (value < (bounds & 0xffff)); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + IntervalWithBounds that = (IntervalWithBounds) o; + return getInterval() == that.getInterval() && getBounds() == that.getBounds(); + } + + @Override + public int hashCode() { + return 31 * getInterval() + getBounds(); + } + + @Override + public String toString() { + return "IntervalWithBounds{" + + "interval=" + getInterval() + + ", bounds=" + getBounds() + + '}'; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/MultiIntervalPostingList.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/MultiIntervalPostingList.java new file mode 100644 index 00000000000..1811f11b621 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/MultiIntervalPostingList.java @@ -0,0 +1,67 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.search.predicate.utils.PostingListSearch; + +/** + * Shared implementation for posting lists that may have multiple intervals. + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public abstract class MultiIntervalPostingList implements PostingList { + private final int[] docIds; + private final int[] dataRefs; + private final long subquery; + private final int length; + private int currentIndex; + private int currentDocId; + + public MultiIntervalPostingList(int[] docIds, int[] dataRefs, long subquery) { + this.docIds = docIds; + this.dataRefs = dataRefs; + this.subquery = subquery; + this.length = docIds.length; + this.currentIndex = 0; + this.currentDocId = -1; + } + + @Override + public final boolean nextDocument(int docId) { + int index = currentIndex; + index = PostingListSearch.interpolationSearch(docIds, index, length, docId); + if (index == length) { + return false; + } + this.currentDocId = docIds[index]; + this.currentIndex = index; + assert currentDocId > docId; + return true; + } + + @Override + public final boolean prepareIntervals() { + return prepareIntervals(dataRefs[currentIndex]); + } + + protected abstract boolean prepareIntervals(int dataRef); + + @Override + public final int size() { + return length; + } + + @Override + public final int getDocId() { + return currentDocId; + } + + @Override + public final int[] getDocIds() { + return docIds; + } + + @Override + public final long getSubquery() { + return subquery; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/Posting.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/Posting.java new file mode 100644 index 00000000000..776e428a6ff --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/Posting.java @@ -0,0 +1,50 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +/** + * Represents an entry in a posting list, containing an integer id and integer data reference. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public class Posting implements Comparable<Posting> { + + private final int id; + private final int dataRef; + + public Posting(int id, int dataRef) { + this.id = id; + this.dataRef = dataRef; + } + + public int getId() { + return id; + } + + public int getDataRef() { + return dataRef; + } + + @Override + public int compareTo(Posting o) { + return Integer.compareUnsigned(id, o.id); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Posting posting = (Posting) o; + + if (id != posting.id) return false; + return dataRef == posting.dataRef; + + } + + @Override + public int hashCode() { + int result = id; + result = 31 * result + dataRef; + return result; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/PostingList.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PostingList.java new file mode 100644 index 00000000000..f0f310f1962 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PostingList.java @@ -0,0 +1,53 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +/** + * Interface for posting lists to be used by the algorithm implemented in PredicateSearch. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public interface PostingList { + /** + * Moves the posting list past the supplied document id. + * @param docId Document id to move past. + * @return True if a new document was found + */ + boolean nextDocument(int docId); + + /** + * Prepare iterator for interval iteration. + * @return True if the iterator has any intervals. + */ + boolean prepareIntervals(); + + /** + * Fetches the next interval for the current document. + * @return True if there was a next interval + */ + boolean nextInterval(); + + /** + * @return The doc id for the current document + */ + int getDocId(); + + /** + * @return The number of documents (actual count or estimate) + */ + int size(); + + /** + * @return The current interval for the current document + */ + int getInterval(); + + /** + * @return the subquery bitmap for this posting list. + */ + long getSubquery(); + + /** + * @return The document ids + */ + int[] getDocIds(); +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateIntervalStore.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateIntervalStore.java new file mode 100644 index 00000000000..2e6598ff252 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateIntervalStore.java @@ -0,0 +1,123 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.google.common.primitives.Ints; +import com.yahoo.search.predicate.serialization.SerializationHelper; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * @author bjorncs + */ +public class PredicateIntervalStore { + + private final int[][] intervalsList; + + public PredicateIntervalStore(int[][] intervalsList) { + this.intervalsList = intervalsList; + } + + public int[] get(int intervalRef) { + assert intervalRef < intervalsList.length; + return intervalsList[intervalRef]; + } + + public void writeToOutputStream(DataOutputStream out) throws IOException { + out.writeInt(intervalsList.length); + for (int[] intervals : intervalsList) { + SerializationHelper.writeIntArray(intervals, out); + } + } + + public static PredicateIntervalStore fromInputStream(DataInputStream in) throws IOException { + int length = in.readInt(); + int[][] intervalsList = new int[length][]; + for (int i = 0; i < length; i++) { + intervalsList[i] = SerializationHelper.readIntArray(in); + } + return new PredicateIntervalStore(intervalsList); + } + + public static class Builder { + private final List<int[]> intervalsListBuilder = new ArrayList<>(); + private final Map<Entry, Integer> intervalsListIndexes = new HashMap<>(); + private final Map<Integer, Integer> entriesForSize = new HashMap<>(); + private int cacheHits = 0; + private int totalInserts = 0; + + public int insert(List<Integer> intervals) { + int size = intervals.size(); + if (size == 0) { + throw new IllegalArgumentException("Cannot insert interval list of size 0"); + } + int[] array = Ints.toArray(intervals); + Entry entry = new Entry(array); + ++totalInserts; + if (intervalsListIndexes.containsKey(entry)) { + ++cacheHits; + return intervalsListIndexes.get(entry); + } else { + int index = intervalsListBuilder.size(); + intervalsListBuilder.add(array); + intervalsListIndexes.put(entry, index); + entriesForSize.merge(size, 1, Integer::sum); + return index; + } + } + + public PredicateIntervalStore build() { + int nIntervals = intervalsListBuilder.size(); + int[][] intervalsList = new int[nIntervals][]; + for (int i = 0; i < nIntervals; i++) { + intervalsList[i] = intervalsListBuilder.get(i); + } + return new PredicateIntervalStore(intervalsList); + } + + public int getCacheHits() { + return cacheHits; + } + + public int getTotalInserts() { + return totalInserts; + } + + public Map<Integer, Integer> getEntriesForSize() { + return entriesForSize; + } + + public int getNumberOfIntervals() { + return intervalsListBuilder.size(); + } + + private static class Entry { + public final int[] intervals; + public final int hashCode; + + public Entry(int[] intervals) { + this.intervals = intervals; + this.hashCode = Arrays.hashCode(intervals); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + Entry entry = (Entry) o; + return Arrays.equals(intervals, entry.intervals); + } + + @Override + public int hashCode() { + return hashCode; + } + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateOptimizer.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateOptimizer.java new file mode 100644 index 00000000000..c3388ffea2b --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateOptimizer.java @@ -0,0 +1,46 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.document.predicate.Predicate; +import com.yahoo.search.predicate.Config; +import com.yahoo.search.predicate.optimization.AndOrSimplifier; +import com.yahoo.search.predicate.optimization.BooleanSimplifier; +import com.yahoo.search.predicate.optimization.ComplexNodeTransformer; +import com.yahoo.search.predicate.optimization.FeatureConjunctionTransformer; +import com.yahoo.search.predicate.optimization.NotNodeReorderer; +import com.yahoo.search.predicate.optimization.OrSimplifier; +import com.yahoo.search.predicate.optimization.PredicateOptions; +import com.yahoo.search.predicate.optimization.PredicateProcessor; + +/** + * Prepares the predicate for indexing. + * Performs several optimization passes on the predicate. + * + * @author bjorncs + */ +public class PredicateOptimizer { + private final PredicateProcessor[] processors; + private final PredicateOptions options; + + public PredicateOptimizer(Config config) { + this.options = new PredicateOptions(config.arity, config.lowerBound, config.upperBound); + processors = new PredicateProcessor[]{ + new AndOrSimplifier(), + new BooleanSimplifier(), + new ComplexNodeTransformer(), + new OrSimplifier(), + new NotNodeReorderer(), + new FeatureConjunctionTransformer(config.useConjunctionAlgorithm) + }; + } + + /** + * @return The optimized predicate. + */ + public Predicate optimizePredicate(Predicate predicate) { + for (PredicateProcessor processor : processors) { + predicate = processor.process(predicate, options); + } + return predicate; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateRangeTermExpander.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateRangeTermExpander.java new file mode 100644 index 00000000000..290c81c2ca8 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateRangeTermExpander.java @@ -0,0 +1,116 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.document.predicate.PredicateHash; + +/** + * Expands range terms from a query to find the set of features they translate to. + * + * @author bjorncs + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public class PredicateRangeTermExpander { + private final int arity; + private final int maxPositiveLevels; + private final int maxNegativeLevels; + private final long lowerBound; + private final long upperBound; + + /** + * Creates a PredicateRangeTermExpander with default value range. + * + * @param arity The arity used to index the predicates + */ + public PredicateRangeTermExpander(int arity) { + this(arity, Long.MIN_VALUE, Long.MAX_VALUE); + } + + /** + * @param arity The arity used to index the predicates + * @param lowerBound The minimum value used by any range predicate in the system + * @param upperBound The maximum value used by any range predicate in the system + */ + public PredicateRangeTermExpander(int arity, long lowerBound, long upperBound) { + this.arity = arity; + this.lowerBound = lowerBound; + this.upperBound = upperBound; + this.maxPositiveLevels = calculateMaxLevels(upperBound); + this.maxNegativeLevels = calculateMaxLevels(-lowerBound); + } + + private int calculateMaxLevels(long t) { + int maxLevels = 1; + while ((t /= this.arity) != 0) { + maxLevels++; + } + return maxLevels; + } + + /** + * Expands a range term to a set of features (ranges and edges) to be used in a query. + * + * @param key The term key + * @param value The term value + * @param rangeHandler Handler for range features (long) + * @param edgeHandler Handler for edge features (long, int) + */ + public void expand(String key, long value, RangeHandler rangeHandler, EdgeHandler edgeHandler) { + if (value < lowerBound || value > upperBound) { + // Value outside bounds -> expand to nothing. + return; + } + int maxLevels = value > 0 ? maxPositiveLevels : maxNegativeLevels; + int sign = value > 0 ? 1 : -1; + // Append key to feature string builder + StringBuilder builder = new StringBuilder(128); + builder.append(key).append('='); + + long levelSize = arity; + long edgeInterval = (value / arity) * arity; + edgeHandler.handleEdge(createEdgeFeatureHash(builder, edgeInterval), (int) Math.abs(value - edgeInterval)); + for (int i = 0; i < maxLevels; ++i) { + long start = (value / levelSize) * levelSize; + if (Math.abs(start) + levelSize - 1 < 0) { // overflow + break; + } + rangeHandler.handleRange(createRangeFeatureHash(builder, start, start + sign * (levelSize - 1))); + levelSize *= arity; + if (levelSize <= 0 && levelSize != Long.MIN_VALUE) { //overflow + break; + } + } + } + + private long createRangeFeatureHash(StringBuilder builder, long start, long end) { + int prefixLength = builder.length(); + String feature = end > 0 + ? builder.append(start).append('-').append(end).toString() + : builder.append(end).append('-').append(Math.abs(start)).toString(); + + builder.setLength(prefixLength); + return PredicateHash.hash64(feature); + } + + private long createEdgeFeatureHash(StringBuilder builder, long edgeInterval) { + int prefixLength = builder.length(); + String feature = builder.append(edgeInterval).toString(); + builder.setLength(prefixLength); + return PredicateHash.hash64(feature); + } + + /** + * Callback for ranges generated by the expansion. + */ + @FunctionalInterface + public interface RangeHandler { + void handleRange(long featureHash); + } + + /** + * Callback for edges generated by the expansion. + */ + @FunctionalInterface + public interface EdgeHandler { + void handleEdge(long featureHash, int value); + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateSearch.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateSearch.java new file mode 100644 index 00000000000..c40bd944b7b --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/PredicateSearch.java @@ -0,0 +1,281 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.search.predicate.Hit; +import com.yahoo.search.predicate.SubqueryBitmap; +import com.yahoo.search.predicate.utils.PrimitiveArraySorter; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.List; +import java.util.Optional; +import java.util.Spliterator; +import java.util.function.Consumer; +import java.util.stream.Stream; +import java.util.stream.StreamSupport; + +/** + * Implementation of the "Interval" boolean search algorithm. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class PredicateSearch { + private final PostingList[] postingLists; + private final byte[] nPostingListsForDocument; + private final byte[] minFeatureIndex; + private final int[] docIds; + private final int[] intervals; + private final long[] subqueries; + private final long[] subqueryMarkers; + private final boolean[] visited; + private final short[] intervalEnds; + + private short[] sortedIndexes; + private short[] sortedIndexesMergeBuffer; + private int nPostingLists; + + /** + * Creates a search for a set of posting lists. + * @param postingLists Posting lists for the boolean variables that evaluate to true + * @param nPostingListsForDocument The number of posting list for each docId + * @param minFeatureIndex Index from docId to min-feature value. + * @param intervalEnds The interval end for each document. + * @param highestIntervalEnd The highest end value. + */ + public PredicateSearch( + List<PostingList> postingLists, byte[] nPostingListsForDocument, + byte[] minFeatureIndex, short[] intervalEnds, int highestIntervalEnd) { + int size = postingLists.size(); + this.nPostingListsForDocument = nPostingListsForDocument; + this.minFeatureIndex = minFeatureIndex; + this.nPostingLists = size; + this.postingLists = postingLists.toArray(new PostingList[postingLists.size()]); + this.sortedIndexes = new short[size]; + this.sortedIndexesMergeBuffer = new short[size]; + this.docIds = new int[size]; + this.intervals = new int[size]; + this.subqueries = new long[size]; + this.subqueryMarkers = new long[highestIntervalEnd + 1]; + this.visited = new boolean[highestIntervalEnd + 1]; + this.intervalEnds = intervalEnds; + + // Sort posting list array based on the underlying number of documents (largest first). + Arrays.sort(this.postingLists, (l, r) -> -Integer.compare(l.size(), r.size())); + + for (short i = 0; i < size; ++i) { + PostingList postingList = this.postingLists[i]; + sortedIndexes[i] = i; + docIds[i] = postingList.getDocId(); + subqueries[i] = postingList.getSubquery(); + } + // All posting lists start at beginId, so no need to sort yet. + } + + /** + * @return A stream of Hit-objects from a lazy evaluation of the boolean search algorithm. + */ + public Stream<Hit> stream() { + if (nPostingLists == 0) { + return Stream.empty(); + } + return StreamSupport.stream(new PredicateSpliterator(), false); + } + + private class PredicateSpliterator implements java.util.Spliterator<Hit> { + private int lastHit = -1; + + @Override + public boolean tryAdvance(Consumer<? super Hit> action) { + Optional<Hit> optionalHit = seek(lastHit + 1); + optionalHit.ifPresent(hit -> { + lastHit = hit.getDocId(); + action.accept(hit); + }); + return optionalHit.isPresent(); + } + + @Override + public Spliterator<Hit> trySplit() { + return null; + } + + @Override + public long estimateSize() { + return Long.MAX_VALUE; + } + + @Override + public int characteristics() { + return ORDERED | DISTINCT | SORTED | NONNULL; + } + + @Override + public Comparator<Hit> getComparator() { + return null; + } + } + + private Optional<Hit> seek(int docId) { + boolean skippedToEnd = skipMinFeature(docId); + while (nPostingLists > 0 && !skippedToEnd) { + int docId0 = docIds[sortedIndexes[0]]; + int minFeature = minFeatureIndex[docId0]; + int k = minFeature > 0 ? minFeature - 1 : 0; + int intervalEnd = Short.toUnsignedInt(intervalEnds[docId0]); + if (k < nPostingLists) { + int docIdK = docIds[sortedIndexes[k]]; + if (docId0 == docIdK) { + if (evaluateHit(docId0, k, intervalEnd)) { + return Optional.of(new Hit(docId0, subqueryMarkers[intervalEnd])); + } + } + } + skippedToEnd = skipMinFeature(docId0 + 1); + } + return Optional.empty(); + } + + private boolean skipMinFeature(int docId) { + int nDocuments = nPostingListsForDocument.length; + while (docId < nDocuments && minFeatureIndex[docId] > nPostingListsForDocument[docId]) { + ++docId; + } + if (docId < nDocuments) { + advanceAllTo(docId); + return false; + } + return true; + } + + private boolean evaluateHit(int docId, int k, int intervalEnd) { + int candidates = k + 1; + for (int i = candidates; i < nPostingLists; ++i) { + if (docIds[sortedIndexes[i]] == docId) { + ++candidates; + } else { + break; + } + } + + int nNoIntervalIterators = 0; + for (int i = 0; i < candidates; ++i) { + short index = sortedIndexes[i]; + PostingList postingList = postingLists[index]; + if (postingList.prepareIntervals()) { + intervals[index] = postingList.getInterval(); + } else { + ++nNoIntervalIterators; + intervals[index] = 0xFFFFFFFF; + } + } + PrimitiveArraySorter.sort(sortedIndexes, 0, candidates, (a, b) -> Integer.compareUnsigned(intervals[a], intervals[b])); + candidates -= nNoIntervalIterators; + + Arrays.fill(subqueryMarkers, 0, intervalEnd + 1, 0); + subqueryMarkers[0] = SubqueryBitmap.ALL_SUBQUERIES; + Arrays.fill(visited, 0, intervalEnd + 1, false); + visited[0] = true; + int highestEndSeen = 1; + for (int i = 0; i < candidates; ) { + int index = sortedIndexes[i]; + int lastEnd = addInterval(index, highestEndSeen); + if (lastEnd == -1) { + return false; + } + highestEndSeen = Math.max(lastEnd, highestEndSeen); + PostingList postingList = postingLists[index]; + if (postingList.nextInterval()) { + intervals[index] = postingList.getInterval(); + restoreSortedOrder(i, candidates); + } else { + ++i; + } + } + return subqueryMarkers[intervalEnd] != 0; + } + + private void restoreSortedOrder(int first, int last) { + short indexToMove = sortedIndexes[first]; + long intervalToMove = Integer.toUnsignedLong(intervals[indexToMove]); + while (++first < last && intervalToMove > Integer.toUnsignedLong(intervals[sortedIndexes[first]])) { + sortedIndexes[first - 1] = sortedIndexes[first]; + } + sortedIndexes[first - 1] = indexToMove; + } + + /** + * Returns the end value of the interval, + * or -1 if the highest end value seen is less than the interval begin. + */ + private int addInterval(int index, int highestEndSeen) { + int interval = intervals[index]; + long subqueryBitMap = subqueries[index]; + if (Interval.isZStar1Interval(interval)) { + int begin = Interval.getZStar1Begin(interval); + int end = Interval.getZStar1End(interval); + if (highestEndSeen < begin) return -1; + markSubquery(begin, end, ~subqueryMarkers[begin]); + return end; + } else { + int begin = Interval.getBegin(interval); + int end = Interval.getEnd(interval); + if (highestEndSeen < begin -1) return -1; + markSubquery(begin - 1, end, subqueryMarkers[begin - 1] & subqueryBitMap); + return end; + } + } + + private void markSubquery(int begin, int end, long subqueryBitmap) { + if (visited[begin]) { + visited[end] = true; + subqueryMarkers[end] |= subqueryBitmap; + } + } + + // Advances all posting lists to (or beyond) docId. + private void advanceAllTo(int docId) { + int i = 0; + int completedCount = 0; + for (; i < nPostingLists; ++i) { + if (docIds[sortedIndexes[i]] >= docId) { + break; + } + if (!advanceOneTo(docId, i)) { + ++completedCount; + } + } + // No need to sort if all posting lists are finished. + if (i > 0 && nPostingLists > completedCount) { + sortIndexes(i); + // Decrement the number of posting lists. + } + nPostingLists -= completedCount; + } + + // Advances a single posting list to (or beyond) docId. + private boolean advanceOneTo(int docId, int index) { + int i = sortedIndexes[index]; + PostingList postingList = postingLists[i]; + if (postingList.nextDocument(docId - 1)) { + docIds[i] = postingList.getDocId(); + return true; + } + docIds[i] = Integer.MAX_VALUE; // will be last after sorting. + return false; + } + + private void sortIndexes(int numUpdated) { + // Sort the updated elements + boolean swapMergeBuffer = + PrimitiveArraySorter.sortAndMerge(sortedIndexes, sortedIndexesMergeBuffer, numUpdated, nPostingLists, + (a, b) -> Integer.compare(docIds[a], docIds[b])); + if (swapMergeBuffer) { + // Swap references + short[] temp = sortedIndexes; + sortedIndexes = sortedIndexesMergeBuffer; + sortedIndexesMergeBuffer = temp; + } + + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/SimpleIndex.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/SimpleIndex.java new file mode 100644 index 00000000000..da30a27d09e --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/SimpleIndex.java @@ -0,0 +1,110 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.gs.collections.api.map.primitive.LongObjectMap; +import com.gs.collections.api.tuple.primitive.LongObjectPair; +import com.gs.collections.impl.map.mutable.primitive.LongObjectHashMap; +import com.yahoo.search.predicate.serialization.SerializationHelper; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +/** + * An index mapping keys of type Long to lists of postings of generic data. + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class SimpleIndex { + + private final LongObjectMap<Entry> dictionary; + + public SimpleIndex(LongObjectMap<Entry> dictionary) { + this.dictionary = dictionary; + } + + /** + * Retrieves a posting list for a given key + * @param key Key to lookup + * @return List of postings + */ + public Entry getPostingList(long key) { + return dictionary.get(key); + } + + public void writeToOutputStream(DataOutputStream out) throws IOException { + out.writeInt(dictionary.size()); + for (LongObjectPair<Entry> pair : dictionary.keyValuesView()) { + out.writeLong(pair.getOne()); + Entry entry = pair.getTwo(); + SerializationHelper.writeIntArray(entry.docIds, out); + SerializationHelper.writeIntArray(entry.dataRefs, out); + } + } + + public static SimpleIndex fromInputStream(DataInputStream in) throws IOException { + int nEntries = in.readInt(); + LongObjectHashMap<Entry> dictionary = new LongObjectHashMap<>(nEntries); + for (int i = 0; i < nEntries; i++) { + long key = in.readLong(); + int[] docIds = SerializationHelper.readIntArray(in); + int[] dataRefs = SerializationHelper.readIntArray(in); + dictionary.put(key, new Entry(docIds, dataRefs)); + } + dictionary.compact(); + return new SimpleIndex(dictionary); + } + + public static class Entry { + public final int[] docIds; + public final int[] dataRefs; + + private Entry(int[] docIds, int[] dataRefs) { + this.docIds = docIds; + this.dataRefs = dataRefs; + } + } + + public static class Builder { + private final HashMap<Long, List<Posting>> dictionaryBuilder = new HashMap<>(); + private int entryCount; + + /** + * Inserts an object with an id for a key. + * @param key Key to map from + * @param posting Entry for the posting list + */ + public void insert(long key, Posting posting) { + dictionaryBuilder.computeIfAbsent(key, (k) -> new ArrayList<>()).add(posting); + ++entryCount; + } + + public SimpleIndex build() { + LongObjectHashMap<Entry> dictionary = new LongObjectHashMap<>(); + for (Map.Entry<Long, List<Posting>> entry : dictionaryBuilder.entrySet()) { + List<Posting> postings = entry.getValue(); + Collections.sort(postings); + int size = postings.size(); + int[] docIds = new int[size]; + int[] dataRefs = new int[size]; + for (int i = 0; i < size; i++) { + Posting posting = postings.get(i); + docIds[i] = posting.getId(); + dataRefs[i] = posting.getDataRef(); + } + dictionary.put(entry.getKey(), new Entry(docIds, dataRefs)); + } + dictionary.compact(); + return new SimpleIndex(dictionary); + } + + public int getEntryCount() { return entryCount; } + public int getKeyCount() { return dictionaryBuilder.size(); } + } + +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/ZeroConstraintPostingList.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/ZeroConstraintPostingList.java new file mode 100644 index 00000000000..0dcd6533b34 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/ZeroConstraintPostingList.java @@ -0,0 +1,73 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.search.predicate.SubqueryBitmap; + +/** + * Wraps an int stream of document ids into a PostingList. + * All documents in the stream are considered matches. + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class ZeroConstraintPostingList implements PostingList { + private final int[] docIds; + private final int length; + private int currentIndex; + private int currentDocId; + + public ZeroConstraintPostingList(int[] docIds) { + this.docIds = docIds; + this.currentIndex = 0; + this.currentDocId = -1; + this.length = docIds.length; + } + + @Override + public boolean nextDocument(int docId) { + int currentDocId = this.currentDocId; + while (currentIndex < length && currentDocId <= docId) { + currentDocId = docIds[currentIndex++]; + } + if (currentDocId <= docId) { + return false; + } + this.currentDocId = currentDocId; + return true; + } + + @Override + public boolean prepareIntervals() { + return true; + } + + @Override + public boolean nextInterval() { + return false; + } + + @Override + public int size() { + return length; + } + + @Override + public int getInterval() { + return Interval.fromBoundaries(1, Interval.ZERO_CONSTRAINT_RANGE); + } + + @Override + public int getDocId() { + return currentDocId; + } + + @Override + public long getSubquery() { + return SubqueryBitmap.ALL_SUBQUERIES; + } + + @Override + public int[] getDocIds() { + return docIds; + } + +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/ZstarCompressedPostingList.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/ZstarCompressedPostingList.java new file mode 100644 index 00000000000..90d2d6352c2 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/ZstarCompressedPostingList.java @@ -0,0 +1,66 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.search.predicate.SubqueryBitmap; + +/** + * Wraps a posting list of compressed NOT-features. + * The compression works by implying an interval of size 1 after each + * stored interval, unless the next interval starts with 16 bits of 0, + * in which case the current interval is extended to the next. + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class ZstarCompressedPostingList extends MultiIntervalPostingList { + private final PredicateIntervalStore store; + private int[] currentIntervals; + private int currentIntervalIndex; + private int prevInterval; + private int currentInterval; + + + /** + * @param docIds Posting list as a stream. + */ + public ZstarCompressedPostingList(PredicateIntervalStore store, int[] docIds, int[] dataRefs) { + super(docIds, dataRefs, SubqueryBitmap.ALL_SUBQUERIES); + this.store = store; + } + + @Override + protected boolean prepareIntervals(int dataRef) { + currentIntervals = store.get(dataRef); + currentIntervalIndex = 0; + return nextInterval(); + } + + @Override + public boolean nextInterval() { + int nextInterval = -1; + if (currentIntervalIndex < currentIntervals.length) { + nextInterval = currentIntervals[currentIntervalIndex]; + } + if (prevInterval != 0) { + if (Interval.isZStar2Interval(nextInterval)) { + this.currentInterval = Interval.combineZStarIntervals(prevInterval, nextInterval); + ++currentIntervalIndex; + } else { + int end = Interval.getZStar1End(prevInterval); + this.currentInterval = Interval.fromZStar1Boundaries(end, end + 1); + } + prevInterval = 0; + return true; + } else if (nextInterval != -1) { + this.currentInterval = nextInterval; + ++currentIntervalIndex; + prevInterval = nextInterval; + return true; + } + return false; + } + + @Override + public int getInterval() { + return currentInterval; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionHit.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionHit.java new file mode 100644 index 00000000000..230150f43dc --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionHit.java @@ -0,0 +1,52 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index.conjunction; + +import com.yahoo.search.predicate.SubqueryBitmap; + +/** + * Represents a conjunction hit. See {@link ConjunctionIndex}. + * + * @author bjorncs + */ +public class ConjunctionHit implements Comparable<ConjunctionHit> { + public final long conjunctionId; + public final long subqueryBitmap; + + public ConjunctionHit(long conjunctionId, long subqueryBitmap) { + this.conjunctionId = conjunctionId; + this.subqueryBitmap = subqueryBitmap; + } + + @Override + public int compareTo(ConjunctionHit other) { + return Long.compareUnsigned(conjunctionId, other.conjunctionId); + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + ConjunctionHit that = (ConjunctionHit) o; + + if (conjunctionId != that.conjunctionId) return false; + return subqueryBitmap == that.subqueryBitmap; + + } + + @Override + public int hashCode() { + int result = (int) (conjunctionId ^ (conjunctionId >>> 32)); + result = 31 * result + (int) (subqueryBitmap ^ (subqueryBitmap >>> 32)); + return result; + } + + @Override + public String toString() { + if (subqueryBitmap == SubqueryBitmap.DEFAULT_VALUE) { + return "" + conjunctionId; + } else { + return "[" + conjunctionId + ",0x" + Long.toHexString(subqueryBitmap) + "]"; + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionId.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionId.java new file mode 100644 index 00000000000..b51f648dcb6 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionId.java @@ -0,0 +1,28 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index.conjunction; + +/** + * Conjunction id format: + * bit 31-1: id/hash + * bit 0: 0: negated, 1: not negated + * + * @author bjorncs + */ +public class ConjunctionId { + + public static int compare(int c1, int c2) { + return Integer.compare(c1 | 1, c2 | 1); + } + + public static boolean equals(int c1, int c2) { + return (c1 | 1) == (c2 | 1); + } + + public static boolean isPositive(int c) { + return (c & 1) == 1; + } + + public static int nextId(int c) { + return (c | 1) + 1; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIdIterator.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIdIterator.java new file mode 100644 index 00000000000..e0859e93609 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIdIterator.java @@ -0,0 +1,47 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index.conjunction; + +/** + * Conjunction id posting list iterator for a single feature/assignment (e.g. a=b). + * + * @author bjorncs + */ +public class ConjunctionIdIterator { + + private final int[] conjunctionIds; + private final long subqueryBitmap; + private int currentConjunctionId; + private int length; + private int index; + + public ConjunctionIdIterator(long subqueryBitmap, int[] conjunctionIds) { + this.subqueryBitmap = subqueryBitmap; + this.conjunctionIds = conjunctionIds; + this.currentConjunctionId = conjunctionIds[0]; + this.length = conjunctionIds.length; + this.index = 0; + } + + public boolean next(int conjunctionId) { + if (index == length) return false; + + int candidate = currentConjunctionId; + while (ConjunctionId.compare(conjunctionId, candidate) > 0 && ++index < length) { + candidate = conjunctionIds[index]; + } + currentConjunctionId = candidate; + return ConjunctionId.compare(conjunctionId, candidate) <= 0; + } + + public long getSubqueryBitmap() { + return subqueryBitmap; + } + + public int getConjunctionId() { + return currentConjunctionId; + } + + public int[] getConjunctionIds() { + return conjunctionIds; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndex.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndex.java new file mode 100644 index 00000000000..c272cb5fb92 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndex.java @@ -0,0 +1,279 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index.conjunction; + +import com.gs.collections.api.map.primitive.IntObjectMap; +import com.gs.collections.api.map.primitive.LongObjectMap; +import com.gs.collections.api.tuple.primitive.IntObjectPair; +import com.gs.collections.api.tuple.primitive.LongObjectPair; +import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap; +import com.gs.collections.impl.map.mutable.primitive.LongObjectHashMap; +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.search.predicate.PredicateQuery; +import com.yahoo.search.predicate.SubqueryBitmap; +import com.yahoo.search.predicate.serialization.SerializationHelper; +import com.yahoo.search.predicate.utils.PrimitiveArraySorter; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Optional; + +/** + * A searchable index of conjunctions (see {@link FeatureConjunction} / {@link IndexableFeatureConjunction}). + * Implements the algorithm described in the paper <a href="http://dl.acm.org/citation.cfm?id=1687633">Indexing Boolean Expressions</a>. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class ConjunctionIndex { + // A map from K value to FeatureIndex + private final IntObjectMap<FeatureIndex> kIndex; + private final int[] zList; + private final long[] idMapping; + + public ConjunctionIndex(IntObjectMap<FeatureIndex> kIndex, int[] zList, long[] idMapping) { + this.kIndex = kIndex; + this.zList = zList; + this.idMapping = idMapping; + } + + public Searcher searcher() { + return new Searcher(); + } + + public void writeToOutputStream(DataOutputStream out) throws IOException { + SerializationHelper.writeIntArray(zList, out); + SerializationHelper.writeLongArray(idMapping, out); + out.writeInt(kIndex.size()); + for (IntObjectPair<FeatureIndex> p : kIndex.keyValuesView()) { + out.writeInt(p.getOne()); + p.getTwo().writeToOutputStream(out); + } + } + + public static ConjunctionIndex fromInputStream(DataInputStream in) throws IOException { + int[] zList = SerializationHelper.readIntArray(in); + long[] idMapping = SerializationHelper.readLongArray(in); + int kIndexSize = in.readInt(); + IntObjectHashMap<FeatureIndex> kIndex = new IntObjectHashMap<>(kIndexSize); + for (int i = 0; i < kIndexSize; i++) { + int key = in.readInt(); + kIndex.put(key, FeatureIndex.fromInputStream(in)); + } + kIndex.compact(); + return new ConjunctionIndex(kIndex, zList, idMapping); + } + + public static class FeatureIndex { + // Maps a feature id to conjunction id + private final LongObjectMap<int[]> map; + + public FeatureIndex(LongObjectMap<int[]> map) { + this.map = map; + } + + public Optional<int[]> getConjunctionIdsForFeature(long featureId) { + return Optional.ofNullable(map.get(featureId)); + } + + public void writeToOutputStream(DataOutputStream out) throws IOException { + out.writeInt(map.size()); + for (LongObjectPair<int[]> p : map.keyValuesView()) { + out.writeLong(p.getOne()); + SerializationHelper.writeIntArray(p.getTwo(), out); + } + } + + public static FeatureIndex fromInputStream(DataInputStream in) throws IOException { + int mapSize = in.readInt(); + LongObjectHashMap<int[]> map = new LongObjectHashMap<>(mapSize); + for (int i = 0; i < mapSize; i++) { + long key = in.readLong(); + map.put(key, SerializationHelper.readIntArray(in)); + } + map.compact(); + return new FeatureIndex(map); + } + } + + public class Searcher { + private final byte[] iteratorsPerConjunction; + + private Searcher() { + this.iteratorsPerConjunction = new byte[idMapping.length]; + } + + /** + * Retrieves a list of hits for the given query. + * + * @param query Specifies the boolean variables that are true. + * @return List of hits + */ + public List<ConjunctionHit> search(PredicateQuery query) { + List<ConjunctionHit> conjunctionHits = new ArrayList<>(); + int uniqueKeys = (int) query.getFeatures().stream().map(e -> e.key).distinct().count(); + for (int k = uniqueKeys; k >= 0; k--) { + List<ConjunctionIdIterator> iterators = new ArrayList<>(); + getFeatureIndex(k) + .ifPresent(featureIndex -> addFeatureIterators(query, featureIndex, iterators)); + if (k == 0 && zList.length > 0) { + iterators.add(new ConjunctionIdIterator(SubqueryBitmap.ALL_SUBQUERIES, zList)); + } + if (!iterators.isEmpty()) { + calculateIteratorsPerConjunction(iterators); + findMatchingConjunctions(k, iterators, conjunctionHits, iteratorsPerConjunction); + } + } + return conjunctionHits; + } + + private void calculateIteratorsPerConjunction(List<ConjunctionIdIterator> iterators) { + Arrays.fill(iteratorsPerConjunction, (byte)0); + for (ConjunctionIdIterator iterator : iterators) { + for (int id : iterator.getConjunctionIds()) { + if (ConjunctionId.isPositive(id)) { + ++iteratorsPerConjunction[id >>> 1]; + } + } + } + } + + private Optional<FeatureIndex> getFeatureIndex(int k) { + return Optional.ofNullable(kIndex.get(k)); + } + + private void addFeatureIterators(PredicateQuery query, FeatureIndex featureIndex, List<ConjunctionIdIterator> iterators) { + query.getFeatures().stream() + .map(e -> toSingleTermIterator(e, featureIndex)) + .filter(Optional::isPresent) + .map(Optional::get) + .forEach(iterators::add); + } + + private Optional<ConjunctionIdIterator> toSingleTermIterator(PredicateQuery.Feature feature, FeatureIndex featureIndex) { + return featureIndex.getConjunctionIdsForFeature(feature.featureHash) + .map(conjunctions -> new ConjunctionIdIterator(feature.subqueryBitmap, conjunctions)); + } + + private void findMatchingConjunctions(int k, List<ConjunctionIdIterator> iterators, List<ConjunctionHit> matchingIds, byte[] iteratorsPerConjunction) { + if (k == 0) { + k = 1; + } + int nextId = getNextId(0, k, iteratorsPerConjunction); + if (nextId == -1) { + return; // no hits + } + + int nIterators = iterators.size(); + if (nIterators < k) { + return; // No hits + } + short[] sortedIndexes = new short[nIterators]; + short[] sortedIndexesMergeBuffer = new short[nIterators]; + for (short i = 0; i < nIterators; ++i) { + sortedIndexes[i] = i; + } + + int[] currentIds = new int[nIterators]; + int nCompleted = initializeIterators(iterators, sortedIndexes, currentIds, nextId); + nIterators -= nCompleted; + + while (nIterators >= k) { + int id0 = currentIds[sortedIndexes[0]]; + int idK = currentIds[sortedIndexes[k - 1]]; + + // There should be at least k iterators for conjunction. + if (ConjunctionId.equals(id0, idK)) { + long matchingSubqueries = SubqueryBitmap.ALL_SUBQUERIES; + // Find first positive conjunction + int firstPositive = 0; + while (firstPositive < nIterators && !ConjunctionId.isPositive(currentIds[sortedIndexes[firstPositive]])) { + // AND in the complement of the bitmap for negative conjunctions. + matchingSubqueries &= ~iterators.get(sortedIndexes[firstPositive]).getSubqueryBitmap(); + ++firstPositive; + } + if (firstPositive + k <= nIterators) { + // Verify that at there are k positive iterators for the current conjunction. + id0 = currentIds[sortedIndexes[firstPositive]]; + idK = currentIds[sortedIndexes[firstPositive + k - 1]]; + if (id0 == idK) { // We know that id0 is positive conjunction + for (int i = firstPositive; i < firstPositive + k; i++) { + matchingSubqueries &= iterators.get(sortedIndexes[i]).getSubqueryBitmap(); + } + if (matchingSubqueries != 0) { + matchingIds.add(new ConjunctionHit(toExternalId(id0), matchingSubqueries)); + } + } + } + } + + // Advance iterators to next conjunction. + nextId = getNextId(ConjunctionId.nextId(id0), k, iteratorsPerConjunction); + if (nextId == -1) { + return; + } + int completed = 0; + int i; + for (i = 0; i < nIterators; ++i) { + short index = sortedIndexes[i]; + if (ConjunctionId.compare(currentIds[index], nextId) < 0) { + ConjunctionIdIterator iterator = iterators.get(index); + if (iterator.next(nextId)) { + currentIds[index] = iterator.getConjunctionId(); + } else { + currentIds[index] = Integer.MAX_VALUE; + ++completed; + } + } else { + break; + } + } + if (i > 0 && nIterators - completed >= k) { + boolean swapMergeBuffer = + PrimitiveArraySorter.sortAndMerge(sortedIndexes, sortedIndexesMergeBuffer, i, nIterators, + (a, b) -> Integer.compare(currentIds[a], currentIds[b])); + if (swapMergeBuffer) { + short[] temp = sortedIndexes; + sortedIndexes = sortedIndexesMergeBuffer; + sortedIndexesMergeBuffer = temp; + } + } + nIterators -= completed; + } + } + + private int initializeIterators(List<ConjunctionIdIterator> iterators, short[] sortedIndexes, int[] currentIds, int nextId) { + int nCompleted = 0; + int nIterators = iterators.size(); + for (int i = 0; i < nIterators; i++) { + ConjunctionIdIterator iterator = iterators.get(i); + if (iterator.next(nextId)) { + currentIds[i] = iterator.getConjunctionId(); + } else { + currentIds[i] = Integer.MAX_VALUE; + ++nCompleted; + + } + } + PrimitiveArraySorter.sort(sortedIndexes, (a, b) -> Integer.compare(currentIds[a], currentIds[b])); + return nCompleted; + } + + private int getNextId(int fromId, int k, byte[] iteratorsPerConjunction) { + int id = fromId >>> 1; + int nDocuments = iteratorsPerConjunction.length; + while (id < nDocuments && iteratorsPerConjunction[id] < k) { + ++id; + } + return id == nDocuments ? -1 : ((id << 1) | 1); + + } + + private long toExternalId(int internalId) { + return idMapping[internalId >>> 1]; + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndexBuilder.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndexBuilder.java new file mode 100644 index 00000000000..d1086eaca23 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndexBuilder.java @@ -0,0 +1,120 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index.conjunction; + +import com.google.common.primitives.Ints; +import com.google.common.primitives.Longs; +import com.gs.collections.api.map.primitive.IntObjectMap; +import com.gs.collections.impl.map.mutable.primitive.IntObjectHashMap; +import com.gs.collections.impl.map.mutable.primitive.LongObjectHashMap; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.TreeSet; + +/** + * A builder for {@link ConjunctionIndex}. + * + * @author bjorncs + */ +public class ConjunctionIndexBuilder { + // A map from K value to FeatureIndex + private final HashMap<Integer, FeatureIndexBuilder> kIndexBuilder = new HashMap<>(); + private final List<Integer> zListBuilder = new ArrayList<>(); + // Unique ids / mapping from internal to external id. LinkedHashSet as the insertion order is crucial. + private final Set<Long> seenIds = new LinkedHashSet<>(); + private int idCounter = 0; + private int conjunctionsSeen = 0; + + private static class FeatureIndexBuilder { + // Maps a feature id to conjunction id + private final Map<Long, Set<Integer>> map = new HashMap<>(); + + public void insert(long featureId, int conjunctionId) { + map.computeIfAbsent(featureId, k -> new TreeSet<>()).add(conjunctionId); + } + } + + public void indexConjunction(IndexableFeatureConjunction c) { + ++conjunctionsSeen; + long externalId = c.id; + if (seenIds.contains(externalId)) return; + + seenIds.add(externalId); + int internalId = generateInternalId(); + FeatureIndexBuilder featureIndexBuilder = kIndexBuilder.computeIfAbsent(c.k, (k) -> new FeatureIndexBuilder()); + c.features.forEach(f -> featureIndexBuilder.insert(f, internalId)); + c.negatedFeatures.forEach(f -> featureIndexBuilder.insert(f, internalId & ~1)); + if (c.k == 0) { + zListBuilder.add(internalId); + } + } + + private int generateInternalId() { + return ((idCounter++) << 1) | 1; + } + + public ConjunctionIndex build() { + int[] zList = Ints.toArray(zListBuilder); + IntObjectMap<ConjunctionIndex.FeatureIndex> kIndex = buildKIndex(kIndexBuilder); + long[] idMapping = Longs.toArray(seenIds); + return new ConjunctionIndex(kIndex, zList, idMapping); + } + + /** + * @return The number of unique features in index. + */ + public long calculateFeatureCount() { + return kIndexBuilder.values().stream() + .map(index -> index.map.keySet()) + .reduce( + new HashSet<>(), + (acc, keySet) -> { + keySet.forEach(acc::add); + return acc; + }, (acc1, acc2) -> { + acc1.addAll(acc2); + return acc1; + }) + .size(); + } + + /** + * @return The number of unique conjunctions indexed. + */ + public long getUniqueConjunctionCount() { + return seenIds.size(); + } + + public int getZListSize() { + return zListBuilder.size(); + } + + public int getConjunctionsSeen() { + return conjunctionsSeen; + } + + private static IntObjectMap<ConjunctionIndex.FeatureIndex> buildKIndex(HashMap<Integer, FeatureIndexBuilder> kIndexBuilder) { + IntObjectHashMap<ConjunctionIndex.FeatureIndex> map = new IntObjectHashMap<>(); + for (Map.Entry<Integer, FeatureIndexBuilder> entry : kIndexBuilder.entrySet()) { + map.put(entry.getKey(), buildFeatureIndex(entry.getValue())); + } + map.compact(); + return map; + } + + private static ConjunctionIndex.FeatureIndex buildFeatureIndex(FeatureIndexBuilder featureIndexBuilder) { + LongObjectHashMap<int[]> map = new LongObjectHashMap<>(); + for (Map.Entry<Long, Set<Integer>> featureEntry : featureIndexBuilder.map.entrySet()) { + int[] conjunctionIds = Ints.toArray(featureEntry.getValue()); + map.put(featureEntry.getKey(), conjunctionIds); + } + map.compact(); + return new ConjunctionIndex.FeatureIndex(map); + } + +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/IndexableFeatureConjunction.java b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/IndexableFeatureConjunction.java new file mode 100644 index 00000000000..016b2ddfc8e --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/index/conjunction/IndexableFeatureConjunction.java @@ -0,0 +1,75 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index.conjunction; + +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.document.predicate.FeatureSet; +import com.yahoo.document.predicate.Negation; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.search.predicate.index.Feature; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * IndexableFeatureConjunction is a post-processed {@link FeatureConjunction} which can be indexed by {@link ConjunctionIndex}. + * + * @author bjorncs + */ +public class IndexableFeatureConjunction { + /** Conjunction id */ + public final long id; + /** K value - number of non-negated operands */ + public final int k; + // Hashed features from non-negated feature sets. + public final Set<Long> features = new HashSet<>(); + // Hash features from negated feature sets. + public final Set<Long> negatedFeatures = new HashSet<>(); + + public IndexableFeatureConjunction(FeatureConjunction conjunction) { + List<Predicate> operands = conjunction.getOperands(); + int nNegatedFeatureSets = 0; + for (Predicate operand : operands) { + if (operand instanceof FeatureSet) { + addFeatures((FeatureSet)operand, features); + } else { + FeatureSet featureSet = (FeatureSet)((Negation) operand).getOperand(); + addFeatures(featureSet, negatedFeatures); + ++nNegatedFeatureSets; + } + } + + id = calculateConjunctionId(); + k = operands.size() - nNegatedFeatureSets; + } + + private static void addFeatures(FeatureSet featureSet, Set<Long> features) { + String key = featureSet.getKey(); + featureSet.getValues().forEach(value -> features.add(Feature.createHash(key, value))); + } + + private long calculateConjunctionId() { + long posHash = 0; + for (long feature : features) { + posHash ^= feature; + } + long negHash = 0; + for (long feature : negatedFeatures) { + negHash ^= feature; + } + return (posHash + 3 * negHash) | 1; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + IndexableFeatureConjunction that = (IndexableFeatureConjunction) o; + return id == that.id; + } + + @Override + public int hashCode() { + return (int) (id ^ (id >>> 32)); + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/optimization/FeatureConjunctionTransformer.java b/predicate-search/src/main/java/com/yahoo/search/predicate/optimization/FeatureConjunctionTransformer.java new file mode 100644 index 00000000000..07786645250 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/optimization/FeatureConjunctionTransformer.java @@ -0,0 +1,135 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.optimization; + +import com.yahoo.document.predicate.Conjunction; +import com.yahoo.document.predicate.Disjunction; +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.document.predicate.FeatureRange; +import com.yahoo.document.predicate.FeatureSet; +import com.yahoo.document.predicate.Negation; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.search.predicate.index.conjunction.ConjunctionIndex; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +/** + * Transforms Conjunctions with only (negated) {@link FeatureSet} instances to {@link FeatureConjunction}. + * The {@link FeatureConjunction}s are indexed by the {@link ConjunctionIndex}. + * + * @author bjorncs + */ +public class FeatureConjunctionTransformer implements PredicateProcessor { + + // Only Conjunctions having less or equal number of FeatureSet operands than threshold are converted to FeatureConjunction. + private static final int CONVERSION_THRESHOLD = Integer.MAX_VALUE; + private final boolean useConjunctionAlgorithm; + + public FeatureConjunctionTransformer(boolean useConjunctionAlgorithm) { + this.useConjunctionAlgorithm = useConjunctionAlgorithm; + } + + @Override + public Predicate process(Predicate predicate, PredicateOptions options) { + if (useConjunctionAlgorithm) { + return transform(predicate); + } + return predicate; + } + + private static Predicate transform(Predicate predicate) { + if (predicate instanceof Conjunction) { + Conjunction conjunction = (Conjunction) predicate; + conjunction.getOperands().replaceAll(FeatureConjunctionTransformer::transform); + long nValidOperands = numberOfValidFeatureSetOperands(conjunction); + if (nValidOperands > 1 && nValidOperands <= CONVERSION_THRESHOLD) { + return convertConjunction(conjunction, nValidOperands); + } + } else if (predicate instanceof Disjunction) { + ((Disjunction)predicate).getOperands().replaceAll(FeatureConjunctionTransformer::transform); + } else if (predicate instanceof Negation) { + Negation negation = (Negation) predicate; + negation.setOperand(transform(negation.getOperand())); + } + return predicate; + } + + /** + * Conversion rules: + * 1) A {@link FeatureConjunction} may only consist of FeatureSets having unique keys. + * If multiple {@link FeatureSet} share the same key, they have to be placed into separate FeatureConjunctions. + * 2) A FeatureConjunction must have at least 2 operands. + * 3) Any operand that is not a FeatureSet, negated or not, + * (e.g {@link FeatureRange}) cannot be placed into a FeatureConjunction. + * 4) All FeatureSets may only have a single value. + * + * See the tests in FeatureConjunctionTransformerTest for conversion examples. + */ + private static Predicate convertConjunction(Conjunction conjunction, long nValidOperands) { + List<Predicate> operands = conjunction.getOperands(); + // All operands are instance of FeatureSet are valid and may therefor be placed into a single FeatureConjunction. + if (nValidOperands == operands.size()) { + return new FeatureConjunction(operands); + } + + List<Predicate> invalidFeatureConjunctionOperands = new ArrayList<>(); + List<Map<String, Predicate>> featureConjunctionOperandsList = new ArrayList<>(); + featureConjunctionOperandsList.add(new TreeMap<>()); + for (Predicate operand : operands) { + if (FeatureConjunction.isValidFeatureConjunctionOperand(operand)) { + addFeatureConjunctionOperand(featureConjunctionOperandsList, operand); + } else { + invalidFeatureConjunctionOperands.add(operand); + } + } + + // Create a Conjunction root. + Conjunction newConjunction = new Conjunction(); + newConjunction.addOperands(invalidFeatureConjunctionOperands); + // For all operand partitions: create FeatureConjunction if partition has more than a single predicate. + for (Map<String, Predicate> featureConjunctionOperands : featureConjunctionOperandsList) { + Collection<Predicate> values = featureConjunctionOperands.values(); + if (featureConjunctionOperands.size() == 1) { + // Add single operand directly to root conjunction. + newConjunction.addOperands(values); + } else { + newConjunction.addOperand(new FeatureConjunction(new ArrayList<>(values))); + } + } + return newConjunction; + } + + private static void addFeatureConjunctionOperand(List<Map<String, Predicate>> featureConjunctionOperandsList, Predicate operand) { + String key = getFeatureSetKey(operand); + for (Map<String, Predicate> featureConjunctionOperands : featureConjunctionOperandsList) { + if (!featureConjunctionOperands.containsKey(key)) { + featureConjunctionOperands.put(key, operand); + return; + } + } + Map<String, Predicate> conjunctionOperands = new TreeMap<>(); + conjunctionOperands.put(key, operand); + featureConjunctionOperandsList.add(conjunctionOperands); + } + + private static long numberOfValidFeatureSetOperands(Conjunction conjunction) { + return conjunction.getOperands().stream() + .filter(FeatureConjunction::isValidFeatureConjunctionOperand) + .map(FeatureConjunctionTransformer::getFeatureSetKey) + .distinct() + .count(); + } + + private static String getFeatureSetKey(Predicate predicate) { + if (predicate instanceof FeatureSet) { + return ((FeatureSet) predicate).getKey(); + } else { + Negation negation = (Negation) predicate; + return ((FeatureSet) negation.getOperand()).getKey(); + } + } + +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/package-info.java b/predicate-search/src/main/java/com/yahoo/search/predicate/package-info.java new file mode 100644 index 00000000000..4fd744770b4 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/package-info.java @@ -0,0 +1,8 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +@com.yahoo.osgi.annotation.ExportPackage +@com.yahoo.api.annotations.PublicApi +/* + The predicate package is exported by the document module (OSGi). + Do not remove unless the intention is to modify the public API of document. + */ +package com.yahoo.search.predicate; diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/serialization/PredicateQuerySerializer.java b/predicate-search/src/main/java/com/yahoo/search/predicate/serialization/PredicateQuerySerializer.java new file mode 100644 index 00000000000..80c96ea32d6 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/serialization/PredicateQuerySerializer.java @@ -0,0 +1,108 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.serialization; + +import com.fasterxml.jackson.core.JsonFactory; +import com.fasterxml.jackson.core.JsonGenerator; +import com.yahoo.search.predicate.PredicateQuery; +import com.yahoo.search.predicate.PredicateQueryParser; +import com.yahoo.search.predicate.SubqueryBitmap; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.io.StringWriter; +import java.io.Writer; +import java.util.List; + +import static java.util.stream.Collectors.toList; + +/** + * Converts {@link PredicateQuery} to and from JSON + * + * Example: + * { + * features: [ + * {"k": "key-name", "v":"value", "s":"0xDEADBEEFDEADBEEF"} + * ], + * rangeFeatures: [ + * {"k": "key-name", "v":42, "s":"0xDEADBEEFDEADBEEF"} + * ] + * } + * + * @author bjorncs + */ +public class PredicateQuerySerializer { + private final JsonFactory factory = new JsonFactory(); + private final PredicateQueryParser parser = new PredicateQueryParser(); + + public String toJSON(PredicateQuery query) { + try { + StringWriter writer = new StringWriter(1024); + toJSON(query, writer); + return writer.toString(); + } catch (IOException e) { + throw new RuntimeException(e); + } + } + + public void toJSON(PredicateQuery query, Writer writer) throws IOException { + try (JsonGenerator g = factory.createGenerator(writer)) { + g.writeStartObject(); + + // Write features + g.writeArrayFieldStart("features"); + for (PredicateQuery.Feature feature : query.getFeatures()) { + writeFeature(feature.key, feature.value, feature.subqueryBitmap, g, JsonGenerator::writeStringField); + } + g.writeEndArray(); + + // Write rangeFeatures + g.writeArrayFieldStart("rangeFeatures"); + for (PredicateQuery.RangeFeature rangeFeature : query.getRangeFeatures()) { + writeFeature(rangeFeature.key, rangeFeature.value, rangeFeature.subqueryBitmap, g, + JsonGenerator::writeNumberField); + } + g.writeEndArray(); + + g.writeEndObject(); + } + } + + private static <T> void writeFeature( + String key, T value, long subqueryBitmap, JsonGenerator g, ValueWriter<T> valueWriter) + throws IOException { + + g.writeStartObject(); + g.writeStringField("k", key); + valueWriter.write(g, "v", value); + if (subqueryBitmap != SubqueryBitmap.DEFAULT_VALUE) { + g.writeStringField("s", toHexString(subqueryBitmap)); + } + g.writeEndObject(); + } + + @FunctionalInterface + private interface ValueWriter<T> { + void write(JsonGenerator g, String key, T value) throws IOException; + } + + public PredicateQuery fromJSON(String json) { + PredicateQuery query = new PredicateQuery(); + parser.parseJsonQuery(json, query::addFeature, query::addRangeFeature); + return query; + } + + public static List<PredicateQuery> parseQueriesFromFile(String queryFile, int maxQueryCount) throws IOException { + PredicateQuerySerializer serializer = new PredicateQuerySerializer(); + try (BufferedReader reader = new BufferedReader(new FileReader(queryFile), 8 * 1024)) { + return reader.lines() + .limit(maxQueryCount) + .map(serializer::fromJSON) + .collect(toList()); + } + } + + private static String toHexString(long subqueryBitMap) { + return "0x" + Long.toHexString(subqueryBitMap); + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/serialization/SerializationHelper.java b/predicate-search/src/main/java/com/yahoo/search/predicate/serialization/SerializationHelper.java new file mode 100644 index 00000000000..e63ac946c7e --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/serialization/SerializationHelper.java @@ -0,0 +1,79 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.serialization; + +import com.yahoo.search.predicate.PredicateIndex; + +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; + +/** + * Misc utility functions to help serialization of {@link PredicateIndex}. + * + * @author bjorncs + */ +public class SerializationHelper { + public static void writeIntArray(int[] array, DataOutputStream out) throws IOException { + out.writeInt(array.length); + for (int v : array) { + out.writeInt(v); + } + } + + public static int[] readIntArray(DataInputStream in) throws IOException { + int length = in.readInt(); + int[] array = new int[length]; + for (int i = 0; i < length; i++) { + array[i] = in.readInt(); + } + return array; + } + + public static void writeByteArray(byte[] array, DataOutputStream out) throws IOException { + out.writeInt(array.length); + for (int v : array) { + out.writeByte(v); + } + } + + public static byte[] readByteArray(DataInputStream in) throws IOException { + int length = in.readInt(); + byte[] array = new byte[length]; + for (int i = 0; i < length; i++) { + array[i] = in.readByte(); + } + return array; + } + + public static void writeLongArray(long[] array, DataOutputStream out) throws IOException { + out.writeInt(array.length); + for (long v : array) { + out.writeLong(v); + } + } + + public static long[] readLongArray(DataInputStream in) throws IOException { + int length = in.readInt(); + long[] array = new long[length]; + for (int i = 0; i < length; i++) { + array[i] = in.readLong(); + } + return array; + } + + public static void writeShortArray(short[] array, DataOutputStream out) throws IOException { + out.writeInt(array.length); + for (short v : array) { + out.writeShort(v); + } + } + + public static short[] readShortArray(DataInputStream in) throws IOException { + int length = in.readInt(); + short[] array = new short[length]; + for (int i = 0; i < length; i++) { + array[i] = in.readShort(); + } + return array; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/utils/PostingListSearch.java b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/PostingListSearch.java new file mode 100644 index 00000000000..93246bfaf85 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/PostingListSearch.java @@ -0,0 +1,89 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.utils; + +/** + * Algorithms for searching in the docId arrays in posting lists. + * @author bjorncs + */ +public class PostingListSearch { + + // Use linear search when size less than threshold + public static final int LINEAR_SEARCH_THRESHOLD = 16; + // Use linear search when value difference between first value and key is less than threshold + public static final int LINEAR_SEARCH_THRESHOLD_2 = 32; + // User binary search when size is less than threshold + public static final int BINARY_SEARCH_THRESHOLD = 32768; + + public static int interpolationSearch(int[] a, int fromIndex, int toIndex, int key) { + int low = fromIndex; + int lowVal = a[low]; + if (key - lowVal < LINEAR_SEARCH_THRESHOLD_2) { + return linearSearch(a, low, toIndex, key); + } + int high = toIndex - 1; + int diff = high - low; + if (diff <= BINARY_SEARCH_THRESHOLD) { + return binarySearch(a, low, toIndex, key); + } + int highVal = a[high]; + do { + if (key == lowVal) { + return low + 1; + } + if (key >= highVal) { + return high + 1; + } + int mean = (int) (diff * (long) (key - lowVal) / (highVal - lowVal)); + int eps = diff >>> 4; + int lowMid = low + Math.max(0, mean - eps); + int highMid = low + Math.min(diff, mean + eps); + assert lowMid <= highMid; + assert lowMid >= low; + assert highMid <= high; + + if (a[lowMid] > key) { + high = lowMid; + highVal = a[lowMid]; + } else if (a[highMid] <= key) { + low = highMid; + lowVal = a[highMid]; + } else { + low = lowMid; + lowVal = a[lowMid]; + high = highMid; + highVal = a[highMid]; + } + assert low <= high; + diff = high - low; + } while (diff >= BINARY_SEARCH_THRESHOLD); + return binarySearch(a, low, high + 1, key); + } + + /** + * Modified binary search: + * - Returns the first index where a[index] is larger then key + */ + private static int binarySearch(int[] a, int fromIndex, int toIndex, int key) { + assert fromIndex < toIndex; + int low = fromIndex; + int high = toIndex - 1; + while (high - low > LINEAR_SEARCH_THRESHOLD) { + int mid = (low + high) >>> 1; + assert mid < high; + if (a[mid] < key) { + low = mid + 1; + } else { + high = mid; + } + } + return linearSearch(a, low, high + 1, key); + } + + private static int linearSearch(int[] a, int low, int high, int key) { + assert low < high; + while (low < high && a[low] <= key) { + ++low; + } + return low; + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/utils/PrimitiveArraySorter.java b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/PrimitiveArraySorter.java new file mode 100644 index 00000000000..63b7acc6042 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/PrimitiveArraySorter.java @@ -0,0 +1,97 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.utils; + +/** + * This class enables sorting of an array of primitive short values using a supplied comparator for custom ordering. + * The sort methods in Java standard library cannot sort using a comparator for primitive arrays. + * Sorting is performed using Quicksort. + * + * @author bjorncs + */ +public class PrimitiveArraySorter { + + @FunctionalInterface + public interface ShortComparator { + int compare(short l, short r); + } + + private PrimitiveArraySorter() {} + + public static void sort(short[] array, ShortComparator comparator) { + sort(array, 0, array.length, comparator); + } + + public static void sort(short[] array, int fromIndex, int toIndex, ShortComparator comparator) { + // Sort using insertion sort for size less then 20. + if (toIndex - fromIndex <= 20) { + insertionSort(array, fromIndex, toIndex, comparator); + return; + } + int i = fromIndex; + int j = toIndex - 1; + short pivotValue = array[i + (j - i) / 2]; // Use middle item as pivot value. + while (i < j) { + while (comparator.compare(pivotValue, array[i]) > 0) ++i; + while (comparator.compare(array[j], pivotValue) > 0) --j; + if (i < j) { + short temp = array[i]; + array[i] = array[j]; + array[j] = temp; + ++i; + --j; + } + } + if (fromIndex < j) { + sort(array, fromIndex, j + 1, comparator); + } + if (i < toIndex - 1) { + sort(array, i, toIndex, comparator); + } + } + + public static boolean sortAndMerge(short[] array, short[] mergeArray, int pivotIndex, int toIndex, ShortComparator comparator) { + if (array.length == 1) return false; + sort(array, 0, pivotIndex, comparator); + if (pivotIndex == toIndex || comparator.compare(array[pivotIndex - 1], array[pivotIndex]) <= 0) { + return false; + } + merge(array, mergeArray, pivotIndex, toIndex, comparator); + return true; + } + + public static void merge(short[] array, short[] mergeArray, int pivotIndex, ShortComparator comparator) { + merge(array, mergeArray, pivotIndex, array.length, comparator); + } + + public static void merge(short[] array, short[] mergeArray, int pivotIndex, int toIndex, ShortComparator comparator) { + int indexMergeArray = 0; + int indexPartition0 = 0; + int indexPartition1 = pivotIndex; + while (indexPartition0 < pivotIndex && indexPartition1 < toIndex) { + short val0 = array[indexPartition0]; + short val1 = array[indexPartition1]; + if (comparator.compare(val0, val1) <= 0) { + mergeArray[indexMergeArray++] = val0; + ++indexPartition0; + } else { + mergeArray[indexMergeArray++] = val1; + ++indexPartition1; + } + } + int nLeftPartition0 = pivotIndex - indexPartition0; + System.arraycopy(array, indexPartition0, mergeArray, indexMergeArray, nLeftPartition0); + System.arraycopy(array, indexPartition1, mergeArray, indexMergeArray + nLeftPartition0, toIndex - indexPartition1); + } + + private static void insertionSort(short[] array, int fromIndex, int toIndex, ShortComparator comparator) { + for (int i = fromIndex + 1; i < toIndex; ++i) { + int j = i; + while (j > 0 && comparator.compare(array[j - 1], array[j]) > 0) { + short temp = array[j - 1]; + array[j - 1] = array[j]; + array[j] = temp; + --j; + } + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/utils/TargetingQueryFileConverter.java b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/TargetingQueryFileConverter.java new file mode 100644 index 00000000000..a333286b465 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/TargetingQueryFileConverter.java @@ -0,0 +1,289 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.utils; + +import com.google.common.net.UrlEscapers; +import com.yahoo.search.predicate.PredicateQuery; +import com.yahoo.search.predicate.serialization.PredicateQuerySerializer; + +import java.io.BufferedReader; +import java.io.BufferedWriter; +import java.io.File; +import java.io.FileReader; +import java.io.FileWriter; +import java.io.IOException; +import java.util.*; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.stream.Stream; + +import static java.util.stream.Collectors.joining; + +/** + * Converts a targeting query (the format provided by targeting team) into a file of Vespa queries formatted as URLs. + * + * The format is the following: + * - Each line represents one bulk query (upto 64 subqueries) + * - Each bulk query has a set of subqueries separated by ";" + * - Each subquery is of the format: attrName\tattrValue\tsubqueryIndex\tisRangeTerm; + * - Some attributes have no value. + * - Value may contain ";" + * + * @author bjorncs + */ +public class TargetingQueryFileConverter { + + // Subqueries having more than this value are skipped. + private static final int MAX_NUMBER_OF_TERMS = 100; + + private enum OutputFormat {JSON, YQL} + + private TargetingQueryFileConverter() {} + + public static void main(String[] args) throws IOException { + int nQueries = 123042; + int batchFactor = 64; + Subqueries subqueries = parseRiseQueries(new File("test-data/rise-query2.txt"), nQueries); + filterOutHugeSubqueries(subqueries); + List<Query> queries = batchSubqueries(subqueries, batchFactor); + writeSubqueriesToFile( + queries, + new File("test-data/targeting-queries-json-" + batchFactor + "b-" + nQueries + "n.txt"), + OutputFormat.JSON); + writeSubqueriesToFile( + queries, + new File("test-data/targeting-queries-yql-" + batchFactor + "b-" + nQueries + "n.txt"), + OutputFormat.YQL); + } + + + private static void writeSubqueriesToFile(List<Query> queries, File output, OutputFormat outputFormat) + throws IOException { + try (BufferedWriter writer = new BufferedWriter(new FileWriter(output))) { + if (outputFormat == OutputFormat.JSON) { + writeJSONOutput(writer, queries); + } else { + writeYQLOutput(writer, queries); + } + + } + } + + private static void writeJSONOutput(BufferedWriter writer, List<Query> queries) throws IOException { + PredicateQuerySerializer serializer = new PredicateQuerySerializer(); + for (Query query : queries) { + PredicateQuery predicateQuery = toPredicateQuery(query); + String json = serializer.toJSON(predicateQuery); + writer.append(json).append('\n'); + } + } + + private static PredicateQuery toPredicateQuery(Query query) { + PredicateQuery predicateQuery = new PredicateQuery(); + for (Map.Entry<Long, Set<Feature>> e : query.valuesForSubqueries.entrySet()) { + e.getValue().forEach(f -> predicateQuery.addFeature(f.key, f.strValue, e.getKey())); + } + for (Map.Entry<Long, Set<Feature>> e : query.rangesForSubqueries.entrySet()) { + e.getValue().forEach(f -> predicateQuery.addRangeFeature(f.key, f.longValue, e.getKey())); + } + return predicateQuery; + } + + private static void writeYQLOutput(BufferedWriter writer, List<Query> queries) throws IOException { + for (Query query : queries) { + writer.append(toYqlString(query)).append('\n'); + } + } + + private static String toYqlString(Query query) { + StringBuilder yqlBuilder = new StringBuilder("select * from sources * where predicate(boolean, "); + yqlBuilder + .append(createYqlFormatSubqueryMapString(query.valuesForSubqueries, query.isSingleQuery)) + .append(", ") + .append(createYqlFormatSubqueryMapString(query.rangesForSubqueries, query.isSingleQuery)) + .append(");"); + return "/search/?query&nocache&yql=" + UrlEscapers.urlFormParameterEscaper().escape(yqlBuilder.toString()); + } + + /* + * The subqueryBatchFactor determines the batch factor for each query. A maximum of 64 queries can be batched + * into a single query (as subqueries). + * 0 => Do not batch and output plain queries (no subquery). + * 1 => Do not batch, but output queries with single subquery. + */ + private static List<Query> batchSubqueries(Subqueries subqueries, int subqueryBatchFactor) { + Iterator<Integer> iterator = subqueries.subqueries.iterator(); + List<Query> result = new ArrayList<>(); + while (iterator.hasNext()) { + // Aggregate the subqueries that contains a given value. + Map<Feature, Long> subqueriesForValue = new TreeMap<>(); + Map<Feature, Long> subqueriesForRange = new TreeMap<>(); + // Batch single to single subquery for batch factor 0. + for (int i = 0; i < Math.max(1, subqueryBatchFactor) && iterator.hasNext(); ++i) { + Integer subquery = iterator.next(); + registerSubqueryValues(i, subqueries.valuesForSubquery.get(subquery), subqueriesForValue); + registerSubqueryValues(i, subqueries.rangesForSubquery.get(subquery), subqueriesForRange); + } + + // Aggregate the values that are contained in a given set of subqueries. + Query query = new Query(subqueryBatchFactor == 0); + simplifyAndFillQueryValues(query.valuesForSubqueries, subqueriesForValue); + simplifyAndFillQueryValues(query.rangesForSubqueries, subqueriesForRange); + result.add(query); + } + return result; + } + + private static void registerSubqueryValues(int subquery, Set<Feature> values, Map<Feature, Long> subqueriesForValue) { + if (values != null) { + values.forEach(value -> subqueriesForValue.merge(value, 1L << subquery, (ids1, ids2) -> ids1 | ids2)); + } + } + + private static void simplifyAndFillQueryValues(Map<Long, Set<Feature>> queryValues, Map<Feature, Long> subqueriesForValue) { + for (Map.Entry<Feature, Long> entry : subqueriesForValue.entrySet()) { + Feature feature = entry.getKey(); + Long subqueryBitmap = entry.getValue(); + Set<Feature> featureSet = queryValues.computeIfAbsent(subqueryBitmap, (k) -> new HashSet<>()); + featureSet.add(feature); + } + } + + private static String createYqlFormatSubqueryMapString(Map<Long, Set<Feature>> subqueriesForString, boolean isSingleQuery) { + return subqueriesForString.entrySet().stream() + .map(e -> { + Stream<String> features = e.getValue().stream().map(Feature::asYqlString); + if (isSingleQuery) { + return features.collect(joining(", ")); + } else { + // Note: Cannot use method reference as both method toString(int) and method toString() match. + String values = features.collect(joining(", ", "{", "}")); + return String.format("\"0x%s\":%s", Long.toHexString(e.getKey()), values); + } + }) + .collect(joining(", ", "{", "}")); + } + + private static Subqueries parseRiseQueries(File riseQueryFile, int maxQueries) throws IOException { + try (BufferedReader reader = new BufferedReader(new FileReader(riseQueryFile))) { + Subqueries parsedSubqueries = new Subqueries(); + AtomicInteger counter = new AtomicInteger(1); + reader.lines() + .limit(maxQueries) + .forEach(riseQuery -> parseRiseQuery(parsedSubqueries, riseQuery, counter.getAndIncrement())); + return parsedSubqueries; + } + } + + private static void filterOutHugeSubqueries(Subqueries subqueries) { + Iterator<Integer> iterator = subqueries.subqueries.iterator(); + while (iterator.hasNext()) { + Integer subquery = iterator.next(); + Set<Feature> values = subqueries.valuesForSubquery.get(subquery); + Set<Feature> ranges = subqueries.rangesForSubquery.get(subquery); + int sizeValues = values == null ? 0 : values.size(); + int sizeRanges = ranges == null ? 0 : ranges.size(); + if (sizeValues + sizeRanges > MAX_NUMBER_OF_TERMS) { + iterator.remove(); + subqueries.valuesForSubquery.remove(subquery); + subqueries.rangesForSubquery.remove(subquery); + } + } + } + + private static void parseRiseQuery(Subqueries subqueries, String queryString, int queryId) { + StringTokenizer subQueryTokenizer = new StringTokenizer(queryString, "\t", true); + while (subQueryTokenizer.hasMoreTokens()) { + String key = subQueryTokenizer.nextToken("\t"); + subQueryTokenizer.nextToken(); // Consume delimiter + String value = subQueryTokenizer.nextToken(); + if (value.equals("\t")) { + value = ""; + } else { + subQueryTokenizer.nextToken(); // Consume delimiter + } + int subQueryIndex = Integer.parseInt(subQueryTokenizer.nextToken()); + subQueryTokenizer.nextToken(); // Consume delimiter + boolean isRangeTerm = Boolean.parseBoolean(subQueryTokenizer.nextToken(";")); + if (subQueryTokenizer.hasMoreTokens()) { + subQueryTokenizer.nextToken(); // Consume delimiter + } + int subqueryId = subQueryIndex + 64 * queryId; + if (isRangeTerm) { + Set<Feature> rangeFeatures = subqueries.rangesForSubquery.computeIfAbsent( + subqueryId, (id) -> new HashSet<>()); + rangeFeatures.add(new Feature(key, Long.parseLong(value))); + } else { + Set<Feature> features = subqueries.valuesForSubquery.computeIfAbsent(subqueryId, (id) -> new HashSet<>()); + features.add(new Feature(key, value)); + } + subqueries.subqueries.add(subqueryId); + } + } + + private static class Subqueries { + public final TreeSet<Integer> subqueries = new TreeSet<>(); + public final Map<Integer, Set<Feature>> valuesForSubquery = new HashMap<>(); + public final Map<Integer, Set<Feature>> rangesForSubquery = new HashMap<>(); + } + + private static class Query { + public final boolean isSingleQuery; + public final Map<Long, Set<Feature>> valuesForSubqueries = new TreeMap<>(); + public final Map<Long, Set<Feature>> rangesForSubqueries = new TreeMap<>(); + + public Query(boolean isSingleQuery) { + this.isSingleQuery = isSingleQuery; + } + } + + private static class Feature implements Comparable<Feature> { + public final String key; + private final String strValue; + private final long longValue; + + public Feature(String key, String value) { + this.key = key; + this.strValue = value; + this.longValue = 0; + } + + public Feature(String key, long value) { + this.key = key; + this.strValue = null; + this.longValue = value; + } + + public String asYqlString() { + if (strValue != null) { + return String.format("\"%s\":\"%s\"", key, strValue); + } else { + return String.format("\"%s\":%dl", key, longValue); + } + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (!(o instanceof Feature)) return false; + + Feature feature = (Feature) o; + + if (longValue != feature.longValue) return false; + if (!key.equals(feature.key)) return false; + return !(strValue != null ? !strValue.equals(feature.strValue) : feature.strValue != null); + + } + + @Override + public int hashCode() { + int result = key.hashCode(); + result = 31 * result + (strValue != null ? strValue.hashCode() : 0); + result = 31 * result + (int) (longValue ^ (longValue >>> 32)); + return result; + } + + @Override + public int compareTo(Feature o) { + return asYqlString().compareTo(o.asYqlString()); + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaFeedParser.java b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaFeedParser.java new file mode 100644 index 00000000000..8ba9236a66c --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaFeedParser.java @@ -0,0 +1,44 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.utils; + +import com.yahoo.document.predicate.Predicate; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.util.function.Consumer; + +/** + * Parses a feed file containing documents in XML format. Its implementation is based on the following assumptions: + * 1. Each document has single predicate field. + * 2. The predicate is stored in a field named "boolean". + * + * @author bjorncs + */ +public class VespaFeedParser { + + public static int parseDocuments(String feedFile, int maxDocuments, Consumer<Predicate> consumer) throws IOException { + int documentCount = 0; + try (BufferedReader reader = new BufferedReader(new FileReader(feedFile), 8 * 1024)) { + reader.readLine(); + reader.readLine(); // Skip to start of first document + String line = reader.readLine(); + while (!line.startsWith("</vespafeed>") && documentCount < maxDocuments) { + while (!line.startsWith("<boolean>")) { + line = reader.readLine(); + } + Predicate predicate = Predicate.fromString(extractBooleanExpression(line)); + consumer.accept(predicate); + ++documentCount; + while (!line.startsWith("<document") && !line.startsWith("</vespafeed>")) { + line = reader.readLine(); + } + } + } + return documentCount; + } + + private static String extractBooleanExpression(String line) { + return line.substring(9, line.length() - 10); + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaFeedWriter.java b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaFeedWriter.java new file mode 100644 index 00000000000..544a9a12af0 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaFeedWriter.java @@ -0,0 +1,43 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.utils; + +import com.yahoo.document.predicate.Predicate; +import org.apache.commons.lang.StringEscapeUtils; + +import java.io.BufferedWriter; +import java.io.IOException; +import java.io.Writer; + +/** + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public class VespaFeedWriter extends BufferedWriter { + private String namespace; + private String documentType; + + VespaFeedWriter(Writer writer, String namespace, String documentType) throws IOException { + super(writer); + this.namespace = namespace; + this.documentType = documentType; + + this.append("<?xml version=\"1.0\" encoding=\"UTF-8\" standalone=\"no\"?>\n"); + this.append("<vespafeed>\n"); + } + + @Override + public void close() throws IOException { + this.append("</vespafeed>\n"); + super.close(); + } + + public void writePredicateDocument(int id, String fieldName, Predicate predicate) { + try { + this.append(String.format("<document documenttype=\"%2$s\" documentid=\"id:%1$s:%2$s::%3$d\">\n", + namespace, documentType, id)); + this.append("<" + fieldName + ">" + StringEscapeUtils.escapeHtml(predicate.toString()) + "</" + fieldName + ">\n"); + this.append("</document>\n"); + } catch (IOException e) { + throw new RuntimeException(e); + } + } +} diff --git a/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaQueryParser.java b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaQueryParser.java new file mode 100644 index 00000000000..b8ec20c59a0 --- /dev/null +++ b/predicate-search/src/main/java/com/yahoo/search/predicate/utils/VespaQueryParser.java @@ -0,0 +1,105 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.utils; + +import com.yahoo.search.predicate.PredicateQuery; + +import java.io.BufferedReader; +import java.io.FileReader; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.URLDecoder; +import java.util.ArrayList; +import java.util.List; +import java.util.function.BiConsumer; + +import static java.util.stream.Collectors.toList; + +/** + * Parses query file containing Vespa queries using the deprecated predicate format (query properties - not YQL). + * + * @author bjorncs + */ +public class VespaQueryParser { + + /** + * Parses a query formatted using the deprecated boolean query format (query properties). + */ + public static List<PredicateQuery> parseQueries(String queryFile, int maxQueryCount) throws IOException { + try (BufferedReader reader = new BufferedReader(new FileReader(queryFile), 8 * 1024)) { + List<PredicateQuery> queries = reader.lines() + .limit(maxQueryCount) + .map(VespaQueryParser::parseQueryFromQueryProperties) + .collect(toList()); + return queries; + } + } + + public static PredicateQuery parseQueryFromQueryProperties(String queryString) { + try { + // Decode the URL in case the query property content is escaped. + queryString = URLDecoder.decode(queryString, "UTF-8"); + PredicateQuery query = new PredicateQuery(); + extractQueryValues(queryString, "boolean.attributes", query::addFeature); + extractQueryValues(queryString, "boolean.rangeAttributes", + (k, v) -> query.addRangeFeature(k, Integer.parseInt(v))); + return query; + } catch (UnsupportedEncodingException e) { + throw new RuntimeException(e); + } + } + + private static void extractQueryValues(String query, String prefix, BiConsumer<String, String> registerTerm) { + int rangeIndex = query.indexOf(prefix); + if (rangeIndex != -1) { + // Adding 2 to skip '={' + int startIndex = rangeIndex + prefix.length() + 2; + // '%7D' represents the end of the predicate string. + int endIndex = query.indexOf("}", startIndex); + String rangeString = query.substring(startIndex, endIndex); + List<Feature> features = new ArrayList<>(); + String[] keyValuePairs = rangeString.split(","); + + for (String keyValuePair : keyValuePairs) { + String[] keyAndValue = keyValuePair.split(":"); + // If not colon is found, the string is part of the previous value. + if (keyAndValue.length == 1) { + Feature feature = features.get(features.size() - 1); + feature.value += ("," + keyValuePair); + } else { + features.add(new Feature(keyAndValue[0], keyAndValue[1])); + } + } + features.stream().forEach(f -> registerTerm.accept(f.key, f.value)); + } + } + + private static class Feature { + public final String key; + public String value; + + private Feature(String key, String value) { + this.key = key; + this.value = value; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + + Feature feature = (Feature) o; + + if (!key.equals(feature.key)) return false; + if (!value.equals(feature.value)) return false; + + return true; + } + + @Override + public int hashCode() { + int result = key.hashCode(); + result = 31 * result + value.hashCode(); + return result; + } + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/PredicateIndexBuilderTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/PredicateIndexBuilderTest.java new file mode 100644 index 00000000000..0c673ffa267 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/PredicateIndexBuilderTest.java @@ -0,0 +1,42 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate; + +import com.yahoo.document.predicate.BooleanPredicate; +import com.yahoo.document.predicate.Predicate; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class PredicateIndexBuilderTest { + + @Test(expected = IllegalArgumentException.class) + public void requireThatIndexingMultiDocumentsWithSameIdThrowsException() { + PredicateIndexBuilder builder = new PredicateIndexBuilder(2); + builder.indexDocument(1, Predicate.fromString("a in ['b']")); + builder.indexDocument(1, Predicate.fromString("c in ['d']")); + } + + @Test + public void requireThatEmptyDocumentsCanBeIndexed() { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + assertEquals(0, builder.getZeroConstraintDocCount()); + builder.indexDocument(2, new BooleanPredicate(true)); + assertEquals(1, builder.getZeroConstraintDocCount()); + builder.build(); + } + + @Test + public void requireThatMultipleDocumentsCanBeIndexed() { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(1, Predicate.fromString("a in ['b']")); + builder.indexDocument(2, Predicate.fromString("a in ['b']")); + builder.indexDocument(3, Predicate.fromString("a in ['b']")); + builder.indexDocument(4, Predicate.fromString("a in ['b']")); + builder.indexDocument(5, Predicate.fromString("a in ['b']")); + builder.build(); + } + +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/PredicateIndexTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/PredicateIndexTest.java new file mode 100644 index 00000000000..25effda4cb9 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/PredicateIndexTest.java @@ -0,0 +1,141 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate; + +import com.yahoo.document.predicate.Predicate; +import org.junit.Test; + +import java.io.IOException; + +import static com.yahoo.search.predicate.serialization.SerializationTestHelper.assertSerializationDeserializationMatches; +import static java.util.stream.Collectors.toList; +import static org.junit.Assert.assertEquals; + +/** + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class PredicateIndexTest { + + private static final int DOC_ID = 42; + + @Test + public void requireThatPredicateIndexCanSearch() { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(1, Predicate.fromString("country in ['no', 'se'] and gender in ['male']")); + builder.indexDocument(0x3fffffe, Predicate.fromString("country in ['no'] and gender in ['female']")); + PredicateIndex index = builder.build(); + PredicateIndex.Searcher searcher = index.searcher(); + PredicateQuery query = new PredicateQuery(); + query.addFeature("country", "no"); + query.addFeature("gender", "male"); + assertEquals("[1]", searcher.search(query).collect(toList()).toString()); + query.addFeature("gender", "female"); + assertEquals("[1, 67108862]", searcher.search(query).collect(toList()).toString()); + } + + @Test + public void requireThatPredicateIndexCanSearchWithNotExpression() { + { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(1, Predicate.fromString("country in ['no'] and gender not in ['male']")); + PredicateIndex index = builder.build(); + PredicateIndex.Searcher searcher = index.searcher(); + PredicateQuery query = new PredicateQuery(); + query.addFeature("country", "no"); + query.addFeature("gender", "female"); + assertEquals("[1]", searcher.search(query).collect(toList()).toString()); + } + { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(DOC_ID, Predicate.fromString("country in ['no'] and gender in ['male']")); + builder.indexDocument(DOC_ID + 1, Predicate.fromString("country not in ['no']")); + PredicateIndex index = builder.build(); + PredicateIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + assertEquals("[43]", searcher.search(query).collect(toList()).toString()); + query.addFeature("country", "no"); + assertEquals(0, searcher.search(query).count()); + } + { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(DOC_ID, Predicate.fromString("country not in ['no'] and gender not in ['male']")); + PredicateIndex index = builder.build(); + PredicateIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + assertEquals(1, searcher.search(query).count()); + query.addFeature("country", "no"); + assertEquals(0, searcher.search(query).count()); + query.addFeature("gender", "male"); + assertEquals(0, searcher.search(query).count()); + + query = new PredicateQuery(); + query.addFeature("gender", "male"); + assertEquals(0, searcher.search(query).count()); + } + } + + @Test + public void requireThatSearchesCanUseSubqueries() { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(DOC_ID, Predicate.fromString("country in [no] and gender in [male]")); + PredicateIndex index = builder.build(); + PredicateIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + query.addFeature("country", "no", 0x3); + assertEquals(0, searcher.search(query).count()); + query.addFeature("gender", "male", 0x6); + assertEquals("[[42,0x2]]", searcher.search(query).collect(toList()).toString()); + } + + @Test + public void requireThatPredicateIndexCanSearchWithRange() { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(1, Predicate.fromString("gender in ['male'] and age in [20..40]")); + builder.indexDocument(2, Predicate.fromString("gender in ['female'] and age in [20..40]")); + PredicateIndex index = builder.build(); + PredicateIndex.Searcher searcher = index.searcher(); + PredicateQuery query = new PredicateQuery(); + query.addFeature("gender", "male"); + query.addRangeFeature("age", 36); + assertEquals("[1]", searcher.search(query).collect(toList()).toString()); + query.addFeature("gender", "female"); + assertEquals("[1, 2]", searcher.search(query).collect(toList()).toString()); + } + + @Test + public void requireThatPredicateIndexCanSearchWithEmptyDocuments() { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(1, Predicate.fromString("true")); + builder.indexDocument(2, Predicate.fromString("false")); + PredicateIndex index = builder.build(); + PredicateIndex.Searcher searcher = index.searcher(); + PredicateQuery query = new PredicateQuery(); + assertEquals("[1]", searcher.search(query).collect(toList()).toString()); + } + + @Test + public void requireThatPredicatesHavingMultipleIdenticalConjunctionsAreSupported() { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(DOC_ID, Predicate.fromString( + "((a in ['b'] and c in ['d']) or x in ['y']) and ((a in ['b'] and c in ['d']) or z in ['w'])")); + PredicateIndex index = builder.build(); + PredicateIndex.Searcher searcher = index.searcher(); + PredicateQuery query = new PredicateQuery(); + query.addFeature("a", "b"); + query.addFeature("c", "d"); + assertEquals("[42]", searcher.search(query).collect(toList()).toString()); + } + + @Test + public void require_that_serialization_and_deserialization_retain_data() throws IOException { + PredicateIndexBuilder builder = new PredicateIndexBuilder(10); + builder.indexDocument(1, Predicate.fromString("country in ['no', 'se'] and gender in ['male']")); + builder.indexDocument(0x3fffffe, Predicate.fromString("country in ['no'] and gender in ['female']")); + PredicateIndex index = builder.build(); + assertSerializationDeserializationMatches( + index, PredicateIndex::writeToOutputStream, PredicateIndex::fromInputStream); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzerTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzerTest.java new file mode 100644 index 00000000000..4d08c34b0b5 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/annotator/PredicateTreeAnalyzerTest.java @@ -0,0 +1,238 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.annotator; + +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.document.predicate.PredicateOperator; +import org.junit.Test; + +import java.util.Arrays; + +import static com.yahoo.document.predicate.Predicates.and; +import static com.yahoo.document.predicate.Predicates.feature; +import static com.yahoo.document.predicate.Predicates.not; +import static com.yahoo.document.predicate.Predicates.or; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertTrue; + +public class PredicateTreeAnalyzerTest { + + @Test + public void require_that_minfeature_is_1_for_simple_term() { + Predicate p = feature("foo").inSet("bar"); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(1, r.minFeature); + assertEquals(1, r.treeSize); + assertTrue(r.sizeMap.isEmpty()); + } + + @Test + public void require_that_minfeature_is_1_for_simple_negative_term() { + Predicate p = not(feature("foo").inSet("bar")); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(1, r.minFeature); + } + + @Test + public void require_that_minfeature_is_sum_for_and() { + Predicate p = + and( + feature("foo").inSet("bar"), + feature("baz").inSet("qux"), + feature("quux").inSet("corge")); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(3, r.minFeature); + assertEquals(3, r.treeSize); + assertEquals(3, r.sizeMap.size()); + assertSizeMapContains(r, pred(p).child(0), 1); + assertSizeMapContains(r, pred(p).child(1), 1); + assertSizeMapContains(r, pred(p).child(2), 1); + } + + @Test + public void require_that_minfeature_is_min_for_or() { + Predicate p = + or( + and( + feature("foo").inSet("bar"), + feature("baz").inSet("qux"), + feature("quux").inSet("corge")), + and( + feature("grault").inSet("garply"), + feature("waldo").inSet("fred"))); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(2, r.minFeature); + assertEquals(5, r.treeSize); + assertEquals(5, r.sizeMap.size()); + assertSizeMapContains(r, pred(p).child(0).child(0), 1); + assertSizeMapContains(r, pred(p).child(0).child(1), 1); + assertSizeMapContains(r, pred(p).child(0).child(2), 1); + assertSizeMapContains(r, pred(p).child(1).child(0), 1); + assertSizeMapContains(r, pred(p).child(1).child(1), 1); + } + + @Test + public void require_that_minfeature_rounds_up() { + Predicate p = + or( + feature("foo").inSet("bar"), + feature("foo").inSet("bar"), + feature("foo").inSet("bar")); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(1, r.minFeature); + assertEquals(3, r.treeSize); + } + + @Test + public void require_that_minvalue_feature_set_considers_all_values() { + { + Predicate p = + and( + feature("foo").inSet("A", "B"), + feature("foo").inSet("B")); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(1, r.minFeature); + assertEquals(2, r.treeSize); + } + { + Predicate p = + and( + feature("foo").inSet("A", "B"), + feature("foo").inSet("C")); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(2, r.minFeature); + assertEquals(2, r.treeSize); + } + } + + @Test + public void require_that_not_features_dont_count_towards_minfeature_calculation() { + Predicate p = + and( + feature("foo").inSet("A"), + not(feature("foo").inSet("A")), + not(feature("foo").inSet("B")), + feature("foo").inSet("B")); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(3, r.minFeature); + assertEquals(6, r.treeSize); + } + + @Test + public void require_that_multilevel_and_stores_size() { + Predicate p = + and( + and( + feature("foo").inSet("bar"), + feature("baz").inSet("qux"), + feature("quux").inSet("corge")), + and( + feature("grault").inSet("garply"), + feature("waldo").inSet("fred"))); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(5, r.minFeature); + assertEquals(5, r.treeSize); + assertEquals(7, r.sizeMap.size()); + assertSizeMapContains(r, pred(p).child(0), 3); + assertSizeMapContains(r, pred(p).child(1), 2); + assertSizeMapContains(r, pred(p).child(0).child(0), 1); + assertSizeMapContains(r, pred(p).child(0).child(1), 1); + assertSizeMapContains(r, pred(p).child(0).child(2), 1); + assertSizeMapContains(r, pred(p).child(1).child(0), 1); + assertSizeMapContains(r, pred(p).child(1).child(1), 1); + } + + @Test + public void require_that_not_ranges_dont_count_towards_minfeature_calculation() { + Predicate p = + and( + feature("foo").inRange(0, 10), + not(feature("foo").inRange(0, 10)), + feature("bar").inRange(0, 10), + not(feature("bar").inRange(0, 10))); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(3, r.minFeature); + assertEquals(6, r.treeSize); + } + + @Test + public void require_that_featureconjunctions_contribute_as_one_feature() { + Predicate p = + conj( + feature("foo").inSet("bar"), + feature("baz").inSet("qux")); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(1, r.minFeature); + assertEquals(1, r.treeSize); + } + + @Test + public void require_that_featureconjunctions_count_as_leaf_in_subtree_calculation() { + Predicate p = + and( + and( + feature("grault").inRange(0, 10), + feature("waldo").inRange(0, 10)), + conj( + feature("foo").inSet("bar"), + feature("baz").inSet("qux"), + feature("quux").inSet("corge"))); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(3, r.minFeature); + assertEquals(3, r.treeSize); + assertEquals(4, r.sizeMap.size()); + assertSizeMapContains(r, pred(p).child(0), 2); + assertSizeMapContains(r, pred(p).child(0).child(0), 1); + assertSizeMapContains(r, pred(p).child(0).child(1), 1); + assertSizeMapContains(r, pred(p).child(1), 1); + } + + @Test + public void require_that_multiple_indentical_feature_conjunctions_does_not_contribute_more_than_one() { + Predicate p = + and( + or( + conj( + feature("a").inSet("b"), + feature("c").inSet("d") + ), + feature("x").inSet("y")), + or( + conj( + feature("a").inSet("b"), + feature("c").inSet("d") + ), + feature("z").inSet("w"))); + PredicateTreeAnalyzerResult r = PredicateTreeAnalyzer.analyzePredicateTree(p); + assertEquals(1, r.minFeature); + assertEquals(4, r.treeSize); + } + + private static FeatureConjunction conj(Predicate... operands) { + return new FeatureConjunction(Arrays.asList(operands)); + } + + private static void assertSizeMapContains(PredicateTreeAnalyzerResult r, PredicateSelector selector, int expectedValue) { + Integer actualValue = r.sizeMap.get(selector.predicate); + assertNotNull(actualValue); + assertEquals(expectedValue, actualValue.intValue()); + } + + private static class PredicateSelector { + public final Predicate predicate; + + public PredicateSelector(Predicate predicate) { + this.predicate = predicate; + } + + public PredicateSelector child(int index) { + PredicateOperator op = (PredicateOperator) predicate; + return new PredicateSelector(op.getOperands().get(index)); + } + } + + private static PredicateSelector pred(Predicate p) { + return new PredicateSelector(p); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotatorTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotatorTest.java new file mode 100644 index 00000000000..7ccc910a4bf --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/annotator/PredicateTreeAnnotatorTest.java @@ -0,0 +1,271 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.annotator; + +import com.google.common.primitives.Ints; +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.document.predicate.FeatureRange; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.document.predicate.PredicateHash; +import com.yahoo.document.predicate.RangeEdgePartition; +import com.yahoo.document.predicate.RangePartition; +import com.yahoo.search.predicate.index.Feature; +import com.yahoo.search.predicate.index.conjunction.IndexableFeatureConjunction; +import com.yahoo.search.predicate.index.IntervalWithBounds; +import org.apache.commons.lang.ArrayUtils; +import org.junit.Test; + +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static com.yahoo.document.predicate.Predicates.and; +import static com.yahoo.document.predicate.Predicates.feature; +import static com.yahoo.document.predicate.Predicates.not; +import static com.yahoo.document.predicate.Predicates.or; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +public class PredicateTreeAnnotatorTest { + + @Test + public void require_that_or_intervals_are_the_same() { + Predicate p = + or( + feature("key1").inSet("value1"), + feature("key2").inSet("value2")); + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(1, r.minFeature); + assertEquals(2, r.intervalEnd); + assertEquals(2, r.intervalMap.size()); + assertIntervalContains(r, "key1=value1", 0x00010002); + assertIntervalContains(r, "key2=value2", 0x00010002); + } + + @Test + public void require_that_ands_below_ors_get_different_intervals() { + Predicate p = + or( + and( + feature("key1").inSet("value1"), + feature("key1").inSet("value1"), + feature("key1").inSet("value1")), + and( + feature("key2").inSet("value2"), + feature("key2").inSet("value2"), + feature("key2").inSet("value2"))); + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(1, r.minFeature); + assertEquals(6, r.intervalEnd); + assertEquals(2, r.intervalMap.size()); + assertIntervalContains(r, "key1=value1", 0x00010001, 0x00020002, 0x00030006); + assertIntervalContains(r, "key2=value2", 0x00010004, 0x00050005, 0x00060006); + } + + @Test + public void require_that_nots_get_correct_intervals() { + Predicate p = + and( + feature("key").inSet("value"), + not(feature("key").inSet("value")), + feature("key").inSet("value"), + not(feature("key").inSet("value"))); + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(2, r.minFeature); + assertEquals(6, r.intervalEnd); + assertEquals(2, r.intervalMap.size()); + assertIntervalContains(r, "key=value", 0x00010001, 0x00020002, 0x00040004, 0x00050005); + assertIntervalContains(r, Feature.Z_STAR_COMPRESSED_ATTRIBUTE_NAME, 0x00020001, 0x00050004); + } + + @Test + public void require_that_final_first_not_interval_is_extended() { + Predicate p = not(feature("key").inSet("A")); + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(1, r.minFeature); + assertEquals(2, r.intervalEnd); + assertEquals(2, r.intervalMap.size()); + assertIntervalContains(r, "key=A", 0x00010001); + assertIntervalContains(r, Feature.Z_STAR_COMPRESSED_ATTRIBUTE_NAME, 0x00010000); + } + + @Test + public void show_different_types_of_not_intervals() { + { + Predicate p = + and( + or( + and( + feature("key").inSet("A"), + not(feature("key").inSet("B"))), + and( + not(feature("key").inSet("C")), + feature("key").inSet("D"))), + feature("foo").inSet("bar")); + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(3, r.minFeature); + assertEquals(7, r.intervalEnd); + assertEquals(6, r.intervalMap.size()); + + assertIntervalContains(r, "foo=bar", 0x00070007); + assertIntervalContains(r, "key=A", 0x00010001); + assertIntervalContains(r, "key=B", 0x00020002); + assertIntervalContains(r, "key=C", 0x00010004); + assertIntervalContains(r, "key=D", 0x00060006); + assertIntervalContains(r, Feature.Z_STAR_COMPRESSED_ATTRIBUTE_NAME, 0x00020001, 0x00000006, 0x00040000); + } + { + Predicate p = + or( + not(feature("key").inSet("A")), + not(feature("key").inSet("B"))); + + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(1, r.minFeature); + assertEquals(4, r.intervalEnd); + assertEquals(3, r.intervalMap.size()); + assertIntervalContains(r, "key=A", 0x00010003); + assertIntervalContains(r, "key=B", 0x00010003); + assertIntervalContains(r, Feature.Z_STAR_COMPRESSED_ATTRIBUTE_NAME, 0x00030000, 0x00030000); + } + { + Predicate p = + or( + and( + not(feature("key").inSet("A")), + not(feature("key").inSet("B"))), + and( + not(feature("key").inSet("C")), + not(feature("key").inSet("D")))); + + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(1, r.minFeature); + assertEquals(8, r.intervalEnd); + assertEquals(5, r.intervalMap.size()); + assertIntervalContains(r, "key=A", 0x00010001); + assertIntervalContains(r, "key=B", 0x00030007); + assertIntervalContains(r, "key=C", 0x00010005); + assertIntervalContains(r, "key=D", 0x00070007); + assertIntervalContains(r, Feature.Z_STAR_COMPRESSED_ATTRIBUTE_NAME, + 0x00010000, 0x00070002, 0x00050000, 0x00070006); + } + } + + @Test + public void require_that_hashed_ranges_get_correct_intervals() { + Predicate p = + and( + range("key", + partition("key=10-19"), + partition("key=20-29"), + edgePartition("key=0", 5, 10, 20), + edgePartition("key=30", 0, 0, 3)), + range("foo", + partition("foo=10-19"), + partition("foo=20-29"), + edgePartition("foo=0", 5, 40, 60), + edgePartition("foo=30", 0, 0, 3))); + + + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(2, r.minFeature); + assertEquals(2, r.intervalEnd); + assertEquals(4, r.intervalMap.size()); + assertEquals(4, r.boundsMap.size()); + assertIntervalContains(r, "key=10-19", 0x00010001); + assertIntervalContains(r, "key=20-29", 0x00010001); + assertBoundsContains(r, "key=0", bound(0x00010001, 0x000a0015)); // [10..20] + assertBoundsContains(r, "key=30", bound(0x00010001, 0x40000004)); // [..3] + + assertIntervalContains(r, "foo=10-19", 0x00020002); + assertIntervalContains(r, "foo=20-29", 0x00020002); + assertBoundsContains(r, "foo=0", bound(0x00020002, 0x0028003d)); // [40..60] + assertBoundsContains(r, "foo=30", bound(0x00020002, 0x40000004)); // [..3] + } + + @Test + public void require_that_extreme_ranges_works() { + Predicate p = + and( + range("max range", partition("max range=9223372036854775806-9223372036854775807")), + range("max edge", edgePartition("max edge=9223372036854775807", 0, 0, 1)), + range("min range", partition("min range=-9223372036854775807-9223372036854775806")), + range("min edge", edgePartition("min edge=-9223372036854775808", 0, 0, 1))); + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(4, r.minFeature); + assertEquals(4, r.intervalEnd); + assertEquals(2, r.intervalMap.size()); + assertEquals(2, r.boundsMap.size()); + assertIntervalContains(r, "max range=9223372036854775806-9223372036854775807", 0x00010001); + assertBoundsContains(r, "max edge=9223372036854775807", bound(0x00020002, 0x40000002)); + assertIntervalContains(r, "min range=-9223372036854775807-9223372036854775806", 0x00030003); + assertBoundsContains(r, "min edge=-9223372036854775808", bound(0x00040004, 0x40000002)); + } + + @Test + public void require_that_featureconjunctions_are_registered_and_given_an_interval() { + Predicate p = + and( + or( + range("key", + partition("key=10-19"), + partition("key=20-29"), + edgePartition("key=0", 5, 10, 20), + edgePartition("key=30", 0, 0, 3)), + conj( + not(feature("keyA").inSet("C")), + feature("keyB").inSet("D"))), + feature("foo").inSet("bar")); + PredicateTreeAnnotations r = PredicateTreeAnnotator.createPredicateTreeAnnotations(p); + assertEquals(2, r.minFeature); + assertEquals(3, r.intervalEnd); + assertEquals(3, r.intervalMap.size()); + assertEquals(2, r.boundsMap.size()); + assertEquals(1, r.featureConjunctions.size()); + + Map.Entry<IndexableFeatureConjunction, List<Integer>> entry = r.featureConjunctions.entrySet().iterator().next(); + assertEquals(1, entry.getValue().size()); + assertEquals(0b1_0000000000000010, entry.getValue().get(0).longValue()); + } + + private static void assertIntervalContains(PredicateTreeAnnotations r, String feature, Integer... expectedIntervals) { + long hash = PredicateHash.hash64(feature); + List<Integer> actualIntervals = r.intervalMap.get(hash); + assertNotNull(actualIntervals); + assertArrayEquals(ArrayUtils.toPrimitive(expectedIntervals), Ints.toArray(actualIntervals)); + } + + private static void assertBoundsContains(PredicateTreeAnnotations r, String feature, IntervalWithBounds expectedBounds) { + long hash = PredicateHash.hash64(feature); + List<IntervalWithBounds> actualBounds = r.boundsMap.get(hash); + assertNotNull(actualBounds); + assertEquals(1, actualBounds.size()); + assertEquals(expectedBounds, actualBounds.get(0)); + } + + private static IntervalWithBounds bound(int interval, int bounds) { + return new IntervalWithBounds(interval, bounds); + } + + private static RangePartition partition(String label) { + return new RangePartition(label); + } + + private static RangePartition edgePartition(String label, long value, int lower, int upper) { + return new RangeEdgePartition(label, value, lower, upper); + } + + private static FeatureRange range(String key, RangePartition... partitions) { + return range(key, null, null, partitions); + } + + private static FeatureRange range(String key, Long lower, Long upper, RangePartition... partitions) { + FeatureRange range = new FeatureRange(key, lower, upper); + Arrays.asList(partitions).forEach(range::addPartition); + return range; + } + + private static FeatureConjunction conj(Predicate... operands) { + return new FeatureConjunction(Arrays.asList(operands)); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/BoundsPostingListTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/BoundsPostingListTest.java new file mode 100644 index 00000000000..06782e3603a --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/BoundsPostingListTest.java @@ -0,0 +1,84 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.google.common.primitives.Ints; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; + +import static java.util.stream.Collectors.toList; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * This test lies in com.yahoo.search.predicate to get access to some methods of PredicateIndex. + * + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public class BoundsPostingListTest { + + @Test + public void requireThatPostingListChecksBounds() { + PredicateIntervalStore.Builder builder = new PredicateIntervalStore.Builder(); + List<Integer> docIds = new ArrayList<>(); + List<Integer> dataRefs = new ArrayList<>(); + for (int id = 1; id < 100; ++id) { + List<IntervalWithBounds> boundsList = new ArrayList<>(); + for (int i = 0; i <= id; ++i) { + int bounds; + if (id < 30) { + bounds = 0x80000000 | i; // diff >= i + } else if (id < 60) { + bounds = 0x40000000 | i; // diff < i + } else { + bounds = (i << 16) | (i + 10); // i < diff < i + 10 + } + boundsList.add(new IntervalWithBounds((i + 1) << 16 | 0xffff, bounds)); + } + docIds.add(id); + dataRefs.add(builder.insert(boundsList.stream().flatMap(IntervalWithBounds::stream).collect(toList()))); + } + + PredicateIntervalStore store = builder.build(); + BoundsPostingList postingList = new BoundsPostingList( + store, Ints.toArray(docIds), Ints.toArray(dataRefs), 0xffffffffffffffffL, 5); + assertEquals(-1, postingList.getDocId()); + assertEquals(0, postingList.getInterval()); + assertEquals(0xffffffffffffffffL, postingList.getSubquery()); + + checkNext(postingList, 0, 1, 2); // [0..] .. [1..] + checkNext(postingList, 1, 2, 3); // [0..] .. [2..] + checkNext(postingList, 10, 11, 6); // [0..] .. [5..] + checkNext(postingList, 20, 21, 6); + + checkNext(postingList, 30, 31, 26); // [..5] .. [..30] + checkNext(postingList, 50, 51, 46); + + checkNext(postingList, 60, 61, 6); // [0..10] .. [5..15] + + postingList = new BoundsPostingList(store, Ints.toArray(docIds), Ints.toArray(dataRefs), 0xffffffffffffffffL, 40); + checkNext(postingList, 0, 1, 2); + checkNext(postingList, 20, 21, 22); + + checkNext(postingList, 30, 31, 0); // skip ahead to match + checkNext(postingList, 32, 33, 0); // skip ahead to match + checkNext(postingList, 33, 34, 0); // skip ahead to match + checkNext(postingList, 40, 41, 1); + checkNext(postingList, 50, 51, 11); // [..40] .. [..50] + + checkNext(postingList, 60, 61, 10); // [31..40] .. [40..49] + } + + private void checkNext(BoundsPostingList postingList, int movePast, int docId, int intervalCount) { + assertTrue("Unable to move past " + movePast, postingList.nextDocument(movePast)); + assertEquals(intervalCount > 0, postingList.prepareIntervals()); + assertEquals(docId, postingList.getDocId()); + for (int i = 0; i < intervalCount - 1; ++i) { + assertTrue("Too few intervals, expected " + intervalCount, postingList.nextInterval()); + } + assertFalse("Too many intervals, expected " + intervalCount, postingList.nextInterval()); + } + +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/CachedPostingListCounterTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/CachedPostingListCounterTest.java new file mode 100644 index 00000000000..d1b8dd01039 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/CachedPostingListCounterTest.java @@ -0,0 +1,116 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.gs.collections.impl.map.mutable.primitive.ObjectIntHashMap; +import org.apache.commons.lang.ArrayUtils; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * @author bjorncs + */ +public class CachedPostingListCounterTest { + + @Test + public void require_that_docids_are_counted_correctly() { + int nDocuments = 4; + byte[] nPostingListsPerDocument = new byte[nDocuments]; + CachedPostingListCounter c = new CachedPostingListCounter(nDocuments); + c.countPostingListsPerDocument( + list( + postingList(0, 1, 2, 3), + postingList(1, 2), + postingList(1, 3), + postingList(3)), + nPostingListsPerDocument); + assertArrayEquals(new byte[]{1, 3, 2, 3}, nPostingListsPerDocument); + } + + @Test + public void require_that_most_costly_posting_lists_are_first_in_bit_vector() { + int nDocuments = 5; + CachedPostingListCounter c = new CachedPostingListCounter(nDocuments); + List<PostingList> list = new ArrayList<>(); + PostingList p1 = postingList(1, 2, 4); + PostingList p2 = postingList(0, 1, 2, 3, 4); + PostingList p3 = postingList(1, 2, 3, 4); + PostingList p4 = postingList(3, 4); + list.add(p1); list.add(p2); list.add(p3); list.add(p4); + for (int i = 0; i < 100; i++) { + list.add(postingList(0)); + } + c.registerUsage(list); + CachedPostingListCounter newC = c.rebuildCache(); + ObjectIntHashMap<int[]> mapping = newC.getPostingListMapping(); + assertEquals(0, mapping.getIfAbsent(p2.getDocIds(), -1)); + assertEquals(1, mapping.getIfAbsent(p3.getDocIds(), -1)); + assertEquals(2, mapping.getIfAbsent(p1.getDocIds(), -1)); + assertEquals(3, mapping.getIfAbsent(p4.getDocIds(), -1)); + + int[] bitVector = newC.getBitVector(); + assertEquals(0b0001, bitVector[0] & 0b1111); + assertEquals(0b0111, bitVector[1] & 0b1111); + assertEquals(0b0111, bitVector[2] & 0b1111); + assertEquals(0b1011, bitVector[3] & 0b1111); + assertEquals(0b1111, bitVector[4] & 0b1111); + } + + @Test + public void require_that_cached_docids_are_counted_correctly() { + int nDocuments = 4; + byte[] nPostingListsPerDocument = new byte[nDocuments]; + CachedPostingListCounter c = new CachedPostingListCounter(nDocuments); + PostingList p1 = postingList(0, 1, 2, 3); + PostingList p2 = postingList(1, 2); + PostingList p3 = postingList(1, 3); + PostingList p4 = postingList(3); + List<PostingList> postingLists = list(p1, p2, p3, p4); + c.registerUsage(postingLists); + CachedPostingListCounter newC = c.rebuildCache(); + newC.countPostingListsPerDocument(postingLists, nPostingListsPerDocument); + assertArrayEquals(new byte[]{1, 3, 2, 3}, nPostingListsPerDocument); + newC.countPostingListsPerDocument(list(p1, p2), nPostingListsPerDocument); + assertArrayEquals(new byte[]{1, 2, 2, 1}, nPostingListsPerDocument); + } + + @Test + public void require_that_cache_rebuilding_behaves_correctly_for_large_amount_of_posting_lists() { + int nDocuments = 4; + byte[] nPostingListsPerDocument = new byte[nDocuments]; + CachedPostingListCounter c = new CachedPostingListCounter(nDocuments); + List<PostingList> postingLists = new ArrayList<>(100 * nDocuments); + for (int i = 0; i < 100 * nDocuments; i++) { + postingLists.add(postingList(i % nDocuments)); + } + c.registerUsage(postingLists); + CachedPostingListCounter newC = c.rebuildCache(); + newC.countPostingListsPerDocument(postingLists, nPostingListsPerDocument); + assertArrayEquals(new byte[]{100, 100, 100, 100}, nPostingListsPerDocument); + + List<PostingList> doc0PostingLists = new ArrayList<>(); + for (int i = 0; i < 100 * nDocuments; i += nDocuments) { + doc0PostingLists.add(postingLists.get(i)); + } + newC.countPostingListsPerDocument(doc0PostingLists, nPostingListsPerDocument); + assertArrayEquals(new byte[]{100, 0, 0, 0}, nPostingListsPerDocument); + } + + private static List<PostingList> list(PostingList... postingLists) { + return Arrays.asList(postingLists); + } + + private static PostingList postingList(Integer... docIds) { + PostingList postingList = mock(PostingList.class); + when(postingList.getDocIds()).thenReturn(ArrayUtils.toPrimitive(docIds)); + return postingList; + } + +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/IntervalPostingListTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/IntervalPostingListTest.java new file mode 100644 index 00000000000..41f4ba55750 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/IntervalPostingListTest.java @@ -0,0 +1,43 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.search.predicate.SubqueryBitmap; +import org.junit.Test; + +import java.util.Arrays; + +import static junit.framework.TestCase.assertFalse; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class IntervalPostingListTest { + @Test + public void requireThatPostingListCanIterate() { + PredicateIntervalStore.Builder builder = new PredicateIntervalStore.Builder(); + int ref1 = builder.insert(Arrays.asList(0x1ffff)); + int ref2 = builder.insert(Arrays.asList(0x1ffff)); + int ref3 = builder.insert(Arrays.asList(0x10001, 0x2ffff)); + IntervalPostingList postingList = new IntervalPostingList( + builder.build(), new int[]{2, 4, 6}, new int[] {ref1, ref2, ref3}, SubqueryBitmap.ALL_SUBQUERIES); + assertEquals(-1, postingList.getDocId()); + assertEquals(0, postingList.getInterval()); + assertEquals(0xffffffffffffffffL, postingList.getSubquery()); + + assertTrue(postingList.nextDocument(0)); + assertTrue(postingList.prepareIntervals()); + assertEquals(2, postingList.getDocId()); + assertEquals(0x1ffff, postingList.getInterval()); + assertFalse(postingList.nextInterval()); + + assertTrue(postingList.nextDocument(4)); + assertTrue(postingList.prepareIntervals()); + assertEquals(6, postingList.getDocId()); + assertEquals(0x10001, postingList.getInterval()); + assertTrue(postingList.nextInterval()); + assertEquals(0x2ffff, postingList.getInterval()); + assertFalse(postingList.nextInterval()); + + assertFalse(postingList.nextDocument(8)); + } + +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateIntervalStoreTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateIntervalStoreTest.java new file mode 100644 index 00000000000..1f2fd390cb5 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateIntervalStoreTest.java @@ -0,0 +1,82 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.google.common.primitives.Ints; +import org.junit.Test; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static com.yahoo.search.predicate.serialization.SerializationTestHelper.assertSerializationDeserializationMatches; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +/** + * @author bjorncs + */ +public class PredicateIntervalStoreTest { + + @Test(expected = IllegalArgumentException.class) + public void requireThatEmptyIntervalListThrows() { + PredicateIntervalStore.Builder builder = new PredicateIntervalStore.Builder(); + builder.insert(new ArrayList<>()); + } + + @Test + public void requireThatSingleIntervalCanBeInserted() { + testInsertAndRetrieve(0x0001ffff); + } + + @Test + public void requireThatMultiIntervalEntriesCanBeInserted() { + testInsertAndRetrieve(0x00010001, 0x00020002, 0x0003ffff); + testInsertAndRetrieve(0x00010001, 0x00020002, 0x00030003, 0x00040004, 0x00050005, 0x00060006, + 0x00070007, 0x00080008, 0x00090009, 0x000a000a); + } + + @Test + public void requireThatDifferentSizeIntervalArraysCanBeInserted() { + PredicateIntervalStore.Builder builder = new PredicateIntervalStore.Builder(); + int intervals1[] = new int[] {0x00010001, 0x00020002}; + int intervals2[] = new int[] {0x00010001, 0x00020002, 0x00030003}; + assertEquals(0, builder.insert(Ints.asList(intervals1))); + assertEquals(1, builder.insert(Ints.asList(intervals2))); + } + + @Test + public void requireThatSerializationAndDeserializationRetainIntervals() throws IOException { + PredicateIntervalStore.Builder builder = new PredicateIntervalStore.Builder(); + builder.insert(Arrays.asList(0x00010001, 0x00020002)); + builder.insert(Arrays.asList(0x00010001, 0x00020002, 0x00030003)); + builder.insert(Arrays.asList(0x0fffffff, 0x00020002, 0x00030003)); + PredicateIntervalStore store = builder.build(); + assertSerializationDeserializationMatches( + store, PredicateIntervalStore::writeToOutputStream, PredicateIntervalStore::fromInputStream); + } + + @Test + public void requireThatEqualIntervalListsReturnsSameReference() { + PredicateIntervalStore.Builder builder = new PredicateIntervalStore.Builder(); + List<Integer> intervals1 = Arrays.asList(0x00010001, 0x00020002); + List<Integer> intervals2 = Arrays.asList(0x00010001, 0x00020002); + int ref1 = builder.insert(intervals1); + int ref2 = builder.insert(intervals2); + PredicateIntervalStore store = builder.build(); + int[] a1 = store.get(ref1); + int[] a2 = store.get(ref2); + assertTrue(a1 == a2); + } + + private static void testInsertAndRetrieve(int... intervals) { + PredicateIntervalStore.Builder builder = new PredicateIntervalStore.Builder(); + int ref = builder.insert(Ints.asList(intervals)); + PredicateIntervalStore store = builder.build(); + + int retrieved[] = store.get(ref); + assertArrayEquals(intervals, retrieved); + } + +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateRangeTermExpanderTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateRangeTermExpanderTest.java new file mode 100644 index 00000000000..8eacf126bd7 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateRangeTermExpanderTest.java @@ -0,0 +1,354 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.document.predicate.PredicateHash; +import org.junit.Test; + +import java.util.Arrays; +import java.util.Iterator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.fail; + +/** + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public class PredicateRangeTermExpanderTest { + @Test + public void requireThatSmallRangeIsExpanded() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10); + Iterator<String> expectedLabels = Arrays.asList( + "key=40-49", + "key=0-99", + "key=0-999", + "key=0-9999", + "key=0-99999", + "key=0-999999", + "key=0-9999999", + "key=0-99999999", + "key=0-999999999", + "key=0-9999999999", + "key=0-99999999999", + "key=0-999999999999", + "key=0-9999999999999", + "key=0-99999999999999", + "key=0-999999999999999", + "key=0-9999999999999999", + "key=0-99999999999999999", + "key=0-999999999999999999").iterator(); + expander.expand("key", 42, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=40"), edge); assertEquals(2, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatLargeRangeIsExpanded() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10); + Iterator<String> expectedLabels = Arrays.asList( + "key=123456789012345670-123456789012345679", + "key=123456789012345600-123456789012345699", + "key=123456789012345000-123456789012345999", + "key=123456789012340000-123456789012349999", + "key=123456789012300000-123456789012399999", + "key=123456789012000000-123456789012999999", + "key=123456789010000000-123456789019999999", + "key=123456789000000000-123456789099999999", + "key=123456789000000000-123456789999999999", + "key=123456780000000000-123456789999999999", + "key=123456700000000000-123456799999999999", + "key=123456000000000000-123456999999999999", + "key=123450000000000000-123459999999999999", + "key=123400000000000000-123499999999999999", + "key=123000000000000000-123999999999999999", + "key=120000000000000000-129999999999999999", + "key=100000000000000000-199999999999999999", + "key=0-999999999999999999").iterator(); + expander.expand("key", 123456789012345678L, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=123456789012345670"), edge); assertEquals(8, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatMaxRangeIsExpanded() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10); + expander.expand("key", 9223372036854775807L, range -> fail(), + (edge, value) -> { + assertEquals(PredicateHash.hash64("key=9223372036854775800"), edge); + assertEquals(7, value); + }); + } + + @Test + public void requireThatSmallNegativeRangeIsExpanded() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10); + Iterator<String> expectedLabels = Arrays.asList( + "key=-49-40", + "key=-99-0", + "key=-999-0", + "key=-9999-0", + "key=-99999-0", + "key=-999999-0", + "key=-9999999-0", + "key=-99999999-0", + "key=-999999999-0", + "key=-9999999999-0", + "key=-99999999999-0", + "key=-999999999999-0", + "key=-9999999999999-0", + "key=-99999999999999-0", + "key=-999999999999999-0", + "key=-9999999999999999-0", + "key=-99999999999999999-0", + "key=-999999999999999999-0").iterator(); + expander.expand("key", -42, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=-40"), edge); assertEquals(2, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatMinRangeIsExpanded() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10); + expander.expand("key", -9223372036854775808L, range -> fail(), + (edge, value) -> { + assertEquals(PredicateHash.hash64("key=-9223372036854775800"), edge); + assertEquals(8, value); + }); + } + + @Test + public void requireThatMinRangeMinus9IsExpanded() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10); + Iterator<String> expectedLabels = Arrays.asList( + "key=-9223372036854775799-9223372036854775790", + "key=-9223372036854775799-9223372036854775700").iterator(); + expander.expand("key", -9223372036854775799L, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=-9223372036854775790"), edge); assertEquals(9, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatMinRangeIsExpandedWithArity8() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(8); + expander.expand("key", -9223372036854775808L, range -> fail(), + (edge, value) -> { + assertEquals(PredicateHash.hash64("key=-9223372036854775808"), edge); + assertEquals(0, value); + }); + } + + @Test + public void requireThatSmallRangeIsExpandedInArity2() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(2); + Iterator<String> expectedLabels = Arrays.asList( + "key=42-43", + "key=40-43", + "key=40-47", + "key=32-47", + "key=32-63", + "key=0-63", + "key=0-127", + "key=0-255", + "key=0-511", + "key=0-1023", + "key=0-2047", + "key=0-4095", + "key=0-8191", + "key=0-16383", + "key=0-32767", + "key=0-65535", + "key=0-131071", + "key=0-262143", + "key=0-524287", + "key=0-1048575", + "key=0-2097151", + "key=0-4194303", + "key=0-8388607", + "key=0-16777215", + "key=0-33554431", + "key=0-67108863", + "key=0-134217727", + "key=0-268435455", + "key=0-536870911", + "key=0-1073741823", + "key=0-2147483647", + "key=0-4294967295", + "key=0-8589934591", + "key=0-17179869183", + "key=0-34359738367", + "key=0-68719476735", + "key=0-137438953471", + "key=0-274877906943", + "key=0-549755813887", + "key=0-1099511627775", + "key=0-2199023255551", + "key=0-4398046511103", + "key=0-8796093022207", + "key=0-17592186044415", + "key=0-35184372088831", + "key=0-70368744177663", + "key=0-140737488355327", + "key=0-281474976710655", + "key=0-562949953421311", + "key=0-1125899906842623", + "key=0-2251799813685247", + "key=0-4503599627370495", + "key=0-9007199254740991", + "key=0-18014398509481983", + "key=0-36028797018963967", + "key=0-72057594037927935", + "key=0-144115188075855871", + "key=0-288230376151711743", + "key=0-576460752303423487", + "key=0-1152921504606846975", + "key=0-2305843009213693951", + "key=0-4611686018427387903", + "key=0-9223372036854775807").iterator(); + expander.expand("key", 42, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=42"), edge); assertEquals(0, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatSmallNegativeRangeIsExpandedInArity2() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(2); + Iterator<String> expectedLabels = Arrays.asList( + "key=-43-42", + "key=-43-40", + "key=-47-40", + "key=-47-32", + "key=-63-32", + "key=-63-0", + "key=-127-0", + "key=-255-0", + "key=-511-0", + "key=-1023-0", + "key=-2047-0", + "key=-4095-0", + "key=-8191-0", + "key=-16383-0", + "key=-32767-0", + "key=-65535-0", + "key=-131071-0", + "key=-262143-0", + "key=-524287-0", + "key=-1048575-0", + "key=-2097151-0", + "key=-4194303-0", + "key=-8388607-0", + "key=-16777215-0", + "key=-33554431-0", + "key=-67108863-0", + "key=-134217727-0", + "key=-268435455-0", + "key=-536870911-0", + "key=-1073741823-0", + "key=-2147483647-0", + "key=-4294967295-0", + "key=-8589934591-0", + "key=-17179869183-0", + "key=-34359738367-0", + "key=-68719476735-0", + "key=-137438953471-0", + "key=-274877906943-0", + "key=-549755813887-0", + "key=-1099511627775-0", + "key=-2199023255551-0", + "key=-4398046511103-0", + "key=-8796093022207-0", + "key=-17592186044415-0", + "key=-35184372088831-0", + "key=-70368744177663-0", + "key=-140737488355327-0", + "key=-281474976710655-0", + "key=-562949953421311-0", + "key=-1125899906842623-0", + "key=-2251799813685247-0", + "key=-4503599627370495-0", + "key=-9007199254740991-0", + "key=-18014398509481983-0", + "key=-36028797018963967-0", + "key=-72057594037927935-0", + "key=-144115188075855871-0", + "key=-288230376151711743-0", + "key=-576460752303423487-0", + "key=-1152921504606846975-0", + "key=-2305843009213693951-0", + "key=-4611686018427387903-0", + "key=-9223372036854775807-0").iterator(); + expander.expand("key", -42, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=-42"), edge); assertEquals(0, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatUpperBoundIsUsed() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10, -99, 9999); + Iterator<String> expectedLabels = Arrays.asList( + "key=40-49", + "key=0-99", + "key=0-999", + "key=0-9999").iterator(); + expander.expand("key", 42, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=40"), edge); assertEquals(2, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatLowerBoundIsUsed() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10, -9999, 99); + Iterator<String> expectedLabels = Arrays.asList( + "key=-49-40", + "key=-99-0", + "key=-999-0", + "key=-9999-0").iterator(); + expander.expand("key", -42, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=-40"), edge); assertEquals(2, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatSearchesOutsideBoundsGenerateNoLabels() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10, 0, 200); + expander.expand("key", -10, x -> fail(), (x,y) -> fail()); + expander.expand("key", 210, x -> fail(), (x, y) -> fail()); + } + + @Test + public void requireThatUpperAndLowerBoundGreaterThan0Works() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10, 100, 9999); + Iterator<String> expectedLabels = Arrays.asList( + "key=140-149", + "key=100-199", + "key=0-999", + "key=0-9999").iterator(); + expander.expand("key", 142, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=140"), edge); assertEquals(2, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatSearchCloseToUnevenUpperBoundIsSensible() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10, -99, 1234); + Iterator<String> expectedLabels = Arrays.asList( + "key=40-49", + "key=0-99", + "key=0-999", + "key=0-9999").iterator(); + expander.expand("key", 42, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=40"), edge); assertEquals(2, value); }); + assertFalse(expectedLabels.hasNext()); + } + + @Test + public void requireThatSearchCloseToMaxUnevenUpperBoundIsSensible() { + PredicateRangeTermExpander expander = new PredicateRangeTermExpander(10, 0, 9223372036854771234L); + Iterator<String> expectedLabels = Arrays.asList( + "key=9223372036854770000-9223372036854770009", + "key=9223372036854770000-9223372036854770099", + "key=9223372036854770000-9223372036854770999").iterator(); + expander.expand("key", 9223372036854770000L, range -> assertEquals(PredicateHash.hash64(expectedLabels.next()), range), + (edge, value) -> { assertEquals(PredicateHash.hash64("key=9223372036854770000"), edge); assertEquals(0, value); }); + assertFalse(expectedLabels.hasNext()); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateSearchTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateSearchTest.java new file mode 100644 index 00000000000..64a44ff3680 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/PredicateSearchTest.java @@ -0,0 +1,305 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import com.yahoo.search.predicate.Hit; +import com.yahoo.search.predicate.SubqueryBitmap; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +import static java.util.stream.Collectors.toList; +import static org.junit.Assert.assertEquals; + +/** + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class PredicateSearchTest { + + @Test + public void requireThatNoStreamsReturnNoResults() { + PredicateSearch search = new PredicateSearch(new ArrayList<>(), new byte[0], new byte[0], new short[0], 1); + assertEquals(0, search.stream().count()); + } + + @Test + public void requireThatSingleStreamFiltersOnConstructedCompleteIntervals() { + PredicateSearch search = createPredicateSearch( + new byte[]{1, 1, 1}, + postingList( + SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x000100ff), + entry(1, 0x00010001, 0x000200ff), + entry(2, 0x00010042))); + assertEquals(Arrays.asList(new Hit(0), new Hit(1)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatMinFeatureIsUsedToPruneResults() { + PredicateSearch search = createPredicateSearch( + new byte[]{3, 1}, + postingList( + SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x000100ff), + entry(1, 0x000100ff))); + assertEquals(Arrays.asList(new Hit(1)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatAHighKCanYieldResults() { + PredicateSearch search = createPredicateSearch( + new byte[]{2}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010001)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x000200ff))); + assertEquals(Arrays.asList(new Hit(0)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatPostingListsAreSortedAfterAdvancing() { + PredicateSearch search = createPredicateSearch( + new byte[] {2, 1, 1, 1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x000100ff), + entry(3, 0x000100ff)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(1, 0x000100ff), + entry(2, 0x000100ff))); + assertEquals(Arrays.asList(new Hit(1), new Hit(2), new Hit(3)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatEmptyPostingListsWork() { + PredicateSearch search = createPredicateSearch( + new byte[0], + postingList(SubqueryBitmap.ALL_SUBQUERIES)); + assertEquals(Arrays.asList().toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatShorterPostingListEndingIsOk() { + PredicateSearch search = createPredicateSearch( + new byte[]{1, 1, 1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x000100ff), + entry(1, 0x000100ff)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(2, 0x000100ff))); + assertEquals(Arrays.asList(new Hit(0), new Hit(1), new Hit(2)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatSortingWorksForManyPostingLists() { + PredicateSearch search = createPredicateSearch( + new byte[]{1, 5, 2, 2}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x000100ff), + entry(1, 0x000100ff)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(1, 0x000100ff), + entry(2, 0x000100ff)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(1, 0x000100ff), + entry(3, 0x000100ff)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(1, 0x000100ff), + entry(2, 0x000100ff)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(1, 0x000100ff), + entry(3, 0x000100ff))); + assertEquals( + Arrays.asList(new Hit(0), new Hit(1), new Hit(2), new Hit(3)).toString(), + search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatInsufficientIntervalCoveragePreventsMatch() { + PredicateSearch search = createPredicateSearch( + new byte[]{1, 1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010001), + entry(1, 0x000200ff))); + assertEquals(Arrays.asList().toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatIntervalsAreSorted() { + PredicateSearch search = createPredicateSearch( + new byte[]{1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010001)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x000300ff)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00020002))); + assertEquals(Arrays.asList(new Hit(0)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatThereCanBeManyIntervals() { + PredicateSearch search = createPredicateSearch( + new byte[]{1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010001, 0x00020002, 0x00030003, 0x000100ff, 0x00040004, 0x00050005, 0x00060006))); + assertEquals(Arrays.asList(new Hit(0)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatNotIsSupported_NoMatch() { + PredicateSearch search = createPredicateSearch( + new byte[]{1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010001)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010000, 0x00ff0001))); + assertEquals(Arrays.asList().toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatNotIsSupported_Match() { + PredicateSearch search = createPredicateSearch( + new byte[]{1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010000, 0x00ff0001))); + assertEquals(Arrays.asList(new Hit(0)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatNotIsSupported_NoMatchBecauseOfPreviousTerm() { + PredicateSearch search = createPredicateSearch( + new byte[]{1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00020001, 0x00ff0001))); + assertEquals(Arrays.asList().toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatIntervalSortingWorksAsUnsigned() { + PredicateSearch search = createPredicateSearch( + new byte[]{1}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010001)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00fe0001, 0x00ff00fe))); + assertEquals(Arrays.asList(new Hit(0)).toString(), search.stream().collect(toList()).toString()); + } + + @Test + public void requireThatMatchCanRequireMultiplePostingLists() { + PredicateSearch search = createPredicateSearch( + new byte[]{6}, + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010001)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x0002000b, 0x00030003)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00040003)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00050004)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00010008, 0x00060006)), + postingList(SubqueryBitmap.ALL_SUBQUERIES, + entry(0, 0x00020002, 0x000700ff))); + assertEquals(Arrays.asList(new Hit(0)).toString(), search.stream().collect(toList()).toString()); + } + + private static PredicateSearch createPredicateSearch(byte[] minFeatures, PostingList... postingLists) { + byte[] nPostingListsForDocument = new byte[minFeatures.length]; + short[] intervalEnds = new short[minFeatures.length]; + Arrays.fill(intervalEnds, (short) 0xFF); + List<PostingList> list = Arrays.asList(postingLists); + for (PostingList postingList : postingLists) { + for (int id : postingList.getDocIds()) { + nPostingListsForDocument[id]++; + } + } + return new PredicateSearch(list, nPostingListsForDocument, minFeatures, intervalEnds, 0xFF); + } + + private static class SimplePostingList implements PostingList { + private final long subquery; + private final Entry[] entries; + private int[] currentIntervals; + private int currentIntervalIndex; + private int currentDocId; + private int currentIndex; + + public SimplePostingList(long subquery, Entry... entries) { + this.subquery = subquery; + this.entries = entries; + this.currentIndex = 0; + this.currentDocId = -1; + } + + @Override + public boolean nextDocument(int docId) { + while (currentIndex < entries.length && entries[currentIndex].docId <= docId) { + ++currentIndex; + } + if (currentIndex == entries.length) { + return false; + } + Entry entry = entries[currentIndex]; + currentDocId = entry.docId; + currentIntervals = entry.intervals; + currentIntervalIndex = 0; + return true; + } + + @Override + public boolean prepareIntervals() { + return true; + } + + @Override + public boolean nextInterval() { + return ++currentIntervalIndex < currentIntervals.length; + } + + @Override + public int getDocId() { + return currentDocId; + } + + @Override + public int size() { + return entries.length; + } + + @Override + public int getInterval() { + return currentIntervals[currentIntervalIndex]; + } + + @Override + public long getSubquery() { + return subquery; + } + + @Override + public int[] getDocIds() { + return Arrays.stream(entries).mapToInt(e -> e.docId).toArray(); + } + + public static class Entry { + public final int docId; + public final int[] intervals; + + public Entry(int docId, int... intervals) { + this.docId = docId; + this.intervals = intervals; + } + } + } + + private static SimplePostingList postingList(long subquery, SimplePostingList.Entry... entries) { + return new SimplePostingList(subquery, entries); + } + + private static SimplePostingList.Entry entry(int docId, int... intervals) { + return new SimplePostingList.Entry(docId, intervals); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/SimpleIndexTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/SimpleIndexTest.java new file mode 100644 index 00000000000..3f6b803c33a --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/SimpleIndexTest.java @@ -0,0 +1,65 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import org.junit.Test; + +import java.io.IOException; + +import static com.yahoo.search.predicate.serialization.SerializationTestHelper.assertSerializationDeserializationMatches; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotNull; + +/** + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + * @author bjorncs + */ +public class SimpleIndexTest { + + private static final long KEY = 0x12345L; + private static final int DOC_ID = 42; + + @Test + public void requireThatValuesCanBeInserted() { + SimpleIndex.Builder builder = new SimpleIndex.Builder(); + builder.insert(KEY, new Posting(DOC_ID, 10)); + SimpleIndex index = builder.build(); + SimpleIndex.Entry e = index.getPostingList(KEY); + assertNotNull(e); + assertEquals(1, e.docIds.length); + + builder = new SimpleIndex.Builder(); + builder.insert(KEY, new Posting(DOC_ID, 10)); + builder.insert(KEY, new Posting(DOC_ID + 1, 20)); + index = builder.build(); + e = index.getPostingList(KEY); + assertEquals(2, e.docIds.length); + assertEquals(10, e.dataRefs[0]); + assertEquals(20, e.dataRefs[1]); + } + + @Test + public void requireThatEntriesAreSortedOnId() { + SimpleIndex.Builder builder = new SimpleIndex.Builder(); + builder.insert(KEY, new Posting(DOC_ID, 10)); + builder.insert(KEY, new Posting(DOC_ID - 1, 20)); // Out of order + builder.insert(KEY, new Posting(DOC_ID + 1, 30)); + SimpleIndex index = builder.build(); + SimpleIndex.Entry entry = index.getPostingList(KEY); + assertEquals(3, entry.docIds.length); + assertEquals(DOC_ID - 1, entry.docIds[0]); + assertEquals(DOC_ID, entry.docIds[1]); + assertEquals(DOC_ID + 1, entry.docIds[2]); + } + + @Test + public void requireThatSerializationAndDeserializationRetainDictionary() throws IOException { + SimpleIndex.Builder builder = new SimpleIndex.Builder(); + builder.insert(KEY, new Posting(DOC_ID, 10)); + builder.insert(KEY, new Posting(DOC_ID + 1, 20)); + builder.insert(KEY, new Posting(DOC_ID + 2, 30)); + builder.insert(KEY + 0xFFFFFF, new Posting(DOC_ID, 100)); + builder.insert(KEY + 0xFFFFFF, new Posting(DOC_ID + 1, 200)); + SimpleIndex index = builder.build(); + assertSerializationDeserializationMatches(index, SimpleIndex::writeToOutputStream, SimpleIndex::fromInputStream); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/ZeroConstraintPostingListTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/ZeroConstraintPostingListTest.java new file mode 100644 index 00000000000..652441b796a --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/ZeroConstraintPostingListTest.java @@ -0,0 +1,36 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public class ZeroConstraintPostingListTest { + + @Test + public void requireThatPostingListCanIterate() { + ZeroConstraintPostingList postingList = + new ZeroConstraintPostingList(new int[] {2, 4, 6, 8}); + assertEquals(-1, postingList.getDocId()); + assertEquals(Interval.fromBoundaries(1, Interval.ZERO_CONSTRAINT_RANGE), postingList.getInterval()); + assertEquals(0xffffffffffffffffL, postingList.getSubquery()); + + assertTrue(postingList.nextDocument(0)); + assertEquals(2, postingList.getDocId()); + assertTrue(postingList.prepareIntervals()); + assertFalse(postingList.nextInterval()); + + assertTrue(postingList.nextDocument(7)); + assertEquals(8, postingList.getDocId()); + + assertTrue(postingList.nextDocument(7)); + assertEquals(8, postingList.getDocId()); + + assertFalse(postingList.nextDocument(8)); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/ZstarCompressedPostingListTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/ZstarCompressedPostingListTest.java new file mode 100644 index 00000000000..3eb389757e3 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/ZstarCompressedPostingListTest.java @@ -0,0 +1,62 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index; + +import org.junit.Test; + +import java.util.Arrays; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author <a href="mailto:magnarn@yahoo-inc.com">Magnar Nedland</a> + */ +public class ZstarCompressedPostingListTest { + @Test + public void requireThatPostingListCanIterate() { + PredicateIntervalStore.Builder builder = new PredicateIntervalStore.Builder(); + int ref1 = builder.insert(Arrays.asList(0x10000)); + int ref2 = builder.insert(Arrays.asList(0x10000, 0x0ffff)); + int ref3 = builder.insert(Arrays.asList(0x10000, 0x00003, 0x40003, 0x60005)); + ZstarCompressedPostingList postingList = new ZstarCompressedPostingList( + builder.build(), new int[]{2, 4, 6}, new int[]{ref1, ref2, ref3}); + assertEquals(-1, postingList.getDocId()); + assertEquals(0, postingList.getInterval()); + assertEquals(0xffffffffffffffffL, postingList.getSubquery()); + + assertTrue(postingList.nextDocument(0)); + assertTrue(postingList.prepareIntervals()); + assertEquals(2, postingList.getDocId()); + assertEquals(0x10000, postingList.getInterval()); + assertTrue(postingList.nextInterval()); + assertEquals(0x20001, postingList.getInterval()); + assertFalse(postingList.nextInterval()); + + assertTrue(postingList.nextDocument(2)); + assertTrue(postingList.prepareIntervals()); + assertEquals(4, postingList.getDocId()); + assertEquals(0x00010000, postingList.getInterval()); + assertTrue(postingList.nextInterval()); + assertEquals(0xffff0001, postingList.getInterval()); + assertFalse(postingList.nextInterval()); + + assertTrue(postingList.nextDocument(4)); + assertTrue(postingList.prepareIntervals()); + assertEquals(6, postingList.getDocId()); + assertEquals(0x10000, postingList.getInterval()); + assertTrue(postingList.nextInterval()); + assertEquals(0x30001, postingList.getInterval()); + assertTrue(postingList.nextInterval()); + assertEquals(0x40003, postingList.getInterval()); + assertTrue(postingList.nextInterval()); + assertEquals(0x50004, postingList.getInterval()); + assertTrue(postingList.nextInterval()); + assertEquals(0x60005, postingList.getInterval()); + assertTrue(postingList.nextInterval()); + assertEquals(0x70006, postingList.getInterval()); + assertFalse(postingList.nextInterval()); + + assertFalse(postingList.nextDocument(6)); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIdIteratorTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIdIteratorTest.java new file mode 100644 index 00000000000..d324faec50a --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIdIteratorTest.java @@ -0,0 +1,56 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index.conjunction; + +import com.yahoo.search.predicate.SubqueryBitmap; +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +/** + * @author bjorncs + */ +public class ConjunctionIdIteratorTest { + + @SuppressWarnings("PointlessBitwiseExpression") + @Test + public void require_that_next_returns_skips_to_correct_value() { + // NOTE: LST bit represents the conjunction sign: 0 => negative, 1 => positive. + int[] conjunctionIds = new int[]{ + 0 | 1, + 2 | 0, + 4 | 0, + 6 | 1, + 8 | 1, + 10 | 0}; + + ConjunctionIdIterator postingList = + new ConjunctionIdIterator(SubqueryBitmap.ALL_SUBQUERIES, conjunctionIds); + + assertEquals(1, postingList.getConjunctionId()); + assertEquals(1, postingList.getConjunctionId()); // Should not change. + + assertTrue(postingList.next(2)); + assertEquals(2, postingList.getConjunctionId()); + assertTrue(postingList.next(0)); // Should not change current conjunction id + assertEquals(2, postingList.getConjunctionId()); + + assertTrue(postingList.next(6 | 1)); // Should skip past id 4 + assertEquals(7, postingList.getConjunctionId()); + + assertTrue(postingList.next(8)); // Should skip to 9 + assertEquals(9, postingList.getConjunctionId()); + + assertTrue(postingList.next(10 | 1)); + assertEquals(10, postingList.getConjunctionId()); + + assertFalse(postingList.next(12)); // End of posting list + } + + @Test + public void require_that_subquery_is_correct() { + ConjunctionIdIterator iterator = new ConjunctionIdIterator(0b1111, new int[]{1}); + assertEquals(0b1111, iterator.getSubqueryBitmap()); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndexTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndexTest.java new file mode 100644 index 00000000000..70fd8b4b6f5 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/index/conjunction/ConjunctionIndexTest.java @@ -0,0 +1,375 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.index.conjunction; + +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.document.predicate.Predicate; +import com.yahoo.search.predicate.PredicateQuery; +import org.junit.Test; + +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static com.yahoo.document.predicate.Predicates.feature; +import static com.yahoo.document.predicate.Predicates.not; +import static com.yahoo.search.predicate.serialization.SerializationTestHelper.assertSerializationDeserializationMatches; +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; + +public class ConjunctionIndexTest { + + @Test + public void require_that_single_conjunction_can_be_indexed() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + builder.indexConjunction(indexableConj(conj(feature("a").inSet("1"), feature("b").inSet("2")))); + assertEquals(2, builder.calculateFeatureCount()); + assertEquals(1, builder.getUniqueConjunctionCount()); + } + + @Test + public void require_that_large_conjunction_can_be_indexed() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + builder.indexConjunction(indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("1"), + feature("c").inSet("1")))); + assertEquals(3, builder.calculateFeatureCount()); + assertEquals(1, builder.getUniqueConjunctionCount()); + } + + @Test + public void require_that_multiple_conjunctions_can_be_indexed() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + builder.indexConjunction(indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("3")))); + builder.indexConjunction(indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("3")))); // Duplicate + builder.indexConjunction(indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("2"), + feature("c").inSet("3")))); + builder.indexConjunction(indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("2"), + feature("c").inSet("3")))); // Duplicate + builder.indexConjunction(indexableConj( + conj( + feature("d").inSet("1"), + feature("e").inSet("5")))); + assertEquals(6, builder.calculateFeatureCount()); + assertEquals(3, builder.getUniqueConjunctionCount()); + } + + @Test + public void require_that_search_for_simple_conjunctions_work() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + + IndexableFeatureConjunction c1 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("2"))); + IndexableFeatureConjunction c2 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("2"), + feature("c").inSet("3"))); + IndexableFeatureConjunction c3 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("5"))); + + builder.indexConjunction(c1); + builder.indexConjunction(c2); + builder.indexConjunction(c3); + ConjunctionIndex index = builder.build(); + ConjunctionIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + query.addFeature("a", "1"); + query.addFeature("b", "2"); + assertHitsEquals(searcher.search(query), c1); + query.addFeature("c", "3"); + assertHitsEquals(searcher.search(query), c1, c2); + query.addFeature("b", "5"); + assertHitsEquals(searcher.search(query), c1, c2, c3); + } + + + @Test + public void require_that_conjunction_with_not_is_indexed() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + builder.indexConjunction(indexableConj( + conj( + not(feature("a").inSet("1")), + not(feature("b").inSet("1"))))); + builder.indexConjunction(indexableConj( + conj( + feature("a").inSet("1"), + not(feature("b").inSet("1"))))); + assertEquals(2, builder.calculateFeatureCount()); + assertEquals(2, builder.getUniqueConjunctionCount()); + assertEquals(1, builder.getZListSize()); + } + + @Test + public void require_that_not_works_when_k_is_0() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + IndexableFeatureConjunction c1 = indexableConj( + conj( + not(feature("a").inSet("1")), + not(feature("b").inSet("1")))); + IndexableFeatureConjunction c2 = indexableConj( + conj( + not(feature("a").inSet("1")), + not(feature("b").inSet("1")), + not(feature("c").inSet("1")))); + IndexableFeatureConjunction c3 = indexableConj( + conj( + not(feature("a").inSet("1")), + not(feature("b").inSet("1")), + not(feature("c").inSet("1")), + not(feature("d").inSet("1")))); + IndexableFeatureConjunction c4 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("1"))); + builder.indexConjunction(c1); + builder.indexConjunction(c2); + builder.indexConjunction(c3); + builder.indexConjunction(c4); + ConjunctionIndex index = builder.build(); + ConjunctionIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + assertHitsEquals(searcher.search(query), c1, c2, c3); + query.addFeature("a", "1"); + query.addFeature("b", "1"); + assertHitsEquals(searcher.search(query), c4); + query.addFeature("c", "1"); + assertHitsEquals(searcher.search(query), c4); + query.addFeature("d", "1"); + assertHitsEquals(searcher.search(query), c4); + } + + @Test + public void require_that_not_works_when_k_is_1() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + IndexableFeatureConjunction c1 = indexableConj( + conj( + feature("a").inSet("1"), + not(feature("b").inSet("1")))); + IndexableFeatureConjunction c2 = indexableConj( + conj( + feature("a").inSet("1"), + not(feature("b").inSet("1")), + not(feature("c").inSet("1")))); + IndexableFeatureConjunction c3 = indexableConj( + conj( + feature("a").inSet("1"), + not(feature("b").inSet("1")), + not(feature("c").inSet("1")), + not(feature("d").inSet("1")))); + builder.indexConjunction(c1); + builder.indexConjunction(c2); + builder.indexConjunction(c3); + ConjunctionIndex index = builder.build(); + ConjunctionIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + assertTrue(searcher.search(query).isEmpty()); + query.addFeature("a", "1"); + assertHitsEquals(searcher.search(query), c1, c2, c3); + query.addFeature("b", "1"); + assertTrue(searcher.search(query).isEmpty()); + query.addFeature("c", "1"); + assertTrue(searcher.search(query).isEmpty()); + query.addFeature("d", "1"); + assertTrue(searcher.search(query).isEmpty()); + } + + @Test + public void require_that_not_works_when_k_is_2() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + IndexableFeatureConjunction c1 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("1"), + not(feature("c").inSet("1")))); + IndexableFeatureConjunction c2 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("1"), + not(feature("c").inSet("1")), + not(feature("d").inSet("1")))); + IndexableFeatureConjunction c3 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("1"), + not(feature("c").inSet("1")), + not(feature("d").inSet("1")), + not(feature("e").inSet("1")))); + builder.indexConjunction(c1); + builder.indexConjunction(c2); + builder.indexConjunction(c3); + ConjunctionIndex index = builder.build(); + ConjunctionIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + query.addFeature("a", "1"); + query.addFeature("b", "1"); + assertHitsEquals(searcher.search(query), c1, c2, c3); + query.addFeature("c", "1"); + assertTrue(searcher.search(query).isEmpty()); + query.addFeature("d", "1"); + assertTrue(searcher.search(query).isEmpty()); + query.addFeature("e", "1"); + assertTrue(searcher.search(query).isEmpty()); + } + + @Test + public void require_that_multi_term_queries_are_supported() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + IndexableFeatureConjunction c1 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("3"))); + builder.indexConjunction(c1); + ConjunctionIndex index = builder.build(); + ConjunctionIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + query.addFeature("a", "1"); + query.addFeature("a", "2"); + assertTrue(searcher.search(query).isEmpty()); + query.addFeature("b", "3"); + assertHitsEquals(searcher.search(query), c1); + } + + @Test + public void require_that_subqueries_are_supported() { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + IndexableFeatureConjunction c1 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("3"), + not(feature("c").inSet("4")))); + IndexableFeatureConjunction c2 = indexableConj( + conj( + feature("a").inSet("1"), + feature("b").inSet("3"))); + IndexableFeatureConjunction c3 = indexableConj( + conj( + feature("a").inSet("2"), + feature("b").inSet("3"))); + IndexableFeatureConjunction c4 = indexableConj( + conj( + feature("e").inSet("5"), + feature("f").inSet("6")) + ); + builder.indexConjunction(c1); + builder.indexConjunction(c2); + builder.indexConjunction(c3); + builder.indexConjunction(c4); + ConjunctionIndex index = builder.build(); + ConjunctionIndex.Searcher searcher = index.searcher(); + + PredicateQuery query = new PredicateQuery(); + + //subquery 0: a=2 and b=3 + //subquery 1: a=1 and b=3 + //subquery 2: a=1 and b=3 + query.addFeature("a", "1", 0b110); + query.addFeature("a", "2", 0b001); + query.addFeature("b", "3", 0b111); + List<ConjunctionHit> expectedHits = matchingConjunctionList( + new ConjunctionHit(c1.id, 0b110), + new ConjunctionHit(c2.id, 0b110), + new ConjunctionHit(c3.id, 0b001) + ); + + List<ConjunctionHit> hits = searcher.search(query); + assertHitsEquals(expectedHits, hits); + + //subquery 0: a=2 and b=3 and c=4 + //subquery 1: a=1 and b=3 + //subquery 2: a=1 and b=3 and c=4 + query.addFeature("c", "4", 0b101); + expectedHits = matchingConjunctionList( + new ConjunctionHit(c1.id, 0b010), + new ConjunctionHit(c2.id, 0b110), + new ConjunctionHit(c3.id, 0b001) + ); + hits = searcher.search(query); + assertHitsEquals(expectedHits, hits); + + // subquery 0: a=2 and e=5 + // subquery 1: b=3 and f=6 + PredicateQuery query2 = new PredicateQuery(); + query2.addFeature("a", "2", 0b01); + query2.addFeature("b", "3", 0b10); + query2.addFeature("e", "5", 0b01); + query2.addFeature("f", "6", 0b10); + expectedHits = matchingConjunctionList( + new ConjunctionHit(c1.id, 0b010), + new ConjunctionHit(c2.id, 0b110), + new ConjunctionHit(c3.id, 0b001) + ); + hits = searcher.search(query); + assertHitsEquals(expectedHits, hits); + } + + @Test + public void require_that_serialization_and_deserialization_retain_data() throws IOException { + ConjunctionIndexBuilder builder = new ConjunctionIndexBuilder(); + builder.indexConjunction(indexableConj( + conj( + not(feature("a").inSet("1")), + not(feature("b").inSet("3")), + not(feature("c").inSet("4"))))); + builder.indexConjunction(indexableConj( + conj( + feature("d").inSet("5"), + feature("e").inSet("6")))); + ConjunctionIndex index = builder.build(); + assertSerializationDeserializationMatches( + index, ConjunctionIndex::writeToOutputStream, ConjunctionIndex::fromInputStream); + } + + private static List<ConjunctionHit> matchingConjunctionList(ConjunctionHit... conjunctionHits) { + return Arrays.asList(conjunctionHits); + } + + private static void assertHitsEquals(List<ConjunctionHit> hits, IndexableFeatureConjunction... conjunctions) { + Arrays.sort(conjunctions, (c1, c2) -> Long.compareUnsigned(c1.id, c2.id)); + Collections.sort(hits); + assertEquals(conjunctions.length, hits.size()); + for (int i = 0; i < hits.size(); i++) { + assertEquals(conjunctions[i].id, hits.get(i).conjunctionId); + } + } + + private static void assertHitsEquals(List<ConjunctionHit> expectedHits, List<ConjunctionHit> hits) { + Collections.sort(expectedHits); + Collections.sort(hits); + assertArrayEquals( + expectedHits.toArray(new ConjunctionHit[expectedHits.size()]), + hits.toArray(new ConjunctionHit[expectedHits.size()])); + } + + private static FeatureConjunction conj(Predicate... operands) { + return new FeatureConjunction(Arrays.asList(operands)); + } + + private static IndexableFeatureConjunction indexableConj(FeatureConjunction conjunction) { + return new IndexableFeatureConjunction(conjunction); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/optimization/FeatureConjunctionTransformerTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/optimization/FeatureConjunctionTransformerTest.java new file mode 100644 index 00000000000..e41c3b0676a --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/optimization/FeatureConjunctionTransformerTest.java @@ -0,0 +1,121 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.optimization; + +import com.yahoo.document.predicate.FeatureConjunction; +import com.yahoo.document.predicate.FeatureRange; +import com.yahoo.document.predicate.FeatureSet; +import com.yahoo.document.predicate.Predicate; +import org.junit.Test; + +import java.util.Arrays; + +import static com.yahoo.document.predicate.Predicates.and; +import static com.yahoo.document.predicate.Predicates.not; +import static com.yahoo.document.predicate.Predicates.or; +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class FeatureConjunctionTransformerTest { + private static final FeatureConjunctionTransformer transformer = new FeatureConjunctionTransformer(true); + + @Test + public void require_that_simple_ands_are_converted() { + assertConvertedPredicateEquals( + conj(not(featureSet(1)), featureSet(2)), + and(not(featureSet(1)), featureSet(2)) + ); + } + + @Test + public void require_that_featureranges_are_split_into_separate_and() { + assertConvertedPredicateEquals( + and(featureRange(2), conj(not(featureSet(1)), featureSet(3))), + and(not(featureSet(1)), featureRange(2), featureSet(3)) + ); + } + + @Test + public void require_that_ors_are_split_into_separate_and() { + assertConvertedPredicateEquals( + and(or(featureSet(1), featureSet(2)), conj(featureSet(3), featureSet(4))), + and(or(featureSet(1), featureSet(2)), featureSet(3), featureSet(4)) + ); + } + + @Test + public void require_that_ands_must_have_more_than_one_featureset_to_be_converted() { + assertConvertedPredicateEquals( + and(featureSet(1), featureRange(2)), + and(featureSet(1), featureRange(2)) + ); + } + + @Test + public void require_that_ordering_of_and_operands_are_preserved() { + assertConvertedPredicateEquals( + and(not(featureRange(1)), featureRange(4), conj(not(featureSet(2)), featureSet(3))), + and(not(featureRange(1)), not(featureSet(2)), featureSet(3), featureRange(4)) + ); + } + + @Test + public void require_that_nested_ands_are_converted() { + assertConvertedPredicateEquals( + and(conj(featureSet(1), featureSet(2)), conj(featureSet(3), featureSet(4))), + and(and(featureSet(1), featureSet(2)), and(featureSet(3), featureSet(4))) + ); + } + + @Test + public void require_that_featureset_with_common_key_is_not_converted() { + assertConvertedPredicateEquals( + and(not(featureSet(1)), featureSet(1)), + and(not(featureSet(1)), featureSet(1)) + ); + } + + @Test + public void require_that_nonunique_featureset_are_split_into_separate_conjunctions() { + assertConvertedPredicateEquals( + and(conj(not(featureSet(1)), featureSet(2)), featureSet(1)), + and(not(featureSet(1)), featureSet(1), featureSet(2)) + ); + assertConvertedPredicateEquals( + and(conj(not(featureSet(1)), featureSet(2)), conj(featureSet(1), featureSet(2))), + and(not(featureSet(1)), featureSet(1), featureSet(2), featureSet(2)) + ); + assertConvertedPredicateEquals( + and(featureRange(3), featureRange(4), conj(not(featureSet(1)), featureSet(2)), conj(featureSet(1), featureSet(2))), + and(not(featureSet(1)), featureSet(1), featureSet(2), featureSet(2), featureRange(3), featureRange(4)) + ); + } + + @Test + public void require_that_featuresets_in_conjunction_may_only_have_a_single_value() { + assertConvertedPredicateEquals( + and(featureSet(1, "a", "b"), featureSet(4, "c", "d"), conj(featureSet(2), featureSet(3))), + and(featureSet(1, "a", "b"), featureSet(2), featureSet(3), featureSet(4, "c", "d")) + ); + } + + private static FeatureConjunction conj(Predicate... operands) { + return new FeatureConjunction(Arrays.asList(operands)); + } + + private static FeatureSet featureSet(int id, String... values) { + if (values.length == 0) { + return new FeatureSet(Integer.toString(id), "a"); + } + return new FeatureSet(Integer.toString(id), values); + } + + private static FeatureRange featureRange(int id) { + return new FeatureRange(Integer.toString(id)); + } + + private static void assertConvertedPredicateEquals(Predicate expectedOutput, Predicate input) { + assertEquals(expectedOutput, transformer.process(input, new PredicateOptions(8))); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/PredicateQuerySerializerTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/PredicateQuerySerializerTest.java new file mode 100644 index 00000000000..133834cc3fe --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/PredicateQuerySerializerTest.java @@ -0,0 +1,54 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.serialization; + +import com.yahoo.search.predicate.PredicateQuery; +import com.yahoo.search.predicate.SubqueryBitmap; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class PredicateQuerySerializerTest { + + @Test + public void require_that_query_is_correctly_parsed_and_written_back_to_json() throws Exception { + String json = + "{\"features\":[" + + "{\"k\":\"k1\",\"v\":\"value1\",\"s\":\"0x1\"}," + + "{\"k\":\"k2\",\"v\":\"value2\",\"s\":\"0x3\"}" + + "],\"rangeFeatures\":[" + + "{\"k\":\"range1\",\"v\":123456789123,\"s\":\"0xffff\"}," + + "{\"k\":\"range2\",\"v\":0}" + + "]}"; + PredicateQuerySerializer serializer = new PredicateQuerySerializer(); + PredicateQuery query = serializer.fromJSON(json); + List<PredicateQuery.Feature> features = query.getFeatures(); + PredicateQuery.Feature f1 = features.get(0); + PredicateQuery.Feature f2 = features.get(1); + List<PredicateQuery.RangeFeature> rangeFeatures = query.getRangeFeatures(); + PredicateQuery.RangeFeature r1 = rangeFeatures.get(0); + PredicateQuery.RangeFeature r2 = rangeFeatures.get(1); + + assertEquals("k1", f1.key); + assertEquals("value1", f1.value); + assertEquals(0x1, f1.subqueryBitmap); + + assertEquals("k2", f2.key); + assertEquals("value2", f2.value); + assertEquals(0x3, f2.subqueryBitmap); + + assertEquals("range1", r1.key); + assertEquals(123456789123l, r1.value); + assertEquals(0xFFFF, r1.subqueryBitmap); + + assertEquals("range2", r2.key); + assertEquals(0l, r2.value); + assertEquals(SubqueryBitmap.DEFAULT_VALUE, r2.subqueryBitmap); + + assertEquals(json, serializer.toJSON(query)); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/SerializationHelperTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/SerializationHelperTest.java new file mode 100644 index 00000000000..4e4d6b40e0d --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/SerializationHelperTest.java @@ -0,0 +1,44 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.serialization; + +import org.junit.Test; + +import java.io.IOException; + +import static com.yahoo.search.predicate.serialization.SerializationTestHelper.*; + +/** + * @author bjorncs + */ +public class SerializationHelperTest { + + @Test + public void require_that_long_serialization_works() throws IOException { + long[] longs = {1, 2, 3, 4}; + assertSerializationDeserializationMatches( + longs, SerializationHelper::writeLongArray, SerializationHelper::readLongArray); + } + + @Test + public void require_that_int_serialization_works() throws IOException { + int[] ints = {1, 2, 3, 4}; + assertSerializationDeserializationMatches( + ints, SerializationHelper::writeIntArray, SerializationHelper::readIntArray); + } + + @Test + public void require_that_byte_serialization_works() throws IOException { + byte[] bytes = {1, 2, 3, 4}; + assertSerializationDeserializationMatches( + bytes, SerializationHelper::writeByteArray, SerializationHelper::readByteArray); + } + + @Test + public void require_that_short_serialization_works() throws IOException { + short[] shorts = {1, 2, 3, 4}; + assertSerializationDeserializationMatches( + shorts, SerializationHelper::writeShortArray, SerializationHelper::readShortArray); + } + + +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/SerializationTestHelper.java b/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/SerializationTestHelper.java new file mode 100644 index 00000000000..47746b15d49 --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/serialization/SerializationTestHelper.java @@ -0,0 +1,48 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.serialization; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; + +import static org.junit.Assert.assertArrayEquals; + +/** + * @author bjorncs + */ +public class SerializationTestHelper { + + private SerializationTestHelper() {} + + public static <T> void assertSerializationDeserializationMatches + (T object, Serializer<T> serializer, Deserializer<T> deserializer) throws IOException { + + ByteArrayOutputStream byteArrayOut = new ByteArrayOutputStream(4096); + DataOutputStream out = new DataOutputStream(byteArrayOut); + serializer.serialize(object, out); + out.flush(); + + byte[] bytes = byteArrayOut.toByteArray(); + DataInputStream in = new DataInputStream(new ByteArrayInputStream(bytes)); + T newObject = deserializer.deserialize(in); + + byteArrayOut = new ByteArrayOutputStream(4096); + out = new DataOutputStream(byteArrayOut); + serializer.serialize(newObject, out); + byte[] newBytes = byteArrayOut.toByteArray(); + assertArrayEquals(bytes, newBytes); + } + + @FunctionalInterface + public interface Serializer<T> { + void serialize(T object, DataOutputStream out) throws IOException; + } + + @FunctionalInterface + public interface Deserializer<T> { + T deserialize(DataInputStream in) throws IOException; + } + +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/utils/PostingListSearchTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/utils/PostingListSearchTest.java new file mode 100644 index 00000000000..3daacb9826b --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/utils/PostingListSearchTest.java @@ -0,0 +1,59 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.utils; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bjorncs + */ +public class PostingListSearchTest { + + @Test + public void require_that_search_find_index_of_first_element_higher() { + int[] values = {2, 8, 4000, 4001, 4100, 10000, 10000000}; + int length = values.length; + assertEquals(0, PostingListSearch.interpolationSearch(values, 0, length, 1)); + for (int value = 3; value < 8; value++) { + assertEquals(1, PostingListSearch.interpolationSearch(values, 0, length, value)); + } + assertEquals(2, PostingListSearch.interpolationSearch(values, 0, length, 8)); + assertEquals(values.length, PostingListSearch.interpolationSearch(values, 0, length, 10000000)); + assertEquals(values.length, PostingListSearch.interpolationSearch(values, 0, length, 10000001)); + } + + @Test + public void require_that_search_is_correct_for_one_size_arrays() { + int[] values = {100}; + assertEquals(0, PostingListSearch.interpolationSearch(values, 0, 1, 0)); + assertEquals(0, PostingListSearch.interpolationSearch(values, 0, 1, 99)); + assertEquals(1, PostingListSearch.interpolationSearch(values, 0, 1, 100)); + assertEquals(1, PostingListSearch.interpolationSearch(values, 0, 1, 101)); + assertEquals(1, PostingListSearch.interpolationSearch(values, 0, 1, 10000)); + } + + @Test + public void require_that_search_is_correct_for_sub_arrays() { + int[] values = {0, 2, 8, 4000, 4001, 4100}; + assertEquals(1, PostingListSearch.interpolationSearch(values, 1, 2, 1)); + assertEquals(2, PostingListSearch.interpolationSearch(values, 1, 2, 2)); + assertEquals(2, PostingListSearch.interpolationSearch(values, 1, 4, 2)); + assertEquals(4, PostingListSearch.interpolationSearch(values, 1, 4, 4000)); + assertEquals(5, PostingListSearch.interpolationSearch(values, 1, 5, 4001)); + assertEquals(5, PostingListSearch.interpolationSearch(values, 1, 5, 4101)); + } + + @Test + public void require_that_search_is_correct_for_large_arrays() { + int length = 10000; + int[] values = new int[length]; + for (int i = 0; i < length; i++) { + values[i] = 2 * i; + } + assertEquals(1, PostingListSearch.interpolationSearch(values, 1, length, 0)); + assertEquals(1227, PostingListSearch.interpolationSearch(values, 1, length, 2452)); + assertEquals(1227, PostingListSearch.interpolationSearch(values, 1, length, 2453)); + assertEquals(1228, PostingListSearch.interpolationSearch(values, 1, length, 2454)); + } +} diff --git a/predicate-search/src/test/java/com/yahoo/search/predicate/utils/PrimitiveArraySorterTest.java b/predicate-search/src/test/java/com/yahoo/search/predicate/utils/PrimitiveArraySorterTest.java new file mode 100644 index 00000000000..950268d5ecf --- /dev/null +++ b/predicate-search/src/test/java/com/yahoo/search/predicate/utils/PrimitiveArraySorterTest.java @@ -0,0 +1,122 @@ +// Copyright 2016 Yahoo Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.search.predicate.utils; + +import org.junit.Test; + +import java.util.Arrays; +import java.util.Random; + +import static org.junit.Assert.assertArrayEquals; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class PrimitiveArraySorterTest { + + @Test + public void sorting_empty_array_should_not_throw_exception() { + short[] array = {}; + PrimitiveArraySorter.sort(array, Short::compare); + } + + @Test + public void test_sorting_single_item_array() { + short[] array = {42}; + PrimitiveArraySorter.sort(array, Short::compare); + assertEquals(42, array[0]); + } + + @Test + public void test_sorting_custom_comparator() { + short[] array = {4, 2, 5}; + PrimitiveArraySorter.sort(array, (a, b) -> Short.compare(b, a)); // Sort using inverse ordering. + short[] expected = {5, 4, 2}; + assertArrayEquals(expected, array); + } + + @Test + public void test_complicated_array() { + short[] array = {20381, -28785, -19398, 17307, -12612, 11459, -30164, -16597, -4267, 30838, 8918, 9014, -26444, + -1232, -14620, 12636, -12389, -4931, 32108, 19854, -12681, 14933, 319, 27348, -4907, 19196, 14209, + -32694, 2579, 9771, -1157, -13717, 28506, -8016, 21423, 23697, 23755, 29650, 25644, -14660, -18952, + 25272, -19933, -11375, -32363, -11766, -29509, -23898, 12398, -2600, -20703, -23812, -8292, -1605, + 28642, 12748, 2547, -14535, 4476, -7802}; + short[] expected = Arrays.copyOf(array, array.length); + Arrays.sort(expected); + PrimitiveArraySorter.sort(array, Short::compare); + assertArrayEquals(expected, array); + } + + @Test + public void sorting_random_arrays_should_produce_identical_result_as_java_sort() { + Random r = new Random(4234); + for (int i = 0; i < 10000; i++) { + short[] original = makeRandomArray(r); + short[] javaSorted = Arrays.copyOf(original, original.length); + short[] customSorted = Arrays.copyOf(original, original.length); + PrimitiveArraySorter.sort(customSorted, Short::compare); + Arrays.sort(javaSorted); + String errorMsg = String.format("%s != %s (before sorting: %s)", Arrays.toString(customSorted), Arrays.toString(javaSorted), Arrays.toString(original)); + assertArrayEquals(errorMsg, customSorted, javaSorted); + } + } + + @Test + public void test_merging_simple_array() { + short[] array = {-20, -12, 2, -22, -11, 33, 44}; + short[] expected = {-22, -20, -12, -11, 2, 33, 44}; + short[] result = new short[array.length]; + PrimitiveArraySorter.merge(array, result, 3, Short::compare); + assertArrayEquals(expected, result); + } + + @Test + public void test_merging_of_random_generated_arrays() { + Random r = new Random(4234); + for (int i = 0; i < 10000; i++) { + short[] array = makeRandomArray(r); + int length = array.length; + short[] mergeArray = new short[length]; + short[] expected = Arrays.copyOf(array, length); + Arrays.sort(expected); + + int pivot = length > 0 ? r.nextInt(length) : 0; + Arrays.sort(array, 0, pivot); + Arrays.sort(array, pivot, length); + PrimitiveArraySorter.merge(array, mergeArray, pivot, Short::compare); + assertArrayEquals(expected, mergeArray); + } + } + + @Test + public void test_sortandmerge_returns_false_when_sort_is_in_place() { + short[] array = {3, 2, 1, 0, 4, 5, 6}; + short[] mergeArray = new short[array.length]; + assertFalse(PrimitiveArraySorter.sortAndMerge(array, mergeArray, 4, 7, Short::compare)); + assertIsSorted(array); + + array = new short[]{3, 2, 1, 0, 4, 5, 6}; + assertTrue(PrimitiveArraySorter.sortAndMerge(array, mergeArray, 3, 7, Short::compare)); + assertIsSorted(mergeArray); + } + + // Create random array with size [0, 99] filled with random values. + private short[] makeRandomArray(Random r) { + short[] array = new short[r.nextInt(100)]; + for (int j = 0; j < array.length; j++) { + array[j] = (short) r.nextInt(); + } + return array; + } + + private static void assertIsSorted(short[] array) { + if (array.length == 0) return; + int prev = array[0]; + for (int i = 1; i < array.length; i++) { + int next = array[i]; + assertTrue(prev <= next); + prev = next; + } + } + +} |