aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reduce.java
blob: 54ca82c4bdfe6c63826821c98dba08c84ea35eec (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer.operations;

import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.GeneratorLambdaFunctionNode;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.Generate;
import com.yahoo.tensor.functions.ScalarFunctions;
import com.yahoo.tensor.functions.TensorFunction;

import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.DoubleUnaryOperator;

/**
 * ONNX Reduce[Sum/Mean/etc] operation.
 *
 * <p>Maps to a Vespa tensor {@code reduce} over the dimensions selected by the ONNX {@code axes}
 * attribute. An optional pre-operator is mapped over the input cells before reducing, and an
 * optional post-operator over the reduced cells (presumably how variants such as ReduceSumSquare
 * or ReduceL2 are expressed — confirm against the code constructing these operations). When
 * {@code keepdims} is set (or absent, which this class treats as set), the reduced dimensions are
 * kept in the output with size 1.
 */
public class Reduce extends IntermediateOperation {

    private final AttributeMap attributeMap;
    private final com.yahoo.tensor.functions.Reduce.Aggregator aggregator;
    /** Mapped over each input cell before reducing; null means no pre-mapping. */
    private final DoubleUnaryOperator preOperator;
    /** Mapped over each output cell after reducing; null means no post-mapping. */
    private final DoubleUnaryOperator postOperator;

    /**
     * Names of the input dimensions to reduce over. Set by {@link #lazyGetType()} and updated by
     * {@link #renameDimensions}; remains null while the input type is unavailable.
     */
    private List<String> reduceDimensions;

    /** Creates a reduce operation with no pre- or post-operator. */
    public Reduce(String modelName, String nodeName,
                  List<IntermediateOperation> inputs,
                  AttributeMap attributeMap,
                  com.yahoo.tensor.functions.Reduce.Aggregator aggregator) {
        this(modelName, nodeName, inputs, attributeMap, aggregator, null, null);
    }

    /**
     * Creates a reduce operation.
     *
     * @param preOperator mapped over the input before reducing, or null for none
     * @param postOperator mapped over the result after reducing, or null for none
     */
    public Reduce(String modelName, String nodeName,
                  List<IntermediateOperation> inputs,
                  AttributeMap attributeMap,
                  com.yahoo.tensor.functions.Reduce.Aggregator aggregator,
                  DoubleUnaryOperator preOperator,
                  DoubleUnaryOperator postOperator) {
        super(modelName, nodeName, inputs);
        this.attributeMap = attributeMap;
        this.aggregator = aggregator;
        this.preOperator = preOperator;
        this.postOperator = postOperator;
    }

    /**
     * Resolves the output type, and as a side effect the dimensions to reduce over.
     *
     * @return the reduced type, or null if the input type is not yet available
     * @throws IllegalArgumentException if an axis lies outside {@code [-rank, rank - 1]}
     */
    @Override
    protected OrderedTensorType lazyGetType() {
        if ( ! allInputTypesPresent(1)) return null;

        OrderedTensorType inputType = inputs.get(0).type().get();
        reduceDimensions = resolveReduceDimensions(inputType);
        return reducedType(inputType, shouldKeepDimensions());
    }

    /**
     * Maps the ONNX {@code axes} attribute to input dimension names. All dimensions are reduced
     * when the attribute is absent or empty. Negative axes count back from the last dimension, and
     * duplicate axes are only reduced once.
     */
    private List<String> resolveReduceDimensions(OrderedTensorType inputType) {
        List<String> allDimensions = inputType.dimensionNames();
        if (attributeMap.getList("axes").isEmpty()) return allDimensions;  // default is to reduce all dimensions

        int rank = inputType.dimensions().size();
        List<String> dimensions = new ArrayList<>();
        for (Value axisValue : attributeMap.getList("axes").get()) {
            int axis = (int) axisValue.asDouble();
            if (axis < -rank || axis >= rank) {
                throw new IllegalArgumentException("Reduce axis " + axis + " is out of range [" + (-rank) +
                                                   ", " + (rank - 1) + "] for input of rank " + rank +
                                                   " in node '" + name() + "'");
            }
            if (axis < 0) axis += rank;
            String name = inputType.dimensions().get(axis).name();
            if ( ! dimensions.contains(name)) {
                dimensions.add(name);
            }
        }
        // An empty axes list also means "reduce all" in ONNX. Passing an empty list through would make
        // the declared type (nothing reduced) disagree with the tensor reduce function, which reduces
        // all dimensions when given an empty dimension list.
        return dimensions.isEmpty() ? allDimensions : dimensions;
    }

    /**
     * Builds the chain: optional pre-map, then reduce over {@link #reduceDimensions}, then (if
     * keepdims) a multiply-join with a generated all-ones tensor that has a size-1 indexed dimension
     * for each reduced dimension, then an optional post-map.
     */
    @Override
    protected TensorFunction<Reference> lazyGetFunction() {
        if ( ! allInputTypesPresent(1)) return null;

        TensorFunction<Reference> inputFunction = inputs.get(0).function().get();
        if (preOperator != null) {
            inputFunction = new com.yahoo.tensor.functions.Map<>(inputFunction, preOperator);
        }
        TensorFunction<Reference> output = new com.yahoo.tensor.functions.Reduce<>(inputFunction, aggregator, reduceDimensions);
        if (shouldKeepDimensions()) {
            // Re-introduce the reduced dimensions with size 1 by multiplying with a generated tensor
            // whose cells are all 1.
            TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
            for (String name : reduceDimensions) {
                typeBuilder.indexed(name, 1);
            }
            TensorType generatedType = typeBuilder.build();
            ExpressionNode generatedExpression = new ConstantNode(new DoubleValue(1));
            Generate<Reference> generatedFunction = new Generate<>(generatedType,
                    new GeneratorLambdaFunctionNode(generatedType, generatedExpression).asLongListToDoubleOperator());
            output = new com.yahoo.tensor.functions.Join<>(output, generatedFunction, ScalarFunctions.multiply());
        }
        if (postOperator != null) {
            output = new com.yahoo.tensor.functions.Map<>(output, postOperator);
        }
        return output;
    }

    /**
     * Renames the dimensions to reduce over, in addition to the superclass renaming. The stored
     * names are left unchanged if any of them has no mapping in the renamer.
     */
    @Override
    public void renameDimensions(DimensionRenamer renamer) {
        super.renameDimensions(renamer);
        if (reduceDimensions == null) return;  // type never resolved: nothing to rename
        List<String> renamedDimensions = new ArrayList<>(reduceDimensions.size());
        for (String name : reduceDimensions) {
            Optional<String> newName = renamer.dimensionNameOf(name);
            if (newName.isEmpty()) {
                return;  // presumably, already renamed
            }
            renamedDimensions.add(newName.get());
        }
        reduceDimensions = renamedDimensions;
    }

    @Override
    public Reduce withInputs(List<IntermediateOperation> inputs) {
        return new Reduce(modelName(), name(), inputs, attributeMap, aggregator, preOperator, postOperator);
    }

    @Override
    public String operationName() { return "Reduce"; }

    /** Returns the {@code keepdims} attribute as a boolean; true when the attribute is absent. */
    private boolean shouldKeepDimensions() {
        Optional<Value> keepDims = attributeMap.get("keepdims");
        return keepDims.isEmpty() || keepDims.get().asBoolean();  // default is 1
    }

    /**
     * Returns the input type with the reduce dimensions either removed, or replaced by size-1
     * indexed dimensions when {@code keepDimensions} is true. Other dimensions keep their order and
     * definition.
     */
    private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
        OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
        for (TensorType.Dimension dimension: inputType.type().dimensions()) {
            if ( ! reduceDimensions.contains(dimension.name())) {
                builder.add(dimension);
            } else if (keepDimensions) {
                builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
            }
        }
        return builder.build();
    }

}