1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
|
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package ai.vespa.rankingexpression.importer;
import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import ai.vespa.rankingexpression.importer.operations.MatMul;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
/**
 * Holds an intermediate representation of an imported model graph.
 * After this intermediate representation is constructed, it is used to
 * simplify and optimize the computational graph and then converted into the
 * final ImportedModel that holds the Vespa ranking expressions for the model.
 *
 * @author lesters
 */
public class IntermediateGraph {

    private final String modelName;
    private final Map<String, IntermediateOperation> operations = new HashMap<>();
    private final Map<String, GraphSignature> signatures = new HashMap<>();

    /** The input and output name mappings of a single named signature. */
    private static class GraphSignature {
        final Map<String, String> inputs = new HashMap<>();
        final Map<String, String> outputs = new HashMap<>();
    }

    /** Creates an empty graph for the model with the given name. */
    public IntermediateGraph(String modelName) {
        this.modelName = modelName;
    }

    /** Returns the model name this graph was created with. */
    public String name() {
        return modelName;
    }

    /**
     * Registers an operation under the given key, replacing any operation already registered under it.
     *
     * @return the operation previously registered under the key, or null if there was none
     */
    public IntermediateOperation put(String key, IntermediateOperation operation) {
        return operations.put(key, operation);
    }

    /** Returns the operation registered under the given key, or null if there is none. */
    public IntermediateOperation get(String key) {
        return operations.get(key);
    }

    /** Returns the names of all signatures in this graph, as a live view of the internal signature map. */
    public Set<String> signatures() {
        return signatures.keySet();
    }

    /**
     * Returns the mutable input map of the given signature, creating an empty signature
     * with that name if it does not exist yet.
     */
    public Map<String, String> inputs(String signature) {
        return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs;
    }

    /**
     * Returns the mutable output map of the given signature, creating an empty signature
     * with that name if it does not exist yet. The values of this map are looked up as
     * operation keys when the graph is optimized.
     */
    public Map<String, String> outputs(String signature) {
        return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs;
    }

    /** Returns the name used for the default signature, which is always "default". */
    public String defaultSignature() {
        return "default";
    }

    /** Returns whether an operation is already registered under the given key. */
    public boolean alreadyImported(String key) {
        return operations.containsKey(key);
    }

    /** Returns the live (mutable) map of all registered operations, keyed by registration key. */
    public Map<String, IntermediateOperation> operations() {
        return operations;
    }

    /** Optimizes this graph. Currently this only assigns dimension names to the operations. */
    public void optimize() {
        renameDimensions();
    }

    /**
     * Find dimension names to avoid excessive renaming while evaluating the model.
     *
     * <p>This runs two passes over the operations reachable from the outputs of every signature:
     * first each operation adds its dimension name constraints to a shared renamer, then, after
     * {@code renamer.solve()} has run, each operation renames its dimensions. Inputs are always
     * handled before the operations consuming them.
     *
     * <p>Each pass handles every operation at most once, even if it is shared between several
     * outputs or reachable along several paths. Without this, graphs with shared sub-expressions
     * would be traversed once per path, which grows exponentially with the nesting of such sharing,
     * and shared operations would be asked to rename their dimensions repeatedly.
     */
    private void renameDimensions() {
        DimensionRenamer renamer = new DimensionRenamer(this);

        Set<IntermediateOperation> constrained = newIdentitySet();
        for (String signature : signatures()) {
            for (String output : outputs(signature).values()) {
                addDimensionNameConstraints(operations.get(output), renamer, constrained);
            }
        }

        renamer.solve();

        Set<IntermediateOperation> renamed = newIdentitySet();
        for (String signature : signatures()) {
            for (String output : outputs(signature).values()) {
                renameDimensions(operations.get(output), renamer, renamed);
            }
        }
    }

    /** Creates an empty set which compares operations by reference identity rather than equals(). */
    private static Set<IntermediateOperation> newIdentitySet() {
        return Collections.newSetFromMap(new IdentityHashMap<>());
    }

    /**
     * Adds the dimension name constraints of the given operation and, before that, of all its inputs.
     * Operations with no type are skipped along with their inputs.
     *
     * @param visited the operations already handled in this pass; the given operation is added to it
     */
    private static void addDimensionNameConstraints(IntermediateOperation operation,
                                                    DimensionRenamer renamer,
                                                    Set<IntermediateOperation> visited) {
        if ( ! visited.add(operation)) return; // already handled through another path
        if (operation.type().isPresent()) {
            operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer, visited));
            operation.addDimensionNameConstraints(renamer);
        }
    }

    /**
     * Renames the dimensions of the given operation and, before that, of all its inputs.
     * Operations with no type are skipped along with their inputs.
     *
     * @param visited the operations already handled in this pass; the given operation is added to it
     */
    private static void renameDimensions(IntermediateOperation operation,
                                         DimensionRenamer renamer,
                                         Set<IntermediateOperation> visited) {
        if ( ! visited.add(operation)) return; // already renamed through another path
        if (operation.type().isPresent()) {
            operation.inputs().forEach(input -> renameDimensions(input, renamer, visited));
            operation.renameDimensions(renamer);
        }
    }

    @Override
    public String toString() {
        return "intermediate graph for '" + modelName + "'";
    }

    /** Returns a multi-line description of the graph with one "key: operation" line per registered operation. */
    public String toFullString() {
        StringBuilder b = new StringBuilder();
        for (var input : operations.entrySet())
            b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n");
        return b.toString();
    }

}
|