aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/impl/TensorAddressAny.java
blob: 31863c99a747fa3b09060a54b85e5b694910ef9f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.tensor.impl;

import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;

import static com.yahoo.tensor.impl.Convert.safe2Int;
import static com.yahoo.tensor.impl.Label.toNumber;
import static com.yahoo.tensor.impl.Label.fromNumber;

/**
 * Parent of the tensor address family centered around representing each dimension label as an int.
 * <ul>
 *   <li>A positive number represents a numeric index usable for direct addressing.</li>
 *   <li>-1 represents an invalid/null address.</li>
 *   <li>Other negative numbers are an enumeration maintained in {@link Label}.</li>
 * </ul>
 *
 * The {@code of(...)} factories taking {@code int}/{@code long} labels validate every label with
 * {@link #sanitize(int)}; the {@code ofUnsafe(...)} factories store the given ints as-is.
 *
 * @author baldersheim
 */
public abstract class TensorAddressAny extends TensorAddress {

    /**
     * Returns the label of dimension {@code i} as a string, obtained by narrowing
     * {@code numericLabel(i)} to an int and converting it with {@link Label#fromNumber}.
     */
    @Override
    public String label(int i) {
        int number = (int) numericLabel(i);
        return fromNumber(number);
    }

    // ---- String labels: each label is converted to its int form by Label.toNumber ----

    /** Returns the shared empty (zero-dimensional) address. */
    public static TensorAddress of() {
        return TensorAddressEmpty.empty;
    }

    /** Returns a one-dimensional address for the given string label. */
    public static TensorAddress of(String label) {
        return new TensorAddressAny1(toNumber(label));
    }

    /** Returns a two-dimensional address for the given string labels. */
    public static TensorAddress of(String label0, String label1) {
        return new TensorAddressAny2(toNumber(label0), toNumber(label1));
    }

    /** Returns a three-dimensional address for the given string labels. */
    public static TensorAddress of(String label0, String label1, String label2) {
        return new TensorAddressAny3(toNumber(label0), toNumber(label1), toNumber(label2));
    }

    /** Returns a four-dimensional address for the given string labels. */
    public static TensorAddress of(String label0, String label1, String label2, String label3) {
        return new TensorAddressAny4(toNumber(label0), toNumber(label1), toNumber(label2), toNumber(label3));
    }

    /**
     * Returns an address with one dimension per element of {@code labels}.
     * The converted ints are passed on without range validation.
     */
    public static TensorAddress of(String[] labels) {
        int[] labelsAsInt = new int[labels.length];
        for (int i = 0; i < labels.length; i++) {
            labelsAsInt[i] = toNumber(labels[i]);
        }
        return ofUnsafe(labelsAsInt);
    }

    // ---- int labels: every label is validated by sanitize ----

    /**
     * Returns a one-dimensional address for the given numeric label.
     *
     * @throws IndexOutOfBoundsException if the label is below {@code Tensor.INVALID_INDEX}
     */
    public static TensorAddress of(int label) {
        return new TensorAddressAny1(sanitize(label));
    }

    /**
     * Returns a two-dimensional address for the given numeric labels.
     *
     * @throws IndexOutOfBoundsException if any label is below {@code Tensor.INVALID_INDEX}
     */
    public static TensorAddress of(int label0, int label1) {
        return new TensorAddressAny2(sanitize(label0), sanitize(label1));
    }

    /**
     * Returns a three-dimensional address for the given numeric labels.
     *
     * @throws IndexOutOfBoundsException if any label is below {@code Tensor.INVALID_INDEX}
     */
    public static TensorAddress of(int label0, int label1, int label2) {
        return new TensorAddressAny3(sanitize(label0), sanitize(label1), sanitize(label2));
    }

    /**
     * Returns a four-dimensional address for the given numeric labels.
     *
     * @throws IndexOutOfBoundsException if any label is below {@code Tensor.INVALID_INDEX}
     */
    public static TensorAddress of(int label0, int label1, int label2, int label3) {
        return new TensorAddressAny4(sanitize(label0), sanitize(label1), sanitize(label2), sanitize(label3));
    }

