aboutsummaryrefslogtreecommitdiffstats
path: root/document/src/main/java/com/yahoo/document/json/readers/TensorModifyUpdateReader.java
blob: d7e124fb7214e03f9b8aa34c8aedab763ad20e93 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package com.yahoo.document.json.readers;

import com.fasterxml.jackson.core.JsonToken;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
import com.yahoo.document.update.TensorModifyUpdate;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;

import java.util.Iterator;

import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;
import static com.yahoo.document.json.readers.TensorReader.TENSOR_BLOCKS;
import static com.yahoo.document.json.readers.TensorReader.TENSOR_CELLS;
import static com.yahoo.document.json.readers.TensorReader.readTensorBlocks;
import static com.yahoo.document.json.readers.TensorReader.readTensorCells;

/**
 * Reader of a "modify" update of a tensor field.
 *
 * <p>The update is read from a JSON object which must contain an {@code "operation"}
 * ({@code "replace"}, {@code "add"} or {@code "multiply"}) and the tensor cells to apply it to,
 * given under either the {@code TENSOR_CELLS} or the {@code TENSOR_BLOCKS} field name of
 * {@link TensorReader}. An optional boolean {@code "create"} field is passed on to the
 * resulting {@link TensorModifyUpdate}.
 */
public class TensorModifyUpdateReader {

    public static final String UPDATE_MODIFY = "modify";
    private static final String MODIFY_OPERATION = "operation";
    private static final String MODIFY_REPLACE = "replace";
    private static final String MODIFY_ADD = "add";
    private static final String MODIFY_MULTIPLY = "multiply";
    private static final String MODIFY_CREATE = "create";

    /**
     * Reads a modify update for the given field from the buffer.
     *
     * @param buffer the token buffer, whose current token must be the start of the modify object
     * @param field the field the update applies to
     * @return the modify update described by the JSON object
     * @throws IllegalArgumentException if the field is not a tensor field, if its tensor type has an
     *                                  indexed unbound dimension, or if the object is missing the
     *                                  operation or the tensor cells, or contains unknown content
     * @throws IndexOutOfBoundsException if a cell addresses an indexed bound dimension outside its bounds
     */
    public static TensorModifyUpdate createModifyUpdate(TokenBuffer buffer, Field field) {
        expectFieldIsOfTypeTensor(field);
        expectTensorTypeHasNoIndexedUnboundDimensions(field);
        expectObjectStart(buffer.current());

        ModifyUpdateResult result = createModifyUpdateResult(buffer, field);
        expectOperationSpecified(result.operation, field.getName());
        expectTensorSpecified(result.tensor, field.getName());

        return new TensorModifyUpdate(result.operation, result.tensor, result.createNonExistingCells);
    }

    /** Throws IllegalArgumentException unless the data type of the field is a tensor data type. */
    private static void expectFieldIsOfTypeTensor(Field field) {
        if ( ! (field.getDataType() instanceof TensorDataType)) {
            throw new IllegalArgumentException("A modify update can only be applied to tensor fields. " +
                                               "Field '" + field.getName() + "' is of type '" +
                                               field.getDataType().getName() + "'");
        }
    }

    /** Throws IllegalArgumentException if any dimension of the field's tensor type is indexed unbound. */
    private static void expectTensorTypeHasNoIndexedUnboundDimensions(Field field) {
        TensorType tensorType = ((TensorDataType)field.getDataType()).getTensorType();
        boolean hasIndexedUnbound = tensorType.dimensions().stream()
                .anyMatch(dim -> dim.type() == TensorType.Dimension.Type.indexedUnbound);
        if (hasIndexedUnbound) {
            throw new IllegalArgumentException("A modify update cannot be applied to tensor types with indexed unbound dimensions. " +
                                               "Field '" + field.getName() + "' has unsupported tensor type '" + tensorType + "'");
        }
    }

    /** Throws IllegalArgumentException if no operation was read from the modify object. */
    private static void expectOperationSpecified(TensorModifyUpdate.Operation operation, String fieldName) {
        if (operation == null) {
            throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain an operation");
        }
    }

    /** Throws IllegalArgumentException if no tensor was read from the modify object. */
    private static void expectTensorSpecified(TensorFieldValue tensor, String fieldName) {
        if (tensor == null) {
            throw new IllegalArgumentException("Modify update for field '" + fieldName + "' does not contain tensor cells");
        }
    }

    /** Mutable holder for the parts of a modify update as they are encountered in the JSON object. */
    private static class ModifyUpdateResult {
        TensorModifyUpdate.Operation operation = null;
        boolean createNonExistingCells = false;
        TensorFieldValue tensor = null;
    }

    /**
     * Consumes the fields of the modify object and collects what they specify.
     * Reading continues for as long as the buffer nesting is at least the nesting
     * observed at the first token inside the object.
     *
     * @throws IllegalArgumentException on a token without a field name or with an unknown field name
     */
    private static ModifyUpdateResult createModifyUpdateResult(TokenBuffer buffer, Field field) {
        ModifyUpdateResult result = new ModifyUpdateResult();
        buffer.next();
        int localNesting = buffer.nesting();
        while (localNesting <= buffer.nesting()) {
            String name = buffer.currentName();
            if (name == null) {
                // Switching on a null String throws a bare NullPointerException; report the problem instead
                throw new IllegalArgumentException("Expected a field name in modify update for field '" + field.getName() + "'");
            }
            switch (name) {
                case MODIFY_OPERATION:
                    result.operation = createOperation(buffer, field.getName());
                    break;
                case MODIFY_CREATE:
                    result.createNonExistingCells = createNonExistingCells(buffer, field.getName());
                    break;
                case TENSOR_CELLS:
                    result.tensor = createTensorFromCells(buffer, field);
                    break;
                case TENSOR_BLOCKS:
                    result.tensor = createTensorFromBlocks(buffer, field);
                    break;
                default:
                    throw new IllegalArgumentException("Unknown JSON string '" + name + "' in modify update for field '" + field.getName() + "'");
            }
            buffer.next();
        }
        return result;
    }

