aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
blob: c18e9f179d6afd39b853d965a034619b2faf72e1 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.tensor.serialization;

import com.yahoo.io.GrowableByteBuffer;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;

import java.util.Iterator;
import java.util.List;
import java.util.Optional;
import java.util.function.Consumer;
import java.util.function.Supplier;
import java.util.stream.Collectors;

/**
 * Implementation of a mixed binary format for a tensor.
 * See eval/src/vespa/eval/tensor/serialization/format.txt for format.
 *
 * <p>Layout as written by {@link #encode}: the number of mapped (sparse) dimensions followed by their names,
 * the number of indexed (dense) dimensions followed by their names and sizes, then — only when there is at least
 * one mapped dimension — the number of dense subspaces (blocks). Each block consists of one label per mapped
 * dimension followed by its dense cell values, encoded using the configured serialization value type.
 *
 * @author lesters
 */
class MixedBinaryFormat implements BinaryFormat {

    /** The cell value type used when writing cell values to, and reading them from, the binary representation. */
    private final TensorType.Value serializationValueType;

    /** Creates a format which serializes cell values as doubles. */
    MixedBinaryFormat() {
        this(TensorType.Value.DOUBLE);
    }

    /**
     * Creates a format which serializes cell values using the given value type.
     *
     * @param serializationValueType the value type cells are encoded as and decoded from
     */
    MixedBinaryFormat(TensorType.Value serializationValueType) {
        this.serializationValueType = serializationValueType;
    }

    /**
     * Writes the given tensor to the buffer in the mixed binary format.
     *
     * @param buffer the buffer to append to
     * @param tensor the tensor to encode; must be a {@link MixedTensor}
     * @throws IllegalArgumentException if the tensor is not a MixedTensor, an indexed dimension has no size,
     *                                  or a size/count does not fit in an int
     */
    @Override
    public void encode(GrowableByteBuffer buffer, Tensor tensor) {
        if ( ! (tensor instanceof MixedTensor mixed))
            throw new IllegalArgumentException("The mixed format is only supported for mixed tensors");
        encodeSparseDimensions(buffer, mixed);
        encodeDenseDimensions(buffer, mixed);
        encodeCells(buffer, mixed);
    }

    /** Writes the number of mapped dimensions followed by each mapped dimension's name. */
    private void encodeSparseDimensions(GrowableByteBuffer buffer, MixedTensor tensor) {
        List<TensorType.Dimension> sparseDimensions = sparseDimensionsOf(tensor.type());
        buffer.putInt1_4Bytes(sparseDimensions.size());
        for (TensorType.Dimension dimension : sparseDimensions)
            buffer.putUtf8String(dimension.name());
    }

    /** Writes the number of indexed dimensions followed by each indexed dimension's name and size. */
    private void encodeDenseDimensions(GrowableByteBuffer buffer, MixedTensor tensor) {
        List<TensorType.Dimension> denseDimensions = denseDimensionsOf(tensor.type());
        buffer.putInt1_4Bytes(denseDimensions.size());
        for (TensorType.Dimension dimension : denseDimensions) {
            buffer.putUtf8String(dimension.name());
            long size = dimension.size().orElseThrow(() ->
                    new IllegalArgumentException("Unknown size of indexed dimension."));
            // The format stores sizes as ints: fail loudly rather than silently truncate
            buffer.putInt1_4Bytes(checkedIntValue(size, "Size of indexed dimension '" + dimension.name() + "'"));
        }
    }

