summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
blob: 4379d50520cc24498966f5077dd982bc0306a9f1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
package com.yahoo.tensor;

/**
 * Utility class for efficient access and iteration along dimensions in indexed tensors.
 * <p>
 * Usage: call {@link #setIndex} to fix the indexes of the dimensions that stay constant during the
 * iteration, then walk the remaining dimension by stepping with its stride:
 * <pre>{@code
 * long base = addr.getDirectIndex();
 * long stride = addr.getStride(dimension);
 * for (long i = 0; i < sizeOfDimension; i++) {
 *     double value = tensor.get(base + i * stride);
 * }
 * }</pre>
 *
 * @author baldersheim
 */
public final class DirectIndexedAddress {

    /** Sizes of the dimensions of the tensor this address points into. */
    private final DimensionSizes sizes;

    /** Index currently set for each dimension; every entry starts at 0. */
    private final int[] indexes;

    /** Sum over all dimensions d of indexes[d] * getStride(d), maintained incrementally by setIndex. */
    private long directIndex;

    private DirectIndexedAddress(DimensionSizes sizes) {
        this.sizes = sizes;
        this.indexes = new int[sizes.dimensions()];
        this.directIndex = 0L;
    }

    /** Creates an address over the given dimension sizes, with the index of every dimension set to 0. */
    static DirectIndexedAddress of(DimensionSizes sizes) {
        return new DirectIndexedAddress(sizes);
    }

    /**
     * Sets the current index of one dimension and updates the direct index accordingly.
     *
     * @param dimension the dimension whose index is set
     * @param index the new index, which must lie in [0, size of the dimension)
     * @throws IndexOutOfBoundsException if index is negative or not smaller than the dimension size
     */
    public void setIndex(int dimension, int index) {
        long size = sizes.size(dimension);
        boolean inRange = index >= 0 && index < size;
        if ( ! inRange)
            throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + size + ">");

        int previous = indexes[dimension];
        // Only this dimension's contribution changes, so shift by the difference rather than recomputing the sum.
        directIndex += (index - previous) * getStride(dimension);
        indexes[dimension] = index;
    }

    /** Returns the index to use for direct lookup in the indexed tensor, given the indexes set so far. */
    public long getDirectIndex() {
        return directIndex;
    }

    /**
     * Returns the stride of the given dimension: the amount the direct index moves when that
     * dimension's index increases by one, as given by {@code sizes.productOfDimensionsAfter(dimension)}.
     */
    public long getStride(int dimension) {
        long stride = sizes.productOfDimensionsAfter(dimension);
        return stride;
    }

}