    /**
     * Maps the text of the current token to a modify operation.
     *
     * @throws IllegalArgumentException if the text is missing or not one of the known operation names
     */
    private static TensorModifyUpdate.Operation createOperation(TokenBuffer buffer, String fieldName) {
        String operationName = buffer.currentText();
        if (operationName != null) { // a null text falls through to the "unknown operation" error below
            switch (operationName) {
                case MODIFY_REPLACE:
                    return TensorModifyUpdate.Operation.REPLACE;
                case MODIFY_ADD:
                    return TensorModifyUpdate.Operation.ADD;
                case MODIFY_MULTIPLY:
                    return TensorModifyUpdate.Operation.MULTIPLY;
                default:
                    break;
            }
        }
        throw new IllegalArgumentException("Unknown operation '" + operationName + "' in modify update for field '" + fieldName + "'");
    }

    /**
     * Reads the value of the "create" field.
     *
     * @return true if the current token is a JSON true, false if it is a JSON false
     * @throws IllegalArgumentException if the current token is not a JSON boolean
     */
    private static boolean createNonExistingCells(TokenBuffer buffer, String fieldName) {
        JsonToken token = buffer.current();
        if (token == JsonToken.VALUE_TRUE) {
            return true;
        }
        if (token == JsonToken.VALUE_FALSE) {
            return false;
        }
        throw new IllegalArgumentException("Unknown value '" + buffer.currentText() + "' for '" + MODIFY_CREATE + "' in modify update for field '" + fieldName + "'");
    }

    /**
     * Reads tensor cells into a tensor of the type returned by
     * {@link TensorModifyUpdate#convertDimensionsToMapped} for the field's tensor type,
     * and validates the cell addresses against the field's original tensor type.
     */
    private static TensorFieldValue createTensorFromCells(TokenBuffer buffer, Field field) {
        TensorType originalType = ((TensorDataType)field.getDataType()).getTensorType();
        TensorType convertedType = TensorModifyUpdate.convertDimensionsToMapped(originalType);

        Tensor.Builder tensorBuilder = Tensor.Builder.of(convertedType);
        readTensorCells(buffer, tensorBuilder);
        Tensor tensor = tensorBuilder.build();

        validateBounds(tensor, originalType);

        return new TensorFieldValue(tensor);
    }

    /**
     * Reads tensor blocks into a tensor of the field's own tensor type, converts the result
     * with {@link #convertToSparse} and validates the cell addresses against the field's tensor type.
     */
    private static TensorFieldValue createTensorFromBlocks(TokenBuffer buffer, Field field) {
        TensorType type = ((TensorDataType)field.getDataType()).getTensorType();

        Tensor.Builder tensorBuilder = Tensor.Builder.of(type);
        readTensorBlocks(buffer, tensorBuilder);
        Tensor tensor = convertToSparse(tensorBuilder.build());
        validateBounds(tensor, type);

        return new TensorFieldValue(tensor);
    }

    /**
     * Returns the given tensor as is if it has no indexed dimensions. Otherwise returns a tensor
     * with the same cells, built with the type returned by
     * {@link TensorModifyUpdate#convertDimensionsToMapped} for the tensor's type.
     */
    private static Tensor convertToSparse(Tensor tensor) {
        if (tensor.type().dimensions().stream().noneMatch(TensorType.Dimension::isIndexed)) return tensor;
        Tensor.Builder builder = Tensor.Builder.of(TensorModifyUpdate.convertDimensionsToMapped(tensor.type()));
        Iterator<Tensor.Cell> cells = tensor.cellIterator();
        while (cells.hasNext()) {
            builder.cell(cells.next());
        }
        return builder.build();
    }

    /**
     * Validates that every cell address of the converted tensor is within the bounds of the
     * indexed bound dimensions of the original type. Nothing is validated if the original type
     * has no indexed bound dimensions.
     *
     * @param convertedTensor the tensor whose cell addresses are checked; label i of an address is
     *                        checked against dimension i of the original type
     * @param originalType the tensor type declaring the bounds
     * @throws IndexOutOfBoundsException if a label of an indexed bound dimension is negative
     *                                   or not smaller than the size of that dimension
     */
    static void validateBounds(Tensor convertedTensor, TensorType originalType) {
        if (originalType.dimensions().stream().noneMatch(d -> d instanceof TensorType.IndexedBoundDimension)) {
            return;
        }
        Iterator<Tensor.Cell> cellIterator = convertedTensor.cellIterator();
        while (cellIterator.hasNext()) {
            TensorAddress address = cellIterator.next().getKey();
            for (int i = 0; i < address.size(); ++i) {
                TensorType.Dimension dim = originalType.dimensions().get(i);
                if ( ! (dim instanceof TensorType.IndexedBoundDimension)) continue;

                long label = address.numericLabel(i);
                long bound = dim.size().get();  // size is non-optional for indexed bound
                // A negative value can never be a valid index, so reject it along with values past the bound
                if (label < 0 || label >= bound) {
                    throw new IndexOutOfBoundsException("Dimension '" + dim.name() +
                                                        "' has label '" + label + "' but type is " + originalType);
                }
            }
        }
    }

}