summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
blob: de3d2be265a077c31f117de5b803a1bb832e02bc (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;

import com.google.common.annotations.Beta;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;

import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Objects;

/**
 * The <i>rename</i> tensor function: produces a tensor holding the same cells as its argument,
 * but where the dimensions listed as 'from' dimensions carry the corresponding 'to' names.
 * Dimensions which are not listed keep their current name.
 *
 * @author bratseth
 */
@Beta
public class Rename extends PrimitiveTensorFunction {

    private final TensorFunction argument;
    private final List<String> fromDimensions;
    private final List<String> toDimensions;

    /** Lookup from each 'from' dimension name to the 'to' name at the same list position */
    private final Map<String, String> fromToMap;

    /**
     * Creates a function renaming a single dimension.
     *
     * @param argument the function producing the tensor to rename a dimension of
     * @param fromDimension the current name of the dimension
     * @param toDimension the name the dimension should get
     */
    public Rename(TensorFunction argument, String fromDimension, String toDimension) {
        this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
    }

    /**
     * Creates a function renaming one or more dimensions: The name at position i in the
     * 'from' list is replaced by the name at position i in the 'to' list.
     *
     * @param argument the function producing the tensor to rename dimensions of
     * @param fromDimensions the current names of the dimensions to rename
     * @param toDimensions the new names, in the same order as the 'from' names
     * @throws NullPointerException if any of the arguments are null
     * @throws IllegalArgumentException if no dimensions are given, or if the two lists differ in size
     */
    public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
        this.argument = Objects.requireNonNull(argument, "The argument tensor cannot be null");
        Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
        Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
        if (fromDimensions.isEmpty())
            throw new IllegalArgumentException("from dimensions is empty, must rename at least one dimension");

        int fromCount = fromDimensions.size();
        int toCount = toDimensions.size();
        if (fromCount != toCount)
            throw new IllegalArgumentException("Rename from and to dimensions must be equal, was " +
                                               fromCount + " and " + toCount);

        this.fromDimensions = ImmutableList.copyOf(fromDimensions);
        this.toDimensions = ImmutableList.copyOf(toDimensions);
        this.fromToMap = fromToMap(fromDimensions, toDimensions);
    }

    /** Returns the immutable list of dimension names this renames from */
    public List<String> fromDimensions() { return fromDimensions; }

    /** Returns the immutable list of dimension names this renames to */
    public List<String> toDimensions() { return toDimensions; }

    /** Pairs up the two (equally sized) lists position by position into a from-name to to-name map */
    private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) {
        Map<String, String> renames = new HashMap<>();
        Iterator<String> toNames = toDimensions.iterator();
        for (String fromName : fromDimensions)
            renames.put(fromName, toNames.next());
        return renames;
    }

    /** Returns the name the given dimension has after renaming: The name itself if it is not renamed */
    private String renamed(String dimensionName) {
        return fromToMap.getOrDefault(dimensionName, dimensionName);
    }

    @Override
    public List<TensorFunction> arguments() { return Collections.singletonList(argument); }

    /**
     * Returns a function applying the same renaming as this to another argument.
     *
     * @throws IllegalArgumentException if the given list does not contain exactly one function
     */
    @Override
    public TensorFunction withArguments(List<TensorFunction> arguments) {
        if (arguments.size() == 1)
            return new Rename(arguments.get(0), fromDimensions, toDimensions);
        throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
    }

    @Override
    public PrimitiveTensorFunction toPrimitive() { return this; }

    /** Returns the type of the argument of this with dimensions renamed */
    @Override
    public TensorType type(TypeContext context) {
        return type(argument.type(context));
    }

    /** Builds the type which results from renaming the dimensions of the given type */
    private TensorType type(TensorType type) {
        TensorType.Builder renamedType = new TensorType.Builder();
        for (TensorType.Dimension original : type.dimensions())
            renamedType.dimension(original.withName(renamed(original.name())));
        return renamedType.build();
    }

    /**
     * Evaluates the argument and copies each of its cells into a new tensor having the renamed type,
     * moving each address label to the position its dimension has in the renamed type.
     */
    @Override
    public Tensor evaluate(EvaluationContext context) {
        Tensor source = argument.evaluate(context);
        TensorType sourceType = source.type();
        TensorType resultType = type(sourceType);

        // For each dimension position in the source type: The position of that dimension in the renamed type
        int[] resultPositions = new int[sourceType.dimensions().size()];
        int sourcePosition = 0;
        for (TensorType.Dimension dimension : sourceType.dimensions())
            resultPositions[sourcePosition++] = resultType.indexOfDimension(renamed(dimension.name())).get();

        Tensor.Builder result = Tensor.Builder.of(resultType);
        Iterator<Tensor.Cell> cells = source.cellIterator();
        while (cells.hasNext()) {
            Map.Entry<TensorAddress, Double> cell = cells.next();
            result.cell(rename(cell.getKey(), resultPositions), cell.getValue());
        }
        return result.build();
    }

    /**
     * Returns an address having the labels of the given address,
     * where the label at position i is placed at position toIndexes[i]
     */
    private TensorAddress rename(TensorAddress address, int[] toIndexes) {
        String[] labels = new String[toIndexes.length];
        int fromIndex = 0;
        for (int toIndex : toIndexes)
            labels[toIndex] = address.label(fromIndex++);
        return TensorAddress.of(labels);
    }

    @Override
    public String toString(ToStringContext context) {
        String argumentString = argument.toString(context);
        String from = toVectorString(fromDimensions);
        String to = toVectorString(toDimensions);
        return "rename(" + argumentString + ", " + from + ", " + to + ")";
    }

    /** Returns a single element as-is, and multiple elements as a parenthesized, comma-separated list */
    private String toVectorString(List<String> elements) {
        if (elements.size() == 1)
            return elements.get(0);
        return "(" + String.join(", ", elements) + ")";
    }

}