aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
blob: a2a3874ecedb5f0814b275497cc2c5e376016e2f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * The <i>rename</i> tensor function returns a tensor where some dimensions are assigned new names.
 * <p>
 * The dimensions to rename and their new names are given as two lists of equal size, where the
 * name at position i of the 'from' list is replaced by the name at position i of the 'to' list.
 * Both lists are copied at construction.
 *
 * @author bratseth
 */
public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {

    private final TensorFunction<NAMETYPE> argument;
    private final List<String> fromDimensions;
    private final List<String> toDimensions;

    /** Lookup from each 'from' dimension name to the 'to' name at the same list position. */
    private final Map<String, String> fromToMap;

    /**
     * Creates a rename of a single dimension.
     *
     * @param argument the function producing the tensor to rename a dimension of
     * @param fromDimension the current name of the dimension
     * @param toDimension the new name of the dimension
     */
    public Rename(TensorFunction<NAMETYPE> argument, String fromDimension, String toDimension) {
        this(argument, List.of(fromDimension), List.of(toDimension));
    }

    /**
     * Creates a rename of one or more dimensions.
     *
     * @param argument the function producing the tensor to rename dimensions of
     * @param fromDimensions the current names of the dimensions to rename
     * @param toDimensions the new names, in the same order as {@code fromDimensions}
     * @throws NullPointerException if any argument is null
     * @throws IllegalArgumentException if {@code fromDimensions} is empty, or the two lists differ in size
     */
    public Rename(TensorFunction<NAMETYPE> argument, List<String> fromDimensions, List<String> toDimensions) {
        Objects.requireNonNull(argument, "The argument tensor cannot be null");
        Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
        Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
        if (fromDimensions.isEmpty())
            throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension");
        if (fromDimensions.size() != toDimensions.size())
            throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " +
                                               fromDimensions.size() + " and " + toDimensions.size());
        this.argument = argument;
        this.fromDimensions = List.copyOf(fromDimensions);
        this.toDimensions = List.copyOf(toDimensions);
        // Build the lookup from the copies, not the caller's lists, so it always agrees with the stored fields
        this.fromToMap = fromToMap(this.fromDimensions, this.toDimensions);
    }

    /** Returns the names of the dimensions which are renamed, as an unmodifiable list. */
    public List<String> fromDimensions() { return fromDimensions; }

    /** Returns the new dimension names, in the same order as {@link #fromDimensions()}, as an unmodifiable list. */
    public List<String> toDimensions() { return toDimensions; }

    /**
     * Pairs the names of the two lists by position. The lists must have the same size.
     * If a name occurs more than once in {@code fromDimensions}, the last pairing wins.
     *
     * @return an unmodifiable map from each 'from' name to its 'to' name
     */
    private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) {
        Map<String, String> map = new HashMap<>();
        for (int i = 0; i < fromDimensions.size(); i++)
            map.put(fromDimensions.get(i), toDimensions.get(i));
        return Map.copyOf(map);
    }

    /** Returns the single argument of this function, as a one-element list. */
    @Override
    public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }

    /**
     * Returns a rename of the given argument using the same dimension names as this.
     *
     * @throws IllegalArgumentException if the list does not contain exactly one argument
     */
    @Override
    public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
        if (arguments.size() != 1)
            throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
        return new Rename<>(arguments.get(0), fromDimensions, toDimensions);
    }

    /** Returns this. */
    @Override
    public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }

    /** Returns the type of the argument in the given context, with the renames of this applied. */
    @Override
    public TensorType type(TypeContext<NAMETYPE> context) {
        return type(argument.type(context));
    }

    /** Applies the renames of this to the given type, by delegating to {@code TypeResolver.rename}. */
    private TensorType type(TensorType type) {
        return TypeResolver.rename(type, fromDimensions, toDimensions);
    }

    /**
     * Evaluates the argument and returns its content with the renamed type.
     * When every dimension keeps its position in the renamed type the argument tensor is
     * reused through {@code withType}; otherwise a new tensor is built with the labels
     * of each cell address moved to their new positions.
     */
    @Override
    public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
        Tensor tensor = argument.evaluate(context);
        TensorType argumentType = tensor.type();
        TensorType renamedType = type(argumentType);

        // toIndexes[i] is the index in the renamed type of the dimension at index i in the argument type
        int dimensionCount = argumentType.dimensions().size();
        int[] toIndexes = new int[dimensionCount];
        for (int i = 0; i < dimensionCount; i++) {
            String dimensionName = argumentType.dimensions().get(i).name();
            String newDimensionName = fromToMap.getOrDefault(dimensionName, dimensionName);
            toIndexes[i] = renamedType.indexOfDimension(newDimensionName)
                                      .orElseThrow(() -> new IllegalStateException("Renamed type " + renamedType +
                                                                                   " has no dimension '" +
                                                                                   newDimensionName + "'"));
        }

        // avoid building a new tensor if dimensions can simply be renamed
        if (simpleRenameIsPossible(toIndexes))
            return tensor.withType(renamedType);

        Tensor.Builder builder = Tensor.Builder.of(renamedType);
        for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
            Map.Entry<TensorAddress, Double> cell = i.next();
            builder.cell(rename(cell.getKey(), toIndexes), cell.getValue());
        }
        return builder.build();
    }

    /**
     * If none of the dimensions change order after rename we can do a simple rename.
     *
     * @return true if every entry of {@code toIndexes} equals its own position
     */
    private static boolean simpleRenameIsPossible(int[] toIndexes) {
        for (int i = 0; i < toIndexes.length; ++i) {
            if (toIndexes[i] != i)
                return false;
        }
        return true;
    }

    /**
     * Returns an address where the label at position i of the given address
     * is placed at position {@code toIndexes[i]}.
     */
    private static TensorAddress rename(TensorAddress address, int[] toIndexes) {
        String[] reorderedLabels = new String[toIndexes.length];
        for (int i = 0; i < toIndexes.length; i++)
            reorderedLabels[toIndexes[i]] = address.label(i);
        return TensorAddress.of(reorderedLabels);
    }

    /**
     * Returns the single element as-is if there is exactly one,
     * and otherwise the elements separated by ", " and enclosed in parentheses.
     */
    private static String toVectorString(List<String> elements) {
        if (elements.size() == 1)
            return elements.get(0);
        return "(" + String.join(", ", elements) + ")";
    }

    /** Returns this on the form {@code rename(argument, from, to)}. */
    @Override
    public String toString(ToStringContext<NAMETYPE> context) {
        return "rename(" + argument.toString(context) + ", " +
                       toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
    }

    /** Returns a hash of the function name, the argument and both dimension name lists. */
    @Override
    public int hashCode() { return Objects.hash("rename", argument, fromDimensions, toDimensions); }

}