aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
blob: 14aa3ebf84e3f8c61a7e0b6e75e73ec64690065f (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.

package ai.vespa.rankingexpression.importer;

import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
import ai.vespa.rankingexpression.importer.operations.MatMul;

import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Holds an intermediate representation of an imported model graph.
 * After this intermediate representation is constructed, it is used to
 * simplify and optimize the computational graph and then converted into the
 * final ImportedModel that holds the Vespa ranking expressions for the model.
 *
 * @author lesters
 */
public class IntermediateGraph {

    private final String modelName;
    private final Map<String, IntermediateOperation> operations = new HashMap<>();
    private final Map<String, GraphSignature> signatures = new HashMap<>();

    private static class GraphSignature {
        final Map<String, String> inputs = new HashMap<>();
        final Map<String, String> outputs = new HashMap<>();
    }

    public IntermediateGraph(String modelName) {
        this.modelName = modelName;
    }

    public String name() {
        return modelName;
    }

    public IntermediateOperation put(String key, IntermediateOperation operation) {
        return operations.put(key, operation);
    }

    public IntermediateOperation get(String key) {
        return operations.get(key);
    }

    public Set<String> signatures() {
        return signatures.keySet();
    }

    public Map<String, String> inputs(String signature) {
        return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).inputs;
    }

    public Map<String, String> outputs(String signature) {
        return signatures.computeIfAbsent(signature, (k) -> new GraphSignature()).outputs;
    }

    public String defaultSignature() {
        return "default";
    }

    public boolean alreadyImported(String key) {
        return operations.containsKey(key);
    }

    public Map<String, IntermediateOperation> operations() {
        return operations;
    }

    public void optimize() {
        renameDimensions();
    }

    /**
     * Find dimension names to avoid excessive renaming while evaluating the model.
     */
    private void renameDimensions() {
        DimensionRenamer renamer = new DimensionRenamer(this);
        for (String signature : signatures()) {
            for (String output : outputs(signature).values()) {
                addDimensionNameConstraints(operations.get(output), renamer);
            }
        }
        renamer.solve();
        for (String signature : signatures()) {
            for (String output : outputs(signature).values()) {
                renameDimensions(operations.get(output), renamer);
            }
        }
    }

    private static void addDimensionNameConstraints(IntermediateOperation operation, DimensionRenamer renamer) {
        if (operation.type().isPresent()) {
            operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
            operation.addDimensionNameConstraints(renamer);
        }
    }

    private static void renameDimensions(IntermediateOperation operation, DimensionRenamer renamer) {
        if (operation.type().isPresent()) {
            operation.inputs().forEach(input -> renameDimensions(input, renamer));
            operation.renameDimensions(renamer);
        }
    }

    @Override
    public String toString() {
        return "intermediate graph for '" + modelName + "'";
    }

    public String toFullString() {
        StringBuilder b = new StringBuilder();
        for (var input : operations.entrySet())
            b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n");
        return b.toString();
    }

}