summaryrefslogtreecommitdiffstats
path: root/document/src/main/java/com/yahoo/document/json/readers/TensorRemoveUpdateReader.java
blob: 0d12e7c074b605a55d805c135e593074e8097606 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.json.readers;

import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
import com.yahoo.document.update.TensorRemoveUpdate;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;

import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectCompositeEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;

/**
 * Reads the JSON form of a remove update for a tensor field.
 *
 * The update is read as an object whose {@value #TENSOR_ADDRESSES} entry is an array of address
 * objects. Each address object maps a dimension name to a label. Only the sparse (non-indexed)
 * dimensions of the field's tensor type may be addressed.
 */
public class TensorRemoveUpdateReader {

    public static final String TENSOR_ADDRESSES = "addresses";

    /**
     * Creates a tensor remove update for the given field from the object the buffer is positioned at.
     *
     * @param buffer the token buffer, positioned at the start of the update object
     * @param field the field being updated; its data type is cast to {@link TensorDataType}
     * @return a remove update wrapping a tensor holding the addresses to remove
     * @throws IllegalArgumentException if the field's tensor type has only indexed dimensions,
     *         if no addresses are given, or if an address names an indexed dimension
     */
    static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) {
        // Validate the token first, then the field type, before any tensor is built.
        expectObjectStart(buffer.currentToken());
        expectTensorTypeHasSparseDimensions(field);

        TensorType fullType = ((TensorDataType)field.getDataType()).getTensorType();
        TensorType sparseType = extractSparseDimensions(fullType);

        Tensor addresses = readRemoveUpdateTensor(buffer, sparseType, fullType);
        expectAddressesAreNonEmpty(field, addresses);

        TensorFieldValue value = new TensorFieldValue(addresses);
        return new TensorRemoveUpdate(value);
    }

    /**
     * Verifies that the tensor type of the field has at least one dimension which is not indexed.
     *
     * @throws IllegalArgumentException if every dimension of the tensor type is indexed
     */
    private static void expectTensorTypeHasSparseDimensions(Field field) {
        TensorDataType dataType = (TensorDataType)field.getDataType();
        TensorType tensorType = dataType.getTensorType();

        boolean hasSparseDimension = tensorType.dimensions().stream().anyMatch(dim -> ! dim.isIndexed());
        if (hasSparseDimension) {
            return;
        }
        throw new IllegalArgumentException("A remove update can only be applied to tensors " +
                "with at least one sparse dimension. Field '" + field.getName() +
                "' has unsupported tensor type '" + tensorType + "'");
    }

    /**
     * Verifies that the tensor built from the update is not empty.
     *
     * @throws IllegalArgumentException if the tensor is empty
     */
    private static void expectAddressesAreNonEmpty(Field field, Tensor tensor) {
        if ( ! tensor.isEmpty()) {
            return;
        }
        throw new IllegalArgumentException("Remove update for field '" + field.getName() + "' does not contain tensor addresses");
    }

    /**
     * Walks the update object in the buffer and collects every address found under
     * {@value #TENSOR_ADDRESSES} into a tensor of the sparse type, giving each address the
     * cell value 1.0. Entries with any other name are passed over.
     */
    private static Tensor readRemoveUpdateTensor(TokenBuffer buffer, TensorType sparseType, TensorType originalType) {
        Tensor.Builder cells = Tensor.Builder.of(sparseType);
        expectObjectStart(buffer.currentToken());

        // Keep consuming tokens until the nesting drops below the level recorded on entry.
        int objectNesting = buffer.nesting();
        buffer.next();
        while (buffer.nesting() >= objectNesting) {
            if (TENSOR_ADDRESSES.equals(buffer.currentName())) {
                readAddressArray(buffer, cells, sparseType, originalType);
            }
            buffer.next();
        }

        expectObjectEnd(buffer.currentToken());
        return cells.build();
    }

    /**
     * Reads the array of address objects the buffer is positioned at, adding one cell with
     * value 1.0 to the given builder for each address.
     */
    private static void readAddressArray(TokenBuffer buffer, Tensor.Builder cells,
                                         TensorType sparseType, TensorType originalType) {
        expectArrayStart(buffer.currentToken());

        int arrayNesting = buffer.nesting();
        buffer.next();
        while (buffer.nesting() >= arrayNesting) {
            cells.cell(readTensorAddress(buffer, sparseType, originalType), 1.0);
            buffer.next();
        }

        expectCompositeEnd(buffer.currentToken());
    }

    /**
     * Reads a single address object, taking each entry's name as the dimension and its text as the label.
     *
     * @throws IllegalArgumentException if a dimension is named which exists in the original
     *         tensor type but not in the sparse type, i.e. an indexed dimension
     */
    private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType sparseType, TensorType originalType) {
        TensorAddress.Builder address = new TensorAddress.Builder(sparseType);
        expectObjectStart(buffer.currentToken());

        int objectNesting = buffer.nesting();
        buffer.next();
        while (buffer.nesting() >= objectNesting) {
            String dimension = buffer.currentName();
            boolean isIndexedDimension = ! sparseType.dimension(dimension).isPresent()
                                         && originalType.dimension(dimension).isPresent();
            if (isIndexedDimension) {
                throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update");
            }
            address.add(dimension, buffer.currentText());
            buffer.next();
        }

        expectObjectEnd(buffer.currentToken());
        return address.build();
    }

    /**
     * Returns a tensor type holding a mapped dimension for each non-indexed dimension of the given type.
     *
     * @param type the tensor type to take dimension names from
     * @return a new tensor type built from the names of the non-indexed dimensions only
     */
    public static TensorType extractSparseDimensions(TensorType type) {
        TensorType.Builder sparse = new TensorType.Builder();
        for (TensorType.Dimension dim : type.dimensions()) {
            if ( ! dim.isIndexed()) {
                sparse.mapped(dim.name());
            }
        }
        return sparse.build();
    }
}