summaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp')
-rw-r--r--searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp39
1 files changed, 24 insertions, 15 deletions
diff --git a/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp b/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp
index b72a75bd19f..f36c1dbfdaa 100644
--- a/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp
+++ b/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp
@@ -41,15 +41,23 @@ TensorFromLabelsBlueprint::setup(const search::fef::IIndexEnvironment &env,
// _params[0] = source ('attribute(name)' OR 'query(param)');
// _params[1] = dimension (optional);
bool validSource = extractSource(params[0].getValue());
+ if (! validSource) {
+ return fail("invalid source: '%s'", params[0].getValue().c_str());
+ }
if (params.size() == 2) {
_dimension = params[1].getValue();
} else {
_dimension = _sourceParam;
}
+ auto vt = ValueType::make_type(CellType::DOUBLE, {{_dimension}});
+ _valueType = ValueType::from_spec(vt.to_spec());
+ if (_valueType.is_error()) {
+ return fail("invalid dimension name: '%s'", _dimension.c_str());
+ }
describeOutput("tensor",
"The tensor created from the given source (attribute field or query parameter)",
- FeatureType::object(ValueType::make_type(CellType::DOUBLE, {{_dimension}})));
- return validSource;
+ FeatureType::object(_valueType));
+ return true;
}
namespace {
@@ -57,23 +65,24 @@ namespace {
FeatureExecutor &
createAttributeExecutor(const search::fef::IQueryEnvironment &env,
const vespalib::string &attrName,
- const vespalib::string &dimension, vespalib::Stash &stash)
+ const ValueType &valueType,
+ vespalib::Stash &stash)
{
const IAttributeVector *attribute = env.getAttributeContext().getAttribute(attrName);
if (attribute == NULL) {
Issue::report("tensor_from_labels feature: The attribute vector '%s' was not found."
" Returning empty tensor.", attrName.c_str());
- return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash);
+ return ConstantTensorExecutor::createEmpty(valueType, stash);
}
if (attribute->isFloatingPointType()) {
Issue::report("tensor_from_labels feature: The attribute vector '%s' must have basic type string or integer."
" Returning empty tensor.", attrName.c_str());
- return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash);
+ return ConstantTensorExecutor::createEmpty(valueType, stash);
}
if (attribute->getCollectionType() == search::attribute::CollectionType::WSET) {
Issue::report("tensor_from_labels feature: The attribute vector '%s' is a weighted set - use tensorFromWeightedSet instead."
" Returning empty tensor.", attrName.c_str());
- return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash);
+ return ConstantTensorExecutor::createEmpty(valueType, stash);
}
// Note that for array attribute vectors the default weight is 1.0 for all values.
// This means we can get the attribute content as weighted content and build
@@ -81,25 +90,25 @@ createAttributeExecutor(const search::fef::IQueryEnvironment &env,
if (attribute->isIntegerType()) {
// Using WeightedStringContent ensures that the integer values are converted
// to strings while extracting them from the attribute.
- return stash.create<TensorFromAttributeExecutor<WeightedStringContent>>(attribute, dimension);
+ return stash.create<TensorFromAttributeExecutor<WeightedStringContent>>(attribute, valueType);
}
// When the underlying attribute is of type string we can reference these values
// using WeightedConstCharContent.
- return stash.create<TensorFromAttributeExecutor<WeightedConstCharContent>>(attribute, dimension);
+ return stash.create<TensorFromAttributeExecutor<WeightedConstCharContent>>(attribute, valueType);
}
FeatureExecutor &
createQueryExecutor(const search::fef::IQueryEnvironment &env,
const vespalib::string &queryKey,
- const vespalib::string &dimension, vespalib::Stash &stash)
+ const ValueType &valueType,
+ vespalib::Stash &stash)
{
- ValueType type = ValueType::make_type(CellType::DOUBLE, {{dimension}});
search::fef::Property prop = env.getProperties().lookup(queryKey);
if (prop.found() && !prop.get().empty()) {
std::vector<vespalib::string> vector;
ArrayParser::parse(prop.get(), vector);
auto factory = FastValueBuilderFactory::get();
- auto builder = factory.create_value_builder<double>(type, 1, 1, vector.size());
+ auto builder = factory.create_value_builder<double>(valueType, 1, 1, vector.size());
std::vector<vespalib::stringref> addr_ref;
for (const auto &elem : vector) {
addr_ref.clear();
@@ -109,7 +118,7 @@ createQueryExecutor(const search::fef::IQueryEnvironment &env,
}
return ConstantTensorExecutor::create(builder->build(std::move(builder)), stash);
}
- return ConstantTensorExecutor::createEmpty(type, stash);
+ return ConstantTensorExecutor::createEmpty(valueType, stash);
}
}
@@ -118,11 +127,11 @@ FeatureExecutor &
TensorFromLabelsBlueprint::createExecutor(const search::fef::IQueryEnvironment &env, vespalib::Stash &stash) const
{
if (_sourceType == ATTRIBUTE_SOURCE) {
- return createAttributeExecutor(env, _sourceParam, _dimension, stash);
+ return createAttributeExecutor(env, _sourceParam, _valueType, stash);
} else if (_sourceType == QUERY_SOURCE) {
- return createQueryExecutor(env, _sourceParam, _dimension, stash);
+ return createQueryExecutor(env, _sourceParam, _valueType, stash);
}
- return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{_dimension}}), stash);
+ return ConstantTensorExecutor::createEmpty(_valueType, stash);
}
} // namespace features