summaryrefslogtreecommitdiffstats
path: root/searchlib/src/vespa/searchlib/features/tensor_from_labels_feature.cpp
blob: b72a75bd19f0c92ca4b39981cd27ab397a88e979 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

#include "tensor_from_labels_feature.h"
#include "array_parser.hpp"
#include "constant_tensor_executor.h"
#include "tensor_from_attribute_executor.h"
#include <vespa/searchlib/fef/properties.h>
#include <vespa/searchlib/fef/feature_type.h>
#include <vespa/searchcommon/attribute/attributecontent.h>
#include <vespa/searchcommon/attribute/iattributevector.h>
#include <vespa/eval/eval/fast_value.h>
#include <vespa/eval/eval/value_type.h>
#include <vespa/vespalib/util/issue.h>
#include <algorithm>

#include <vespa/log/log.h>
LOG_SETUP(".features.tensor_from_labels_feature");

using namespace search::fef;
using search::attribute::IAttributeVector;
using search::attribute::WeightedConstCharContent;
using search::attribute::WeightedStringContent;
using vespalib::eval::FastValueBuilderFactory;
using vespalib::eval::ValueType;
using vespalib::eval::CellType;
using vespalib::Issue;
using search::fef::FeatureType;

namespace search {
namespace features {

TensorFromLabelsBlueprint::TensorFromLabelsBlueprint()
    : TensorFactoryBlueprint("tensorFromLabels")
{
}

/**
 * Parses the feature parameters and declares the single "tensor" output.
 *
 * params[0] is the source: either 'attribute(name)' or 'query(param)'.
 * params[1] is an optional dimension name. When it is absent, the source
 * parameter extracted from params[0] is used as the dimension name.
 *
 * The output type is a double tensor with the one dimension chosen above.
 * Returns whether the source parameter could be extracted.
 */
bool
TensorFromLabelsBlueprint::setup(const search::fef::IIndexEnvironment &env,
                                 const search::fef::ParameterList &params)
{
    (void) env; // the index environment is not needed by this feature
    // extractSource() must run first: it fills in _sourceParam, which may be
    // used as the fallback dimension name below.
    const bool sourceOk = extractSource(params[0].getValue());
    const bool hasExplicitDimension = (params.size() == 2);
    _dimension = hasExplicitDimension ? params[1].getValue() : _sourceParam;
    describeOutput("tensor",
                   "The tensor created from the given source (attribute field or query parameter)",
                   FeatureType::object(ValueType::make_type(CellType::DOUBLE, {{_dimension}})));
    return sourceOk;
}

namespace {

/**
 * Creates an executor that builds a tensor with the single dimension
 * `dimension` from the values of the attribute named `attrName`.
 *
 * Only integer and string attributes that are not weighted sets are supported.
 * In every other case (attribute missing, floating point basic type, or
 * weighted set collection) an Issue is reported and an executor producing an
 * empty tensor of the same type is returned instead.
 *
 * @param env       query environment used to look up the attribute
 * @param attrName  name of the source attribute
 * @param dimension name of the tensor dimension to produce
 * @param stash     owner of the created executor
 * @return reference to an executor allocated in `stash`
 */
FeatureExecutor &
createAttributeExecutor(const search::fef::IQueryEnvironment &env,
                        const vespalib::string &attrName,
                        const vespalib::string &dimension, vespalib::Stash &stash)
{
    // Shared fallback for all unsupported configurations. The type matches
    // the one declared in setup(). The value type is only built when the
    // fallback is actually taken.
    auto emptyTensor = [&]() -> FeatureExecutor & {
        return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{dimension}}), stash);
    };
    const IAttributeVector *attribute = env.getAttributeContext().getAttribute(attrName);
    if (attribute == nullptr) {
        Issue::report("tensor_from_labels feature: The attribute vector '%s' was not found."
                      " Returning empty tensor.", attrName.c_str());
        return emptyTensor();
    }
    if (attribute->isFloatingPointType()) {
        Issue::report("tensor_from_labels feature: The attribute vector '%s' must have basic type string or integer."
                      " Returning empty tensor.", attrName.c_str());
        return emptyTensor();
    }
    if (attribute->getCollectionType() == search::attribute::CollectionType::WSET) {
        Issue::report("tensor_from_labels feature: The attribute vector '%s' is a weighted set - use tensorFromWeightedSet instead."
                      " Returning empty tensor.", attrName.c_str());
        return emptyTensor();
    }
    // Note that for array attribute vectors the default weight is 1.0 for all values.
    // This means we can get the attribute content as weighted content and build
    // the tensor the same way as with weighted set attributes in tensorFromWeightedSet.
    if (attribute->isIntegerType()) {
        // Using WeightedStringContent ensures that the integer values are converted
        // to strings while extracting them from the attribute.
        return stash.create<TensorFromAttributeExecutor<WeightedStringContent>>(attribute, dimension);
    }
    // When the underlying attribute is of type string we can reference these values
    // using WeightedConstCharContent.
    return stash.create<TensorFromAttributeExecutor<WeightedConstCharContent>>(attribute, dimension);
}

