1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
|
// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.document.json.readers;
import com.yahoo.document.Field;
import com.yahoo.document.TensorDataType;
import com.yahoo.document.datatypes.TensorFieldValue;
import com.yahoo.document.json.TokenBuffer;
import com.yahoo.document.update.TensorRemoveUpdate;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectArrayStart;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectCompositeEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectEnd;
import static com.yahoo.document.json.readers.JsonParserHelpers.expectObjectStart;
/**
 * Reads the JSON representation of a remove update for a tensor field and turns it into a
 * {@link TensorRemoveUpdate}.
 * <p>
 * The update is expected to be an object holding an {@value #TENSOR_ADDRESSES} array. Each array
 * element is an object mapping sparse dimension names to labels, e.g.
 * {@code {"addresses": [{"x": "a"}, {"x": "b"}]}}. Indexed dimensions must not be given.
 */
public class TensorRemoveUpdateReader {

    public static final String TENSOR_ADDRESSES = "addresses";

    /**
     * Parses a tensor remove update for the given field from the buffer.
     *
     * @param buffer token buffer whose current token is the start of the update object
     * @param field  the field being updated; its data type is cast to {@link TensorDataType}
     * @return a remove update whose tensor has one cell (value 1.0) per address to remove
     * @throws IllegalArgumentException if the field's tensor type has no sparse (non-indexed)
     *         dimension, if an address names an indexed dimension, or if no addresses are given
     */
    static TensorRemoveUpdate createTensorRemoveUpdate(TokenBuffer buffer, Field field) {
        expectObjectStart(buffer.currentToken());
        TensorType fieldType = ((TensorDataType) field.getDataType()).getTensorType();
        requireSparseDimension(field, fieldType);

        // Addresses only carry labels for the sparse dimensions, so they are built against
        // a type restricted to those dimensions.
        TensorType sparseType = extractSparseDimensions(fieldType);
        Tensor addresses = readAddressTensor(buffer, sparseType, fieldType);
        requireNonEmpty(field, addresses);
        return new TensorRemoveUpdate(new TensorFieldValue(addresses));
    }

    /** Fails unless at least one dimension of the field's tensor type is not indexed. */
    private static void requireSparseDimension(Field field, TensorType tensorType) {
        boolean hasSparseDimension = tensorType.dimensions().stream()
                .anyMatch(dimension -> ! dimension.isIndexed());
        if ( ! hasSparseDimension) {
            throw new IllegalArgumentException("A remove update can only be applied to tensors " +
                    "with at least one sparse dimension. Field '" + field.getName() +
                    "' has unsupported tensor type '" + tensorType + "'");
        }
    }

    /** Fails if the parsed update did not produce any address to remove. */
    private static void requireNonEmpty(Field field, Tensor addresses) {
        if (addresses.isEmpty()) {
            throw new IllegalArgumentException("Remove update for field '" + field.getName() + "' does not contain tensor addresses");
        }
    }

    /**
     * Reads the update object at the buffer's current token. Every address found in the
     * {@value #TENSOR_ADDRESSES} array becomes a cell with value 1.0 in the returned tensor.
     * <p>
     * NOTE(review): keys other than {@value #TENSOR_ADDRESSES} are not rejected. Because the scan
     * continues token by token while the nesting stays at or above the object's level, an object
     * or array value under an unrelated key is walked through as well. An
     * {@value #TENSOR_ADDRESSES} key nested inside such a value would also be read. Confirm whether
     * unknown keys should be an error instead.
     */
    private static Tensor readAddressTensor(TokenBuffer buffer, TensorType sparseType, TensorType fieldType) {
        Tensor.Builder cells = Tensor.Builder.of(sparseType);
        expectObjectStart(buffer.currentToken());
        int objectNesting = buffer.nesting();
        buffer.next();
        while (buffer.nesting() >= objectNesting) {
            if (TENSOR_ADDRESSES.equals(buffer.currentName())) {
                readAddressArray(buffer, cells, sparseType, fieldType);
            }
            buffer.next();
        }
        expectObjectEnd(buffer.currentToken());
        return cells.build();
    }

    /** Reads the address array at the current token, adding each address to {@code cells} with value 1.0. */
    private static void readAddressArray(TokenBuffer buffer, Tensor.Builder cells,
                                         TensorType sparseType, TensorType fieldType) {
        expectArrayStart(buffer.currentToken());
        int arrayNesting = buffer.nesting();
        buffer.next();
        while (buffer.nesting() >= arrayNesting) {
            cells.cell(readTensorAddress(buffer, sparseType, fieldType), 1.0);
            buffer.next();
        }
        expectCompositeEnd(buffer.currentToken());
    }

    /**
     * Reads one address object (dimension name to label) at the current token.
     *
     * @throws IllegalArgumentException if the address names a dimension that exists in the field's
     *         type but not in the sparse type, i.e. an indexed dimension
     */
    private static TensorAddress readTensorAddress(TokenBuffer buffer, TensorType sparseType, TensorType fieldType) {
        TensorAddress.Builder address = new TensorAddress.Builder(sparseType);
        expectObjectStart(buffer.currentToken());
        int objectNesting = buffer.nesting();
        buffer.next();
        while (buffer.nesting() >= objectNesting) {
            String dimension = buffer.currentName();
            boolean namesIndexedDimension = ! sparseType.dimension(dimension).isPresent()
                    && fieldType.dimension(dimension).isPresent();
            if (namesIndexedDimension) {
                throw new IllegalArgumentException("Indexed dimension address '" + dimension + "' should not be specified in remove update");
            }
            address.add(dimension, buffer.currentText());
            buffer.next();
        }
        expectObjectEnd(buffer.currentToken());
        return address.build();
    }

    /**
     * Returns a tensor type containing only the non-indexed dimensions of {@code type}, each
     * declared as a mapped dimension.
     */
    public static TensorType extractSparseDimensions(TensorType type) {
        TensorType.Builder sparse = new TensorType.Builder();
        for (TensorType.Dimension dimension : type.dimensions()) {
            if ( ! dimension.isIndexed()) {
                sparse.mapped(dimension.name());
            }
        }
        return sparse.build();
    }

}
|