diff options
Diffstat (limited to 'searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp')
-rw-r--r-- | searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp | 39 |
1 files changed, 24 insertions, 15 deletions
diff --git a/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp b/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp index b72a75bd19f..f36c1dbfdaa 100644 --- a/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp +++ b/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp @@ -41,15 +41,23 @@ TensorFromLabelsBlueprint::setup(const search::fef::IIndexEnvironment &env, // _params[0] = source ('attribute(name)' OR 'query(param)'); // _params[1] = dimension (optional); bool validSource = extractSource(params[0].getValue()); + if (! validSource) { + return fail("invalid source: '%s'", params[0].getValue().c_str()); + } if (params.size() == 2) { _dimension = params[1].getValue(); } else { _dimension = _sourceParam; } + auto vt = ValueType::make_type(CellType::DOUBLE, {{_dimension}}); + _valueType = ValueType::from_spec(vt.to_spec()); + if (_valueType.is_error()) { + return fail("invalid dimension name: '%s'", _dimension.c_str()); + } describeOutput("tensor", "The tensor created from the given source (attribute field or query parameter)", - FeatureType::object(ValueType::make_type(CellType::DOUBLE, {{_dimension}}))); - return validSource; + FeatureType::object(_valueType)); + return true; } namespace { @@ -57,23 +65,24 @@ namespace { FeatureExecutor & createAttributeExecutor(const search::fef::IQueryEnvironment &env, const vespalib::string &attrName, - const vespalib::string &dimension, vespalib::Stash &stash) + const ValueType &valueType, + vespalib::Stash &stash) { const IAttributeVector *attribute = env.getAttributeContext().getAttribute(attrName); if (attribute == NULL) { Issue::report("tensor_from_labels feature: The attribute vector '%s' was not found." " Returning empty tensor.", attrName.c_str()); - return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash); + return ConstantTensorExecutor::createEmpty(valueType, stash); } if (attribute->isFloatingPointType()) { Issue::report("tensor_from_labels feature: The attribute vector '%s' must have basic type string or integer." " Returning empty tensor.", attrName.c_str()); - return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash); + return ConstantTensorExecutor::createEmpty(valueType, stash); } if (attribute->getCollectionType() == search::attribute::CollectionType::WSET) { Issue::report("tensor_from_labels feature: The attribute vector '%s' is a weighted set - use tensorFromWeightedSet instead." " Returning empty tensor.", attrName.c_str()); - return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash); + return ConstantTensorExecutor::createEmpty(valueType, stash); } // Note that for array attribute vectors the default weight is 1.0 for all values. // This means we can get the attribute content as weighted content and build @@ -81,25 +90,25 @@ createAttributeExecutor(const search::fef::IQueryEnvironment &env, if (attribute->isIntegerType()) { // Using WeightedStringContent ensures that the integer values are converted // to strings while extracting them from the attribute. - return stash.create<TensorFromAttributeExecutor<WeightedStringContent>>(attribute, dimension); + return stash.create<TensorFromAttributeExecutor<WeightedStringContent>>(attribute, valueType); } // When the underlying attribute is of type string we can reference these values // using WeightedConstCharContent. - return stash.create<TensorFromAttributeExecutor<WeightedConstCharContent>>(attribute, dimension); + return stash.create<TensorFromAttributeExecutor<WeightedConstCharContent>>(attribute, valueType); } FeatureExecutor & createQueryExecutor(const search::fef::IQueryEnvironment &env, const vespalib::string &queryKey, - const vespalib::string &dimension, vespalib::Stash &stash) + const ValueType &valueType, + vespalib::Stash &stash) { - ValueType type = ValueType::make_type(CellType::DOUBLE, {{dimension}}); search::fef::Property prop = env.getProperties().lookup(queryKey); if (prop.found() && !prop.get().empty()) { std::vector<vespalib::string> vector; ArrayParser::parse(prop.get(), vector); auto factory = FastValueBuilderFactory::get(); - auto builder = factory.create_value_builder<double>(type, 1, 1, vector.size()); + auto builder = factory.create_value_builder<double>(valueType, 1, 1, vector.size()); std::vector<vespalib::stringref> addr_ref; for (const auto &elem : vector) { addr_ref.clear(); @@ -109,7 +118,7 @@ createQueryExecutor(const search::fef::IQueryEnvironment &env, } return ConstantTensorExecutor::create(builder->build(std::move(builder)), stash); } - return ConstantTensorExecutor::createEmpty(type, stash); + return ConstantTensorExecutor::createEmpty(valueType, stash); } } @@ -118,11 +127,11 @@ FeatureExecutor & TensorFromLabelsBlueprint::createExecutor(const search::fef::IQueryEnvironment &env, vespalib::Stash &stash) const { if (_sourceType == ATTRIBUTE_SOURCE) { - return createAttributeExecutor(env, _sourceParam, _dimension, stash); + return createAttributeExecutor(env, _sourceParam, _valueType, stash); } else if (_sourceType == QUERY_SOURCE) { - return createQueryExecutor(env, _sourceParam, _dimension, stash); + return createQueryExecutor(env, _sourceParam, _valueType, stash); } - return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{_dimension}}), stash); + return ConstantTensorExecutor::createEmpty(_valueType, stash); } } // namespace features |