/**
 * Creates an executor returning a constant tensor built from a query property.
 *
 * The property value found under `queryKey` is parsed with ArrayParser into a
 * list of labels. Every distinct label becomes one cell with value 1.0 along
 * `dimension`. A missing or empty property yields an empty tensor of the same
 * type.
 *
 * @param env       query environment providing the query properties
 * @param queryKey  name of the query property holding the labels
 * @param dimension name of the tensor dimension to produce
 * @param stash     owner of the created executor
 * @return reference to an executor allocated in `stash`
 */
FeatureExecutor &
createQueryExecutor(const search::fef::IQueryEnvironment &env,
                    const vespalib::string &queryKey,
                    const vespalib::string &dimension, vespalib::Stash &stash)
{
    ValueType type = ValueType::make_type(CellType::DOUBLE, {{dimension}});
    search::fef::Property prop = env.getProperties().lookup(queryKey);
    if (!prop.found() || prop.get().empty()) {
        return ConstantTensorExecutor::createEmpty(type, stash);
    }
    std::vector<vespalib::string> labels;
    ArrayParser::parse(prop.get(), labels);
    // A label may be repeated in the query parameter. Feeding the same
    // address to the value builder twice is presumably not supported
    // (NOTE(review): confirm against FastValue), and it is never needed:
    // every label maps to 1.0. Keep each label once.
    std::sort(labels.begin(), labels.end());
    labels.erase(std::unique(labels.begin(), labels.end()), labels.end());
    // Bind to the shared factory instead of copying it.
    const auto &factory = FastValueBuilderFactory::get();
    auto builder = factory.create_value_builder<double>(type, 1, 1, labels.size());
    std::vector<vespalib::stringref> addr_ref;
    for (const auto &label : labels) {
        addr_ref.clear();
        addr_ref.push_back(label);
        auto cell_array = builder->add_subspace(addr_ref);
        cell_array[0] = 1.0;
    }
    return ConstantTensorExecutor::create(builder->build(std::move(builder)), stash);
}

}

/**
 * Creates the per-query executor for this feature.
 *
 * Dispatches on the source kind resolved in setup(): attribute sources and
 * query-property sources each get their dedicated executor. Any other source
 * kind yields an executor producing an empty tensor with the configured
 * dimension.
 */
FeatureExecutor &
TensorFromLabelsBlueprint::createExecutor(const search::fef::IQueryEnvironment &env, vespalib::Stash &stash) const
{
    switch (_sourceType) {
    case ATTRIBUTE_SOURCE:
        return createAttributeExecutor(env, _sourceParam, _dimension, stash);
    case QUERY_SOURCE:
        return createQueryExecutor(env, _sourceParam, _dimension, stash);
    default:
        break;
    }
    // No usable source: fall back to an empty tensor of the declared type.
    return ConstantTensorExecutor::createEmpty(ValueType::make_type(CellType::DOUBLE, {{_dimension}}), stash);
}

} // namespace features
} // namespace search