    /**
     * Returns an address with one dimension per element of {@code labels}.
     * For more than four labels the given array itself is handed to {@code TensorAddressAnyN}
     * after validation (no copy is made here).
     *
     * @throws IndexOutOfBoundsException if any label is below {@code Tensor.INVALID_INDEX}
     */
    public static TensorAddress of(int ... labels) {
        return switch (labels.length) {
            case 0 -> of();
            case 1 -> of(labels[0]);
            case 2 -> of(labels[0], labels[1]);
            case 3 -> of(labels[0], labels[1], labels[2]);
            case 4 -> of(labels[0], labels[1], labels[2], labels[3]);
            default -> {
                // Validate only; sanitize returns its argument unchanged.
                for (int label : labels) {
                    sanitize(label);
                }
                yield new TensorAddressAnyN(labels);
            }
        };
    }

    // ---- long labels: narrowed with safe2Int, then validated like int labels ----

    /** Returns a one-dimensional address; the label is narrowed with {@code safe2Int} and validated. */
    public static TensorAddress of(long label) {
        return of(safe2Int(label));
    }

    /** Returns a two-dimensional address; labels are narrowed with {@code safe2Int} and validated. */
    public static TensorAddress of(long label0, long label1) {
        return of(safe2Int(label0), safe2Int(label1));
    }

    /** Returns a three-dimensional address; labels are narrowed with {@code safe2Int} and validated. */
    public static TensorAddress of(long label0, long label1, long label2) {
        return of(safe2Int(label0), safe2Int(label1), safe2Int(label2));
    }

    /** Returns a four-dimensional address; labels are narrowed with {@code safe2Int} and validated. */
    public static TensorAddress of(long label0, long label1, long label2, long label3) {
        return of(safe2Int(label0), safe2Int(label1), safe2Int(label2), safe2Int(label3));
    }

    /**
     * Returns an address with one dimension per element of {@code labels}.
     * Every label is narrowed with {@code safe2Int} and then validated, regardless of how many
     * labels are given, so this agrees with the fixed-arity {@code long} overloads.
     *
     * @throws IndexOutOfBoundsException if any narrowed label is below {@code Tensor.INVALID_INDEX}
     */
    public static TensorAddress of(long ... labels) {
        return switch (labels.length) {
            case 0 -> of();
            // Arguments are ints here, so these resolve to the sanitizing of(int, ...) overloads.
            case 1 -> of(safe2Int(labels[0]));
            case 2 -> of(safe2Int(labels[0]), safe2Int(labels[1]));
            case 3 -> of(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]));
            case 4 -> of(safe2Int(labels[0]), safe2Int(labels[1]), safe2Int(labels[2]), safe2Int(labels[3]));
            default -> {
                int[] labelsAsInt = new int[labels.length];
                for (int i = 0; i < labels.length; i++) {
                    labelsAsInt[i] = safe2Int(labels[i]);
                }
                yield of(labelsAsInt);
            }
        };
    }

    // ---- Unchecked factories: labels are stored exactly as given ----

    private static TensorAddress ofUnsafe(int label) {
        return new TensorAddressAny1(label);
    }

    private static TensorAddress ofUnsafe(int label0, int label1) {
        return new TensorAddressAny2(label0, label1);
    }

    private static TensorAddress ofUnsafe(int label0, int label1, int label2) {
        return new TensorAddressAny3(label0, label1, label2);
    }

    private static TensorAddress ofUnsafe(int label0, int label1, int label2, int label3) {
        return new TensorAddressAny4(label0, label1, label2, label3);
    }

    /**
     * Returns an address with one dimension per element of {@code labels} without validating them.
     * For more than four labels the given array itself is handed to {@code TensorAddressAnyN}.
     */
    public static TensorAddress ofUnsafe(int ... labels) {
        return switch (labels.length) {
            case 0 -> of();
            case 1 -> ofUnsafe(labels[0]);
            case 2 -> ofUnsafe(labels[0], labels[1]);
            case 3 -> ofUnsafe(labels[0], labels[1], labels[2]);
            case 4 -> ofUnsafe(labels[0], labels[1], labels[2], labels[3]);
            default -> new TensorAddressAnyN(labels);
        };
    }

    /**
     * Returns {@code label} unchanged if it is not below {@code Tensor.INVALID_INDEX}.
     *
     * @throws IndexOutOfBoundsException if {@code label} is below {@code Tensor.INVALID_INDEX}
     */
    private static int sanitize(int label) {
        if (label < Tensor.INVALID_INDEX) {
            throw new IndexOutOfBoundsException("cell label " + label + " must be positive");
        }
        return label;
    }
}