summaryrefslogtreecommitdiffstats
path: root/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp
diff options
context:
space:
mode:
authorTor Egge <Tor.Egge@online.no>2023-04-03 11:39:51 +0200
committerTor Egge <Tor.Egge@online.no>2023-04-03 11:39:51 +0200
commit9514d4f9eb4c51bec63f0a22abd3f5717eb6118d (patch)
tree7c527ccf76cca56d0f5b12f6aa6b87b5a3f9a4d0 /streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp
parent9d93baa2b9d23258f8e760e3d804ee9065cf9a58 (diff)
Wire in TensorExtAttribute.
Diffstat (limited to 'streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp')
-rw-r--r--streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp26
1 files changed, 26 insertions, 0 deletions
diff --git a/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp b/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp
index d0e3d1f038b..2caf2de1d0b 100644
--- a/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp
+++ b/streamingvisitors/src/vespa/searchvisitor/searchvisitor.cpp
@@ -7,6 +7,7 @@
#include <vespa/persistence/spi/docentry.h>
#include <vespa/document/datatype/positiondatatype.h>
#include <vespa/document/datatype/documenttype.h>
+#include <vespa/document/datatype/tensor_data_type.h>
#include <vespa/document/datatype/weightedsetdatatype.h>
#include <vespa/document/datatype/mapdatatype.h>
#include <vespa/searchlib/aggregation/modifiers.h>
@@ -14,6 +15,7 @@
#include <vespa/searchlib/common/packets.h>
#include <vespa/searchlib/uca/ucaconverter.h>
#include <vespa/searchlib/features/setup.h>
+#include <vespa/searchlib/tensor/tensor_ext_attribute.h>
#include <vespa/searchcommon/attribute/config.h>
#include <vespa/vespalib/geo/zcurve.h>
#include <vespa/vespalib/objects/nbostream.h>
@@ -99,6 +101,16 @@ createMultiValueAttribute(const vespalib::string & name, const document::FieldVa
return {};
}
+const document::TensorDataType*
+get_tensor_type(const document::FieldValue& fv)
+{
+ auto tfv = dynamic_cast<const document::TensorFieldValue*>(&fv);
+ if (tfv == nullptr) {
+ return nullptr;
+ }
+ return dynamic_cast<const document::TensorDataType*>(tfv->getDataType());
+}
+
AttributeVector::SP
createAttribute(const vespalib::string & name, const document::FieldValue & fv)
{
@@ -111,6 +123,12 @@ createAttribute(const vespalib::string & name, const document::FieldValue & fv)
return std::make_shared<search::SingleStringExtAttribute>(name);
} else if (fv.isA(document::FieldValue::Type::RAW)) {
return std::make_shared<search::attribute::SingleRawExtAttribute>(name);
+ } else if (fv.isA(document::FieldValue::Type::TENSOR) && get_tensor_type(fv) != nullptr) {
+ search::attribute::Config cfg(search::attribute::BasicType::TENSOR, search::attribute::CollectionType::SINGLE);
+ auto tdt = get_tensor_type(fv);
+ assert(tdt != nullptr);
+ cfg.setTensorType(tdt->getTensorType());
+ return std::make_shared<search::tensor::TensorExtAttribute>(name, cfg);
} else {
LOG(debug, "Can not make an attribute out of %s of type '%s'.", name.c_str(), fv.className());
}
@@ -444,6 +462,14 @@ SearchVisitor::AttributeInserter::onPrimitive(uint32_t, const Content & c)
} else if (_attribute.is_raw_type()) {
auto raw_value = value.getAsRaw();
attr.add(vespalib::ConstArrayRef<char>(raw_value.first, raw_value.second), c.getWeight());
+ } else if (_attribute.isTensorType()) {
+ auto tfvalue = dynamic_cast<const document::TensorFieldValue*>(&value);
+ if (tfvalue != nullptr) {
+ auto tensor = tfvalue->getAsTensorPtr();
+ if (tensor != nullptr) {
+ attr.add(*tensor, c.getWeight());
+ }
+ }
} else {
assert(false && "We got an attribute vector that is of an unknown type");
}