1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
|
package com.yahoo.searchdefinition.expressiontransforms;
import com.yahoo.searchdefinition.RankProfile;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.ImportResult;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.TensorFlowImporter;
import com.yahoo.searchlib.rankingexpression.rule.Arguments;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
import java.util.Map;
import java.util.Optional;
/**
* Replaces instances of the tensorflow(model-path, signature, output)
* pseudofeature with the native Vespa ranking expression implementing
* the same computation.
*
* @author bratseth
*/
public class TensorFlowFeatureConverter extends ExpressionTransformer {

    private final TensorFlowImporter tensorFlowImporter = new TensorFlowImporter();

    /** The profile to which the constant tensors of every imported model are added */
    private final RankProfile profile;

    public TensorFlowFeatureConverter(RankProfile profile) {
        this.profile = profile;
    }

    /**
     * Converts reference nodes via {@link #transformFeature}, delegates composite nodes
     * to {@code transformChildren}, and returns every other node unchanged.
     */
    @Override
    public ExpressionNode transform(ExpressionNode node) {
        if (node instanceof ReferenceNode)
            return transformFeature((ReferenceNode) node);
        else if (node instanceof CompositeNode)
            return super.transformChildren((CompositeNode) node);
        else
            return node;
    }

    /**
     * Returns the root of the imported expression if the given feature is named "tensorflow",
     * or the feature itself otherwise.
     * Arguments: 0 = model path (required), 1 = signature name (optional), 2 = output name (optional).
     * On success, all constants of the imported model are added to the rank profile.
     *
     * @throws IllegalArgumentException if the arguments are invalid or the requested
     *         signature/output/expression does not exist in the model
     */
    private ExpressionNode transformFeature(ReferenceNode feature) {
        try {
            if ( ! feature.getName().equals("tensorflow")) return feature;

            if (feature.getArguments().isEmpty())
                throw new IllegalArgumentException("A tensorflow node must take an argument pointing to " +
                                                   "the tensorflow model directory under [application]/models");

            // Find the specified expression
            ImportResult result = tensorFlowImporter.importModel(asString(feature.getArguments().expressions().get(0)));
            ImportResult.Signature signature = chooseOrDefault("signatures", result.signatures(),
                                                               optionalArgument(1, feature.getArguments()));
            String output = chooseOrDefault("outputs", signature.outputs(),
                                            optionalArgument(2, feature.getArguments()));

            // Validate before touching the profile, so a failed lookup leaves it unmodified
            // and the caller gets an IllegalArgumentException rather than a NullPointerException
            if (result.expressions().get(output) == null)
                throw new IllegalArgumentException("Model does not have an expression for output '" + output + "'");

            // Add all constants
            result.constants().forEach((k, v) -> profile.addConstantTensor(k, new TensorValue(v)));
            return result.expressions().get(output).getRoot();
        }
        catch (IllegalArgumentException e) {
            throw new IllegalArgumentException("Could not import tensorflow model from " + feature, e);
        }
    }

    /**
     * Returns the specified, existing map value, or the only map value if no key is specified.
     * Throws IllegalArgumentException in all other cases.
     *
     * @param valueDescription plural noun describing the values, used in error messages
     * @param map the candidate values by name
     * @param key the name of the value to choose, if specified
     */
    private <T> T chooseOrDefault(String valueDescription, Map<String, T> map, Optional<String> key) {
        if ( ! key.isPresent()) {
            if (map.size() == 0)
                throw new IllegalArgumentException("No " + valueDescription + " are present");
            if (map.size() > 1)
                throw new IllegalArgumentException("Model has multiple " + valueDescription + ", but no " +
                                                   valueDescription + " argument is specified");
            return map.values().iterator().next();
        }
        else {
            T value = map.get(key.get());
            if (value == null)
                throw new IllegalArgumentException("Model does not have the specified " +
                                                   valueDescription + " '" + key.get() + "'");
            return value;
        }
    }

    /** Returns the argument at the given index as an unquoted string, or empty if there is no such argument */
    private Optional<String> optionalArgument(int argumentIndex, Arguments arguments) {
        if (argumentIndex >= arguments.expressions().size())
            return Optional.empty();
        return Optional.of(asString(arguments.expressions().get(argumentIndex)));
    }

    /**
     * Returns the source string of the given constant node with any surrounding quotes removed.
     *
     * @throws IllegalArgumentException if the node is not a ConstantNode
     */
    private String asString(ExpressionNode node) {
        if ( ! (node instanceof ConstantNode))
            throw new IllegalArgumentException("Expected a constant string as tensorflow argument, but got '" + node + "'");
        return stripQuotes(((ConstantNode)node).sourceString());
    }

    /**
     * Removes a matching pair of surrounding quotes (' or ") from the given string.
     * Strings which do not start with a quote (including the empty string) are returned unchanged.
     *
     * @throws IllegalArgumentException if the string starts with a quote but does not end with the same quote
     */
    private String stripQuotes(String s) {
        if (s.isEmpty() || ! isQuoteSign(s.codePointAt(0))) return s;
        int quote = s.codePointAt(0);
        // A lone quote character, or a closing quote which differs from the opening one, is malformed
        if (s.length() < 2 || s.codePointAt(s.length() - 1) != quote)
            throw new IllegalArgumentException("tensorflow argument [" + s + "] is missing endquote");
        return s.substring(1, s.length() - 1);
    }

    private boolean isQuoteSign(int c) {
        return c == '\'' || c == '"';
    }

}
|