summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
blob: 5d711aac100c67bc2633629220e42b9f0887701e (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations;

import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
import org.tensorflow.framework.NodeDef;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;

/**
 * Wraps a TensorFlow node and produces the corresponding Vespa tensor operation.
 * During import, a graph of these operations is constructed. Then, the
 * types are used to deduce sensible dimension names using the
 * DimensionRenamer. After the types have been renamed, the proper
 * Vespa expressions can be extracted.
 *
 * @author lesters
 */
public abstract class TensorFlowOperation {

    /** Prefix prepended to {@link #vespaName()} to form the name of a generated macro. */
    protected final static String MACRO_PREFIX = "tf_macro_";

    /** The TensorFlow node wrapped by this operation. */
    protected final NodeDef node;
    /** Port index given at construction (presumably an output port of {@code node}; not used in this class). */
    protected final int port;
    /** Operations feeding into this one: an unmodifiable snapshot taken at construction. */
    protected final List<TensorFlowOperation> inputs;
    /** Operations consuming this one; each consumer adds itself here when it is constructed. */
    protected final List<TensorFlowOperation> outputs = new ArrayList<>();
    /** Warnings collected while importing this operation. */
    protected final List<String> importWarnings = new ArrayList<>();

    /** Cached result of {@link #lazyGetType()}; see {@link #type()}. */
    protected OrderedTensorType type;
    /** Cached Vespa function; see {@link #function()}. */
    protected TensorFunction function;
    /** The real function of this operation when it is exposed as a macro, otherwise null. */
    protected TensorFunction macro = null;

    private Value constantValue = null;
    private List<TensorFlowOperation> controlInputs = Collections.emptyList();

    /**
     * Creates an operation and registers it as an output of each of its inputs.
     *
     * @param node   the TensorFlow node to wrap
     * @param inputs the operations producing this node's inputs; the list is copied
     * @param port   the port index for this operation
     */
    TensorFlowOperation(NodeDef node, List<TensorFlowOperation> inputs, int port) {
        this.node = node;
        this.port = port;
        // Snapshot the list so later changes to the caller's list cannot make `inputs`
        // disagree with the outputs registered on each input below.
        this.inputs = Collections.unmodifiableList(new ArrayList<>(inputs));
        this.inputs.forEach(input -> input.outputs.add(this));
    }

    /** Computes the Vespa type of this operation, or returns null if it cannot be determined. */
    protected abstract OrderedTensorType lazyGetType();

    /** Computes the Vespa function of this operation, or returns null if it cannot be produced. */
    protected abstract TensorFunction lazyGetFunction();

    /**
     * Returns the Vespa tensor type of this operation if it exists.
     * The type is computed by {@link #lazyGetType()} on first use and cached;
     * it is passed to {@code OrderedTensorType.verifyType} on every call.
     */
    public Optional<OrderedTensorType> type() {
        if (type == null) {
            type = lazyGetType();
        }
        OrderedTensorType.verifyType(node, type);
        return Optional.ofNullable(type);
    }

    /**
     * Returns the Vespa tensor function implementing all operations from this node with inputs.
     * <p>
     * A constant operation is represented by a reference to {@code constant(<vespaName>)}.
     * An operation consumed by more than one output has its real function stored as
     * {@link #macro()}, and a {@link VariableTensor} named {@link #macroName()} is returned
     * instead. When no function can be produced (or, for a macro, no type is available),
     * the result is empty, and the computation is retried on the next call.
     */
    public Optional<TensorFunction> function() {
        if (function == null) {
            if (isConstant()) {
                ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
                function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
            } else if (outputs.size() > 1) {
                function = macroReference();
            } else {
                function = lazyGetFunction();
            }
        }
        return Optional.ofNullable(function);
    }

    /**
     * Computes this operation's function, stores it as the macro, and returns a variable
     * referencing that macro. Returns null (leaving the macro unset) if either the function
     * or the type is unavailable, so no reference to a missing macro is ever produced.
     */
    private TensorFunction macroReference() {
        TensorFunction macroFunction = lazyGetFunction();
        if (macroFunction == null) {
            return null;
        }
        // Go through type() rather than the field: the type may not have been computed yet.
        Optional<OrderedTensorType> macroType = type();
        if ( ! macroType.isPresent()) {
            return null;
        }
        macro = macroFunction;
        return new VariableTensor(macroName(), macroType.get().type());
    }

    /** Returns the wrapped TensorFlow node. */
    public NodeDef node() { return node; }

    /** Returns the unmodifiable list of inputs. */
    public List<TensorFlowOperation> inputs() { return inputs; }

    /** Returns an unmodifiable view of the outputs. If a node has multiple outputs, consider adding a macro. */
    public List<TensorFlowOperation> outputs() { return Collections.unmodifiableList(outputs); }

    /** Returns a Vespa ranking expression that should be added as a macro, if any. */
    public Optional<TensorFunction> macro() { return Optional.ofNullable(macro); }

    /** Adds dimension name constraints for this operation. Does nothing by default. */
    public void addDimensionNameConstraints(DimensionRenamer renamer) { }

    /** Performs dimension renaming on this operation's type. */
    public void renameDimensions(DimensionRenamer renamer) { type = OrderedTensorType.rename(type, renamer); }

    /** Returns true for operations that are inputs to the model itself (as opposed to inputs to the operation). */
    public boolean isInput() { return false; }

    /**
     * Returns true if this node is constant. By default an operation is constant when all
     * of its inputs are; an operation without inputs is therefore constant unless a
     * subclass overrides this.
     */
    public boolean isConstant() { return inputs.stream().allMatch(TensorFlowOperation::isConstant); }

    /** Sets the constant value. */
    public void setConstantValue(Value value) { constantValue = value; }

    /** Returns the constant value if it has been set. */
    public Optional<Value> getConstantValue() { return Optional.ofNullable(constantValue); }

    /** Sets the external control inputs. */
    public void setControlInputs(List<TensorFlowOperation> inputs) { this.controlInputs = inputs; }

    /** Returns an unmodifiable view of the control inputs for this operation. */
    public List<TensorFlowOperation> getControlInputs() { return Collections.unmodifiableList(this.controlInputs); }

    /** Returns the node name with every '/' replaced by '_', or null if the node has no name. */
    public String vespaName() {
        String name = node.getName();
        return name != null ? name.replace('/', '_') : null;
    }

    /** Returns {@link #vespaName()} prefixed with {@link #MACRO_PREFIX}, or null if there is no name. */
    public String macroName() {
        String name = vespaName();
        return name != null ? MACRO_PREFIX + name : null;
    }

    /** Returns an unmodifiable view of the warnings produced during this operation's lifetime. */
    public List<String> warnings() { return Collections.unmodifiableList(importWarnings); }

    /**
     * Checks that the control inputs and inputs all have the property extracted by {@code func}.
     *
     * @param expected the number of (non-control) inputs this operation requires
     * @param func     extracts an optional property (such as the type or function) from an operation
     * @return false if any control input or input lacks the property, true otherwise
     * @throws IllegalArgumentException if the number of inputs differs from {@code expected};
     *         this is only checked once all control inputs have the property
     */
    boolean verifyInputs(int expected, Function<TensorFlowOperation, Optional<?>> func) {
        if (!controlInputs.stream().map(func).allMatch(Optional::isPresent)) {
            return false;
        }
        if (inputs.size() != expected) {
            throw new IllegalArgumentException("Expected " + expected + " inputs " +
                    "for '" + node.getName() + "', got " + inputs.size());
        }
        return inputs.stream().map(func).allMatch(Optional::isPresent);
    }

    /** Returns whether all control inputs and exactly {@code expected} inputs have a type. */
    boolean allInputTypesPresent(int expected) {
        return verifyInputs(expected, TensorFlowOperation::type);
    }

    /** Returns whether all control inputs and exactly {@code expected} inputs have a function. */
    boolean allInputFunctionsPresent(int expected) {
        return verifyInputs(expected, TensorFlowOperation::function);
    }

}