diff options
author | Tor Egge <Tor.Egge@online.no> | 2023-06-05 16:42:45 +0200 |
---|---|---|
committer | Tor Egge <Tor.Egge@online.no> | 2023-06-05 16:42:45 +0200 |
commit | c0d9b10280db007e376bbc28d75f8b774db2f9a8 (patch) | |
tree | 3b35a69f6006dc36ab139b24fa057f200a3fe03e /streamingvisitors | |
parent | 678fa9ff6d6e363416ec7fe400395c9f3003934e (diff) |
Setup distance metrics for streaming search.
Add range checks when converting to internal distance threshold.
Diffstat (limited to 'streamingvisitors')
4 files changed, 27 insertions, 4 deletions
diff --git a/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp b/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp index c53dfae294a..9d62122af87 100644 --- a/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp +++ b/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp @@ -140,7 +140,7 @@ get_tensor_type(const document::FieldValue& fv) } AttributeVector::SP -createAttribute(const vespalib::string & name, const document::FieldValue & fv) +createAttribute(const vespalib::string & name, const document::FieldValue & fv, search::attribute::DistanceMetric dm) { LOG(debug, "Create single value attribute '%s' with value type '%s'", name.c_str(), fv.className()); if (fv.isA(document::FieldValue::Type::BYTE) || fv.isA(document::FieldValue::Type::INT) || fv.isA(document::FieldValue::Type::LONG)) { @@ -156,6 +156,7 @@ createAttribute(const vespalib::string & name, const document::FieldValue & fv) auto tdt = get_tensor_type(fv); assert(tdt != nullptr); cfg.setTensorType(tdt->getTensorType()); + cfg.set_distance_metric(dm); return std::make_shared<search::tensor::TensorExtAttribute>(name, cfg); } else { LOG(debug, "Can not make an attribute out of %s of type '%s'.", name.c_str(), fv.className()); @@ -860,7 +861,7 @@ void SearchVisitor::setupAttributeVector(const FieldPath &fieldPath) { } else if (typeSeen == WSET) { attr = createMultiValueAttribute (attrName, fv, false); } else { - attr = createAttribute(attrName, fv); + attr = createAttribute(attrName, fv, _fieldSearchSpecMap.get_distance_metric(attrName)); } if (attr) { diff --git a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp index c6dfa792c6a..772f336e5df 100644 --- a/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp +++ b/streamingvisitors/src/vespa/vsm/searcher/nearest_neighbor_field_searcher.cpp @@ -33,10 +33,11 @@ namespace { constexpr uint32_t scratch_docid = 0; std::unique_ptr<TensorExtAttribute> -make_attribute(const ValueType& tensor_type) +make_attribute(const ValueType& tensor_type, search::attribute::DistanceMetric dm) { Config cfg(BasicType::TENSOR, CollectionType::SINGLE); cfg.setTensorType(tensor_type); + cfg.set_distance_metric(dm); auto result = std::make_unique<TensorExtAttribute>("nnfs_attr", cfg); uint32_t docid; result->addDoc(docid); @@ -94,7 +95,7 @@ NearestNeighborFieldSearcher::prepare(search::streaming::QueryTermList& qtl, vespalib::Issue::report("Data type for field %u is '%s', but expected it to be a tensor type", field(), field_paths[field()].back().getDataType().toString().c_str()); } - _attr = make_attribute(tensor_type->getTensorType()); + _attr = make_attribute(tensor_type->getTensorType(), _metric); _calcs.clear(); for (auto term : qtl) { auto* nn_term = term->as_nearest_neighbor_query_node(); diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp index 98ed8a26938..f6ac3a6c88a 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp +++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.cpp @@ -313,6 +313,23 @@ void FieldSearchSpecMap::buildSearcherMap(const StringFieldIdTMapT & fieldsInQue std::sort(fieldSearcherMap.begin(), fieldSearcherMap.end(), lesserField); } +search::attribute::DistanceMetric +FieldSearchSpecMap::get_distance_metric(const vespalib::string& name) const +{ + auto dm = search::attribute::DistanceMetric::Euclidean; + auto fid = _nameIdMap.fieldNo(name); + if (fid == vsm::StringFieldIdTMap::npos) { + return dm; + } + auto itr = _specMap.find(fid); + if (itr == _specMap.end()) { + return dm; + } + if (!itr->second.uses_nearest_neighbor_search_method()) { + return dm; + } + return vsm::NearestNeighborFieldSearcher::distance_metric_from_string(itr->second.get_arg1()); +} vespalib::asciistream & operator <<(vespalib::asciistream & os, const FieldSearchSpecMap & df) { diff --git a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h index 14a30ed8c36..0fa0eca4357 100644 --- a/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h +++ b/streamingvisitors/src/vespa/vsm/vsm/fieldsearchspec.h @@ -1,6 +1,7 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. #pragma once +#include <vespa/searchcommon/attribute/distance_metric.h> #include <vespa/vsm/searcher/fieldsearcher.h> #include <vespa/vsm/config/vsm-cfif.h> @@ -21,6 +22,8 @@ public: FieldIdT id() const { return _id; } bool valid() const { return static_cast<bool>(_searcher); } size_t maxLength() const { return _maxLength; } + bool uses_nearest_neighbor_search_method() const noexcept { return _searchMethod == VsmfieldsConfig::Fieldspec::Searchmethod::NEAREST_NEIGHBOR; } + const vespalib::string& get_arg1() const noexcept { return _arg1; } /** * Reconfigures the field searcher based on information in the given query term. @@ -87,6 +90,7 @@ public: friend vespalib::asciistream & operator <<(vespalib::asciistream & os, const FieldSearchSpecMap & f); static vespalib::string stripNonFields(const vespalib::string & rawIndex); + search::attribute::DistanceMetric get_distance_metric(const vespalib::string& name) const; private: FieldSearchSpecMapT _specMap; // mapping from field id to field search spec |