summaryrefslogtreecommitdiffstats
path: root/config-model/src/main/java/com/yahoo/searchdefinition/expressiontransforms/TensorFlowFeatureConverter.java
blob: e5886030d44f4f38e51cb46e4dbd91da0b074b3d (plain) (blame)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
package com.yahoo.searchdefinition.expressiontransforms;

import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;

import java.util.Map;
import java.util.Optional;

/**
 * Replaces instances of the tensorflow(model-path, signature, output)
 * pseudofeature with the native Vespa ranking expression implementing
 * the same computation.
 * <p>
 * The signature and output arguments are optional: when the signature is omitted the model must
 * contain exactly one signature, and when the output is omitted the chosen signature must have
 * exactly one output. Constant tensors defined by an imported model are added to the rank profile.
 *
 * @author bratseth
 */
public class TensorFlowFeatureConverter extends ExpressionTransformer {

    private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();

    /** The rank profile which receives the constant tensors of imported models */
    private final RankProfile profile;

    public TensorFlowFeatureConverter(RankProfile profile) {
        this.profile = profile;
    }

    /**
     * Replaces tensorflow(...) references by the expression they import, recursing into
     * composite nodes. All other nodes are returned unchanged.
     */
    @Override
    public ExpressionNode transform(ExpressionNode node) {
        if (node instanceof ReferenceNode)
            return transformFeature((ReferenceNode) node);
        else if (node instanceof CompositeNode)
            return super.transformChildren((CompositeNode) node);
        else
            return node;
    }

    /**
     * Returns the root of the imported expression if the given feature is tensorflow(...),
     * or the feature itself otherwise.
     *
     * @throws IllegalArgumentException if the arguments are invalid or the model cannot be imported,
     *                                  wrapping the underlying cause
     */
    private ExpressionNode transformFeature(ReferenceNode feature) {
        if ( ! feature.getName().equals("tensorflow")) return feature;

        try {
            if (feature.getArguments().isEmpty())
                throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
                                                   "the tensorflow model directory under [application]/models");

            // Find the specified expression
            ImportResult result = tensorFlowImporter.importModel(asString(feature.getArguments().expressions().get(0)));
            ImportResult.Signature signature = chooseOrDefault("signatures", result.signatures(),
                                                               optionalArgument(1, feature.getArguments()));
            String output = chooseOrDefault("outputs", signature.outputs(),
                                            optionalArgument(2, feature.getArguments()));

            // Verify the output resolves to an expression before mutating the profile,
            // so a failed conversion does not leave constants behind
            if ( ! result.expressions().containsKey(output))
                throw new IllegalArgumentException("Model does not have an expression for output '" + output + "'");

            // Add all constants
            result.constants().forEach((k, v) -> profile.addConstantTensor(k, new TensorValue(v)));

            return result.expressions().get(output).getRoot();
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Could not import tensorflow model from " + feature, e);
        }
    }

    /**
     * Returns the specified, existing map value, or the only map value if no key is specified.
     * Throws IllegalArgumentException in all other cases.
     *
     * @param valueDescription plural name of the values, used in error messages
     * @param map the candidate values by name
     * @param key the name of the value to choose, or empty to choose the single value
     */
    private static <T> T chooseOrDefault(String valueDescription, Map<String, T> map, Optional<String> key) {
        if ( ! key.isPresent()) {
            if (map.isEmpty())
                throw new IllegalArgumentException("No " + valueDescription + " are present");
            if (map.size() > 1)
                throw new IllegalArgumentException("Model has multiple " + valueDescription + ", but no " +
                                                   valueDescription + " argument is specified");
            return map.values().iterator().next();
        }

        T value = map.get(key.get());
        if (value == null)
            throw new IllegalArgumentException("Model does not have the specified " +
                                               valueDescription + " '" + key.get() + "'");
        return value;
    }

    /** Returns the argument at the given index as a string, or empty if there are not that many arguments */
    private static Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
        if (argumentIndex >= arguments.expressions().size())
            return Optional.empty();
        return Optional.of(asString(arguments.expressions().get(argumentIndex)));
    }

    /** Returns the source string of the given constant node with any surrounding quotes removed */
    private static String asString(ExpressionNode node) {
        if ( ! (node instanceof ConstantNode))
            throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node + "'");
        return stripQuotes(((ConstantNode)node).sourceString());
    }

    /**
     * Removes a matching pair of surrounding single or double quotes.
     * Strings which do not start with a quote (including the empty string) are returned as-is.
     *
     * @throws IllegalArgumentException if the string starts with a quote but does not end with the same quote
     */
    private static String stripQuotes(String s) {
        if (s.isEmpty() || ! isQuoteSign(s.charAt(0))) return s;
        char quote = s.charAt(0);
        // A single quote character cannot be both the opening and the closing quote
        if (s.length() < 2 || s.charAt(s.length() - 1) != quote)
            throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
        return s.substring(1, s.length() - 1);
    }

    private static boolean isQuoteSign(int c) {
        return c == '\'' || c == '"';
    }

}