    /** Writes all cells, converting each value to the configured serialization value type. */
    private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) {
        switch (serializationValueType) {
            case DOUBLE -> encodeCells(buffer, tensor, buffer::putDouble);
            case FLOAT -> encodeCells(buffer, tensor, val -> buffer.putFloat(val.floatValue()));
            case BFLOAT16 -> encodeCells(buffer, tensor, val ->
                    buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue())));
            case INT8 -> encodeCells(buffer, tensor, val -> buffer.put((byte) val.floatValue()));
            // Writing nothing here would leave a header without cells and break decoding
            default -> throw new IllegalStateException("Unsupported serialization value type " + serializationValueType);
        }
    }

    /**
     * Writes the block count (only if there are mapped dimensions), then for each dense subspace its mapped
     * labels followed by its cell values, each passed to {@code consumer} for encoding.
     */
    private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor, Consumer<Double> consumer) {
        TensorType type = tensor.type();
        List<TensorType.Dimension> sparseDimensions = sparseDimensionsOf(type);
        long denseSubspaceSize = tensor.denseSubspaceSize();
        if ( ! sparseDimensions.isEmpty())
            buffer.putInt1_4Bytes(checkedIntValue(tensor.size() / denseSubspaceSize, "Number of dense subspaces"));

        // Position of each mapped dimension within a full cell address; loop invariant, so resolve once
        int[] sparseIndexes = new int[sparseDimensions.size()];
        for (int i = 0; i < sparseIndexes.length; i++) {
            sparseIndexes[i] = type.indexOfDimension(sparseDimensions.get(i).name()).orElseThrow(() ->
                    new IllegalStateException("Dimension not found in address."));
        }

        Iterator<Tensor.Cell> cells = tensor.cellIterator();
        while (cells.hasNext()) {
            // Assumes the iterator yields all cells of one dense subspace consecutively; the mapped labels
            // are taken from the first cell of each subspace
            Tensor.Cell first = cells.next();
            for (int index : sparseIndexes)
                buffer.putUtf8String(first.getKey().label(index));
            consumer.accept(first.getValue());
            for (long i = 1; i < denseSubspaceSize; i++)
                consumer.accept(cells.next().getValue());
        }
    }

    /**
     * Reads a tensor in the mixed binary format from the buffer.
     *
     * @param optionalType the type the decoded tensor should have. If present, its value type must equal the
     *                     serialization value type and the serialized type must be assignable to it.
     *                     If absent, the type read from the buffer is used.
     * @param buffer the buffer to read from
     * @return the decoded tensor
     * @throws IllegalArgumentException on a value type mismatch or if the serialized type is not assignable
     *                                  to the given type
     */
    @Override
    public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
        TensorType type;
        if (optionalType.isPresent()) {
            type = optionalType.get();
            if (type.valueType() != this.serializationValueType) {
                throw new IllegalArgumentException("Tensor value type mismatch. Value type " + type.valueType() +
                        " is not " + this.serializationValueType);
            }
            TensorType serializedType = decodeType(buffer);
            if ( ! serializedType.isAssignableTo(type))
                throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
                                                   " cannot be assigned to type " + type);
        }
        else {
            type = decodeType(buffer);
        }
        // NOTE(review): this cast assumes MixedTensor.Builder.of(type) returns a BoundBuilder; confirm behavior
        // when a caller-supplied type has indexed dimensions without a size
        MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder) MixedTensor.Builder.of(type);
        decodeCells(buffer, builder, type);
        return builder.build();
    }

    /** Reads the mapped and indexed dimension headers and builds a type with the serialization value type. */
    private TensorType decodeType(GrowableByteBuffer buffer) {
        TensorType.Builder builder = new TensorType.Builder(serializationValueType);
        int numMappedDimensions = buffer.getInt1_4Bytes();
        for (int i = 0; i < numMappedDimensions; ++i)
            builder.mapped(buffer.getUtf8String());
        int numIndexedDimensions = buffer.getInt1_4Bytes();
        for (int i = 0; i < numIndexedDimensions; ++i) {
            String name = buffer.getUtf8String(); // the name precedes the size in the stream
            builder.indexed(name, buffer.getInt1_4Bytes());
        }
        return builder.build();
    }

    /** Reads all cells, converting each value from the configured serialization value type to a double. */
    private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
        switch (serializationValueType) {
            case DOUBLE -> decodeCells(buffer, builder, type, buffer::getDouble);
            case FLOAT -> decodeCells(buffer, builder, type, () -> (double) buffer.getFloat());
            case BFLOAT16 -> decodeCells(buffer, builder, type, () ->
                    (double) TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort()));
            case INT8 -> decodeCells(buffer, builder, type, () -> (double) buffer.get());
            // Reading nothing here would silently produce an empty tensor and leave the buffer misaligned
            default -> throw new IllegalStateException("Unsupported serialization value type " + serializationValueType);
        }
    }

    /**
     * Reads the block count (only if there are mapped dimensions; otherwise there is exactly one block), then
     * for each block its mapped labels and its dense cell values, obtained one by one from {@code supplier}.
     */
    private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type, Supplier<Double> supplier) {
        List<TensorType.Dimension> sparseDimensions = sparseDimensionsOf(type);
        TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions);
        int denseSubspaceSize = checkedIntValue(builder.denseSubspaceSize(), "Dense subspace size");

        int numBlocks = sparseDimensions.isEmpty() ? 1 : buffer.getInt1_4Bytes();

        // NOTE(review): the same array is passed to builder.block for every block, which assumes block()
        // copies the values rather than retaining the array — confirm
        double[] denseSubspace = new double[denseSubspaceSize];
        for (int block = 0; block < numBlocks; ++block) {
            TensorAddress.Builder sparseAddress = new TensorAddress.Builder(sparseType);
            for (TensorType.Dimension sparseDimension : sparseDimensions)
                sparseAddress.add(sparseDimension.name(), buffer.getUtf8String());
            for (int offset = 0; offset < denseSubspaceSize; offset++)
                denseSubspace[offset] = supplier.get();
            builder.block(sparseAddress.build(), denseSubspace);
        }
    }

    /** Returns the mapped (non-indexed) dimensions of the given type, in the order of {@code type.dimensions()}. */
    private static List<TensorType.Dimension> sparseDimensionsOf(TensorType type) {
        return type.dimensions().stream().filter(d -> ! d.isIndexed()).toList();
    }

    /** Returns the indexed dimensions of the given type, in the order of {@code type.dimensions()}. */
    private static List<TensorType.Dimension> denseDimensionsOf(TensorType type) {
        return type.dimensions().stream().filter(TensorType.Dimension::isIndexed).toList();
    }

    /**
     * Narrows a long to an int, failing instead of silently truncating.
     *
     * @param value the value to narrow
     * @param description what the value represents, used in the error message
     * @return the value as an int
     * @throws IllegalArgumentException if the value is outside the int range
     */
    private static int checkedIntValue(long value, String description) {
        if (value < Integer.MIN_VALUE || value > Integer.MAX_VALUE)
            throw new IllegalArgumentException(description + " " + value + " does not fit in an int");
        return (int) value;
    }

}