summary refs log tree commit diff stats
path: root/searchlib
diff options
context:
space:
mode:
author Geir Storli <geirst@yahooinc.com> 2022-07-06 16:05:30 +0000
committer Geir Storli <geirst@yahooinc.com> 2022-07-06 16:05:30 +0000
commit 848f046a4878ec17276bd25524d50e87ed55383d (patch)
tree b7e8c4f8178b45f1d9e93927db8d98507cacd6a3 /searchlib
parent 193ab3ee7d39f555c4f7959f668d494a06a3d4b8 (diff)
Refactor shared code between closeness and distance features.
This is in preparation for using DistanceCalculator when raw score is not available.
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/vespa/searchlib/features/CMakeLists.txt 3
-rw-r--r-- searchlib/src/vespa/searchlib/features/closenessfeature.cpp 30
-rw-r--r-- searchlib/src/vespa/searchlib/features/distance_calculator_bundle.cpp 52
-rw-r--r-- searchlib/src/vespa/searchlib/features/distance_calculator_bundle.h 40
-rw-r--r-- searchlib/src/vespa/searchlib/features/distancefeature.cpp 34
5 files changed, 110 insertions, 49 deletions
diff --git a/searchlib/src/vespa/searchlib/features/CMakeLists.txt b/searchlib/src/vespa/searchlib/features/CMakeLists.txt
index 88531a46cb1..8acf28f4a2f 100644
--- a/searchlib/src/vespa/searchlib/features/CMakeLists.txt
+++ b/searchlib/src/vespa/searchlib/features/CMakeLists.txt
@@ -12,7 +12,7 @@ vespa_add_library(searchlib_features OBJECT
debug_wait.cpp
dense_tensor_attribute_executor.cpp
direct_tensor_attribute_executor.cpp
- great_circle_distance_feature.cpp
+ distance_calculator_bundle.cpp
distancefeature.cpp
distancetopathfeature.cpp
documenttestutils.cpp
@@ -29,6 +29,7 @@ vespa_add_library(searchlib_features OBJECT
foreachfeature.cpp
freshnessfeature.cpp
global_sequence_feature.cpp
+ great_circle_distance_feature.cpp
internal_max_reduce_prod_join_feature.cpp
item_raw_score_feature.cpp
jarowinklerdistancefeature.cpp
diff --git a/searchlib/src/vespa/searchlib/features/closenessfeature.cpp b/searchlib/src/vespa/searchlib/features/closenessfeature.cpp
index 04fc2a263be..17bd914d690 100644
--- a/searchlib/src/vespa/searchlib/features/closenessfeature.cpp
+++ b/searchlib/src/vespa/searchlib/features/closenessfeature.cpp
@@ -1,6 +1,7 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
#include "closenessfeature.h"
+#include "distance_calculator_bundle.h"
#include "utils.h"
#include <vespa/searchcommon/common/schema.h>
#include <vespa/searchlib/fef/properties.h>
@@ -16,8 +17,8 @@ namespace search::features {
/** Implements the executor for converting NNS rawscore to a closeness feature. */
class ConvertRawScoreToCloseness : public fef::FeatureExecutor {
private:
- std::vector<fef::TermFieldHandle> _handles;
- const fef::MatchData *_md;
+ DistanceCalculatorBundle _bundle;
+ const fef::MatchData *_md;
void handle_bind_match_data(const fef::MatchData &md) override {
_md = &md;
}
@@ -28,32 +29,15 @@ public:
};
ConvertRawScoreToCloseness::ConvertRawScoreToCloseness(const fef::IQueryEnvironment &env, uint32_t fieldId)
- : _handles(),
+ : _bundle(env, fieldId),
_md(nullptr)
{
- _handles.reserve(env.getNumTerms());
- for (uint32_t i = 0; i < env.getNumTerms(); ++i) {
- search::fef::TermFieldHandle handle = util::getTermFieldHandle(env, i, fieldId);
- if (handle != search::fef::IllegalHandle) {
- _handles.push_back(handle);
- }
- }
}
ConvertRawScoreToCloseness::ConvertRawScoreToCloseness(const fef::IQueryEnvironment &env, const vespalib::string &label)
- : _handles(),
+ : _bundle(env, label),
_md(nullptr)
{
- const ITermData *term = util::getTermByLabel(env, label);
- if (term != nullptr) {
- // expect numFields() == 1
- for (uint32_t i = 0; i < term->numFields(); ++i) {
- TermFieldHandle handle = term->field(i).getHandle();
- if (handle != IllegalHandle) {
- _handles.push_back(handle);
- }
- }
- }
}
void
@@ -61,8 +45,8 @@ ConvertRawScoreToCloseness::execute(uint32_t docId)
{
feature_t max_closeness = 0.0;
assert(_md);
- for (auto handle : _handles) {
- const TermFieldMatchData *tfmd = _md->resolveTermField(handle);
+ for (const auto& elem : _bundle.elements()) {
+ const TermFieldMatchData *tfmd = _md->resolveTermField(elem.handle);
if (tfmd->getDocId() == docId) {
feature_t converted = tfmd->getRawScore();
max_closeness = std::max(max_closeness, converted);
diff --git a/searchlib/src/vespa/searchlib/features/distance_calculator_bundle.cpp b/searchlib/src/vespa/searchlib/features/distance_calculator_bundle.cpp
new file mode 100644
index 00000000000..361ccce2fe2
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/features/distance_calculator_bundle.cpp
@@ -0,0 +1,52 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#include "distance_calculator_bundle.h"
+#include "utils.h"
+#include <vespa/searchlib/fef/iqueryenvironment.h>
+#include <vespa/searchlib/tensor/distance_calculator.h>
+
+using search::fef::ITermData;
+using search::fef::IllegalHandle;
+using search::fef::TermFieldHandle;
+
+namespace search::features {
+
+DistanceCalculatorBundle::Element::Element(fef::TermFieldHandle handle_in) // Stores the given term field handle; no calculator is attached here.
+ : handle(handle_in),
+ calc() // value-initialized std::unique_ptr, i.e. starts out null
+{
+}
+
+DistanceCalculatorBundle::Element::~Element() = default; // Defaulted in this file, which includes distance_calculator.h; presumably because the header only forward-declares DistanceCalculator.
+
+DistanceCalculatorBundle::DistanceCalculatorBundle(const fef::IQueryEnvironment& env, // Adds one element for each query term that yields a legal handle for field_id.
+ uint32_t field_id)
+ : _elems()
+{
+ _elems.reserve(env.getNumTerms()); // upper bound: the loop below adds at most one element per term
+ for (uint32_t i = 0; i < env.getNumTerms(); ++i) {
+ search::fef::TermFieldHandle handle = util::getTermFieldHandle(env, i, field_id);
+ if (handle != search::fef::IllegalHandle) { // terms for which the lookup returns IllegalHandle are skipped
+ _elems.emplace_back(handle);
+ }
+ }
+}
+
+DistanceCalculatorBundle::DistanceCalculatorBundle(const fef::IQueryEnvironment& env, // Adds one element for each legal field handle of the term looked up by label.
+ const vespalib::string& label)
+ : _elems()
+{
+ const ITermData *term = util::getTermByLabel(env, label);
+ if (term != nullptr) { // when the lookup returns nullptr the bundle stays empty
+ // expect numFields() == 1 (the loop nevertheless visits every field the term reports)
+ for (uint32_t i = 0; i < term->numFields(); ++i) {
+ TermFieldHandle handle = term->field(i).getHandle();
+ if (handle != IllegalHandle) { // fields without a legal handle are skipped
+ _elems.emplace_back(handle);
+ }
+ }
+ }
+}
+
+}
+
diff --git a/searchlib/src/vespa/searchlib/features/distance_calculator_bundle.h b/searchlib/src/vespa/searchlib/features/distance_calculator_bundle.h
new file mode 100644
index 00000000000..d28a315edd1
--- /dev/null
+++ b/searchlib/src/vespa/searchlib/features/distance_calculator_bundle.h
@@ -0,0 +1,40 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+#pragma once
+
+#include <vespa/searchlib/fef/handle.h>
+#include <vespa/vespalib/stllike/string.h>
+#include <memory>
+#include <vector>
+
+namespace search::tensor { class DistanceCalculator; }
+namespace search::fef { class IQueryEnvironment; }
+
+namespace search::features {
+
+/**
+ * A bundle of term-field tuples (TermFieldHandle, DistanceCalculator) used by the closeness and distance rank features.
+ *
+ * For most document ids the raw score is available in the TermFieldMatchData retrieved using the TermFieldHandle,
+ * as it was calculated during matching. In the other cases the DistanceCalculator can be used to calculate the score on the fly.
+ */
+class DistanceCalculatorBundle {
+public:
+ struct Element {
+ fef::TermFieldHandle handle; // handle of the term field this element refers to
+ std::unique_ptr<search::tensor::DistanceCalculator> calc; // owning pointer to the calculator; may be empty
+ Element(Element&& rhs) noexcept = default; // Needed as std::vector::reserve() is used.
+ Element(fef::TermFieldHandle handle_in); // NOTE(review): single-argument constructor is not explicit, so it also acts as an implicit conversion
+ ~Element(); // only declared here; DistanceCalculator is forward-declared (incomplete) in this header
+ };
+private:
+ std::vector<Element> _elems; // elements collected by the constructors
+
+public:
+ DistanceCalculatorBundle(const fef::IQueryEnvironment& env, uint32_t field_id); // builds the bundle from a field id
+ DistanceCalculatorBundle(const fef::IQueryEnvironment& env, const vespalib::string& label); // builds the bundle from a query term label
+
+ const std::vector<Element>& elements() const { return _elems; } // read-only access to the collected elements
+};
+
+}
diff --git a/searchlib/src/vespa/searchlib/features/distancefeature.cpp b/searchlib/src/vespa/searchlib/features/distancefeature.cpp
index d0a2c1a3838..4a279da9bdd 100644
--- a/searchlib/src/vespa/searchlib/features/distancefeature.cpp
+++ b/searchlib/src/vespa/searchlib/features/distancefeature.cpp
@@ -1,16 +1,17 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+#include "distance_calculator_bundle.h"
#include "distancefeature.h"
+#include "utils.h"
+#include <vespa/document/datatype/positiondatatype.h>
#include <vespa/searchcommon/common/schema.h>
#include <vespa/searchlib/common/geo_location_spec.h>
#include <vespa/searchlib/fef/matchdata.h>
-#include <vespa/document/datatype/positiondatatype.h>
#include <vespa/vespalib/geo/zcurve.h>
#include <vespa/vespalib/util/issue.h>
#include <vespa/vespalib/util/stash.h>
#include <cmath>
#include <limits>
-#include "utils.h"
#include <vespa/log/log.h>
LOG_SETUP(".features.distancefeature");
@@ -24,8 +25,8 @@ namespace search::features {
/** Implements the executor for converting NNS rawscore to a distance feature. */
class ConvertRawscoreToDistance : public fef::FeatureExecutor {
private:
- std::vector<fef::TermFieldHandle> _handles;
- const fef::MatchData *_md;
+ DistanceCalculatorBundle _bundle;
+ const fef::MatchData *_md;
void handle_bind_match_data(const fef::MatchData &md) override {
_md = &md;
}
@@ -36,32 +37,15 @@ public:
};
ConvertRawscoreToDistance::ConvertRawscoreToDistance(const fef::IQueryEnvironment &env, uint32_t fieldId)
- : _handles(),
+ : _bundle(env, fieldId),
_md(nullptr)
{
- _handles.reserve(env.getNumTerms());
- for (uint32_t i = 0; i < env.getNumTerms(); ++i) {
- search::fef::TermFieldHandle handle = util::getTermFieldHandle(env, i, fieldId);
- if (handle != search::fef::IllegalHandle) {
- _handles.push_back(handle);
- }
- }
}
ConvertRawscoreToDistance::ConvertRawscoreToDistance(const fef::IQueryEnvironment &env, const vespalib::string &label)
- : _handles(),
+ : _bundle(env, label),
_md(nullptr)
{
- const ITermData *term = util::getTermByLabel(env, label);
- if (term != nullptr) {
- // expect numFields() == 1
- for (uint32_t i = 0; i < term->numFields(); ++i) {
- TermFieldHandle handle = term->field(i).getHandle();
- if (handle != IllegalHandle) {
- _handles.push_back(handle);
- }
- }
- }
}
void
@@ -69,8 +53,8 @@ ConvertRawscoreToDistance::execute(uint32_t docId)
{
feature_t min_distance = std::numeric_limits<feature_t>::max();
assert(_md);
- for (auto handle : _handles) {
- const TermFieldMatchData *tfmd = _md->resolveTermField(handle);
+ for (const auto& elem : _bundle.elements()) {
+ const TermFieldMatchData *tfmd = _md->resolveTermField(elem.handle);
if (tfmd->getDocId() == docId) {
feature_t invdist = tfmd->getRawScore();
feature_t converted = (1.0 / invdist) - 1.0;