aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
blob: 3c8a6bde23204a01a1b96c981571479ab871dd80 (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package ai.vespa.rankingexpression.importer;

import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import ai.vespa.rankingexpression.importer.operations.MatMul;

import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.Consumer;

/**
 * Holds an intermediate representation of an imported model graph.
 * After this intermediate representation is constructed, it is used to
 * simplify and optimize the computational graph and then converted into the
 * final ImportedModel that holds the Vespa ranking expressions for the model.
 *
 * @author lesters
 */
public class IntermediateGraph {

    private final String modelName;

    /** All operations of this graph, keyed by the name they were registered under. */
    private final Map<String, IntermediateOperation> operations = new HashMap<>();

    /**
     * Signatures by name. The values of a signature's outputs map are looked up as keys
     * in {@link #operations} when the graph is optimized.
     */
    private final Map<String, GraphSignature> signatures = new HashMap<>();

    /** The input and output name mappings of a single signature. */
    private static class GraphSignature {
        final Map<String, String> inputs = new HashMap<>();
        final Map<String, String> outputs = new HashMap<>();
    }

    public IntermediateGraph(String modelName) {
        this.modelName = modelName;
    }

    /** Returns the model name this graph was constructed with. */
    public String name() {
        return modelName;
    }

    /**
     * Stores an operation under the given key, replacing any operation already stored there.
     *
     * @return the operation previously stored under this key, or null if there was none
     */
    public IntermediateOperation put(String key, IntermediateOperation operation) {
        return operations.put(key, operation);
    }

    /** Returns the operation stored under the given key, or null if there is none. */
    public IntermediateOperation get(String key) {
        return operations.get(key);
    }

    /** Returns a live view of the names of all signatures currently in this graph. */
    public Set<String> signatures() {
        return signatures.keySet();
    }

    /**
     * Returns the mutable input mapping of the given signature.
     * The signature is created (with empty mappings) if it does not exist yet.
     */
    public Map<String, String> inputs(String signature) {
        return signature(signature).inputs;
    }

    /**
     * Returns the mutable output mapping of the given signature.
     * The signature is created (with empty mappings) if it does not exist yet.
     * Values are expected to be keys of {@link #operations()}; see {@link #optimize()}.
     */
    public Map<String, String> outputs(String signature) {
        return signature(signature).outputs;
    }

    /** Returns the named signature, registering a new empty one if absent. */
    private GraphSignature signature(String name) {
        return signatures.computeIfAbsent(name, k -> new GraphSignature());
    }

    /** Returns the name of the default signature: "default". */
    public String defaultSignature() {
        return "default";
    }

    /** Returns whether an operation is already stored under the given key. */
    public boolean alreadyImported(String key) {
        return operations.containsKey(key);
    }

    /** Returns the live, mutable map of all operations in this graph, keyed by name. */
    public Map<String, IntermediateOperation> operations() {
        return operations;
    }

    /**
     * Optimizes this graph in place. Currently this only selects dimension names.
     *
     * @throws IllegalStateException if a signature output refers to an operation not in this graph
     */
    public void optimize() {
        renameDimensions();
    }

    // NOTE(review): not referenced anywhere in this class and is mutable shared state;
    // looks like a leftover. Kept because it is package-visible — confirm no other users, then remove.
    static int counter = 0;

    /**
     * Find dimension names to avoid excessive renaming while evaluating the model.
     *
     * Constraints are collected from every typed operation reachable from a signature output,
     * solved together, and then applied by asking those same operations to rename their dimensions.
     *
     * @throws IllegalStateException if a signature output refers to an operation not in this graph
     */
    private void renameDimensions() {
        DimensionRenamer renamer = new DimensionRenamer(this);
        forEachOutputOperation(output ->
                visitTypedPostOrder(output, operation -> operation.addDimensionNameConstraints(renamer)));
        renamer.solve();
        forEachOutputOperation(output ->
                visitTypedPostOrder(output, operation -> operation.renameDimensions(renamer)));
    }

    /** Calls the action with the operation referenced by each output of each signature. */
    private void forEachOutputOperation(Consumer<IntermediateOperation> action) {
        for (GraphSignature signature : signatures.values()) {
            for (String output : signature.outputs.values())
                action.accept(outputOperation(output));
        }
    }

    /** Resolves an output's operation, failing with a descriptive error instead of a later NullPointerException. */
    private IntermediateOperation outputOperation(String name) {
        IntermediateOperation operation = operations.get(name);
        if (operation == null)
            throw new IllegalStateException("Output '" + name + "' in " + this +
                                            " refers to an operation which is not in the graph");
        return operation;
    }

    /**
     * Calls the action on the given operation and on all operations it transitively takes as input,
     * visiting inputs before the operations that consume them. Operations without a type are skipped,
     * together with their inputs. Each operation name is visited at most once per call.
     */
    private static void visitTypedPostOrder(IntermediateOperation root, Consumer<IntermediateOperation> action) {
        visitTypedPostOrder(root, action, new HashSet<>());
    }

    private static void visitTypedPostOrder(IntermediateOperation operation,
                                            Consumer<IntermediateOperation> action,
                                            Set<String> visited) {
        if (visited.contains(operation.name())) return;
        if ( ! operation.type().isPresent()) return;
        // NOTE(review): the name is recorded only after the inputs are processed, so this
        // assumes the typed part of the graph reachable from an output is acyclic.
        operation.inputs().forEach(input -> visitTypedPostOrder(input, action, visited));
        action.accept(operation);
        visited.add(operation.name());
    }

    @Override
    public String toString() {
        return "intermediate graph for '" + modelName + "'";
    }

    /** Returns one line per operation, "key: <operation full string>", in unspecified (hash map) order. */
    public String toFullString() {
        StringBuilder b = new StringBuilder();
        for (var entry : operations.entrySet())
            b.append(entry.getKey()).append(": ").append(entry.getValue().toFullString()).append("\n");
        return b.toString();
    }

}