From 0970129d98a386753e2fa24c559c77392691c633 Mon Sep 17 00:00:00 2001 From: Håvard Pettersen Date: Mon, 6 Nov 2017 15:27:58 +0000 Subject: clean up tensor engine API make Tensor a subclass of Value --- searchlib/src/tests/features/constant/constant_test.cpp | 5 ++--- searchlib/src/tests/features/tensor/tensor_test.cpp | 2 +- .../features/tensor_from_labels/tensor_from_labels_test.cpp | 2 +- .../tensor_from_weighted_set_test.cpp | 2 +- .../tensor/dense_tensor_store/dense_tensor_store_test.cpp | 2 +- .../src/vespa/searchlib/features/constant_tensor_executor.h | 11 ++++++----- .../searchlib/features/dense_tensor_attribute_executor.cpp | 6 ++---- .../searchlib/features/dense_tensor_attribute_executor.h | 1 - .../vespa/searchlib/features/tensor_attribute_executor.cpp | 13 +++++-------- .../vespa/searchlib/features/tensor_attribute_executor.h | 2 +- .../searchlib/features/tensor_from_attribute_executor.h | 4 ++-- 11 files changed, 22 insertions(+), 28 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/tests/features/constant/constant_test.cpp b/searchlib/src/tests/features/constant/constant_test.cpp index a10f76e25ba..4a88fde58ce 100644 --- a/searchlib/src/tests/features/constant/constant_test.cpp +++ b/searchlib/src/tests/features/constant/constant_test.cpp @@ -19,7 +19,6 @@ using namespace search::features; using vespalib::eval::Function; using vespalib::eval::Value; using vespalib::eval::DoubleValue; -using vespalib::eval::TensorValue; using vespalib::eval::TensorSpec; using vespalib::eval::ValueType; using vespalib::tensor::DefaultTensorEngine; @@ -39,7 +38,7 @@ Tensor::UP createTensor(const TensorCells &cells, } Tensor::UP make_tensor(const TensorSpec &spec) { - auto tensor = DefaultTensorEngine::ref().create(spec); + auto tensor = DefaultTensorEngine::ref().from_spec(spec); return Tensor::UP(dynamic_cast(tensor.release())); } @@ -80,7 +79,7 @@ struct ExecFixture ValueType type(tensor->getType()); test.getIndexEnv().addConstantValue(name, std::move(type), - std::make_unique(std::move(tensor))); + std::move(tensor)); } void addDouble(const vespalib::string &name, const double value) { diff --git a/searchlib/src/tests/features/tensor/tensor_test.cpp b/searchlib/src/tests/features/tensor/tensor_test.cpp index be7bb9defac..b097f27342d 100644 --- a/searchlib/src/tests/features/tensor/tensor_test.cpp +++ b/searchlib/src/tests/features/tensor/tensor_test.cpp @@ -54,7 +54,7 @@ Tensor::UP createTensor(const TensorCells &cells, } Tensor::UP make_tensor(const TensorSpec &spec) { - auto tensor = DefaultTensorEngine::ref().create(spec); + auto tensor = DefaultTensorEngine::ref().from_spec(spec); return Tensor::UP(dynamic_cast(tensor.release())); } diff --git a/searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp b/searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp index 0a900ad9ec8..1ac524b5d0b 100644 --- a/searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp +++ b/searchlib/src/tests/features/tensor_from_labels/tensor_from_labels_test.cpp @@ -36,7 +36,7 @@ typedef search::AttributeVector::SP AttributePtr; typedef FtTestApp FTA; Tensor::UP make_tensor(const TensorSpec &spec) { - auto tensor = DefaultTensorEngine::ref().create(spec); + auto tensor = DefaultTensorEngine::ref().from_spec(spec); return Tensor::UP(dynamic_cast(tensor.release())); } diff --git a/searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp b/searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp index cad0c56b0ca..e0eee954a53 100644 --- a/searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp +++ b/searchlib/src/tests/features/tensor_from_weighted_set/tensor_from_weighted_set_test.cpp @@ -37,7 +37,7 @@ typedef search::AttributeVector::SP AttributePtr; typedef FtTestApp FTA; Tensor::UP make_tensor(const TensorSpec &spec) { - auto tensor = DefaultTensorEngine::ref().create(spec); + auto tensor = DefaultTensorEngine::ref().from_spec(spec); return Tensor::UP(dynamic_cast(tensor.release())); } diff --git a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp index 28b4ad3c4e4..2e88f0e90b0 100644 --- a/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp +++ b/searchlib/src/tests/tensor/dense_tensor_store/dense_tensor_store_test.cpp @@ -21,7 +21,7 @@ using EntryRef = DenseTensorStore::EntryRef; Tensor::UP makeTensor(const TensorSpec &spec) { - auto tensor = DefaultTensorEngine::ref().create(spec); + auto tensor = DefaultTensorEngine::ref().from_spec(spec); return Tensor::UP(dynamic_cast(tensor.release())); } diff --git a/searchlib/src/vespa/searchlib/features/constant_tensor_executor.h b/searchlib/src/vespa/searchlib/features/constant_tensor_executor.h index 1a0e425e0ef..43ce48282ee 100644 --- a/searchlib/src/vespa/searchlib/features/constant_tensor_executor.h +++ b/searchlib/src/vespa/searchlib/features/constant_tensor_executor.h @@ -18,10 +18,10 @@ namespace features { class ConstantTensorExecutor : public fef::FeatureExecutor { private: - const vespalib::eval::TensorValue::UP _tensor; + vespalib::eval::Value::UP _tensor; public: - ConstantTensorExecutor(vespalib::eval::TensorValue::UP tensor) + ConstantTensorExecutor(vespalib::eval::Value::UP tensor) : _tensor(std::move(tensor)) {} virtual bool isPure() override { return true; } @@ -29,11 +29,12 @@ public: outputs().set_object(0, *_tensor); } static fef::FeatureExecutor &create(std::unique_ptr tensor, vespalib::Stash &stash) { - return stash.create(std::make_unique(std::move(tensor))); + return stash.create(std::move(tensor)); } static fef::FeatureExecutor &createEmpty(const vespalib::eval::ValueType &valueType, vespalib::Stash &stash) { - return create(vespalib::tensor::DefaultTensorEngine::ref() - .create(vespalib::eval::TensorSpec(valueType.to_spec())), stash); + const auto &engine = vespalib::tensor::DefaultTensorEngine::ref(); + auto spec = vespalib::eval::TensorSpec(valueType.to_spec()); + return stash.create(engine.from_spec(spec)); } static fef::FeatureExecutor &createEmpty(vespalib::Stash &stash) { return createEmpty(vespalib::eval::ValueType::double_type(), stash); diff --git a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp index 76252486bf4..487bc724e07 100644 --- a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp +++ b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.cpp @@ -5,7 +5,6 @@ using search::tensor::DenseTensorAttribute; using vespalib::eval::Tensor; -using vespalib::eval::TensorValue; using vespalib::tensor::MutableDenseTensorView; namespace search { @@ -14,8 +13,7 @@ namespace features { DenseTensorAttributeExecutor:: DenseTensorAttributeExecutor(const DenseTensorAttribute *attribute) : _attribute(attribute), - _tensorView(_attribute->getConfig().tensorType()), - _tensor(_tensorView) + _tensorView(_attribute->getConfig().tensorType()) { } @@ -23,7 +21,7 @@ void DenseTensorAttributeExecutor::execute(uint32_t docId) { _attribute->getTensor(docId, _tensorView); - outputs().set_object(0, _tensor); + outputs().set_object(0, _tensorView); } } // namespace features diff --git a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h index 68042075942..ac3d327c12a 100644 --- a/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h +++ b/searchlib/src/vespa/searchlib/features/dense_tensor_attribute_executor.h @@ -19,7 +19,6 @@ class DenseTensorAttributeExecutor : public fef::FeatureExecutor private: const search::tensor::DenseTensorAttribute *_attribute; vespalib::tensor::MutableDenseTensorView _tensorView; - vespalib::eval::TensorValue _tensor; public: DenseTensorAttributeExecutor(const search::tensor::DenseTensorAttribute *attribute); diff --git a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp index 6ee7664f2bb..03393d6f590 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp +++ b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.cpp @@ -3,8 +3,6 @@ #include "tensor_attribute_executor.h" #include -using vespalib::eval::TensorValue; - namespace search { namespace features { @@ -12,20 +10,19 @@ TensorAttributeExecutor:: TensorAttributeExecutor(const search::tensor::TensorAttribute *attribute) : _attribute(attribute), _emptyTensor(attribute->getEmptyTensor()), - _tensor(*_emptyTensor) + _tensor() { } void TensorAttributeExecutor::execute(uint32_t docId) { - auto tensor = _attribute->getTensor(docId); - if (!tensor) { - _tensor = TensorValue(*_emptyTensor); + _tensor = _attribute->getTensor(docId); + if (_tensor) { + outputs().set_object(0, *_tensor); } else { - _tensor = TensorValue(std::move(tensor)); + outputs().set_object(0, *_emptyTensor); } - outputs().set_object(0, _tensor); } } // namespace features diff --git a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h index 198b03e3d1d..0f1e21c8cad 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h +++ b/searchlib/src/vespa/searchlib/features/tensor_attribute_executor.h @@ -17,7 +17,7 @@ class TensorAttributeExecutor : public fef::FeatureExecutor private: const search::tensor::TensorAttribute *_attribute; std::unique_ptr _emptyTensor; - vespalib::eval::TensorValue _tensor; + std::unique_ptr _tensor; public: TensorAttributeExecutor(const search::tensor::TensorAttribute *attribute); diff --git a/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h b/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h index 31b92f89538..f102749f1b6 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h +++ b/searchlib/src/vespa/searchlib/features/tensor_from_attribute_executor.h @@ -22,7 +22,7 @@ private: const search::attribute::IAttributeVector *_attribute; vespalib::string _dimension; WeightedBufferType _attrBuffer; - vespalib::eval::TensorValue::UP _tensor; + std::unique_ptr _tensor; public: TensorFromAttributeExecutor(const search::attribute::IAttributeVector *attribute, @@ -48,7 +48,7 @@ TensorFromAttributeExecutor::execute(uint32_t docId) builder.add_label(dimensionEnum, vespalib::string(_attrBuffer[i].value())); builder.add_cell(_attrBuffer[i].weight()); } - _tensor = vespalib::eval::TensorValue::UP(new vespalib::eval::TensorValue(builder.build())); + _tensor = builder.build(); outputs().set_object(0, *_tensor); } -- cgit v1.2.3 From 238064fb7136aedecbff4f37c0a48f3b0152d32a Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 9 Nov 2017 13:34:19 +0100 Subject: Add tensor conformance test in Java --- searchlib/pom.xml | 10 ++ .../searchlib/tensor/TensorConformanceTest.java | 154 +++++++++++++++++++++ 2 files changed, 164 insertions(+) create mode 100644 searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java (limited to 'searchlib') diff --git a/searchlib/pom.xml b/searchlib/pom.xml index 36e6fa1ffda..c669903c3da 100644 --- a/searchlib/pom.xml +++ b/searchlib/pom.xml @@ -34,6 +34,16 @@ vespajlib ${project.version} + + com.fasterxml.jackson.core + jackson-core + test + + + com.fasterxml.jackson.core + jackson-databind + test + diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java new file mode 100644 index 00000000000..27aaeb776e4 --- /dev/null +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -0,0 +1,154 @@ +package com.yahoo.searchlib.tensor; + +import com.fasterxml.jackson.databind.JsonNode; +import com.fasterxml.jackson.databind.ObjectMapper; +import com.yahoo.io.GrowableByteBuffer; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleCompatibleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.parser.ParseException; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.serialization.TypedBinaryFormat; +import org.junit.Assert; +import org.junit.Test; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Optional; + +import static org.junit.Assert.assertEquals; + +public class TensorConformanceTest { + + private static String testPath = "eval/src/apps/tensor_conformance/test_spec.json"; + + @Test + public void testConformance() throws IOException { + File testSpec = new File(testPath); + if (!testSpec.exists()) { + testSpec = new File("../" + testPath); + } + int count = 0; + List failList = new ArrayList<>(); + + try(BufferedReader br = new BufferedReader(new FileReader(testSpec))) { + String test = br.readLine(); + while (test != null) { + boolean success = testCase(test, count); + if (!success) { + failList.add(count); + } + test = br.readLine(); + count++; + } + } + if (failList.size() > 0) { + System.out.println("Conformance test fails:"); + System.out.println(failList); + } + + // Disable this for now: + //assertEquals(0, failList.size()); + } + + private boolean testCase(String test, int count) throws IOException { + try { + ObjectMapper mapper = new ObjectMapper(); + JsonNode node = mapper.readTree(test); + if (node.has("num_tests")) { + Assert.assertEquals(node.get("num_tests").asInt(), count); + } else if (node.has("expression")) { + String expression = node.get("expression").asText(); + MapContext context = getInput(node.get("inputs")); + Tensor expect = getTensor(node.get("result").get("expect").asText()); + Tensor result = evaluate(expression, context); + boolean equals = Tensor.equals(result, expect); + if (!equals) { + System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); + } + return Tensor.equals(result, expect); + } + } catch (Exception e) { + System.out.println(count + " : " + e.toString()); + } + return false; + } + + private Tensor evaluate(String expression, MapContext context) throws ParseException { + Value value = new RankingExpression(expression).evaluate(context); + if (!(value instanceof TensorValue)) { + throw new IllegalArgumentException("Result is not a tensor"); + } + return ((TensorValue)value).asTensor(); + } + + private MapContext getInput(JsonNode inputs) { + MapContext context = new MapContext(); + for (Iterator i = inputs.fieldNames(); i.hasNext(); ) { + String name = i.next(); + String value = inputs.get(name).asText(); + Tensor tensor = getTensor(value); + context.put(name, new TensorValue(tensor)); + } + return context; + } + + private Tensor getTensor(String binaryRepresentation) { + byte[] bin = getBytes(binaryRepresentation); + return TypedBinaryFormat.decode(Optional.empty(), GrowableByteBuffer.wrap(bin)); + } + + private byte[] getBytes(String binaryRepresentation) { + return parseHexValue(binaryRepresentation.substring(2)); + } + + private byte[] parseHexValue(String s) { + final int len = s.length(); + byte[] bytes = new byte[len/2]; + for (int i = 0; i < len; i += 2) { + int c1 = hexValue(s.charAt(i)) << 4; + int c2 = hexValue(s.charAt(i + 1)); + bytes[i/2] = (byte)(c1 + c2); + } + return bytes; + } + + private int hexValue(Character c) { + if (c >= 'a' && c <= 'f') { + return c - 'a' + 10; + } else if (c >= 'A' && c <= 'F') { + return c - 'A' + 10; + } else if (c >= '0' && c <= '9') { + return c - '0'; + } + throw new IllegalArgumentException("Hex contains illegal characters"); + } + + private static String valueType(Value value) { + if (value instanceof StringValue) { + return "string"; + } + if (value instanceof BooleanValue) { + return "boolean"; + } + if (value instanceof DoubleCompatibleValue) { + return "double"; + } + if (value instanceof TensorValue) { + return ((TensorValue)value).asTensor().type().toString(); + } + return "unknown"; + } + + +} + -- cgit v1.2.3 From 59f5ccf748a6e2d6f8e4b7565f106594e1d057f3 Mon Sep 17 00:00:00 2001 From: Tor Egge Date: Sun, 12 Nov 2017 21:11:27 +0100 Subject: Use std::lock_guard instead of std::unique_lock. --- fastlib/src/vespa/fastlib/net/httpserver.cpp | 16 ++++++++-------- fastlib/src/vespa/fastlib/text/normwordfolder.cpp | 4 ++-- .../proton_config_fetcher/proton_config_fetcher_test.cpp | 4 ++-- .../src/vespa/searchcore/fdispatch/common/search.cpp | 4 ++-- .../vespa/searchcore/fdispatch/search/engine_base.cpp | 8 ++++---- .../vespa/searchcore/fdispatch/search/fnet_search.cpp | 6 +++--- .../vespa/searchcore/fdispatch/search/nodemanager.cpp | 16 ++++++++-------- .../searchcore/proton/attribute/attribute_directory.cpp | 6 +++--- .../searchcore/proton/attribute/attributedisklayout.cpp | 6 +++--- .../src/vespa/searchcore/proton/metrics/job_tracker.cpp | 4 ++-- .../proton/reference/document_db_reference_registry.cpp | 2 +- .../vespa/searchcore/proton/server/pendinglidtracker.cpp | 7 +++---- searchlib/src/tests/postinglistbm/andstress.cpp | 2 +- staging_vespalib/src/vespa/vespalib/util/clock.cpp | 2 +- 14 files changed, 43 insertions(+), 44 deletions(-) (limited to 'searchlib') diff --git a/fastlib/src/vespa/fastlib/net/httpserver.cpp b/fastlib/src/vespa/fastlib/net/httpserver.cpp index a9bba95a8ff..0d1b75ec7fe 100644 --- a/fastlib/src/vespa/fastlib/net/httpserver.cpp +++ b/fastlib/src/vespa/fastlib/net/httpserver.cpp @@ -367,7 +367,7 @@ int Fast_HTTPServer::Start(void) int retCode = FASTLIB_SUCCESS; { - std::unique_lock runningGuard(_runningMutex); + std::lock_guard runningGuard(_runningMutex); if (!_isRunning) { // Try listening retCode = Listen(); @@ -391,7 +391,7 @@ int Fast_HTTPServer::Start(void) void Fast_HTTPServer::Stop(void) { { - std::unique_lock runningGuard(_runningMutex); + std::lock_guard runningGuard(_runningMutex); _stopSignalled = true; if (_acceptThread) { _acceptThread->SetBreakFlag(); @@ -407,7 +407,7 @@ Fast_HTTPServer::Stop(void) { bool Fast_HTTPServer::StopSignalled(void) { bool retVal; - std::unique_lock runningGuard(_runningMutex); + std::lock_guard runningGuard(_runningMutex); retVal = _stopSignalled; return retVal; } @@ -458,7 +458,7 @@ void Fast_HTTPServer::Run(FastOS_ThreadInterface *thisThread, void *params) Fast_Socket *mySocket; { - std::unique_lock runningGuard(_runningMutex); + std::lock_guard runningGuard(_runningMutex); _isRunning = true; _stopSignalled = false; } @@ -516,7 +516,7 @@ void Fast_HTTPServer::Run(FastOS_ThreadInterface *thisThread, void *params) _serverSocket.SetSocketEvent(NULL); } - std::unique_lock runningGuard(_runningMutex); + std::lock_guard runningGuard(_runningMutex); _isRunning = false; } @@ -1040,7 +1040,7 @@ void Fast_HTTPServer::HandleFileRequest(const string & url, Fast_HTTPConnection& void Fast_HTTPServer::SetBaseDir(const char *baseDir) { - std::unique_lock runningGuard(_runningMutex); + std::lock_guard runningGuard(_runningMutex); if (!_isRunning) { _baseDir = baseDir; @@ -1178,14 +1178,14 @@ void Fast_HTTPServer::OutputNotFound(Fast_HTTPConnection& conn, void Fast_HTTPServer::AddConnection(Fast_HTTPConnection* connection) { - std::unique_lock connectionGuard(_connectionLock); + std::lock_guard connectionGuard(_connectionLock); _connections.Insert(connection); } void Fast_HTTPServer::RemoveConnection(Fast_HTTPConnection* connection) { - std::unique_lock connectionGuard(_connectionLock); + std::lock_guard connectionGuard(_connectionLock); _connections.RemoveElement(connection); _connectionCond.notify_one(); } diff --git a/fastlib/src/vespa/fastlib/text/normwordfolder.cpp b/fastlib/src/vespa/fastlib/text/normwordfolder.cpp index f383ff85df5..ca1f260515f 100644 --- a/fastlib/src/vespa/fastlib/text/normwordfolder.cpp +++ b/fastlib/src/vespa/fastlib/text/normwordfolder.cpp @@ -29,7 +29,7 @@ Fast_NormalizeWordFolder::Setup(uint32_t flags) { // Only allow setting these when not initialized or initializing... { - std::unique_lock initGuard(_initMutex); + std::lock_guard initGuard(_initMutex); _doAccentRemoval = (DO_ACCENT_REMOVAL & flags) != 0; // _doSmallToNormalKana = (DO_SMALL_TO_NORMAL_KANA & flags) != 0; // _doKatakanaToHiragana = (DO_KATAKANA_TO_HIRAGANA & flags) != 0; @@ -48,7 +48,7 @@ Fast_NormalizeWordFolder::Initialize() { unsigned int i; if (!_isInitialized) { - std::unique_lock initGuard(_initMutex); + std::lock_guard initGuard(_initMutex); if (!_isInitialized) { for (i = 0; i < 128; i++) diff --git a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp index 6c682ea33e9..b9059338f27 100644 --- a/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp +++ b/searchcore/src/tests/proton/proton_config_fetcher/proton_config_fetcher_test.cpp @@ -170,12 +170,12 @@ struct ProtonConfigOwner : public proton::IProtonConfigurer return getConfigured(); } virtual void reconfigure(std::shared_ptr cfg) override { - std::unique_lock guard(_mutex); + std::lock_guard guard(_mutex); _config.set(cfg); _configured = true; } bool getConfigured() const { - std::unique_lock guard(_mutex); + std::lock_guard guard(_mutex); return _configured; } BootstrapConfig::SP getBootstrapConfig() { diff --git a/searchcore/src/vespa/searchcore/fdispatch/common/search.cpp b/searchcore/src/vespa/searchcore/fdispatch/common/search.cpp index 7685ddcc328..7b060e793f6 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/common/search.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/common/search.cpp @@ -165,7 +165,7 @@ void FastS_SyncSearchAdapter::DoneQuery(FastS_ISearch *, FastS_SearchContext) { - std::unique_lock guard(_lock); + std::lock_guard guard(_lock); _queryDone = true; if (_waitQuery) { _cond.notify_one(); @@ -177,7 +177,7 @@ void FastS_SyncSearchAdapter::DoneDocsums(FastS_ISearch *, FastS_SearchContext) { - std::unique_lock guard(_lock); + std::lock_guard guard(_lock); _docsumsDone = true; if (_waitDocsums) { _cond.notify_one(); diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/engine_base.cpp b/searchcore/src/vespa/searchcore/fdispatch/search/engine_base.cpp index 83312a41875..24668db6024 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/search/engine_base.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/search/engine_base.cpp @@ -112,7 +112,7 @@ void FastS_EngineBase::SlowQuery(double limit, double secs, bool silent) { { - std::unique_lock engineGuard(_lock); + std::lock_guard engineGuard(_lock); _stats._slowQueryCnt++; _stats._slowQuerySecs += secs; } @@ -127,7 +127,7 @@ void FastS_EngineBase::SlowDocsum(double limit, double secs) { { - std::unique_lock engineGuard(_lock); + std::lock_guard engineGuard(_lock); _stats._slowDocsumCnt++; _stats._slowDocsumSecs += secs; } @@ -173,7 +173,7 @@ FastS_EngineBase::SampleQueueLens() double queueLen; double activecnt; - std::unique_lock engineGuard(_lock); + std::lock_guard engineGuard(_lock); if (_stats._queueLenSampleCnt > 0) queueLen = (double) _stats._queueLenSampleAcc / (double) _stats._queueLenSampleCnt; else @@ -217,7 +217,7 @@ FastS_EngineBase::MarkBad(uint32_t badness) bool worse = false; { - std::unique_lock engineGuard(_lock); + std::lock_guard engineGuard(_lock); if (badness > _badness) { _badness = badness; worse = true; diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/fnet_search.cpp b/searchcore/src/vespa/searchcore/fdispatch/search/fnet_search.cpp index 0cfbdc8b69a..85599b9e897 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/search/fnet_search.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/search/fnet_search.cpp @@ -1077,7 +1077,7 @@ FastS_FNET_Search::Search(uint32_t searchOffset, // allow FNET responses while requests are being sent { - std::unique_lock searchGuard(_lock); + std::lock_guard searchGuard(_lock); ++_pendingQueries; // add Elephant query node to avoid early query done ++_queryNodes; // add Elephant query node to avoid early query done _FNET_mode = FNET_QUERY; @@ -1102,7 +1102,7 @@ FastS_FNET_Search::Search(uint32_t searchOffset, // finalize setup and check if query is still in progress bool done; { - std::unique_lock searchGuard(_lock); + std::lock_guard searchGuard(_lock); assert(_queryNodes >= _pendingQueries); for (uint32_t i: send_failed) { // conditional revert of state for failed nodes @@ -1398,7 +1398,7 @@ FastS_FNET_Search::GetDocsums(const FastS_hitresult *hits, uint32_t hitcnt) ConnectDocsumNodes(ignoreRow); bool done; { - std::unique_lock searchGuard(_lock); + std::lock_guard searchGuard(_lock); // patch in engine dependent features and send docsum requests diff --git a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp index 6fddfae2ab0..4b272a615a6 100644 --- a/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp +++ b/searchcore/src/vespa/searchcore/fdispatch/search/nodemanager.cpp @@ -125,7 +125,7 @@ FastS_NodeManager::CheckTempFail() _checkTempFailScheduled = false; tempfail = false; { - std::unique_lock mangerGuard(_managerLock); + std::lock_guard mangerGuard(_managerLock); FastS_DataSetCollection *dsc = PeekDataSetCollection(); for (unsigned int i = 0; i < dsc->GetMaxNumDataSets(); i++) { FastS_DataSetBase *ds; @@ -166,7 +166,7 @@ uint32_t FastS_NodeManager::SetPartMap(const PartitionsConfig& partmap, unsigned int waitms) { - std::unique_lock configGuard(_configLock); + std::lock_guard configGuard(_configLock); FastS_DataSetCollDesc *configDesc = new FastS_DataSetCollDesc(); if (!configDesc->ReadConfig(partmap)) { LOG(error, "NodeManager::SetPartMap: Failed to load configuration"); @@ -275,7 +275,7 @@ FastS_NodeManager::SetDataSetCollection(FastS_DataSetCollection *dsc) } else { { - std::unique_lock managerGuard(_managerLock); + std::lock_guard managerGuard(_managerLock); _gencnt++; gencnt = _gencnt; @@ -304,7 +304,7 @@ FastS_NodeManager::GetDataSetCollection() { FastS_DataSetCollection *ret; - std::unique_lock managerGuard(_managerLock); + std::lock_guard managerGuard(_managerLock); ret = _datasetCollection; FastS_assert(ret != NULL); ret->addRef(); @@ -320,8 +320,8 @@ FastS_NodeManager::ShutdownConfig() FastS_DataSetCollection *old_dsc; { - std::unique_lock configGuard(_configLock); - std::unique_lock managerGuard(_managerLock); + std::lock_guard configGuard(_configLock); + std::lock_guard managerGuard(_managerLock); _shutdown = true; // disallow SetPartMap dsc = _datasetCollection; _datasetCollection = new FastS_DataSetCollection(_appCtx); @@ -347,7 +347,7 @@ FastS_NodeManager::GetTotalPartitions() uint32_t ret; ret = 0; - std::unique_lock managerGuard(_managerLock); + std::lock_guard managerGuard(_managerLock); FastS_DataSetCollection *dsc = PeekDataSetCollection(); for (unsigned int i = 0; i < dsc->GetMaxNumDataSets(); i++) { FastS_DataSetBase *ds; @@ -429,7 +429,7 @@ FastS_NodeManager::CheckEvents(FastS_TimeKeeper *timeKeeper) FastS_DataSetCollection *tmp; { - std::unique_lock managerGuard(_managerLock); + std::lock_guard managerGuard(_managerLock); old_dsc = _oldDSCList; } diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attribute_directory.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attribute_directory.cpp index eac994bb339..f775f4443f8 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/attribute_directory.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/attribute_directory.cpp @@ -56,7 +56,7 @@ AttributeDirectory::getDirName() const { std::shared_ptr diskLayout; { - std::unique_lock guard(_mutex); + std::lock_guard guard(_mutex); assert(!_diskLayout.expired()); diskLayout = _diskLayout.lock(); } @@ -204,7 +204,7 @@ void AttributeDirectory::detach() { assert(empty()); - std::unique_lock guard(_mutex); + std::lock_guard guard(_mutex); _diskLayout.reset(); } @@ -238,7 +238,7 @@ AttributeDirectory::tryGetWriter() bool AttributeDirectory::empty() const { - std::unique_lock guard(_mutex); + std::lock_guard guard(_mutex); return _snapInfo.snapshots().empty(); } diff --git a/searchcore/src/vespa/searchcore/proton/attribute/attributedisklayout.cpp b/searchcore/src/vespa/searchcore/proton/attribute/attributedisklayout.cpp index 1fcffa92cce..bb2f99d077b 100644 --- a/searchcore/src/vespa/searchcore/proton/attribute/attributedisklayout.cpp +++ b/searchcore/src/vespa/searchcore/proton/attribute/attributedisklayout.cpp @@ -59,7 +59,7 @@ AttributeDiskLayout::getAttributeDir(const vespalib::string &name) std::shared_ptr AttributeDiskLayout::createAttributeDir(const vespalib::string &name) { - std::unique_lock guard(_mutex); + std::lock_guard guard(_mutex); auto itr = _dirs.find(name); if (itr == _dirs.end()) { auto dir = std::make_shared(shared_from_this(), name); @@ -81,7 +81,7 @@ AttributeDiskLayout::removeAttributeDir(const vespalib::string &name, search::Se writer->invalidateOldSnapshots(serialNum); writer->removeInvalidSnapshots(); if (writer->removeDiskDir()) { - std::unique_lock guard(_mutex); + std::lock_guard guard(_mutex); auto itr = _dirs.find(name); assert(itr != _dirs.end()); assert(dir.get() == itr->second.get()); @@ -89,7 +89,7 @@ AttributeDiskLayout::removeAttributeDir(const vespalib::string &name, search::Se writer->detach(); } } else { - std::unique_lock guard(_mutex); + std::lock_guard guard(_mutex); auto itr = _dirs.find(name); if (itr != _dirs.end()) { assert(dir.get() != itr->second.get()); diff --git a/searchcore/src/vespa/searchcore/proton/metrics/job_tracker.cpp b/searchcore/src/vespa/searchcore/proton/metrics/job_tracker.cpp index 753c84cd9b6..6d05ce8c57d 100644 --- a/searchcore/src/vespa/searchcore/proton/metrics/job_tracker.cpp +++ b/searchcore/src/vespa/searchcore/proton/metrics/job_tracker.cpp @@ -20,14 +20,14 @@ JobTracker::sampleLoad(time_point now, const std::lock_guard &guard) void JobTracker::start() { - std::unique_lock guard(_lock); + std::lock_guard guard(_lock); _sampler.startJob(std::chrono::steady_clock::now()); } void JobTracker::end() { - std::unique_lock guard(_lock); + std::lock_guard guard(_lock); _sampler.endJob(std::chrono::steady_clock::now()); } diff --git a/searchcore/src/vespa/searchcore/proton/reference/document_db_reference_registry.cpp b/searchcore/src/vespa/searchcore/proton/reference/document_db_reference_registry.cpp index 75a20f9f8e5..68aaad3b557 100644 --- a/searchcore/src/vespa/searchcore/proton/reference/document_db_reference_registry.cpp +++ b/searchcore/src/vespa/searchcore/proton/reference/document_db_reference_registry.cpp @@ -30,7 +30,7 @@ DocumentDBReferenceRegistry::get(vespalib::stringref name) const std::shared_ptr DocumentDBReferenceRegistry::tryGet(vespalib::stringref name) const { - std::unique_lock guard(_lock); + std::lock_guard guard(_lock); auto itr = _handlers.find(name); if (itr == _handlers.end()) { return std::shared_ptr(); diff --git a/searchcore/src/vespa/searchcore/proton/server/pendinglidtracker.cpp b/searchcore/src/vespa/searchcore/proton/server/pendinglidtracker.cpp index 15283c170cc..79bf970aeac 100644 --- a/searchcore/src/vespa/searchcore/proton/server/pendinglidtracker.cpp +++ b/searchcore/src/vespa/searchcore/proton/server/pendinglidtracker.cpp @@ -6,7 +6,6 @@ namespace proton { -using LockGuard = std::unique_lock; PendingLidTracker::PendingLidTracker() : _mutex(), _cond(), @@ -19,12 +18,12 @@ PendingLidTracker::~PendingLidTracker() { void PendingLidTracker::produce(uint32_t lid) { - LockGuard guard(_mutex); + std::lock_guard guard(_mutex); _pending[lid]++; } void PendingLidTracker::consume(uint32_t lid) { - LockGuard guard(_mutex); + std::lock_guard guard(_mutex); auto found = _pending.find(lid); assert (found != _pending.end()); assert (found->second > 0); @@ -38,7 +37,7 @@ PendingLidTracker::consume(uint32_t lid) { void PendingLidTracker::waitForConsumedLid(uint32_t lid) { - LockGuard guard(_mutex); + std::unique_lock guard(_mutex); while (_pending.find(lid) != _pending.end()) { _cond.wait(guard); } diff --git a/searchlib/src/tests/postinglistbm/andstress.cpp b/searchlib/src/tests/postinglistbm/andstress.cpp index 736d53508b4..40f919509e8 100644 --- a/searchlib/src/tests/postinglistbm/andstress.cpp +++ b/searchlib/src/tests/postinglistbm/andstress.cpp @@ -280,7 +280,7 @@ AndStressMaster::Task * AndStressMaster::getTask() { Task *result = NULL; - std::unique_lock taskGuard(_taskLock); + std::lock_guard taskGuard(_taskLock); if (_taskIdx < _tasks.size()) { result = &_tasks[_taskIdx]; ++_taskIdx; diff --git a/staging_vespalib/src/vespa/vespalib/util/clock.cpp b/staging_vespalib/src/vespa/vespalib/util/clock.cpp index b19b067afa9..c9768417914 100644 --- a/staging_vespalib/src/vespa/vespalib/util/clock.cpp +++ b/staging_vespalib/src/vespa/vespalib/util/clock.cpp @@ -45,7 +45,7 @@ void Clock::Run(FastOS_ThreadInterface *thread, void *arguments) void Clock::stop(void) { - std::unique_lock guard(_lock); + std::lock_guard guard(_lock); _stop = true; _cond.notify_all(); } -- cgit v1.2.3 From fa7d6c2ec6180d69568a75a7293bc97294a5c811 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Tue, 14 Nov 2017 11:11:57 +0100 Subject: Fix 'Tensors cannot be compared with ~=' --- .../rankingexpression/evaluation/TensorValue.java | 1 + vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 17 ++++++++++++++++- 2 files changed, 17 insertions(+), 1 deletion(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 6cf15837da1..88abbe279aa 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -103,6 +103,7 @@ public class TensorValue extends Value { case SMALLEREQUAL: return value.smallerOrEqual(argument); case EQUAL: return value.equal(argument); case NOTEQUAL: return value.notEqual(argument); + case APPROX_EQUAL: return value.approxEqual(argument); default: throw new UnsupportedOperationException("Tensors cannot be compared with " + operator); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 8fc80e3b440..10098e24e76 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -177,6 +177,7 @@ public interface Tensor { default Tensor smallerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a <= b ? 1.0 : 0.0)); } default Tensor equal(Tensor argument) { return join(argument, (a, b) -> ( a == b ? 1.0 : 0.0)); } default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); } + default Tensor approxEqual(Tensor argument) { return join(argument, (a, b) -> ( approxEquals(a,b) ? 1.0 : 0.0)); } default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); } default Tensor avg(List dimensions) { return reduce(Reduce.Aggregator.avg, dimensions); } @@ -261,11 +262,25 @@ public interface Tensor { Cell aCell = aIterator.next(); double aValue = aCell.getValue(); double bValue = b.get(aCell.getKey()); - if (Math.abs(aValue-bValue) > 1e-7) return false; // TODO: determine relative precision + if (!approxEquals(aValue, bValue, 1e-6)) return false; } return true; } + static boolean approxEquals(double x, double y, double tolerance) { + return Math.abs(x-y) < tolerance; + } + + static boolean approxEquals(double x, double y) { + if (y < -1.0 || y > 1.0) { + x = Math.nextAfter(x/y, 1.0); + y = 1.0; + } else { + x = Math.nextAfter(x, y); + } + return x==y; + } + // ----------------- Factories /** -- cgit v1.2.3 From 21192547ac52e1ed45dffdd50ec99e02c04ac8cd Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Tue, 14 Nov 2017 11:16:09 +0100 Subject: Fix 'Cannot combine two tensors using pow' --- .../com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java | 1 + vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 1 + 2 files changed, 2 insertions(+) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 88abbe279aa..42bf7b75141 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -121,6 +121,7 @@ public class TensorValue extends Value { case min: return value.min(argument); case max: return value.max(argument); case atan2: return value.atan2(argument); + case pow: return value.pow(argument); default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 10098e24e76..cb9ed1f5eb8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -171,6 +171,7 @@ public interface Tensor { default Tensor max(Tensor argument) { return join(argument, (a, b) -> (a > b ? a : b )); } default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); } default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); } + default Tensor pow(Tensor argument) { return join(argument, Math::pow); } default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); } default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); } default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); } -- cgit v1.2.3 From 58357a693f9ff78477e1d1a99ecbfad59a331420 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Tue, 14 Nov 2017 11:25:07 +0100 Subject: Fix 'Cannot combine two tensors using fmod' --- .../com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java | 1 + vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 1 + 2 files changed, 2 insertions(+) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 42bf7b75141..935aadc4559 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -122,6 +122,7 @@ public class TensorValue extends Value { case max: return value.max(argument); case atan2: return value.atan2(argument); case pow: return value.pow(argument); + case fmod: return value.fmod(argument); default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index cb9ed1f5eb8..bdf976819bb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -172,6 +172,7 @@ public interface Tensor { default Tensor min(Tensor argument) { return join(argument, (a, b) -> (a < b ? a : b )); } default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); } default Tensor pow(Tensor argument) { return join(argument, Math::pow); } + default Tensor fmod(Tensor argument) { return join(argument, (a, b) -> ( a % b )); } default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); } default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); } default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); } -- cgit v1.2.3 From fe52b637bde1e2e1b11b6158666f9002e1c3cd0a Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Tue, 14 Nov 2017 12:40:22 +0100 Subject: Fix 'Cannot combine two tensors using ldexp' Also, make 'ldexp' in Java comparable to C++, as the library call in C++ does an implicit cast to int. --- .../com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java | 1 + .../main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java | 2 +- vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 1 + 3 files changed, 3 insertions(+), 1 deletion(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 935aadc4559..45988ef0776 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -123,6 +123,7 @@ public class TensorValue extends Value { case atan2: return value.atan2(argument); case pow: return value.pow(argument); case fmod: return value.fmod(argument); + case ldexp: return value.ldexp(argument); default: throw new UnsupportedOperationException("Cannot combine two tensors using " + function); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java index fc4a511b307..c3c1c371a68 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Function.java @@ -39,7 +39,7 @@ public enum Function implements Serializable { atan2(2) { public double evaluate(double x, double y) { return atan2(x,y); } }, fmod(2) { public double evaluate(double x, double y) { return x % y; } }, - ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,y); } }, + ldexp(2) { public double evaluate(double x, double y) { return x*pow(2,(int)y); } }, max(2) { public double evaluate(double x, double y) { return max(x,y); } }, min(2) { public double evaluate(double x, double y) { return min(x,y); } }, pow(2) { public double evaluate(double x, double y) { return pow(x,y); } }; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index bdf976819bb..2ed211539d8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -173,6 +173,7 @@ public interface Tensor { default Tensor atan2(Tensor argument) { return join(argument, Math::atan2); } default Tensor pow(Tensor argument) { return join(argument, Math::pow); } default Tensor fmod(Tensor argument) { return join(argument, (a, b) -> ( a % b )); } + default Tensor ldexp(Tensor argument) { return join(argument, (a, b) -> ( a * Math.pow(2.0, (int)b) )); } default Tensor larger(Tensor argument) { return join(argument, (a, b) -> ( a > b ? 1.0 : 0.0)); } default Tensor largerOrEqual(Tensor argument) { return join(argument, (a, b) -> ( a >= b ? 1.0 : 0.0)); } default Tensor smaller(Tensor argument) { return join(argument, (a, b) -> ( a < b ? 1.0 : 0.0)); } -- cgit v1.2.3 From 56fad78e8d03ff18348ddb34d34d5ff2431b9128 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Tue, 14 Nov 2017 13:53:31 +0100 Subject: Add % to Java ranking expressions to conform with C++ --- .../rankingexpression/evaluation/DoubleCompatibleValue.java | 5 +++++ .../searchlib/rankingexpression/evaluation/DoubleValue.java | 11 +++++++++++ .../searchlib/rankingexpression/evaluation/StringValue.java | 5 +++++ .../searchlib/rankingexpression/evaluation/TensorValue.java | 9 +++++++++ .../yahoo/searchlib/rankingexpression/evaluation/Value.java | 2 ++ .../searchlib/rankingexpression/rule/ArithmeticOperator.java | 6 +++++- searchlib/src/main/javacc/RankingExpressionParser.jj | 4 +++- 7 files changed, 40 insertions(+), 2 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index dab89fe8955..0ed2bdd6331 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -38,6 +38,11 @@ public abstract class DoubleCompatibleValue extends Value { return new DoubleValue(asDouble() / value.asDouble()); } + @Override + public Value modulo(Value value) { + return new DoubleValue(asDouble() % value.asDouble()); + } + @Override public Value compare(TruthOperator operator, Value value) { return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 28272e58c91..17157ab385f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -97,6 +97,17 @@ public final class DoubleValue extends DoubleCompatibleValue { } } + @Override + public Value modulo(Value value) { + try { + return mutable(this.value % value.asDouble()); + } + catch (UnsupportedOperationException e) { + throw unsupported("modulo",value); + } + } + + @Override public Value function(Function function, Value value) { // use the tensor implementation of max and min if the argument is a tensor diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index acf301f3b80..5374a9d3ce6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -67,6 +67,11 @@ public class StringValue extends Value { throw new UnsupportedOperationException("String values ('" + value + "') does not support division"); } + @Override + public Value modulo(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') does not support modulo"); + } + @Override public Value compare(TruthOperator operator, Value value) { if (operator.equals(TruthOperator.EQUAL)) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 45988ef0776..b283603e713 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -81,6 +81,15 @@ public class TensorValue extends Value { return new TensorValue(value.map((value) -> value / argument.asDouble())); } + @Override + public Value modulo(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.fmod(((TensorValue) argument).value)); + else + return new TensorValue(value.map((value) -> value % argument.asDouble())); + } + + private Tensor asTensor(Value value, String operationName) { if ( ! (value instanceof TensorValue)) throw new UnsupportedOperationException("Could not perform " + operationName + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index a63387506a0..f42082321b3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -41,6 +41,8 @@ public abstract class Value { public abstract Value divide(Value value); + public abstract Value modulo(Value value); + /** Perform the comparison specified by the operator between this value and the given value */ public abstract Value compare(TruthOperator operator, Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java index 5a5237c2608..2187a96ba4d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java @@ -25,8 +25,11 @@ public enum ArithmeticOperator { }}, DIVIDE(3, "/") { public Value evaluate(Value x, Value y) { return x.divide(y); + }}, + MODULO(4, "%") { public Value evaluate(Value x, Value y) { + return x.modulo(y); }}; - + /** A list of all the operators in this in order of decreasing precedence */ public static final List operatorsByPrecedence = operatorsByPrecedence(); @@ -52,6 +55,7 @@ public enum ArithmeticOperator { private static List operatorsByPrecedence() { List operators = new ArrayList<>(); + operators.add(MODULO); operators.add(DIVIDE); operators.add(MULTIPLY); operators.add(MINUS); diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index c3b9235cc93..01fed00202c 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -65,6 +65,7 @@ TOKEN : | | | + | | | @@ -202,7 +203,8 @@ ArithmeticOperator arithmetic() : { } ( { return ArithmeticOperator.PLUS; } | { return ArithmeticOperator.MINUS; } |
{ return ArithmeticOperator.DIVIDE; } | - { return ArithmeticOperator.MULTIPLY; } ) + { return ArithmeticOperator.MULTIPLY; } | + { return ArithmeticOperator.MODULO; } ) { return null; } } -- cgit v1.2.3 From 7ac79b89486b537882894970922c521772eb3ad6 Mon Sep 17 00:00:00 2001 From: Håvard Pettersen Date: Tue, 14 Nov 2017 14:15:49 +0000 Subject: we need to keep all iterators to handle additional ranges --- .../src/vespa/searchlib/queryeval/multisearch.cpp | 26 +++++----------------- .../src/vespa/searchlib/queryeval/multisearch.h | 1 - 2 files changed, 5 insertions(+), 22 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/vespa/searchlib/queryeval/multisearch.cpp b/searchlib/src/vespa/searchlib/queryeval/multisearch.cpp index 19d744dfd28..b63a54785a4 100644 --- a/searchlib/src/vespa/searchlib/queryeval/multisearch.cpp +++ b/searchlib/src/vespa/searchlib/queryeval/multisearch.cpp @@ -27,32 +27,16 @@ MultiSearch::remove(size_t index) void MultiSearch::doUnpack(uint32_t docid) { - size_t sz(_children.size()); - for (size_t i = 0; i < sz; ) { - if (__builtin_expect(_children[i]->getDocId() < docid, false)) { - _children[i]->doSeek(docid); - if (_children[i]->isAtEnd()) { - sz = deactivate(i); - continue; - } + for (SearchIterator *child: _children) { + if (__builtin_expect(child->getDocId() < docid, false)) { + child->doSeek(docid); } - if (__builtin_expect(_children[i]->getDocId() == docid, false)) { - _children[i]->doUnpack(docid); + if (__builtin_expect(child->getDocId() == docid, false)) { + child->doUnpack(docid); } - i++; } } -size_t -MultiSearch::deactivate(size_t idx) -{ - assert(idx < _children.size()); - delete _children[idx]; - _children[idx] = _children.back(); - _children.pop_back(); - return _children.size(); -} - MultiSearch::MultiSearch(const Children & children) : _children(children) { diff --git a/searchlib/src/vespa/searchlib/queryeval/multisearch.h b/searchlib/src/vespa/searchlib/queryeval/multisearch.h index 16bbd5d4ecc..d67f895ddb5 100644 --- a/searchlib/src/vespa/searchlib/queryeval/multisearch.h +++ b/searchlib/src/vespa/searchlib/queryeval/multisearch.h @@ -54,7 +54,6 @@ private: virtual void onInsert(size_t index) { (void) index; } bool isMultiSearch() const override { return true; } - size_t deactivate(size_t index); Children _children; }; -- cgit v1.2.3 From 9b913de9cd46de589d8f29436bfc46c1159a53de Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 15 Nov 2017 14:39:56 +0100 Subject: Add boolean operators to Java ranking evaluation --- .../evaluation/DoubleCompatibleValue.java | 15 +++++++ .../rankingexpression/evaluation/StringValue.java | 23 ++++++++-- .../rankingexpression/evaluation/TensorValue.java | 20 +++++++++ .../rankingexpression/evaluation/Value.java | 6 +++ .../rankingexpression/rule/ArithmeticNode.java | 2 +- .../rankingexpression/rule/ArithmeticOperator.java | 20 ++++++--- .../searchlib/rankingexpression/rule/NotNode.java | 50 ++++++++++++++++++++++ .../src/main/javacc/RankingExpressionParser.jj | 29 +++++++++---- .../evaluation/EvaluationTestCase.java | 33 ++++++++++++++ .../evaluation/EvaluationTester.java | 8 ++++ 10 files changed, 187 insertions(+), 19 deletions(-) create mode 100644 searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 0ed2bdd6331..0868af9bc72 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -43,6 +43,21 @@ public abstract class DoubleCompatibleValue extends Value { return new DoubleValue(asDouble() % value.asDouble()); } + @Override + public Value and(Value value) { + return new BooleanValue(asBoolean() && value.asBoolean()); + } + + @Override + public Value or(Value value) { + return new BooleanValue(asBoolean() || value.asBoolean()); + } + + @Override + public Value not() { + return new BooleanValue(!asBoolean()); + } + @Override public Value compare(TruthOperator operator, Value value) { return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index 5374a9d3ce6..b62081f2c6a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -54,22 +54,37 @@ public class StringValue extends Value { @Override public Value subtract(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') does not support subtraction"); + throw new UnsupportedOperationException("String values ('" + value + "') do not support subtraction"); } @Override public Value multiply(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') does not support multiplication"); + throw new UnsupportedOperationException("String values ('" + value + "') do not support multiplication"); } @Override public Value divide(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') does not support division"); + throw new UnsupportedOperationException("String values ('" + value + "') do not support division"); } @Override public Value modulo(Value value) { - throw new UnsupportedOperationException("String values ('" + value + "') does not support modulo"); + throw new UnsupportedOperationException("String values ('" + value + "') do not support modulo"); + } + + @Override + public Value and(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support and"); + } + + @Override + public Value or(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support or"); + } + + @Override + public Value not() { + throw new UnsupportedOperationException("String values ('" + value + "') do not support not"); } @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index b283603e713..919a23eeaf5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -89,6 +89,26 @@ public class TensorValue extends Value { return new TensorValue(value.map((value) -> value % argument.asDouble())); } + @Override + public Value and(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) && (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) && argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value or(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.join(((TensorValue)argument).value, (a, b) -> ((a!=0.0) || (b!=0.0)) ? 1.0 : 0.0 )); + else + return new TensorValue(value.map((value) -> ((value!=0.0) || argument.asBoolean()) ? 1 : 0)); + } + + @Override + public Value not() { + return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); + } private Tensor asTensor(Value value, String operationName) { if ( ! (value instanceof TensorValue)) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index f42082321b3..bcbce6e646f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -43,6 +43,12 @@ public abstract class Value { public abstract Value modulo(Value value); + public abstract Value and(Value value); + + public abstract Value or(Value value); + + public abstract Value not(); + /** Perform the comparison specified by the operator between this value and the given value */ public abstract Value compare(TruthOperator operator, Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index 91d8abec1be..518a15bcc87 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -77,7 +77,7 @@ public final class ArithmeticNode extends CompositeNode { Iterator child = children.iterator(); Deque stack = new ArrayDeque<>(); - stack.push(new ValueItem(ArithmeticOperator.PLUS, child.next().evaluate(context))); + stack.push(new ValueItem(ArithmeticOperator.OR, child.next().evaluate(context))); for (Iterator it = operators.iterator(); it.hasNext() && child.hasNext();) { ArithmeticOperator op = it.next(); if (!stack.isEmpty()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java index 2187a96ba4d..aae59fe2af8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java @@ -14,22 +14,28 @@ import java.util.List; */ public enum ArithmeticOperator { - PLUS(0, "+") { public Value evaluate(Value x, Value y) { + OR(0, "||") { public Value evaluate(Value x, Value y) { + return x.or(y); + }}, + AND(1, "&&") { public Value evaluate(Value x, Value y) { + return x.and(y); + }}, + PLUS(2, "+") { public Value evaluate(Value x, Value y) { return x.add(y); }}, - MINUS(1, "-") { public Value evaluate(Value x, Value y) { + MINUS(3, "-") { public Value evaluate(Value x, Value y) { return x.subtract(y); }}, - MULTIPLY(2, "*") { public Value evaluate(Value x, Value y) { + MULTIPLY(4, "*") { public Value evaluate(Value x, Value y) { return x.multiply(y); }}, - DIVIDE(3, "/") { public Value evaluate(Value x, Value y) { + DIVIDE(5, "/") { public Value evaluate(Value x, Value y) { return x.divide(y); }}, - MODULO(4, "%") { public Value evaluate(Value x, Value y) { + MODULO(6, "%") { public Value evaluate(Value x, Value y) { return x.modulo(y); }}; - + /** A list of all the operators in this in order of decreasing precedence */ public static final List operatorsByPrecedence = operatorsByPrecedence(); @@ -60,6 +66,8 @@ public enum ArithmeticOperator { operators.add(MULTIPLY); operators.add(MINUS); operators.add(PLUS); + operators.add(AND); + operators.add(OR); return Collections.unmodifiableList(operators); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java new file mode 100644 index 00000000000..8c459a032bd --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -0,0 +1,50 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; + +import java.util.Collections; +import java.util.Deque; +import java.util.List; + +/** + * A node which flips the logical value produced from the nested expression. + * + * @author lesters + */ +public class NotNode extends BooleanNode { + + private final ExpressionNode value; + + public NotNode(ExpressionNode value) { + this.value = value; + } + + public ExpressionNode getValue() { + return value; + } + + @Override + public List children() { + return Collections.singletonList(value); + } + + @Override + public String toString(SerializationContext context, Deque path, CompositeNode parent) { + return "!" + value.toString(context, path, parent); + } + + @Override + public Value evaluate(Context context) { + return value.evaluate(context).not(); + } + + @Override + public NotNode setChildren(List children) { + if (children.size() != 1) throw new IllegalArgumentException("Expected 1 children but got " + children.size()); + return new NotNode(children.get(0)); + } + +} + diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 01fed00202c..035a92b0365 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -86,6 +86,10 @@ TOKEN : | | + | + | + | + | | | @@ -204,7 +208,9 @@ ArithmeticOperator arithmetic() : { } { return ArithmeticOperator.MINUS; } |
{ return ArithmeticOperator.DIVIDE; } | { return ArithmeticOperator.MULTIPLY; } | - { return ArithmeticOperator.MODULO; } ) + { return ArithmeticOperator.MODULO; } | + { return ArithmeticOperator.AND; } | + { return ArithmeticOperator.OR; } ) { return null; } } @@ -224,16 +230,23 @@ ExpressionNode value() : { ExpressionNode ret; boolean neg = false; + boolean not = false; } { - ( [ LOOKAHEAD(2) { neg = true; } ] - ( ret = constantPrimitive() | - LOOKAHEAD(2) ret = ifExpression() | - LOOKAHEAD(4) ret = function() | - ret = feature() | - ret = queryFeature() | + ( + [ { not = true; } ] + [ LOOKAHEAD(2) { neg = true; } ] + ( ret = constantPrimitive() | + LOOKAHEAD(2) ret = ifExpression() | + LOOKAHEAD(4) ret = function() | + ret = feature() | + ret = queryFeature() | ( ret = expression() { ret = new EmbracedNode(ret); } ) ) ) - { return neg ? new NegativeNode(ret) : ret; } + { + ret = not ? new NotNode(ret) : ret; + ret = neg ? new NegativeNode(ret) : ret; + return ret; + } } IfNode ifExpression() : diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 5d357777657..26d3695dd07 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -37,6 +37,7 @@ public class EvaluationTestCase { tester.assertEvaluates(26, "2*3+4*5"); tester.assertEvaluates(1, "2/6+4/6"); tester.assertEvaluates(2 * 3 * 4 + 3 * 4 * 5 - 4 * 200 / 10, "2*3*4+3*4*5-4*200/10"); + tester.assertEvaluates(3, "1 + 10 % 6 / 2"); // Conditionals tester.assertEvaluates(2 * (3 * 4 + 3) * (4 * 5 - 4 * 200) / 10, "2*(3*4+3)*(4*5-4*200)/10"); @@ -88,6 +89,38 @@ public class EvaluationTestCase { tester.assertEvaluates(1.25, "5*if(1>=1.1, one_half, if(min(1,2) Date: Wed, 15 Nov 2017 15:40:00 +0100 Subject: Add Java ranking power operator --- .../evaluation/DoubleCompatibleValue.java | 5 +++++ .../rankingexpression/evaluation/StringValue.java | 5 +++++ .../rankingexpression/evaluation/TensorValue.java | 8 ++++++++ .../searchlib/rankingexpression/evaluation/Value.java | 2 ++ .../rankingexpression/rule/ArithmeticOperator.java | 4 ++++ searchlib/src/main/javacc/RankingExpressionParser.jj | 16 +++++++++------- .../rankingexpression/evaluation/EvaluationTestCase.java | 3 +++ 7 files changed, 36 insertions(+), 7 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java index 0868af9bc72..ea750295423 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleCompatibleValue.java @@ -58,6 +58,11 @@ public abstract class DoubleCompatibleValue extends Value { return new BooleanValue(!asBoolean()); } + @Override + public Value power(Value value) { + return new DoubleValue(Function.pow.evaluate(asDouble(), value.asDouble())); + } + @Override public Value compare(TruthOperator operator, Value value) { return new BooleanValue(operator.evaluate(asDouble(), value.asDouble())); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java index b62081f2c6a..ac8aba6a617 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/StringValue.java @@ -87,6 +87,11 @@ public class StringValue extends Value { throw new UnsupportedOperationException("String values ('" + value + "') do not support not"); } + @Override + public Value power(Value value) { + throw new UnsupportedOperationException("String values ('" + value + "') do not support ^"); + } + @Override public Value compare(TruthOperator operator, Value value) { if (operator.equals(TruthOperator.EQUAL)) diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java index 919a23eeaf5..49c3ccb7b01 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/TensorValue.java @@ -110,6 +110,14 @@ public class TensorValue extends Value { return new TensorValue(value.map((value) -> (value==0.0) ? 1.0 : 0.0)); } + @Override + public Value power(Value argument) { + if (argument instanceof TensorValue) + return new TensorValue(value.pow(((TensorValue)argument).value)); + else + return new TensorValue(value.map((value) -> Math.pow(value, argument.asDouble()))); + } + private Tensor asTensor(Value value, String operationName) { if ( ! (value instanceof TensorValue)) throw new UnsupportedOperationException("Could not perform " + operationName + diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java index bcbce6e646f..b2ccbe572d0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Value.java @@ -49,6 +49,8 @@ public abstract class Value { public abstract Value not(); + public abstract Value power(Value value); + /** Perform the comparison specified by the operator between this value and the given value */ public abstract Value compare(TruthOperator operator, Value value); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java index aae59fe2af8..a715490e95a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticOperator.java @@ -34,6 +34,9 @@ public enum ArithmeticOperator { }}, MODULO(6, "%") { public Value evaluate(Value x, Value y) { return x.modulo(y); + }}, + POWER(7, "^") { public Value evaluate(Value x, Value y) { + return x.power(y); }}; /** A list of all the operators in this in order of decreasing precedence */ @@ -61,6 +64,7 @@ public enum ArithmeticOperator { private static List operatorsByPrecedence() { List operators = new ArrayList<>(); + operators.add(POWER); operators.add(MODULO); operators.add(DIVIDE); operators.add(MULTIPLY); diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 035a92b0365..7821ab88b86 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -66,6 +66,7 @@ TOKEN : | | | + | | | @@ -204,13 +205,14 @@ ExpressionNode arithmeticExpression() : ArithmeticOperator arithmetic() : { } { - ( { return ArithmeticOperator.PLUS; } | - { return ArithmeticOperator.MINUS; } | -
{ return ArithmeticOperator.DIVIDE; } | - { return ArithmeticOperator.MULTIPLY; } | - { return ArithmeticOperator.MODULO; } | - { return ArithmeticOperator.AND; } | - { return ArithmeticOperator.OR; } ) + ( { return ArithmeticOperator.PLUS; } | + { return ArithmeticOperator.MINUS; } | +
{ return ArithmeticOperator.DIVIDE; } | + { return ArithmeticOperator.MULTIPLY; } | + { return ArithmeticOperator.MODULO; } | + { return ArithmeticOperator.AND; } | + { return ArithmeticOperator.OR; } | + { return ArithmeticOperator.POWER; } ) { return null; } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 26d3695dd07..8e34f35245d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -29,6 +29,7 @@ public class EvaluationTestCase { tester.assertEvaluates(0.75, "0.5 + 0.25"); tester.assertEvaluates(0.75, "one_half + a_quarter"); tester.assertEvaluates(1.25, "0.5 - 0.25 + one"); + tester.assertEvaluates(9.0, "3 ^ 2"); // String tester.assertEvaluates(1, "if(\"a\"==\"a\",1,0)"); @@ -38,6 +39,8 @@ public class EvaluationTestCase { tester.assertEvaluates(1, "2/6+4/6"); tester.assertEvaluates(2 * 3 * 4 + 3 * 4 * 5 - 4 * 200 / 10, "2*3*4+3*4*5-4*200/10"); tester.assertEvaluates(3, "1 + 10 % 6 / 2"); + tester.assertEvaluates(10.0, "3 ^ 2 + 1"); + tester.assertEvaluates(18.0, "2 * 3 ^ 2"); // Conditionals tester.assertEvaluates(2 * (3 * 4 + 3) * (4 * 5 - 4 * 200) / 10, "2*(3*4+3)*(4*5-4*200)/10"); -- cgit v1.2.3 From ed9640e21c4b918b26db24a5b2fb3ee877bd0ce8 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 15 Nov 2017 19:43:59 +0100 Subject: Add Java ranking set membership for tensors --- .../rankingexpression/rule/SetMembershipNode.java | 32 +++++++++++++++++++--- 1 file changed, 28 insertions(+), 4 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index f8e44f1087c..f6b1a1a8979 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -4,9 +4,14 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; -import java.util.*; +import java.util.ArrayList; +import java.util.Deque; +import java.util.List; +import java.util.function.Predicate; /** * A node which returns true or false depending on a set membership test @@ -55,11 +60,30 @@ public class SetMembershipNode extends BooleanNode { @Override public Value evaluate(Context context) { Value value = testValue.evaluate(context); + if (value instanceof TensorValue) { + return evaluateTensor(((TensorValue) value).asTensor(), context); + } + return evaluateValue(value, context); + } + + private Value evaluateValue(Value value, Context context) { + return new BooleanValue(testMembership(value::equals, context)); + } + + private Value evaluateTensor(Tensor tensor, Context context) { + return new TensorValue(tensor.map((value) -> contains(value, context) ? 1.0 : 0.0)); + } + + private boolean contains(double value, Context context) { + return testMembership((setValue) -> setValue.asDouble() == value, context); + } + + private boolean testMembership(Predicate test, Context context) { for (ExpressionNode setValue : setValues) { - if (setValue.evaluate(context).equals(value)) - return new BooleanValue(true); + if (test.test(setValue.evaluate(context))) + return true; } - return new BooleanValue(false); + return false; } @Override -- cgit v1.2.3 From d248daea0a53004b7f15fb36393504d182171f01 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Thu, 16 Nov 2017 12:26:32 +0100 Subject: Enable Java tensor conformance test --- .../evaluation/EvaluationTestCase.java | 28 ++++++++++- .../searchlib/tensor/TensorConformanceTest.java | 54 ++++++++-------------- .../tests/rankingexpression/rankingexpressionlist | 4 ++ 3 files changed, 50 insertions(+), 36 deletions(-) (limited to 'searchlib') diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 8e34f35245d..82e5d0cfe5b 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -143,6 +143,16 @@ public class EvaluationTestCase { "min(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }"); tester.assertEvaluates("{ {d1:0}:0, {d1:1}:0, {d1:2 }:10 }", "max(tensor0, 0)", "{ {d1:0}:-10, {d1:1}:0, {d1:2}:10 }"); + // operators + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 % 2 == map(tensor0, f(x) (x % 2))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 || 1 == map(tensor0, f(x) (x || 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "tensor0 && 1 == map(tensor0, f(x) (x && 1))", "{ {d1:0}:2, {d1:1}:3, {d1:2}:4 }"); + tester.assertEvaluates("{ {d1:0}:1, {d1:1}:1, {d1:2 }:1 }", + "!tensor0 == map(tensor0, f(x) (!x))", "{ {d1:0}:0, {d1:1}:1, {d1:2}:0 }"); + // -- explicitly implemented functions (not foolproof tests as we don't bother testing float value equivalence) tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "abs(tensor0)", "{ {x:0}:1, {x:1}:-2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "acos(tensor0)", "{ {x:0}:1, {x:1}:1 }"); @@ -158,8 +168,9 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "isNan(tensor0)", "{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:0 }", "log(tensor0)", "{ {x:0}:1, {x:1}:1 }"); tester.assertEvaluates("{ {x:0}:0, {x:1}:1 }", "log10(tensor0)", "{ {x:0}:1, {x:1}:10 }"); - tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)", "{ {x:0}:3, {x:1}:8 }"); + tester.assertEvaluates("{ {x:0}:0, {x:1}:2 }", "fmod(tensor0, 3)","{ {x:0}:3, {x:1}:8 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:8 }", "pow(tensor0, 3)", "{ {x:0}:1, {x:1}:2 }"); + tester.assertEvaluates("{ {x:0}:8, {x:1}:16 }", "ldexp(tensor0,3.1)","{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "relu(tensor0)", "{ {x:0}:1, {x:1}:2 }"); tester.assertEvaluates("{ {x:0}:1, {x:1}:2 }", "round(tensor0)", "{ {x:0}:1, {x:1}:1.8 }"); tester.assertEvaluates("{ {x:0}:0.5, {x:1}:0.5 }", "sigmoid(tensor0)","{ {x:0}:0, {x:1}:0 }"); @@ -237,6 +248,16 @@ public class EvaluationTestCase { "max(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:5 }", "min(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }", + "pow(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:243, {x:1,y:0}:16807 }", + "tensor0 ^ tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }", + "fmod(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:3, {x:1,y:0}:2 }", + "tensor0 % tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); + tester.assertEvaluates("{ {x:0,y:0}:96, {x:1,y:0}:224 }", + "ldexp(tensor0, tensor1)", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5.1 }"); tester.assertEvaluates("{ {x:0,y:0,z:0}:7, {x:0,y:0,z:1}:13, {x:1,y:0,z:0}:21, {x:1,y:0,z:1}:39, {x:0,y:1,z:0}:55, {x:0,y:1,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:0 }", "tensor0 * tensor1", "{ {x:0,y:0}:1, {x:1,y:0}:3, {x:0,y:1}:5, {x:1,y:1}:0 }", "{ {y:0,z:0}:7, {y:1,z:0}:11, {y:0,z:1}:13, {y:1,z:1}:0 }"); tester.assertEvaluates("{ {x:0,y:1,z:0}:35, {x:0,y:1,z:1}:65 }", @@ -261,8 +282,13 @@ public class EvaluationTestCase { "tensor0 <= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:5 }"); tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }", "tensor0 == tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); + tester.assertEvaluates("{ {x:0,y:0}:0, {x:1,y:0}:1 }", + "tensor0 ~= tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); tester.assertEvaluates("{ {x:0,y:0}:1, {x:1,y:0}:0 }", "tensor0 != tensor1", "{ {x:0}:3, {x:1}:7 }", "{ {y:0}:7 }"); + tester.assertEvaluates("{ {x:0}:1, {x:1}:0 }", + "tensor0 in [1,2,3]", "{ {x:0}:3, {x:1}:7 }"); + // TODO // argmax // argmin diff --git a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java index 27aaeb776e4..dde9d4bf21e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/tensor/TensorConformanceTest.java @@ -18,6 +18,7 @@ import org.junit.Test; import java.io.BufferedReader; import java.io.File; +import java.io.FileNotFoundException; import java.io.FileReader; import java.io.IOException; import java.util.ArrayList; @@ -51,32 +52,32 @@ public class TensorConformanceTest { count++; } } - if (failList.size() > 0) { - System.out.println("Conformance test fails:"); - System.out.println(failList); - } - - // Disable this for now: - //assertEquals(0, failList.size()); + assertEquals(failList.size() + " conformance test fails: " + failList, 0, failList.size()); } - private boolean testCase(String test, int count) throws IOException { + private boolean testCase(String test, int count) { try { ObjectMapper mapper = new ObjectMapper(); JsonNode node = mapper.readTree(test); + if (node.has("num_tests")) { Assert.assertEquals(node.get("num_tests").asInt(), count); - } else if (node.has("expression")) { - String expression = node.get("expression").asText(); - MapContext context = getInput(node.get("inputs")); - Tensor expect = getTensor(node.get("result").get("expect").asText()); - Tensor result = evaluate(expression, context); - boolean equals = Tensor.equals(result, expect); - if (!equals) { - System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); - } - return Tensor.equals(result, expect); + return true; + } + if (!node.has("expression")) { + return true; // ignore } + + String expression = node.get("expression").asText(); + MapContext context = getInput(node.get("inputs")); + Tensor expect = getTensor(node.get("result").get("expect").asText()); + Tensor result = evaluate(expression, context); + boolean equals = Tensor.equals(result, expect); + if (!equals) { + System.out.println(count + " : Tensors not equal. Result: " + result.toString() + " Expected: " + expect.toString() + " -> expression \"" + expression + "\""); + } + return equals; + } catch (Exception e) { System.out.println(count + " : " + e.toString()); } @@ -133,22 +134,5 @@ public class TensorConformanceTest { throw new IllegalArgumentException("Hex contains illegal characters"); } - private static String valueType(Value value) { - if (value instanceof StringValue) { - return "string"; - } - if (value instanceof BooleanValue) { - return "boolean"; - } - if (value instanceof DoubleCompatibleValue) { - return "double"; - } - if (value instanceof TensorValue) { - return ((TensorValue)value).asTensor().type().toString(); - } - return "unknown"; - } - - } diff --git a/searchlib/src/tests/rankingexpression/rankingexpressionlist b/searchlib/src/tests/rankingexpression/rankingexpressionlist index 327f2b161cd..77b2294c668 100644 --- a/searchlib/src/tests/rankingexpression/rankingexpressionlist +++ b/searchlib/src/tests/rankingexpression/rankingexpressionlist @@ -160,3 +160,7 @@ mysum ( mysum(4, 4), value( 4 ), value(4) ); mysum(mysum(4,4),value(4),value(4) "1008\x1977" "100819\x77" if(1.09999~=1.1,2,3); if (1.09999 ~= 1.1, 2, 3) +10 % 3 +1 && 0 || 1 +!a && (a || a) +10 ^ 3 -- cgit v1.2.3 From 85ccf6cfbe0634a8fddf6f17aba6c27ab5910782 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Tue, 21 Nov 2017 09:21:15 +0100 Subject: Replace min/max on tensors with reduce in config model --- .../com/yahoo/searchdefinition/RankProfile.java | 1 + .../yahoo/searchdefinition/TensorTransformer.java | 282 +++++++++++++++++++++ .../processing/TensorTransformTestCase.java | 205 +++++++++++++++ .../rankingexpression/transform/Simplifier.java | 2 +- 4 files changed, 489 insertions(+), 1 deletion(-) create mode 100644 config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java create mode 100644 config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java (limited to 'searchlib') diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index fd249956d5a..1021227b0e6 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -706,6 +706,7 @@ public class RankProfile implements Serializable, Cloneable { expression = new ConstantTensorTransformer(constants, rankPropertiesOutput).transform(expression); expression = new MacroInliner(inlineMacros).transform(expression); expression = new MacroShadower(getMacros()).transform(expression); + expression = new TensorTransformer(this).transform(expression); expression = new Simplifier().transform(expression); for (Map.Entry rankProperty : rankPropertiesOutput.entrySet()) { addRankProperty(rankProperty.getKey(), rankProperty.getValue()); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java new file mode 100644 index 00000000000..e9723042d77 --- /dev/null +++ b/config-model/src/main/java/com/yahoo/searchdefinition/TensorTransformer.java @@ -0,0 +1,282 @@ +package com.yahoo.searchdefinition; + +import com.yahoo.searchdefinition.document.Attribute; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.functions.Reduce; + +import java.util.List; +import java.util.Map; + +/** + * Transforms and simplifies tensor expressions. + * + * Currently transforms min(tensor,dim) and max(tensor,dim) to + * reduce(tensor,min/max,dim). This is necessary as the backend does + * not recognize these forms of min and max. + * + * @author lesters + */ +public class TensorTransformer extends ExpressionTransformer { + + private Search search; + private RankProfile rankprofile; + private Map macros; + + public TensorTransformer(RankProfile rankprofile) { + this.rankprofile = rankprofile; + this.search = rankprofile.getSearch(); + this.macros = rankprofile.getMacros(); + } + + @Override + public ExpressionNode transform(ExpressionNode node) { + if (node instanceof CompositeNode) { + node = transformChildren((CompositeNode) node); + } + if (node instanceof FunctionNode) { + node = transformFunctionNode((FunctionNode) node); + } + return node; + } + + private ExpressionNode transformFunctionNode(FunctionNode node) { + switch (node.getFunction()) { + case min: + case max: + return transformMaxAndMinFunctionNode(node); + } + return node; + } + + /** + * Transforms max and min functions if it can be proven that the first + * argument resolves to a tensor and the second argument is a valid + * dimension in the tensor. If these do not hold, the node will not + * be transformed. + * + * The test for whether or not the first argument resolves to a tensor + * is to evaluate that expression. All values used in the expression + * is bound to a context with dummy values with enough information to + * deduce tensor types. + * + * There is currently no guarantee that all cases will be found. For + * instance, if-statements are problematic. + */ + private ExpressionNode transformMaxAndMinFunctionNode(FunctionNode node) { + if (node.children().size() != 2) { + return node; + } + ExpressionNode arg1 = node.children().get(0); + ExpressionNode arg2 = node.children().get(1); + if (!potentialDimensionName(arg2)) { + return node; + } + try { + String dimension = ((ReferenceNode) arg2).getName(); + Context context = buildContext(arg1); + Value value = arg1.evaluate(context); + if (verifyTensorAndDimension(value, dimension)) { + return replaceMaxAndMinFunction(node); + } + } catch (Exception e) { + // Don't replace the expression in case of any errors, e.g. unknown values or rank features + } + return node; + } + + private boolean potentialDimensionName(ExpressionNode arg) { + return arg instanceof ReferenceNode && ((ReferenceNode) arg).children().size() == 0; + } + + private boolean verifyTensorAndDimension(Value value, String dimension) { + if (value instanceof TensorValue) { + Tensor tensor = ((TensorValue) value).asTensor(); + TensorType type = tensor.type(); + return type.dimensionNames().contains(dimension); + } + return false; + } + + private ExpressionNode replaceMaxAndMinFunction(FunctionNode node) { + ExpressionNode arg1 = node.children().get(0); + ExpressionNode arg2 = node.children().get(1); + + TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); + Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); + String dimension = ((ReferenceNode) arg2).getName(); + + return new TensorFunctionNode(new Reduce(expression, aggregator, dimension)); + } + + /** + * Creates an evaluation context by iterating through the expression tree, and + * adding dummy values with correct types to the context. + */ + private Context buildContext(ExpressionNode node) { + Context context = new MapContext(); + addRoot(node, context); + return context; + } + + private Value emptyStringValue() { + return new StringValue(""); + } + + private Value emptyDoubleValue() { + return new DoubleValue(0.0); + } + + private Value emptyTensorValue(TensorType type) { + Tensor empty = Tensor.Builder.of(type).build(); + return new TensorValue(empty); + } + + private void addRoot(ExpressionNode node, Context context) { + addChildren(node, context); + if (node instanceof ReferenceNode) { + ReferenceNode referenceNode = (ReferenceNode) node; + addIfAttribute(referenceNode, context); + addIfConstant(referenceNode, context); + addIfQuery(referenceNode, context); + addIfTensorFrom(referenceNode, context); + addIfMacro(referenceNode, context); + } + } + + private void addChildren(ExpressionNode node, Context context) { + if (node instanceof CompositeNode) { + List children = ((CompositeNode) node).children(); + for (ExpressionNode child : children) { + addRoot(child, context); + } + } + } + + private void addIfAttribute(ReferenceNode node, Context context) { + if (!node.getName().equals("attribute")) { + return; + } + if (node.children().size() == 0) { + return; + } + String attribute = node.children().get(0).toString(); + Attribute a = search.getAttribute(attribute); + Value v; + if (a.getType() == Attribute.Type.STRING) { + v = emptyStringValue(); + } else if (a.getType() == Attribute.Type.TENSOR) { + v = emptyTensorValue(a.tensorType().orElseThrow(RuntimeException::new)); + } else { + v = emptyDoubleValue(); + } + context.put(node.toString(), v); + } + + private void addIfConstant(ReferenceNode node, Context context) { + if (!node.getName().equals("constant")) { + return; + } + if (node.children().size() != 1) { + return; + } + ExpressionNode child = node.children().get(0); + while (child instanceof CompositeNode && ((CompositeNode) child).children().size() > 0) { + child = ((CompositeNode) child).children().get(0); + } + String name = child.toString(); + addIfConstantInRankProfile(name, node, context); + addIfConstantInRankingConstants(name, node, context); + } + + private void addIfConstantInRankProfile(String name, ReferenceNode node, Context context) { + if (rankprofile.getConstants().containsKey(name)) { + context.put(node.toString(), rankprofile.getConstants().get(name)); + } + } + + private void addIfConstantInRankingConstants(String name, ReferenceNode node, Context context) { + for (RankingConstant rankingConstant : search.getRankingConstants()) { + if (rankingConstant.getName().equals(name)) { + context.put(node.toString(), emptyTensorValue(rankingConstant.getTensorType())); + } + } + } + + private void addIfQuery(ReferenceNode node, Context context) { + if (!node.getName().equals("query")) { + return; + } + if (node.children().size() != 1) { + return; + } + String name = node.children().get(0).toString(); + if (rankprofile.getQueryFeatureTypes().containsKey(name)) { + String type = rankprofile.getQueryFeatureTypes().get(name); + Value v; + if (type.contains("tensor")) { + v = emptyTensorValue(TensorType.fromSpec(type)); + } else if (type.equalsIgnoreCase("string")) { + v = emptyStringValue(); + } else { + v = emptyDoubleValue(); + } + context.put(node.toString(), v); + } + } + + private void addIfTensorFrom(ReferenceNode node, Context context) { + if (!node.getName().startsWith("tensorFrom")) { + return; + } + if (node.children().size() < 1 || node.children().size() > 2) { + return; + } + ExpressionNode source = node.children().get(0); + if (source instanceof CompositeNode && ((CompositeNode) source).children().size() > 0) { + source = ((CompositeNode) source).children().get(0); + } + String dimension = source.toString(); + if (node.children().size() == 2) { + dimension = node.children().get(1).toString(); + } + TensorType type = (new TensorType.Builder()).mapped(dimension).build(); + context.put(node.toString(), emptyTensorValue(type)); + } + + private void addIfMacro(ReferenceNode node, Context context) { + RankProfile.Macro macro = macros.get(node.getName()); + if (macro == null) { + return; + } + ExpressionNode root = macro.getRankingExpression().getRoot(); + Context macroContext = buildContext(root); + addMacroArguments(node, context, macro, macroContext); + Value value = root.evaluate(macroContext); + context.put(node.toString(), value); + } + + private void addMacroArguments(ReferenceNode node, Context context, RankProfile.Macro macro, Context macroContext) { + if (macro.getFormalParams().size() > 0 && node.children().size() > 0) { + for (int i = 0; i < macro.getFormalParams().size() && i < node.children().size(); ++i) { + String param = macro.getFormalParams().get(i); + ExpressionNode argumentExpression = node.children().get(i); + Value arg = argumentExpression.evaluate(context); + macroContext.put(param, arg); + } + } + } + +} diff --git a/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java new file mode 100644 index 00000000000..aa3fd4e9aae --- /dev/null +++ b/config-model/src/test/java/com/yahoo/searchdefinition/processing/TensorTransformTestCase.java @@ -0,0 +1,205 @@ +// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchdefinition.processing; + +import com.yahoo.collections.Pair; +import com.yahoo.component.ComponentId; +import com.yahoo.config.model.application.provider.BaseDeployLogger; +import com.yahoo.search.query.profile.QueryProfileRegistry; +import com.yahoo.search.query.profile.types.FieldDescription; +import com.yahoo.search.query.profile.types.FieldType; +import com.yahoo.search.query.profile.types.QueryProfileType; +import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry; +import com.yahoo.searchdefinition.RankProfile; +import com.yahoo.searchdefinition.RankProfileRegistry; +import com.yahoo.searchdefinition.Search; +import com.yahoo.searchdefinition.SearchBuilder; +import com.yahoo.searchdefinition.SearchDefinitionTestCase; +import com.yahoo.searchdefinition.derived.AttributeFields; +import com.yahoo.searchdefinition.derived.RawRankProfile; +import com.yahoo.searchdefinition.parser.ParseException; +import com.yahoo.vespa.model.container.search.QueryProfiles; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertTrue; + +public class TensorTransformTestCase extends SearchDefinitionTestCase { + + @Test + public void requireThatNormalMaxAndMinAreNotReplaced() throws ParseException { + assertContainsExpression("max(1.0,2.0)", "max(1.0,2.0)"); + assertContainsExpression("min(attribute(double_field),x)", "min(attribute(double_field),x)"); + assertContainsExpression("max(attribute(double_field),attribute(double_array_field))", "max(attribute(double_field),attribute(double_array_field))"); + assertContainsExpression("min(attribute(tensor_field_1),attribute(double_field))", "min(attribute(tensor_field_1),attribute(double_field))"); + assertContainsExpression("max(attribute(tensor_field_1),attribute(tensor_field_2))", "max(attribute(tensor_field_1),attribute(tensor_field_2))"); + assertContainsExpression("min(test_constant_tensor,1.0)", "min(constant(test_constant_tensor),1.0)"); + assertContainsExpression("max(base_constant_tensor,1.0)", "max(constant(base_constant_tensor),1.0)"); + assertContainsExpression("min(constant(file_constant_tensor),1.0)", "min(constant(file_constant_tensor),1.0)"); + assertContainsExpression("max(query(q),1.0)", "max(query(q),1.0)"); + assertContainsExpression("max(query(n),1.0)", "max(query(n),1.0)"); + } + + @Test + public void requireThatMaxAndMinWithTensorAttributesAreReplaced() throws ParseException { + assertContainsExpression("max(attribute(tensor_field_1),x)", "reduce(attribute(tensor_field_1),max,x)"); + assertContainsExpression("1 + max(attribute(tensor_field_1),x)", "1+reduce(attribute(tensor_field_1),max,x)"); + assertContainsExpression("if(attribute(double_field),1 + max(attribute(tensor_field_1),x),0)", "if(attribute(double_field),1+reduce(attribute(tensor_field_1),max,x),0)"); + assertContainsExpression("max(max(attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(max(attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); + assertContainsExpression("max(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),x)", "reduce(if(attribute(double_field),attribute(tensor_field_1),attribute(tensor_field_2)),max,x)"); + assertContainsExpression("max(max(attribute(tensor_field_1),x),x)", "max(reduce(attribute(tensor_field_1),max,x),x)"); // will result in deploy error. + assertContainsExpression("max(max(attribute(tensor_field_2),x),y)", "reduce(reduce(attribute(tensor_field_2),max,x),max,y)"); + } + + @Test + public void requireThatMaxAndMinWithConstantTensorsAreReplaced() throws ParseException { + assertContainsExpression("max(test_constant_tensor,x)", "reduce(constant(test_constant_tensor),max,x)"); + assertContainsExpression("max(base_constant_tensor,x)", "reduce(constant(base_constant_tensor),max,x)"); + assertContainsExpression("min(constant(file_constant_tensor),x)", "reduce(constant(file_constant_tensor),min,x)"); + } + + @Test + public void requireThatMaxAndMinWithTensorExpressionsAreReplaced() throws ParseException { + assertContainsExpression("min(attribute(double_field) + attribute(tensor_field_1),x)", "reduce(attribute(double_field)+attribute(tensor_field_1),min,x)"); + assertContainsExpression("min(attribute(tensor_field_1) * attribute(tensor_field_2),x)", "reduce(attribute(tensor_field_1)*attribute(tensor_field_2),min,x)"); + assertContainsExpression("min(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),x)", "reduce(join(attribute(tensor_field_1),attribute(tensor_field_2),f(x,y)(x*y)),min,x)"); + assertContainsExpression("min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)", "min(join(tensor_field_1,tensor_field_2,f(x,y)(x*y)),x)"); + } + + @Test + public void requireThatMaxAndMinWithTensorFromIsReplaced() throws ParseException { + assertContainsExpression("max(tensorFromLabels(attribute(double_array_field)),double_array_field)", "reduce(tensorFromLabels(attribute(double_array_field)),max,double_array_field)"); + assertContainsExpression("max(tensorFromLabels(attribute(double_array_field),x),x)", "reduce(tensorFromLabels(attribute(double_array_field),x),max,x)"); + assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field)),weightedset_field)", "reduce(tensorFromWeightedSet(attribute(weightedset_field)),max,weightedset_field)"); + assertContainsExpression("max(tensorFromWeightedSet(attribute(weightedset_field),x),x)", "reduce(tensorFromWeightedSet(attribute(weightedset_field),x),max,x)"); + } + + @Test + public void requireThatMaxAndMinWithTensorInQueryIsReplaced() throws ParseException { + assertContainsExpression("max(query(q),x)", "reduce(query(q),max,x)"); + assertContainsExpression("max(query(n),x)", "max(query(n),x)"); + } + + @Test + public void requireThatMaxAndMinWithTensoresReturnedFromMacrosAreReplaced() throws ParseException { + assertContainsExpression("max(returns_tensor,x)", "reduce(rankingExpression(returns_tensor),max,x)"); + assertContainsExpression("max(wraps_returns_tensor,x)", "reduce(rankingExpression(wraps_returns_tensor),max,x)"); + assertContainsExpression("max(tensor_inheriting,x)", "reduce(rankingExpression(tensor_inheriting),max,x)"); + assertContainsExpression("max(returns_tensor_with_arg(attribute(tensor_field_1)),x)", "reduce(rankingExpression(returns_tensor_with_arg@),max,x)"); + } + + + private void assertContainsExpression(String expr, String transformedExpression) throws ParseException { + assertTrue("Expected expression '" + transformedExpression + "' not found", + containsExpression(expr, transformedExpression)); + } + + private boolean containsExpression(String expr, String transformedExpression) throws ParseException { + for (Pair rankPropertyExpression : buildSearch(expr)) { + String rankProperty = rankPropertyExpression.getFirst(); + if (rankProperty.equals("rankingExpression(firstphase).rankingScript")) { + String rankExpression = censorBindingHash(rankPropertyExpression.getSecond().replace(" ","")); + return rankExpression.equals(transformedExpression); + } + } + return false; + } + + private List> buildSearch(String expression) throws ParseException { + RankProfileRegistry rankProfileRegistry = new RankProfileRegistry(); + SearchBuilder builder = new SearchBuilder(rankProfileRegistry); + builder.importString( + "search test {\n" + + " document test { \n" + + " field double_field type double { \n" + + " indexing: summary | attribute \n" + + " }\n" + + " field double_array_field type array { \n" + + " indexing: summary | attribute \n" + + " }\n" + + " field weightedset_field type weightedset { \n" + + " indexing: summary | attribute \n" + + " }\n" + + " field tensor_field_1 type tensor(x{}) { \n" + + " indexing: summary | attribute \n" + + " attribute: tensor(x{}) \n" + + " }\n" + + " field tensor_field_2 type tensor(x[3],y[3]) { \n" + + " indexing: summary | attribute \n" + + " attribute: tensor(x[3],y[3]) \n" + + " }\n" + + " }\n" + + " constant file_constant_tensor {\n" + + " file: constants/tensor.json\n" + + " type: tensor(x{})\n" + + " }\n" + + " rank-profile base {\n" + + " constants {\n" + + " base_constant_tensor {\n" + + " value: { {x:0}:0 }\n" + + " }\n" + + " }\n" + + " macro base_tensor() {\n" + + " expression: constant(base_constant_tensor)\n" + + " }\n" + + " }\n" + + " rank-profile test inherits base {\n" + + " constants {\n" + + " test_constant_tensor {\n" + + " value: { {x:0}:1 }\n" + + " }\n" + + " }\n" + + " macro returns_tensor_with_arg(arg1) {\n" + + " expression: 2.0 * arg1\n" + + " }\n" + + " macro wraps_returns_tensor() {\n" + + " expression: returns_tensor\n" + + " }\n" + + " macro returns_tensor() {\n" + + " expression: attribute(tensor_field_2)\n" + + " }\n" + + " macro tensor_inheriting() {\n" + + " expression: base_tensor\n" + + " }\n" + + " first-phase {\n" + + " expression: " + expression + "\n" + + " }\n" + + " }\n" + + "}\n"); + builder.build(new BaseDeployLogger(), setupQueryProfileTypes()); + Search s = builder.getSearch(); + RankProfile test = rankProfileRegistry.getRankProfile(s, "test").compile(); + List> testRankProperties = new RawRankProfile(test, new AttributeFields(s)).configProperties(); + for (Object o : testRankProperties) + System.out.println(o); + return testRankProperties; + } + + private static QueryProfiles setupQueryProfileTypes() { + QueryProfileRegistry registry = new QueryProfileRegistry(); + QueryProfileTypeRegistry typeRegistry = registry.getTypeRegistry(); + QueryProfileType type = new QueryProfileType(new ComponentId("testtype")); + type.addField(new FieldDescription("ranking.features.query(q)", + FieldType.fromString("tensor(x{})", typeRegistry)), typeRegistry); + type.addField(new FieldDescription("ranking.features.query(n)", + FieldType.fromString("integer", typeRegistry)), typeRegistry); + typeRegistry.register(type); + return new QueryProfiles(registry); + } + + private String censorBindingHash(String s) { + StringBuilder b = new StringBuilder(); + boolean areInHash = false; + for (int i = 0; i < s.length() ; i++) { + char current = s.charAt(i); + if ( ! Character.isLetterOrDigit(current)) // end of hash + areInHash = false; + if ( ! areInHash) + b.append(current); + if (current == '@') // start of hash + areInHash = true; + } + return b.toString(); + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java index ede7c861d98..ebad0d5c21f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/Simplifier.java @@ -94,7 +94,7 @@ public class Simplifier extends ExpressionTransformer { private ExpressionNode transformIf(IfNode node) { if ( ! isConstant(node.getCondition())) return node; - if (((BooleanValue)node.getCondition().evaluate(null)).asBoolean()) + if ((node.getCondition().evaluate(null)).asBoolean()) return node.getTrueExpression(); else return node.getFalseExpression(); -- cgit v1.2.3