diff options
33 files changed, 430 insertions, 246 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java index a590a3a74bf..f2ad6a3ba2f 100644 --- a/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java +++ b/config-model/src/main/java/com/yahoo/schema/derived/VsmFields.java @@ -17,6 +17,7 @@ import com.yahoo.schema.document.Attribute; import com.yahoo.schema.document.Case; import com.yahoo.schema.document.FieldSet; import com.yahoo.schema.document.GeoPos; +import com.yahoo.schema.document.ImmutableSDField; import com.yahoo.schema.document.Matching; import com.yahoo.schema.document.MatchType; import com.yahoo.schema.document.SDDocumentType; @@ -52,34 +53,34 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { doctypes.put(document.getName(), docType); } for (Object o : document.fieldSet()) { - derive(docType, (SDField) o); + derive(docType, (SDField) o, false, false); } } - private void derive(StreamingDocumentType document, SDField field) { + private void derive(StreamingDocumentType document, SDField field, boolean isStructField, boolean ignoreAttributeAspect) { if (field.usesStructOrMap()) { if (GeoPos.isAnyPos(field)) { - StreamingField streamingField = new StreamingField(field); + var streamingField = new StreamingField(field, isStructField, true); addField(streamingField.getName(), streamingField); addFieldToIndices(document, field.getName(), streamingField); } for (SDField structField : field.getStructFields()) { - derive(document, structField); // Recursion + derive(document, structField, true, ignoreAttributeAspect || GeoPos.isAnyPos(field)); // Recursion } } else { - if (! (field.doesIndexing() || field.doesSummarying() || field.doesAttributing()) ) + if (! (field.doesIndexing() || field.doesSummarying() || isAttributeField(field, isStructField, ignoreAttributeAspect)) ) return; - StreamingField streamingField = new StreamingField(field); + var streamingField = new StreamingField(field, isStructField, ignoreAttributeAspect); addField(streamingField.getName(),streamingField); - deriveIndices(document, field, streamingField); + deriveIndices(document, field, streamingField, isStructField, ignoreAttributeAspect); } } - private void deriveIndices(StreamingDocumentType document, SDField field, StreamingField streamingField) { + private void deriveIndices(StreamingDocumentType document, SDField field, StreamingField streamingField, boolean isStructField, boolean ignoreAttributeAspect) { if (field.doesIndexing()) { addFieldToIndices(document, field.getName(), streamingField); - } else if (field.doesAttributing()) { + } else if (isAttributeField(field, isStructField, ignoreAttributeAspect)) { for (String indexName : field.getAttributes().keySet()) { addFieldToIndices(document, indexName, streamingField); } @@ -115,6 +116,17 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { } } + private static boolean isAttributeField(ImmutableSDField field, boolean isStructField, boolean ignoreAttributeAspect) { + if (field.doesAttributing()) { + return true; + } + if (!isStructField || ignoreAttributeAspect) { + return false; + } + var attribute = field.getAttributes().get(field.getName()); + return attribute != null; + } + private static class StreamingField { private final String name; @@ -170,8 +182,8 @@ public class VsmFields extends Derived implements VsmfieldsConfig.Producer { } - public StreamingField(SDField field) { - this(field.getName(), field.getDataType(), field.getMatching(), field.doesAttributing(), getDistanceMetric(field)); + public StreamingField(SDField field, boolean isStructField, boolean ignoreAttributeAspect) { + this(field.getName(), field.getDataType(), field.getMatching(), isAttributeField(field, isStructField, ignoreAttributeAspect), getDistanceMetric(field)); } private StreamingField(String name, DataType sourceType, Matching matching, boolean isAttribute, Attribute.DistanceMetric distanceMetric) { diff --git a/config-model/src/test/java/com/yahoo/schema/derived/VsmFieldsTestCase.java b/config-model/src/test/java/com/yahoo/schema/derived/VsmFieldsTestCase.java index 61d636d911f..852f567ccfa 100644 --- a/config-model/src/test/java/com/yahoo/schema/derived/VsmFieldsTestCase.java +++ b/config-model/src/test/java/com/yahoo/schema/derived/VsmFieldsTestCase.java @@ -6,6 +6,7 @@ import com.yahoo.config.model.deploy.TestProperties; import com.yahoo.config.model.test.MockApplicationPackage; import com.yahoo.document.DataType; import com.yahoo.documentmodel.NewDocumentReferenceDataType; +import com.yahoo.schema.ApplicationBuilder; import com.yahoo.schema.Schema; import com.yahoo.schema.document.Case; import com.yahoo.schema.document.MatchType; @@ -13,9 +14,14 @@ import com.yahoo.schema.document.Matching; import com.yahoo.schema.document.SDDocumentType; import com.yahoo.schema.document.SDField; import com.yahoo.schema.document.TemporarySDField; +import com.yahoo.schema.parser.ParseException; import com.yahoo.vespa.config.search.vsm.VsmfieldsConfig; import org.junit.jupiter.api.Test; +import java.util.HashSet; +import java.util.Set; + +import static com.yahoo.config.model.test.TestUtil.joinLines; import static org.junit.jupiter.api.Assertions.assertEquals; /** @@ -77,4 +83,72 @@ public class VsmFieldsTestCase { testIndexMatching(new Matching(MatchType.WORD).setCase(Case.CASED), VsmfieldsConfig.Fieldspec.Normalize.NONE, "word"); } + + private static Set<String> getIndexes(VsmfieldsConfig config, String field) { + var indexes = new HashSet<String>(); + var doctype = config.documenttype(0); + for (var index : doctype.index()) { + for (var indexField : index.field()) { + if (field.equals(indexField.name())) { + indexes.add(index.name()); + break; + } + } + } + return indexes; + } + + @Test + void deriveIndexFromNestedAttributes() throws ParseException { + String sd = joinLines( + "schema test {", + " document test {", + " field map_field type map<string,int> {", + " indexing: summary", + " struct-field key { indexing: attribute }", + " struct-field value { indexing: attribute }", + " }", + " }", + "}"); + var schema = ApplicationBuilder.createFromString(sd).getSchema(); + var config = vsmfieldsConfig(schema); + assertEquals(Set.of("map_field", "map_field.key"), getIndexes(config, "map_field.key")); + assertEquals(Set.of("map_field", "map_field.value"), getIndexes(config, "map_field.value")); + } + + @Test + void deriveIndexFromIndexStatement() throws ParseException { + String sd = joinLines( + "schema test {", + " document test {", + " field map_field type map<string,int> {", + " indexing: summary | index", + " }", + " }", + "}"); + var schema = ApplicationBuilder.createFromString(sd).getSchema(); + var config = vsmfieldsConfig(schema); + assertEquals(Set.of("map_field", "map_field.key"), getIndexes(config, "map_field.key")); + assertEquals(Set.of("map_field", "map_field.value"), getIndexes(config, "map_field.value")); + } + + @Test + void positionFieldTypeBlocksderivingOfIndexFromNestedAttributes() throws ParseException { + String sd = joinLines( + "schema test {", + " document test {", + " field pos type position {", + " indexing: attribute | summary", + " struct-field x { indexing: attribute }", + " struct-field y { indexing: attribute }", + " }", + " }", + "}"); + var schema = ApplicationBuilder.createFromString(sd).getSchema(); + var config = vsmfieldsConfig(schema); + assertEquals(Set.of("pos"), getIndexes(config, "pos")); + assertEquals(Set.of(), getIndexes(config, "pos.x")); + assertEquals(Set.of(), getIndexes(config, "pos.y")); + } + } diff --git a/container-search/src/main/java/com/yahoo/prelude/querytransform/CJKSearcher.java b/container-search/src/main/java/com/yahoo/prelude/querytransform/CJKSearcher.java index 4f4573a1e5a..878ba230274 100644 --- a/container-search/src/main/java/com/yahoo/prelude/querytransform/CJKSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/querytransform/CJKSearcher.java @@ -79,7 +79,7 @@ public class CJKSearcher extends Searcher { for (ListIterator<Item> i = ((CompositeItem) root).getItemIterator(); i.hasNext();) { Item item = i.next(); Item transformedItem = transform(item); - if (item != transformedItem) + if (item != transformedItem && ((CompositeItem) root).acceptsItemsOfType(transformedItem.getItemType())) i.set(transformedItem); } return root; diff --git a/container-search/src/main/java/com/yahoo/search/yql/MinimalQueryInserter.java b/container-search/src/main/java/com/yahoo/search/yql/MinimalQueryInserter.java index 0116d668d48..2d4ac86e3b2 100644 --- a/container-search/src/main/java/com/yahoo/search/yql/MinimalQueryInserter.java +++ b/container-search/src/main/java/com/yahoo/search/yql/MinimalQueryInserter.java @@ -75,14 +75,15 @@ public class MinimalQueryInserter extends Searcher { @Override public Result search(Query query, Execution execution) { + if (query.properties().get(YQL) == null) return execution.search(query); + Result errorResult; try { - if (query.properties().get(YQL) == null) return execution.search(query); - Result result = insertQuery(query, ParserEnvironment.fromExecutionContext(execution.context())); - return (result == null) ? execution.search(query) : result; + errorResult = insertQuery(query, ParserEnvironment.fromExecutionContext(execution.context())); } catch (IllegalArgumentException e) { throw new IllegalInputException("Illegal YQL query", e); } + return (errorResult == null) ? execution.search(query) : errorResult; } private static Result insertQuery(Query query, ParserEnvironment env) { diff --git a/container-search/src/test/java/com/yahoo/prelude/querytransform/test/CJKSearcherTestCase.java b/container-search/src/test/java/com/yahoo/prelude/querytransform/test/CJKSearcherTestCase.java index 7a287f8dcc9..22ba8754572 100644 --- a/container-search/src/test/java/com/yahoo/prelude/querytransform/test/CJKSearcherTestCase.java +++ b/container-search/src/test/java/com/yahoo/prelude/querytransform/test/CJKSearcherTestCase.java @@ -18,6 +18,8 @@ import com.yahoo.search.query.parser.ParserEnvironment; import com.yahoo.search.query.parser.ParserFactory; import com.yahoo.search.searchchain.Execution; +import com.yahoo.search.test.QueryTestCase; +import com.yahoo.search.yql.MinimalQueryInserter; import org.junit.jupiter.api.Test; import static org.junit.jupiter.api.Assertions.assertEquals; @@ -53,6 +55,13 @@ public class CJKSearcherTestCase { TestLinguistics.INSTANCE); } + @Test + public void testEquivAndChinese() { + Query query = new Query(QueryTestCase.httpEncode("search?yql=select * from music-only where default contains equiv('a', 'b c') or default contains '东'")); + new Execution(new Chain<>(new MinimalQueryInserter(), new CJKSearcher()), Execution.Context.createContextStub()).search(query); + assertEquals("OR (EQUIV default:a default:'b c') default:东", query.getModel().getQueryTree().toString()); + } + private void assertTransformed(String queryString, String expected, Query.Type mode, Language actualLanguage, Language queryLanguage, Linguistics linguistics) { Parser parser = ParserFactory.newInstance(mode, new ParserEnvironment() diff --git a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java index 75e9525f09b..29a651aabf4 100644 --- a/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/yql/YqlParserTestCase.java @@ -1116,9 +1116,7 @@ public class YqlParserTestCase { void testBackslash() { { String queryString = "select * from testtype where title contains \"\\\\\""; // Java escaping * YQL escaping - QueryTree query = parse(queryString); - assertEquals("title:\\", query.toString()); } diff --git a/document/src/main/java/com/yahoo/document/DocumentUpdate.java b/document/src/main/java/com/yahoo/document/DocumentUpdate.java index ebaf821c29b..26254f3c081 100644 --- a/document/src/main/java/com/yahoo/document/DocumentUpdate.java +++ b/document/src/main/java/com/yahoo/document/DocumentUpdate.java @@ -347,6 +347,7 @@ public class DocumentUpdate extends DocumentOperation implements Iterable<FieldP } public final void serialize(GrowableByteBuffer buf) { + // TODO shouldn't this be createHead()?! serialize(DocumentSerializerFactory.create6(buf)); } diff --git a/documentapi/abi-spec.json b/documentapi/abi-spec.json index d00c89ae737..02d764deab8 100644 --- a/documentapi/abi-spec.json +++ b/documentapi/abi-spec.json @@ -3015,6 +3015,7 @@ ], "methods" : [ "public abstract boolean encode(com.yahoo.messagebus.Routable, com.yahoo.document.serialization.DocumentSerializer)", + "public byte[] encode(int, com.yahoo.messagebus.Routable)", "public abstract com.yahoo.messagebus.Routable decode(com.yahoo.document.serialization.DocumentDeserializer)" ], "fields" : [ ] diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableFactories80.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableFactories80.java index 2d29697717b..4712f6d2442 100644 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableFactories80.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableFactories80.java @@ -19,6 +19,7 @@ import com.yahoo.document.TestAndSetCondition; import com.yahoo.document.serialization.DocumentDeserializer; import com.yahoo.document.serialization.DocumentDeserializerFactory; import com.yahoo.document.serialization.DocumentSerializer; +import com.yahoo.document.serialization.DocumentSerializerFactory; import com.yahoo.io.GrowableByteBuffer; import com.yahoo.messagebus.Routable; import com.yahoo.vdslib.DocumentSummary; @@ -56,17 +57,31 @@ abstract class RoutableFactories80 { } @Override - public boolean encode(Routable obj, DocumentSerializer out) { + public byte[] encode(int msgType, Routable obj) { try { var protoMsg = encoderFn.apply(apiClass.cast(obj)); - // TODO avoid this buffer indirection by directly exposing an OutputStream to write into...! - // ... or at the very least have a way to preallocate buffer output of protoMsg.getSerializedSize() bytes! - out.getBuf().put(protoMsg.toByteArray()); - } catch (RuntimeException e) { + int protoSize = protoMsg.getSerializedSize(); + // The message payload contains a 4-byte header int which specifies the type of the message + // that follows. We want to write this header and the subsequence message bytes using a single + // allocation and without unneeded copying, so we create one array for both purposes and encode + // directly into it. Aside from the header, this is pretty much a mirror image of what the + // toByteArray() method on Protobuf message objects already does. + var buf = new byte[4 + protoSize]; + ByteBuffer.wrap(buf).putInt(msgType); // In network order (default setting) + var protoStream = CodedOutputStream.newInstance(buf, 4, protoSize); + protoMsg.writeTo(protoStream); // Writing straight to array, no need to flush + protoStream.checkNoSpaceLeft(); + return buf; + } catch (IOException | RuntimeException e) { logCodecError("encoding", e); - return false; + return null; } - return true; + } + + @Override + public boolean encode(Routable obj, DocumentSerializer out) { + // Legacy encode; not supported + return false; } @Override @@ -225,7 +240,7 @@ abstract class RoutableFactories80 { private static ByteBuffer serializeUpdate(DocumentUpdate update) { var buf = new GrowableByteBuffer(); - update.serialize(buf); + update.serialize(DocumentSerializerFactory.createHead(buf)); buf.flip(); return buf.getByteBuffer(); } diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableFactory.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableFactory.java index e98c9ab3a40..c635aa1581e 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableFactory.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableFactory.java @@ -3,6 +3,8 @@ package com.yahoo.documentapi.messagebus.protocol; import com.yahoo.document.serialization.DocumentDeserializer; import com.yahoo.document.serialization.DocumentSerializer; +import com.yahoo.document.serialization.DocumentSerializerFactory; +import com.yahoo.io.GrowableByteBuffer; import com.yahoo.messagebus.Routable; /** @@ -31,6 +33,32 @@ public interface RoutableFactory { boolean encode(Routable obj, DocumentSerializer out); /** + * <p>Encode a message type and object payload to a byte array. This is an alternative, + * optional method to {@link #encode(Routable, DocumentSerializer)}, but which defers all + * buffer management to the callee. This allows protocol implementations to make more + * efficient use of memory, as they do not have to deal with DocumentSerializer indirections.</p> + * + * <p>Implementations <strong>must</strong> ensure that the first 4 bytes of the returned + * byte array contain a 32-bit integer (in network order) equal to the provided msgType value.</p> + * + * @param msgType A positive integer indicating the concrete message type of obj. + * @param obj The message to encode. + * @return A byte buffer encapsulating the message type and the serialized representation + * of obj, or null if encoding failed. + */ + default byte[] encode(int msgType, Routable obj) { + var out = DocumentSerializerFactory.createHead(new GrowableByteBuffer(8192)); + out.putInt(null, msgType); + if (!encode(obj, out)) { + return null; + } + byte[] ret = new byte[out.getBuf().position()]; + out.getBuf().rewind(); + out.getBuf().get(ret); + return ret; + } + + /** * <p>This method decodes the given byte buffer to a routable.</p> <p>Return false to signal failure.</p> <p>This * method is NOT exception safe.</p> * diff --git a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableRepository.java b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableRepository.java index 47117471615..56d23d36811 100755 --- a/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableRepository.java +++ b/documentapi/src/main/java/com/yahoo/documentapi/messagebus/protocol/RoutableRepository.java @@ -93,17 +93,12 @@ final class RoutableRepository { log.log(Level.SEVERE,"Can not encode routable type " + type + " (version " + version + "). Only major version 5 and up supported."); return new byte[0]; } - DocumentSerializer out= DocumentSerializerFactory.createHead(new GrowableByteBuffer(8192)); - - out.putInt(null, type); - if (!factory.encode(obj, out)) { + byte[] ret = factory.encode(type, obj); + if (ret == null) { log.log(Level.SEVERE, "Routable factory " + factory.getClass().getName() + " failed to serialize " + "routable of type " + type + " (version " + version + ")."); return new byte[0]; } - byte[] ret = new byte[out.getBuf().position()]; - out.getBuf().rewind(); - out.getBuf().get(ret); return ret; } diff --git a/documentapi/src/vespa/documentapi/messagebus/routablerepository.cpp b/documentapi/src/vespa/documentapi/messagebus/routablerepository.cpp index 3e1ae07f7ca..7a03afbaf48 100644 --- a/documentapi/src/vespa/documentapi/messagebus/routablerepository.cpp +++ b/documentapi/src/vespa/documentapi/messagebus/routablerepository.cpp @@ -72,10 +72,6 @@ RoutableRepository::decode(const vespalib::Version &version, mbus::BlobRef data) if (!ret) { LOG(error, "Routable factory failed to deserialize routable of type %d (version %s).", type, version.toString().c_str()); - - std::ostringstream ost; - document::StringUtil::printAsHex(ost, data.data(), data.size()); - LOG(error, "%s", ost.str().c_str()); return {}; } return ret; diff --git a/searchcore/src/vespa/searchcore/proton/matching/query.cpp b/searchcore/src/vespa/searchcore/proton/matching/query.cpp index 5ade0a44b8a..1d7a693b1c9 100644 --- a/searchcore/src/vespa/searchcore/proton/matching/query.cpp +++ b/searchcore/src/vespa/searchcore/proton/matching/query.cpp @@ -200,7 +200,8 @@ Query::reserveHandles(const IRequestContext & requestContext, ISearchContext &co void Query::optimize(bool sort_by_cost) { - _blueprint = Blueprint::optimize_and_sort(std::move(_blueprint), true, sort_by_cost); + auto opts = Blueprint::Options::all().sort_by_cost(sort_by_cost); + _blueprint = Blueprint::optimize_and_sort(std::move(_blueprint), true, opts); LOG(debug, "optimized blueprint:\n%s\n", _blueprint->asString().c_str()); } @@ -222,7 +223,8 @@ Query::handle_global_filter(const IRequestContext & requestContext, uint32_t doc } // optimized order may change after accounting for global filter: trace.addEvent(5, "Optimize query execution plan to account for global filter"); - _blueprint = Blueprint::optimize_and_sort(std::move(_blueprint), true, sort_by_cost); + auto opts = Blueprint::Options::all().sort_by_cost(sort_by_cost); + _blueprint = Blueprint::optimize_and_sort(std::move(_blueprint), true, opts); LOG(debug, "blueprint after handle_global_filter:\n%s\n", _blueprint->asString().c_str()); // strictness may change if optimized order changed: fetchPostings(ExecuteInfo::create(true, 1.0, requestContext.getDoom(), requestContext.thread_bundle())); diff --git a/searchlib/src/tests/nearsearch/nearsearch_test.cpp b/searchlib/src/tests/nearsearch/nearsearch_test.cpp index 95701e59444..6f7cf85258b 100644 --- a/searchlib/src/tests/nearsearch/nearsearch_test.cpp +++ b/searchlib/src/tests/nearsearch/nearsearch_test.cpp @@ -229,7 +229,8 @@ Test::testNearSearch(MyQuery &query, uint32_t matchId) near_b->addChild(query.getTerm(i).make_blueprint(fieldId, i)); } bp->setDocIdLimit(1000); - bp = search::queryeval::Blueprint::optimize_and_sort(std::move(bp), true, true); + auto opts = search::queryeval::Blueprint::Options::all(); + bp = search::queryeval::Blueprint::optimize_and_sort(std::move(bp), true, opts); bp->fetchPostings(search::queryeval::ExecuteInfo::TRUE); search::fef::MatchData::UP md(layout.createMatchData()); search::queryeval::SearchIterator::UP near = bp->createSearch(*md, true); diff --git a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp index 2a21d66c090..1af9ee6cff7 100644 --- a/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/blueprint_test.cpp @@ -17,13 +17,15 @@ using namespace search::fef; namespace { +auto opts = Blueprint::Options::all(); + //----------------------------------------------------------------------------- class MyOr : public IntermediateBlueprint { private: - FlowCalc make_flow_calc(bool strict, double flow) const override { - return flow_calc<OrFlow>(strict, flow); + FlowCalc make_flow_calc(InFlow in_flow) const override { + return flow_calc<OrFlow>(in_flow); } public: FlowStats calculate_flow_stats(uint32_t) const final { @@ -451,7 +453,7 @@ TEST_F("testChildAndNotCollapsing", Fixture) ); TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, opts); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -491,7 +493,7 @@ TEST_F("testChildAndCollapsing", Fixture) TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, opts); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -530,10 +532,9 @@ TEST_F("testChildOrCollapsing", Fixture) ); TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - // we sort non-strict here since the default costs of 1/est for - // non-strict/strict leaf iterators makes the order of iterators - // under a strict OR irrelevant. - unsorted = Blueprint::optimize_and_sort(std::move(unsorted), false, true); + // we sort non-strict here since a strict OR does not have a + // deterministic sort order. + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), false, opts); TEST_DO(f.check_equal(*sorted, *unsorted)); } @@ -577,7 +578,7 @@ TEST_F("testChildSorting", Fixture) TEST_DO(f.check_not_equal(*sorted, *unsorted)); unsorted->setDocIdLimit(1000); - unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, true); + unsorted = Blueprint::optimize_and_sort(std::move(unsorted), true, opts); TEST_DO(f.check_equal(*sorted, *unsorted)); } diff --git a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp index 31db731a598..f192ea93b0e 100644 --- a/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp +++ b/searchlib/src/tests/queryeval/blueprint/intermediate_blueprints_test.cpp @@ -132,7 +132,8 @@ TEST("test AndNot Blueprint") { template <typename BP> void optimize(std::unique_ptr<BP> &ref, bool strict) { - auto optimized = Blueprint::optimize_and_sort(std::move(ref), strict, true); + auto opts = Blueprint::Options::all(); + auto optimized = Blueprint::optimize_and_sort(std::move(ref), strict, opts); ref.reset(dynamic_cast<BP*>(optimized.get())); ASSERT_TRUE(ref); optimized.release(); @@ -568,9 +569,10 @@ optimize_and_compare(Blueprint::UP top, Blueprint::UP expect, bool strict = true top->setDocIdLimit(1000); expect->setDocIdLimit(1000); TEST_DO(compare(*top, *expect, false)); - top = Blueprint::optimize_and_sort(std::move(top), strict, sort_by_cost); + auto opts = Blueprint::Options::all().sort_by_cost(sort_by_cost); + top = Blueprint::optimize_and_sort(std::move(top), strict, opts); TEST_DO(compare(*top, *expect, true)); - expect = Blueprint::optimize_and_sort(std::move(expect), strict, sort_by_cost); + expect = Blueprint::optimize_and_sort(std::move(expect), strict, opts); TEST_DO(compare(*expect, *top, true)); } @@ -699,11 +701,12 @@ TEST("test empty root node optimization and safeness") { //------------------------------------------------------------------------- auto expect_up = std::make_unique<EmptyBlueprint>(); - compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top1), true, true), true); - compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top2), true, true), true); - compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top3), true, true), true); - compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top4), true, true), true); - compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top5), true, true), true); + auto opts = Blueprint::Options::all(); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top1), true, opts), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top2), true, opts), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top3), true, opts), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top4), true, opts), true); + compare(*expect_up, *Blueprint::optimize_and_sort(std::move(top5), true, opts), true); } TEST("and with one empty child is optimized away") { @@ -711,7 +714,8 @@ TEST("and with one empty child is optimized away") { Blueprint::UP top = ap((new SourceBlenderBlueprint(*selector))-> addChild(ap(MyLeafSpec(10).create())). addChild(addLeafs(std::make_unique<AndBlueprint>(), {{0, true}, 10, 20}))); - top = Blueprint::optimize_and_sort(std::move(top), true, true); + auto opts = Blueprint::Options::all(); + top = Blueprint::optimize_and_sort(std::move(top), true, opts); Blueprint::UP expect_up(ap((new SourceBlenderBlueprint(*selector))-> addChild(ap(MyLeafSpec(10).create())). addChild(std::make_unique<EmptyBlueprint>()))); @@ -888,8 +892,9 @@ TEST("require that replaced blueprints retain source id") { addChild(ap(MyLeafSpec(30).create()->setSourceId(55))))); Blueprint::UP expect2_up(ap(MyLeafSpec(30).create()->setSourceId(42))); //------------------------------------------------------------------------- - top1_up = Blueprint::optimize_and_sort(std::move(top1_up), true, true); - top2_up = Blueprint::optimize_and_sort(std::move(top2_up), true, true); + auto opts = Blueprint::Options::all(); + top1_up = Blueprint::optimize_and_sort(std::move(top1_up), true, opts); + top2_up = Blueprint::optimize_and_sort(std::move(top2_up), true, opts); compare(*expect1_up, *top1_up, true); compare(*expect2_up, *top2_up, true); EXPECT_EQUAL(13u, top1_up->getSourceId()); @@ -1204,7 +1209,8 @@ TEST("require_that_unpack_optimization_is_not_overruled_by_equiv") { TEST("require that ANDNOT without children is optimized to empty search") { Blueprint::UP top_up = std::make_unique<AndNotBlueprint>(); auto expect_up = std::make_unique<EmptyBlueprint>(); - top_up = Blueprint::optimize_and_sort(std::move(top_up), true, true); + auto opts = Blueprint::Options::all(); + top_up = Blueprint::optimize_and_sort(std::move(top_up), true, opts); compare(*expect_up, *top_up, true); } diff --git a/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp b/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp index 3e8bc06bfd8..c4d34ab3565 100644 --- a/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp +++ b/searchlib/src/tests/queryeval/filter_search/filter_search_test.cpp @@ -49,7 +49,7 @@ concept ChildCollector = requires(T a, std::unique_ptr<Blueprint> bp) { struct DefaultBlueprint : Blueprint { FlowStats calculate_flow_stats(uint32_t) const override { abort(); } void optimize(Blueprint* &, OptimizePass) override { abort(); } - void sort(bool, bool) override { abort(); } + double sort(InFlow, const Options &) override { abort(); } const State &getState() const override { abort(); } void fetchPostings(const ExecuteInfo &) override { abort(); } void freeze() override { abort(); } diff --git a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp index 5009a15e438..8b8b6c1282e 100644 --- a/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp +++ b/searchlib/src/tests/queryeval/flow/queryeval_flow_test.cpp @@ -141,7 +141,7 @@ TEST(FlowTest, full_and_flow) { {0.4, 0.4, false}, {0.4*0.7, 0.4*0.7, false}, {0.4*0.7*0.2, 0.4*0.7*0.2, false}}); - verify_flow_calc(flow_calc<AndFlow>(strict, 1.0), + verify_flow_calc(flow_calc<AndFlow>(strict), {0.4, 0.7, 0.2}, {1.0, 0.4, 0.4*0.7, 0.4*0.7*0.2}); } } @@ -153,7 +153,7 @@ TEST(FlowTest, partial_and_flow) { {in*0.4, in*0.4, false}, {in*0.4*0.7, in*0.4*0.7, false}, {in*0.4*0.7*0.2, in*0.4*0.7*0.2, false}}); - verify_flow_calc(flow_calc<AndFlow>(false, in), + verify_flow_calc(flow_calc<AndFlow>(in), {0.4, 0.7, 0.2}, {in*1.0, in*0.4, in*0.4*0.7, in*0.4*0.7*0.2}); } } @@ -164,14 +164,14 @@ TEST(FlowTest, full_or_flow) { {0.6, 1.0-0.6, false}, {0.6*0.3, 1.0-0.6*0.3, false}, {0.6*0.3*0.8, 1.0-0.6*0.3*0.8, false}}); - verify_flow_calc(flow_calc<OrFlow>(false, 1.0), + verify_flow_calc(flow_calc<OrFlow>(1.0), {0.4, 0.7, 0.2}, {1.0, 0.6, 0.6*0.3, 0.6*0.3*0.8}); verify_flow(OrFlow(true), {0.4, 0.7, 0.2}, {{1.0, 0.0, true}, {1.0, 1.0-0.6, true}, {1.0, 1.0-0.6*0.3, true}, {1.0, 1.0-0.6*0.3*0.8, true}}); - verify_flow_calc(flow_calc<OrFlow>(true, 1.0), + verify_flow_calc(flow_calc<OrFlow>(true), {0.4, 0.7, 0.2}, {1.0, 1.0, 1.0, 1.0}); } @@ -182,7 +182,7 @@ TEST(FlowTest, partial_or_flow) { {in*0.6, 1.0-in*0.6, false}, {in*0.6*0.3, 1.0-in*0.6*0.3, false}, {in*0.6*0.3*0.8, 1.0-in*0.6*0.3*0.8, false}}); - verify_flow_calc(flow_calc<OrFlow>(false, in), + verify_flow_calc(flow_calc<OrFlow>(in), {0.4, 0.7, 0.2}, {in, in*0.6, in*0.6*0.3, in*0.6*0.3*0.8}); } } @@ -194,7 +194,7 @@ TEST(FlowTest, full_and_not_flow) { {0.4, 0.4, false}, {0.4*0.3, 0.4*0.3, false}, {0.4*0.3*0.8, 0.4*0.3*0.8, false}}); - verify_flow_calc(flow_calc<AndNotFlow>(strict, 1.0), + verify_flow_calc(flow_calc<AndNotFlow>(strict), {0.4, 0.7, 0.2}, {1.0, 0.4, 0.4*0.3, 0.4*0.3*0.8}); } } @@ -206,45 +206,52 @@ TEST(FlowTest, partial_and_not_flow) { {in*0.4, in*0.4, false}, {in*0.4*0.3, in*0.4*0.3, false}, {in*0.4*0.3*0.8, in*0.4*0.3*0.8, false}}); - verify_flow_calc(flow_calc<AndNotFlow>(false, in), + verify_flow_calc(flow_calc<AndNotFlow>(in), {0.4, 0.7, 0.2}, {in, in*0.4, in*0.4*0.3, in*0.4*0.3*0.8}); } } TEST(FlowTest, full_first_flow_calc) { for (bool strict: {false, true}) { - verify_flow_calc(first_flow_calc(strict, 1.0), + verify_flow_calc(first_flow_calc(strict), {0.4, 0.7, 0.2}, {1.0, 0.4, 0.4, 0.4}); } } TEST(FlowTest, partial_first_flow_calc) { for (double in: {1.0, 0.5, 0.25}) { - verify_flow_calc(first_flow_calc(false, in), + verify_flow_calc(first_flow_calc(in), {0.4, 0.7, 0.2}, {in, in*0.4, in*0.4, in*0.4}); } } TEST(FlowTest, full_full_flow_calc) { for (bool strict: {false, true}) { - verify_flow_calc(full_flow_calc(strict, 1.0), + verify_flow_calc(full_flow_calc(strict), {0.4, 0.7, 0.2}, {1.0, 1.0, 1.0, 1.0}); } } TEST(FlowTest, partial_full_flow_calc) { for (double in: {1.0, 0.5, 0.25}) { - verify_flow_calc(full_flow_calc(false, in), + verify_flow_calc(full_flow_calc(in), {0.4, 0.7, 0.2}, {in, in, in, in}); } } -TEST(FlowTest, flow_calc_strictness_overrides_rate) { - EXPECT_EQ(flow_calc<AndFlow>(true, 0.5)(0.5), 1.0); - EXPECT_EQ(flow_calc<OrFlow>(true, 0.5)(0.5), 1.0); - EXPECT_EQ(flow_calc<AndNotFlow>(true, 0.5)(0.5), 1.0); - EXPECT_EQ(first_flow_calc(true, 0.5)(0.5), 1.0); - EXPECT_EQ(full_flow_calc(true, 0.5)(0.5), 1.0); +TEST(FlowTest, in_flow_strict_vs_rate_interaction) { + EXPECT_EQ(InFlow(true).strict(), true); + EXPECT_EQ(InFlow(true).rate(), 1.0); + EXPECT_EQ(InFlow(false).strict(), false); + EXPECT_EQ(InFlow(false).rate(), 1.0); + EXPECT_EQ(InFlow(0.5).strict(), false); + EXPECT_EQ(InFlow(0.5).rate(), 0.5); + EXPECT_EQ(InFlow(true, 0.5).strict(), true); + EXPECT_EQ(InFlow(true, 0.5).rate(), 1.0); + EXPECT_EQ(InFlow(false, 0.5).strict(), false); + EXPECT_EQ(InFlow(false, 0.5).rate(), 0.5); + EXPECT_EQ(InFlow(-1.0).strict(), false); + EXPECT_EQ(InFlow(-1.0).rate(), 0.0); } TEST(FlowTest, flow_cost) { diff --git a/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp b/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp index 6747fed888c..bdc89363b22 100644 --- a/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp +++ b/searchlib/src/tests/queryeval/iterator_benchmark/iterator_benchmark_test.cpp @@ -342,7 +342,8 @@ non_strict_search(Blueprint& blueprint, MatchData& md, uint32_t docid_limit, dou BenchmarkResult benchmark_search(Blueprint::UP blueprint, uint32_t docid_limit, bool strict, double filter_hit_ratio) { - blueprint->sort(strict, true); + auto opts = Blueprint::Options::all(); + blueprint->sort(strict, opts); blueprint->fetchPostings(ExecuteInfo::createForTest(strict)); // Note: All blueprints get the same TermFieldMatchData instance. // This is OK as long as we don't do unpacking and only use 1 thread. diff --git a/searchlib/src/tests/queryeval/same_element/same_element_test.cpp b/searchlib/src/tests/queryeval/same_element/same_element_test.cpp index c9fcb472b68..64f4fafd2d1 100644 --- a/searchlib/src/tests/queryeval/same_element/same_element_test.cpp +++ b/searchlib/src/tests/queryeval/same_element/same_element_test.cpp @@ -46,7 +46,8 @@ std::unique_ptr<SameElementBlueprint> make_blueprint(const std::vector<FakeResul } Blueprint::UP finalize(Blueprint::UP bp, bool strict) { - Blueprint::UP result = Blueprint::optimize_and_sort(std::move(bp), true, true); + auto opts = Blueprint::Options::all(); + Blueprint::UP result = Blueprint::optimize_and_sort(std::move(bp), true, opts); result->fetchPostings(ExecuteInfo::createForTest(strict)); result->freeze(); return result; diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp index f3539c6989a..5a225328003 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.cpp @@ -121,6 +121,7 @@ Blueprint::Blueprint() noexcept _flow_stats(0.0, 0.0, 0.0), _sourceId(0xffffffff), _docid_limit(0), + _force_strict(false), _frozen(false) { } @@ -481,12 +482,6 @@ IntermediateBlueprint::count_termwise_nodes(const UnpackInfo &unpack) const return termwise_nodes; } -FlowCalc -IntermediateBlueprint::make_flow_calc(bool strict, double flow) const -{ - return full_flow_calc(strict, flow); -} - IntermediateBlueprint::IndexList IntermediateBlueprint::find(const IPredicate & pred) const { @@ -574,13 +569,17 @@ IntermediateBlueprint::optimize(Blueprint* &self, OptimizePass pass) maybe_eliminate_self(self, get_replacement()); } -void -IntermediateBlueprint::sort(bool strict, bool sort_by_cost) +double +IntermediateBlueprint::sort(InFlow in_flow, const Options &opts) { - sort(_children, strict, sort_by_cost); + auto flow_calc = make_flow_calc(in_flow); + sort(_children, in_flow.strict(), opts.sort_by_cost()); for (size_t i = 0; i < _children.size(); ++i) { - _children[i]->sort(strict && inheritStrict(i), sort_by_cost); + double next_rate = flow_calc(_children[i]->estimate()); + _children[i]->sort(InFlow(in_flow.strict() && inheritStrict(i), next_rate), opts); } + // TODO: better cost estimate (due to known in-flow and eagerness) + return in_flow.strict() ? strict_cost() : in_flow.rate() * cost(); } void @@ -647,7 +646,7 @@ IntermediateBlueprint::visitMembers(vespalib::ObjectVisitor &visitor) const void IntermediateBlueprint::fetchPostings(const ExecuteInfo &execInfo) { - FlowCalc flow_calc = make_flow_calc(execInfo.is_strict(), execInfo.hit_rate()); + FlowCalc flow_calc = make_flow_calc(InFlow(execInfo.is_strict(), execInfo.hit_rate())); for (size_t i = 0; i < _children.size(); ++i) { Blueprint & child = *_children[i]; double nextHitRate = flow_calc(child.estimate()); @@ -766,9 +765,11 @@ LeafBlueprint::optimize(Blueprint* &self, OptimizePass pass) maybe_eliminate_self(self, get_replacement()); } -void -LeafBlueprint::sort(bool, bool) +double +LeafBlueprint::sort(InFlow in_flow, const Options &) { + // TODO: better cost estimate (due to known in-flow and eagerness) + return in_flow.strict() ? strict_cost() : in_flow.rate() * cost(); } void diff --git a/searchlib/src/vespa/searchlib/queryeval/blueprint.h b/searchlib/src/vespa/searchlib/queryeval/blueprint.h index c24790ddcf1..0c08e6aedf5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/blueprint.h +++ b/searchlib/src/vespa/searchlib/queryeval/blueprint.h @@ -55,6 +55,29 @@ public: enum class OptimizePass { FIRST, LAST }; + class Options { + private: + bool _sort_by_cost; + bool _allow_force_strict; + public: + constexpr Options() noexcept + : _sort_by_cost(false), + _allow_force_strict(false) {} + constexpr bool sort_by_cost() const noexcept { return _sort_by_cost; } + constexpr Options &sort_by_cost(bool value) noexcept { + _sort_by_cost = value; + return *this; + } + constexpr bool allow_force_strict() const noexcept { return _allow_force_strict; } + constexpr Options &allow_force_strict(bool value) noexcept { + _allow_force_strict = value; + return *this; + } + static constexpr Options all() noexcept { + return Options().sort_by_cost(true).allow_force_strict(true); + } + }; + struct HitEstimate { uint32_t estHits; bool empty; @@ -182,6 +205,7 @@ private: FlowStats _flow_stats; uint32_t _sourceId; uint32_t _docid_limit; + bool _force_strict; bool _frozen; protected: @@ -224,10 +248,10 @@ public: uint32_t get_docid_limit() const noexcept { return _docid_limit; } static Blueprint::UP optimize(Blueprint::UP bp); - virtual void sort(bool strict, bool sort_by_cost) = 0; - static Blueprint::UP optimize_and_sort(Blueprint::UP bp, bool strict, bool sort_by_cost) { + virtual double sort(InFlow in_flow, const Options &opts) = 0; + static Blueprint::UP optimize_and_sort(Blueprint::UP bp, InFlow in_flow, const Options &opts) { auto result = optimize(std::move(bp)); - result->sort(strict, sort_by_cost); + result->sort(in_flow, opts); return result; } virtual void optimize(Blueprint* &self, OptimizePass pass) = 0; @@ -362,7 +386,7 @@ private: bool infer_want_global_filter() const; size_t count_termwise_nodes(const UnpackInfo &unpack) const; - virtual FlowCalc make_flow_calc(bool strict, double flow) const = 0; + virtual FlowCalc make_flow_calc(InFlow in_flow) const = 0; protected: // returns an empty collection if children have empty or @@ -385,7 +409,7 @@ public: void setDocIdLimit(uint32_t limit) noexcept final; void optimize(Blueprint* &self, OptimizePass pass) final; - void sort(bool strict, bool sort_by_cost) override; + double sort(InFlow in_flow, const Options &opts) override; void set_global_filter(const GlobalFilter &global_filter, double estimated_hit_ratio) override; IndexList find(const IPredicate & check) const; @@ -422,7 +446,7 @@ private: State _state; protected: void optimize(Blueprint* &self, OptimizePass pass) final; - void sort(bool strict, bool sort_by_cost) override; + double sort(InFlow in_flow, const Options &opts) override; void setEstimate(HitEstimate est) { _state.estimate(est); notifyChange(); diff --git a/searchlib/src/vespa/searchlib/queryeval/flow.h b/searchlib/src/vespa/searchlib/queryeval/flow.h index 4548baf7545..ade2516b509 100644 --- a/searchlib/src/vespa/searchlib/queryeval/flow.h +++ b/searchlib/src/vespa/searchlib/queryeval/flow.h @@ -11,6 +11,23 @@ namespace search::queryeval { +// Encapsulate information about strictness and in-flow in a structure +// for convenient parameter passing. We do not need an explicit value +// in the strict case since strict basically means the receiving end +// will eventually decide the actual flow. We use a rate of 1.0 for +// strict flow to indicate that the corpus is not reduced externally. +class InFlow { +private: + double _value; +public: + constexpr InFlow(bool strict, double rate) noexcept + : _value(strict ? -1.0 : std::max(rate, 0.0)) {} + constexpr InFlow(bool strict) noexcept : InFlow(strict, 1.0) {} + constexpr InFlow(double rate) noexcept : InFlow(false, rate) {} + constexpr bool strict() noexcept { return _value < 0.0; } + constexpr double rate() noexcept { return strict() ? 1.0 : _value; } +}; + struct FlowStats { double estimate; double cost; @@ -122,16 +139,13 @@ void sort_partial(ADAPTER adapter, T &children, size_t offset) { template <typename ADAPTER, typename T, typename F> double ordered_cost_of(ADAPTER adapter, const T &children, F flow) { - double cost = 0.0; + double total_cost = 0.0; for (const auto &child: children) { - if (flow.strict()) { - cost += adapter.strict_cost(child); - } else { - cost += flow.flow() * adapter.cost(child); - } + double child_cost = flow.strict() ? adapter.strict_cost(child) : (flow.flow() * adapter.cost(child)); + flow.update_cost(total_cost, child_cost); flow.add(adapter.estimate(child)); } - return cost; + return total_cost; } template <typename ADAPTER, typename T> @@ -188,8 +202,7 @@ private: bool _strict; bool _first; public: - AndFlow(bool strict) noexcept : _flow(1.0), _strict(strict), _first(true) {} - AndFlow(double in) noexcept : _flow(in), _strict(false), _first(true) {} + AndFlow(InFlow flow) noexcept : _flow(flow.rate()), _strict(flow.strict()), _first(true) {} void add(double est) noexcept { _flow *= est; _first = false; @@ -203,6 +216,9 @@ public: double estimate() const noexcept { return _first ? 0.0 : _flow; } + void update_cost(double &total_cost, double child_cost) noexcept { + total_cost += child_cost; + } static void sort(auto adapter, auto &children, bool strict) { flow::sort<flow::MinAndCost>(adapter, children); if (strict && children.size() > 1) { @@ -225,8 +241,7 @@ private: bool _strict; bool _first; public: - OrFlow(bool strict) noexcept : _flow(1.0), _strict(strict), _first(true) {} - OrFlow(double in) noexcept : _flow(in), _strict(false), _first(true) {} + OrFlow(InFlow flow) noexcept : _flow(flow.rate()), _strict(flow.strict()), _first(true) {} void add(double est) noexcept { _flow *= (1.0 - est); _first = false; @@ -240,6 +255,9 @@ public: double estimate() const noexcept { return _first ? 0.0 : (1.0 - _flow); } + void update_cost(double &total_cost, double child_cost) noexcept { + total_cost += child_cost; + } static void sort(auto adapter, auto &children, bool strict) { if (!strict) { flow::sort<flow::MinOrCost>(adapter, children); @@ -256,8 +274,7 @@ private: bool _strict; bool _first; public: - AndNotFlow(bool strict) noexcept : _flow(1.0), _strict(strict), _first(true) {} - AndNotFlow(double in) noexcept : _flow(in), _strict(false), _first(true) {} + AndNotFlow(InFlow flow) noexcept : _flow(flow.rate()), _strict(flow.strict()), _first(true) {} void add(double est) noexcept { _flow *= _first ? est : (1.0 - est); _first = false; @@ -271,6 +288,9 @@ public: double estimate() const noexcept { return _first ? 0.0 : _flow; } + void update_cost(double &total_cost, double child_cost) noexcept { + total_cost += child_cost; + } static void sort(auto adapter, auto &children, bool) { flow::sort_partial<flow::MinOrCost>(adapter, children, 1); } @@ -282,21 +302,18 @@ public: using FlowCalc = std::function<double(double)>; template <typename FLOW> -FlowCalc flow_calc(bool strict, double non_strict_rate) { - FLOW flow = strict ? FLOW(true) : FLOW(non_strict_rate); - return [flow](double est) mutable noexcept { +FlowCalc flow_calc(InFlow in_flow) { + return [flow=FLOW(in_flow)](double est) mutable noexcept { double next_flow = flow.flow(); flow.add(est); return next_flow; }; } -inline FlowCalc first_flow_calc(bool strict, double flow) { - if (strict) { - flow = 1.0; - } +inline FlowCalc first_flow_calc(InFlow in_flow) { bool first = true; - return [flow,first](double est) mutable noexcept { + double flow = in_flow.rate(); + return [first,flow](double est) mutable noexcept { double next_flow = flow; if (first) { flow *= est; @@ -306,10 +323,8 @@ inline FlowCalc first_flow_calc(bool strict, double flow) { }; } -inline FlowCalc full_flow_calc(bool strict, double flow) { - if (strict) { - flow = 1.0; - } +inline FlowCalc full_flow_calc(InFlow in_flow) { + double flow = in_flow.rate(); return [flow](double) noexcept { return flow; }; } diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp index b8bf7d40655..9d0acc50ce5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.cpp @@ -208,9 +208,9 @@ AndNotBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) co FlowCalc -AndNotBlueprint::make_flow_calc(bool strict, double flow) const +AndNotBlueprint::make_flow_calc(InFlow in_flow) const { - return flow_calc<AndNotFlow>(strict, flow); + return flow_calc<AndNotFlow>(in_flow); } //----------------------------------------------------------------------------- @@ -308,9 +308,9 @@ AndBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) const } FlowCalc -AndBlueprint::make_flow_calc(bool strict, double flow) const +AndBlueprint::make_flow_calc(InFlow in_flow) const { - return flow_calc<AndFlow>(strict, flow); + return flow_calc<AndFlow>(in_flow); } //----------------------------------------------------------------------------- @@ -408,9 +408,9 @@ OrBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) const } FlowCalc -OrBlueprint::make_flow_calc(bool strict, double flow) const +OrBlueprint::make_flow_calc(InFlow in_flow) const { - return flow_calc<OrFlow>(strict, flow); + return flow_calc<OrFlow>(in_flow); } uint8_t @@ -426,9 +426,9 @@ OrBlueprint::calculate_cost_tier() const //----------------------------------------------------------------------------- FlowCalc -WeakAndBlueprint::make_flow_calc(bool strict, double flow) const +WeakAndBlueprint::make_flow_calc(InFlow in_flow) const { - return flow_calc<OrFlow>(strict, flow); + return flow_calc<OrFlow>(in_flow); } WeakAndBlueprint::~WeakAndBlueprint() = default; @@ -503,9 +503,9 @@ WeakAndBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) c //----------------------------------------------------------------------------- FlowCalc -NearBlueprint::make_flow_calc(bool strict, double flow) const +NearBlueprint::make_flow_calc(InFlow in_flow) const { - return flow_calc<AndFlow>(strict, flow); + return flow_calc<AndFlow>(in_flow); } FlowStats @@ -574,9 +574,9 @@ NearBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) cons //----------------------------------------------------------------------------- FlowCalc -ONearBlueprint::make_flow_calc(bool strict, double flow) const +ONearBlueprint::make_flow_calc(InFlow in_flow) const { - return flow_calc<AndFlow>(strict, flow); + return flow_calc<AndFlow>(in_flow); } FlowStats @@ -735,17 +735,17 @@ RankBlueprint::createFilterSearch(bool strict, FilterConstraint constraint) cons } FlowCalc -RankBlueprint::make_flow_calc(bool strict, double flow) const +RankBlueprint::make_flow_calc(InFlow in_flow) const { - return first_flow_calc(strict, flow); + return first_flow_calc(in_flow); } //----------------------------------------------------------------------------- FlowCalc -SourceBlenderBlueprint::make_flow_calc(bool strict, double flow) const +SourceBlenderBlueprint::make_flow_calc(InFlow in_flow) const { - return full_flow_calc(strict, flow); + return full_flow_calc(in_flow); } SourceBlenderBlueprint::SourceBlenderBlueprint(const ISourceSelector &selector) noexcept diff --git a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h index 0095095dfe8..028898d3f47 100644 --- a/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h +++ b/searchlib/src/vespa/searchlib/queryeval/intermediate_blueprints.h @@ -29,7 +29,7 @@ public: SearchIterator::UP createFilterSearch(bool strict, FilterConstraint constraint) const override; private: - FlowCalc make_flow_calc(bool strict, double flow) const override; + FlowCalc make_flow_calc(InFlow in_flow) const override; uint8_t calculate_cost_tier() const override { return (childCnt() > 0) ? get_children()[0]->getState().cost_tier() : State::COST_TIER_NORMAL; } @@ -57,7 +57,7 @@ public: SearchIterator::UP createFilterSearch(bool strict, FilterConstraint constraint) const override; private: - FlowCalc make_flow_calc(bool strict, double flow) const override; + FlowCalc make_flow_calc(InFlow in_flow) const override; }; //----------------------------------------------------------------------------- @@ -82,7 +82,7 @@ public: SearchIterator::UP createFilterSearch(bool strict, FilterConstraint constraint) const override; private: - FlowCalc make_flow_calc(bool strict, double flow) const override; + FlowCalc make_flow_calc(InFlow in_flow) const override; uint8_t calculate_cost_tier() const override; }; @@ -94,7 +94,7 @@ private: uint32_t _n; std::vector<uint32_t> _weights; - FlowCalc make_flow_calc(bool strict, double flow) const override; + FlowCalc make_flow_calc(InFlow in_flow) const override; public: FlowStats calculate_flow_stats(uint32_t docid_limit) const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; @@ -125,7 +125,7 @@ class NearBlueprint : public IntermediateBlueprint private: uint32_t _window; - FlowCalc make_flow_calc(bool strict, double flow) const override; + FlowCalc make_flow_calc(InFlow in_flow) const override; public: FlowStats calculate_flow_stats(uint32_t docid_limit) const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; @@ -148,7 +148,7 @@ class ONearBlueprint : public IntermediateBlueprint private: uint32_t _window; - FlowCalc make_flow_calc(bool strict, double flow) const override; + FlowCalc make_flow_calc(InFlow in_flow) const override; public: FlowStats calculate_flow_stats(uint32_t docid_limit) const final; HitEstimate combine(const std::vector<HitEstimate> &data) const override; @@ -186,7 +186,7 @@ public: return (childCnt() > 0) ? get_children()[0]->getState().cost_tier() : State::COST_TIER_NORMAL; } private: - FlowCalc make_flow_calc(bool strict, double flow) const override; + FlowCalc make_flow_calc(InFlow in_flow) const override; }; //----------------------------------------------------------------------------- @@ -196,7 +196,7 @@ class SourceBlenderBlueprint final : public IntermediateBlueprint private: const ISourceSelector &_selector; - FlowCalc make_flow_calc(bool strict, double flow) const override; + FlowCalc make_flow_calc(InFlow in_flow) const override; public: explicit SourceBlenderBlueprint(const ISourceSelector &selector) noexcept; ~SourceBlenderBlueprint() override; diff --git a/vespalib/src/tests/array/CMakeLists.txt b/vespalib/src/tests/array/CMakeLists.txt index 3f08aa07b4d..11e81cc42ce 100644 --- a/vespalib/src/tests/array/CMakeLists.txt +++ b/vespalib/src/tests/array/CMakeLists.txt @@ -11,5 +11,6 @@ vespa_add_executable(vespalib_sort_benchmark_app sort_benchmark.cpp DEPENDS vespalib + GTest::gtest ) vespa_add_test(NAME vespalib_sort_benchmark_app COMMAND vespalib_sort_benchmark_app BENCHMARK) diff --git a/vespalib/src/tests/array/sort_benchmark.cpp b/vespalib/src/tests/array/sort_benchmark.cpp index ba95c663332..db5c3c80f5f 100644 --- a/vespalib/src/tests/array/sort_benchmark.cpp +++ b/vespalib/src/tests/array/sort_benchmark.cpp @@ -1,6 +1,6 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/util/rusage.h> +#include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/array.hpp> #include <csignal> #include <algorithm> @@ -10,19 +10,32 @@ LOG_SETUP("sort_benchmark"); using namespace vespalib; -class Test : public TestApp +class SortBenchmark : public ::testing::Test { + int _argc; + char** _argv; public: -private: + SortBenchmark(int argc, char **argv); + ~SortBenchmark() override; + void TestBody() override; +protected: template<typename T> vespalib::Array<T> create(size_t count); template<typename T> void sortDirect(size_t count); template<typename T> void sortInDirect(size_t count); - int Main() override; }; +SortBenchmark::SortBenchmark(int argc, char** argv) + : testing::Test(), + _argc(argc), + _argv(argv) +{ +} + +SortBenchmark::~SortBenchmark() = default; + template<size_t N> class TT { @@ -46,7 +59,7 @@ private: template<typename T> vespalib::Array<T> -Test::create(size_t count) +SortBenchmark::create(size_t count) { vespalib::Array<T> v; v.reserve(count); @@ -58,7 +71,7 @@ Test::create(size_t count) } template<typename T> -void Test::sortDirect(size_t count) +void SortBenchmark::sortDirect(size_t count) { vespalib::Array<T> v(create<T>(count)); LOG(info, "Running sortDirect with %ld count and payload of %ld", v.size(), sizeof(T)); @@ -69,7 +82,7 @@ void Test::sortDirect(size_t count) } template<typename T> -void Test::sortInDirect(size_t count) +void SortBenchmark::sortInDirect(size_t count) { vespalib::Array<T> k(create<T>(count)); LOG(info, "Running sortInDirect with %ld count and payload of %ld", k.size(), sizeof(T)); @@ -84,8 +97,8 @@ void Test::sortInDirect(size_t count) } } -int -Test::Main() +void +SortBenchmark::TestBody() { std::string type("sortdirect"); size_t count = 1000000; @@ -99,8 +112,7 @@ Test::Main() if (_argc > 3) { payLoad = strtol(_argv[3], NULL, 0); } - TEST_INIT("sort_benchmark"); - steady_time start(steady_clock::now()); + steady_time start(steady_clock::now()); if (payLoad < 8) { using T = TT<8>; if (type == "sortdirect") { @@ -176,9 +188,13 @@ Test::Main() } } LOG(info, "rusage = {\n%s\n}", vespalib::RUsage::createSelf(start).toString().c_str()); - ASSERT_EQUAL(0, kill(0, SIGPROF)); - TEST_DONE(); + ASSERT_EQ(0, kill(0, SIGPROF)); } -TEST_APPHOOK(Test); - +int main(int argc, char** argv) { + ::testing::InitGoogleTest(&argc, argv); + ::testing::RegisterTest("SortBenchmark", "benchmark", nullptr, "", + __FILE__, __LINE__, + [=]() -> SortBenchmark* { return new SortBenchmark(argc, argv); }); + return RUN_ALL_TESTS(); +} diff --git a/vespalib/src/tests/bits/CMakeLists.txt b/vespalib/src/tests/bits/CMakeLists.txt index 71a977a295a..3ba0c6afc1c 100644 --- a/vespalib/src/tests/bits/CMakeLists.txt +++ b/vespalib/src/tests/bits/CMakeLists.txt @@ -4,5 +4,6 @@ vespa_add_executable(vespalib_bits_test_app TEST bits_test.cpp DEPENDS vespalib + GTest::gtest ) vespa_add_test(NAME vespalib_bits_test_app COMMAND vespalib_bits_test_app) diff --git a/vespalib/src/tests/bits/bits_test.cpp b/vespalib/src/tests/bits/bits_test.cpp index 6341477c2a7..6b01b2cd63a 100644 --- a/vespalib/src/tests/bits/bits_test.cpp +++ b/vespalib/src/tests/bits/bits_test.cpp @@ -1,23 +1,25 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/util/bits.h> +#include <vespa/vespalib/gtest/gtest.h> using namespace vespalib; -class Test : public TestApp +class BitsTest : public ::testing::Test { -public: - int Main() override; +protected: + BitsTest(); + ~BitsTest() override; template <typename T> void testFixed(const T * v, size_t sz, const T * exp); void testBuffer(); }; -int -Test::Main() +BitsTest::BitsTest() = default; +BitsTest::~BitsTest() = default; + +TEST_F(BitsTest, test_bits) { - TEST_INIT("bits_test"); uint8_t u8[5] = { 0, 0x1, 0x7f, 0x87, 0xff }; uint8_t exp8[5] = { 0, 0x80, 0xfe, 0xe1, 0xff }; testFixed(u8, sizeof(u8)/sizeof(u8[0]), exp8); @@ -31,29 +33,28 @@ Test::Main() uint64_t exp64[5] = { 0, 0x8000000000000000, 0xfe00000000000000, 0xe100000000000000, 0xff00000000000000 }; testFixed(u64, sizeof(u64)/sizeof(u64[0]), exp64); testBuffer(); - TEST_DONE(); } template <typename T> -void Test::testFixed(const T * v, size_t sz, const T * exp) +void BitsTest::testFixed(const T * v, size_t sz, const T * exp) { - EXPECT_EQUAL(0u, Bits::reverse(static_cast<T>(0))); - EXPECT_EQUAL(1ul << (sizeof(T)*8 - 1), Bits::reverse(static_cast<T>(1))); - EXPECT_EQUAL(static_cast<T>(-1), Bits::reverse(static_cast<T>(-1))); + EXPECT_EQ(0u, Bits::reverse(static_cast<T>(0))); + EXPECT_EQ(1ul << (sizeof(T)*8 - 1), Bits::reverse(static_cast<T>(1))); + EXPECT_EQ(static_cast<T>(-1), Bits::reverse(static_cast<T>(-1))); for (size_t i(0); i < sz; i++) { - EXPECT_EQUAL(Bits::reverse(v[i]), exp[i]); - EXPECT_EQUAL(Bits::reverse(Bits::reverse(v[i])), v[i]); + EXPECT_EQ(Bits::reverse(v[i]), exp[i]); + EXPECT_EQ(Bits::reverse(Bits::reverse(v[i])), v[i]); } } -void Test::testBuffer() +void BitsTest::testBuffer() { uint64_t a(0x0102040810204080ul); uint64_t b(a); Bits::reverse(&a, sizeof(a)); - EXPECT_EQUAL(a, Bits::reverse(b)); + EXPECT_EQ(a, Bits::reverse(b)); Bits::reverse(&a, sizeof(a)); - EXPECT_EQUAL(a, b); + EXPECT_EQ(a, b); } -TEST_APPHOOK(Test) +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/btree/CMakeLists.txt b/vespalib/src/tests/btree/CMakeLists.txt index 79bda87867e..bf4edc9e5e9 100644 --- a/vespalib/src/tests/btree/CMakeLists.txt +++ b/vespalib/src/tests/btree/CMakeLists.txt @@ -12,6 +12,7 @@ vespa_add_executable(vespalib_frozenbtree_test_app TEST frozenbtree_test.cpp DEPENDS vespalib + GTest::gtest ) vespa_add_test(NAME vespalib_frozenbtree_test_app COMMAND vespalib_frozenbtree_test_app COST 30) vespa_add_executable(vespalib_btreeaggregation_test_app TEST diff --git a/vespalib/src/tests/btree/frozenbtree_test.cpp b/vespalib/src/tests/btree/frozenbtree_test.cpp index b16a7013db4..ffe8b4516aa 100644 --- a/vespalib/src/tests/btree/frozenbtree_test.cpp +++ b/vespalib/src/tests/btree/frozenbtree_test.cpp @@ -2,12 +2,12 @@ #define DEBUG_FROZENBTREE #define LOG_FROZENBTREEXX -#include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/btree/btreeroot.h> #include <vespa/vespalib/btree/btreeiterator.hpp> #include <vespa/vespalib/btree/btreeroot.hpp> #include <vespa/vespalib/btree/btreenodeallocator.hpp> #include <vespa/vespalib/datastore/buffer_type.hpp> +#include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/rand48.h> #include <map> @@ -24,11 +24,11 @@ using vespalib::GenerationHandler; namespace vespalib { -class FrozenBTreeTest : public vespalib::TestApp +class FrozenBTreeTest : public ::testing::Test { public: using KeyType = int; -private: +protected: std::vector<KeyType> _randomValues; std::vector<KeyType> _sortedRandomValues; @@ -43,7 +43,7 @@ public: using LeafNodeType = Tree::LeafNodeType; using Iterator = Tree::Iterator; using ConstIterator = Tree::ConstIterator; -private: +protected: GenerationHandler *_generationHandler; NodeAllocator *_allocator; Tree *_tree; @@ -70,21 +70,21 @@ private: } public: FrozenBTreeTest(); - ~FrozenBTreeTest(); - - int Main() override; + ~FrozenBTreeTest() override; }; FrozenBTreeTest::FrozenBTreeTest() - : vespalib::TestApp(), + : ::testing::Test(), _randomValues(), _sortedRandomValues(), _generationHandler(NULL), _allocator(NULL), _tree(NULL), _randomGenerator() -{} -FrozenBTreeTest::~FrozenBTreeTest() {} +{ +} + +FrozenBTreeTest::~FrozenBTreeTest() = default; void FrozenBTreeTest::allocTree() @@ -410,11 +410,8 @@ FrozenBTreeTest::printEnumTree(const Tree *tree, -int -FrozenBTreeTest::Main() +TEST_F(FrozenBTreeTest, test_frozen_btree) { - TEST_INIT("frozenbtree_test"); - fillRandomValues(1000); sortRandomValues(); @@ -451,21 +448,8 @@ FrozenBTreeTest::Main() true); insertRandomValues(*_tree, *_allocator, _randomValues); freeTree(true); - - fillRandomValues(1000000); - sortRandomValues(); - - allocTree(); - insertRandomValues(*_tree, *_allocator, _randomValues); - traverseTreeIterator(*_tree, - *_allocator, - _sortedRandomValues, - false); - freeTree(false); - - TEST_DONE(); } } -TEST_APPHOOK(vespalib::FrozenBTreeTest); +GTEST_MAIN_RUN_ALL_TESTS() diff --git a/vespalib/src/tests/compress/CMakeLists.txt b/vespalib/src/tests/compress/CMakeLists.txt index 3085aec6bd8..31679e1f479 100644 --- a/vespalib/src/tests/compress/CMakeLists.txt +++ b/vespalib/src/tests/compress/CMakeLists.txt @@ -4,5 +4,6 @@ vespa_add_executable(vespalib_compress_test_app TEST compress_test.cpp DEPENDS vespalib + GTest::gtest ) vespa_add_test(NAME vespalib_compress_test_app COMMAND vespalib_compress_test_app) diff --git a/vespalib/src/tests/compress/compress_test.cpp b/vespalib/src/tests/compress/compress_test.cpp index c4383a1c193..7ad69b54e06 100644 --- a/vespalib/src/tests/compress/compress_test.cpp +++ b/vespalib/src/tests/compress/compress_test.cpp @@ -1,58 +1,59 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -#include <vespa/vespalib/testkit/testapp.h> #include <vespa/vespalib/util/compress.h> +#include <vespa/vespalib/gtest/gtest.h> #include <vespa/vespalib/util/exceptions.h> using namespace vespalib; using compress::Integer; -class CompressTest : public TestApp +class CompressTest : public ::testing::Test { -private: +protected: + CompressTest(); + ~CompressTest() override; void verifyPositiveNumber(uint64_t n, const uint8_t * expected, size_t sz); void verifyNumber(int64_t n, const uint8_t * expected, size_t sz); - void requireThatPositiveNumberCompressCorrectly(); - void requireThatNumberCompressCorrectly(); -public: - int Main() override; }; +CompressTest::CompressTest() = default; +CompressTest::~CompressTest() = default; + void CompressTest::verifyPositiveNumber(uint64_t n, const uint8_t * expected, size_t sz) { uint8_t buf[8]; - EXPECT_EQUAL(sz, Integer::compressPositive(n, buf)); - EXPECT_EQUAL(sz, Integer::compressedPositiveLength(n)); + EXPECT_EQ(sz, Integer::compressPositive(n, buf)); + EXPECT_EQ(sz, Integer::compressedPositiveLength(n)); for (size_t i(0); i < sz; i++) { - EXPECT_EQUAL(expected[i], buf[i]); + EXPECT_EQ(expected[i], buf[i]); } EXPECT_FALSE(Integer::check_decompress_positive_space(expected, 0u)); EXPECT_FALSE(Integer::check_decompress_positive_space(expected, sz - 1)); EXPECT_TRUE(Integer::check_decompress_positive_space(expected, sz)); uint64_t v(0); - EXPECT_EQUAL(sz, Integer::decompressPositive(v, expected)); - EXPECT_EQUAL(n, v); + EXPECT_EQ(sz, Integer::decompressPositive(v, expected)); + EXPECT_EQ(n, v); } void CompressTest::verifyNumber(int64_t n, const uint8_t * expected, size_t sz) { uint8_t buf[8]; - EXPECT_EQUAL(sz, Integer::compress(n, buf)); - EXPECT_EQUAL(sz, Integer::compressedLength(n)); + EXPECT_EQ(sz, Integer::compress(n, buf)); + EXPECT_EQ(sz, Integer::compressedLength(n)); for (size_t i(0); i < sz; i++) { - EXPECT_EQUAL(expected[i], buf[i]); + EXPECT_EQ(expected[i], buf[i]); } EXPECT_FALSE(Integer::check_decompress_space(expected, 0u)); EXPECT_FALSE(Integer::check_decompress_space(expected, sz - 1)); EXPECT_TRUE(Integer::check_decompress_space(expected, sz)); int64_t v(0); - EXPECT_EQUAL(sz, Integer::decompress(v, expected)); - EXPECT_EQUAL(n, v); + EXPECT_EQ(sz, Integer::decompress(v, expected)); + EXPECT_EQ(n, v); } #define VERIFY_POSITIVE(n, p) verifyPositiveNumber(n, p, sizeof(p)) -void -CompressTest::requireThatPositiveNumberCompressCorrectly() + +TEST_F(CompressTest, require_that_positive_number_compress_correctly) { const uint8_t zero[1] = {0}; VERIFY_POSITIVE(0, zero); @@ -73,19 +74,19 @@ CompressTest::requireThatPositiveNumberCompressCorrectly() VERIFY_POSITIVE(0x40000000, x40000000); EXPECT_TRUE(false); } catch (const IllegalArgumentException & e) { - EXPECT_EQUAL("Number '1073741824' too big, must extend encoding", e.getMessage()); + EXPECT_EQ("Number '1073741824' too big, must extend encoding", e.getMessage()); } try { VERIFY_POSITIVE(-1, x40000000); EXPECT_TRUE(false); } catch (const IllegalArgumentException & e) { - EXPECT_EQUAL("Number '18446744073709551615' too big, must extend encoding", e.getMessage()); + EXPECT_EQ("Number '18446744073709551615' too big, must extend encoding", e.getMessage()); } } #define VERIFY_NUMBER(n, p) verifyNumber(n, p, sizeof(p)) -void -CompressTest::requireThatNumberCompressCorrectly() + +TEST_F(CompressTest, require_that_number_compress_correctly) { const uint8_t zero[1] = {0}; VERIFY_NUMBER(0, zero); @@ -106,7 +107,7 @@ CompressTest::requireThatNumberCompressCorrectly() VERIFY_NUMBER(0x20000000, x20000000); EXPECT_TRUE(false); } catch (const IllegalArgumentException & e) { - EXPECT_EQUAL("Number '536870912' too big, must extend encoding", e.getMessage()); + EXPECT_EQ("Number '536870912' too big, must extend encoding", e.getMessage()); } const uint8_t mzero[1] = {0x81}; VERIFY_NUMBER(-1, mzero); @@ -127,19 +128,8 @@ CompressTest::requireThatNumberCompressCorrectly() VERIFY_NUMBER(-0x20000000, mx20000000); EXPECT_TRUE(false); } catch (const IllegalArgumentException & e) { - EXPECT_EQUAL("Number '-536870912' too big, must extend encoding", e.getMessage()); + EXPECT_EQ("Number '-536870912' too big, must extend encoding", e.getMessage()); } } -int -CompressTest::Main() -{ - TEST_INIT("compress_test"); - - requireThatPositiveNumberCompressCorrectly(); - requireThatNumberCompressCorrectly(); - - TEST_DONE(); -} - -TEST_APPHOOK(CompressTest) +GTEST_MAIN_RUN_ALL_TESTS() |