summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/DirectIndexedAddress.java
blob: 377523618760dc8511b3da15be344e5ac17e1d42 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
package com.yahoo.tensor;

/**
 * Utility class for efficient access and iteration along dimensions in indexed tensors.
 * <p>
 * Usage: fix the indexes of the dimensions that stay constant during an iteration with
 * {@link #setIndex(int, int)}, then step along the remaining dimension directly:
 * <pre>{@code
 * long base = addr.getDirectIndex();
 * long stride = addr.getStride(dimension);
 * for (long i = 0; i < sizeOfDimension; i++)
 *     double value = tensor.get(base + i * stride);
 * }</pre>
 * Valid indexes of a dimension are {@code [0, size)}, as enforced by {@link #setIndex(int, int)}.
 */
public final class DirectIndexedAddress {

    /** The dimension sizes this address is computed against. */
    private final DimensionSizes sizes;

    /** The currently set index of each dimension; every dimension starts at index 0. */
    private final int[] indexes;

    /**
     * Running sum of (index * stride) over all dimensions. It starts at 0 to match the
     * all-zero {@link #indexes}, and {@link #setIndex(int, int)} keeps it in sync incrementally.
     */
    private long directIndex = 0;

    private DirectIndexedAddress(DimensionSizes sizes) {
        this.sizes = sizes;
        this.indexes = new int[sizes.dimensions()];
    }

    static DirectIndexedAddress of(DimensionSizes sizes) {
        return new DirectIndexedAddress(sizes);
    }

    /**
     * Sets the current index of a dimension, adjusting the direct index by the change
     * in that dimension's index times its stride.
     *
     * @param dimension the dimension whose index is set
     * @param index the new index, which must lie in {@code [0, size of dimension)}
     * @throws IndexOutOfBoundsException if {@code index} lies outside that range
     */
    public void setIndex(int dimension, int index) {
        boolean withinBounds = index >= 0 && index < sizes.size(dimension);
        if ( ! withinBounds)
            throw new IndexOutOfBoundsException("Index " + index + " outside of [0," + sizes.size(dimension) + ">");

        // Widen the index difference before multiplying so the product is computed in long arithmetic.
        long offsetChange = (long) (index - indexes[dimension]) * getStride(dimension);
        indexes[dimension] = index;
        directIndex += offsetChange;
    }

    /** Returns the index to use for a direct lookup in an indexed tensor at the current address. */
    public long getDirectIndex() {
        return directIndex;
    }

    /**
     * Returns the stride of the given dimension: how much the direct index changes when that
     * dimension's index changes by one.
     */
    public long getStride(int dimension) {
        return sizes.productOfDimensionsAfter(dimension);
    }

}