diff options
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java | 70 |
1 files changed, 55 insertions, 15 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java index b4a9b363ade..b3c1708a0f4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java @@ -12,40 +12,80 @@ import java.util.Map; import java.util.stream.Collectors; /** - * The result of importing a TensorFlow model into Vespa: - * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model - * - A list of named constant tensors - * - A list of expected input tensors, with their tensor type - * - A list of warning messages + * The result of importing a TensorFlow model into Vespa. + * - A set of signatures which are named collections of inputs and outputs. + * - A set of named constant tensors represented by Variable nodes in TensorFlow. + * - A list of warning messages. * * @author bratseth */ // This object can be built incrementally within this package, but is immutable when observed from outside the package -// TODO: Retain signature structure in ImportResult (input + output-expression bundles) public class ImportResult { - private final List<RankingExpression> expressions = new ArrayList<>(); - private final Map<String, Tensor> constants = new HashMap<>(); + private final Map<String, Signature> signatures = new HashMap<>(); private final Map<String, TensorType> arguments = new HashMap<>(); + private final Map<String, Tensor> constants = new HashMap<>(); private final List<String> warnings = new ArrayList<>(); - void add(RankingExpression expression) { expressions.add(expression); } - void set(String name, Tensor constant) { constants.put(name, constant); } - void set(String name, TensorType argument) { arguments.put(name, argument); } + void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); } + void constant(String name, Tensor constant) { constants.put(name, constant); } void warn(String warning) { warnings.add(warning); } - /** Returns an immutable list of the expressions of this */ - public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); } + /** Returns the given signature. If it does not already exist it is added to this. */ + Signature signature(String name) { + return signatures.computeIfAbsent(name, n -> new Signature(n)); + } + + /** Returns an immutable map of the arguments ("Placeholders") of this */ + public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } /** Returns an immutable map of the constants of this */ public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); } - /** Returns an immutable map of the arguments of this */ - public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); } + /** Returns an immutable map of the signatures of this */ + public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); } /** Returns an immutable list, in natural sort order of the warnings generated while importing this */ public List<String> warnings() { return warnings.stream().sorted().collect(Collectors.toList()); } + /** + * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types, + * and outputs maps to ranking expressions stemming from conversion of TensorFlow nodes and the inputs make up the + * context which is needed to evaluate the expression. + */ + public class Signature { + + private final String name; + private final Map<String, String> inputs = new HashMap<>(); + private final Map<String, RankingExpression> outputs = new HashMap<>(); + + Signature(String name) { + this.name = name; + } + + void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); } + void output(String name, RankingExpression expression) { outputs.put(name, expression); } + + /** Returns the result this is part of */ + ImportResult owner() { return ImportResult.this; } + + /** + * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name + * to argument (Placeholder) name in the owner of this + */ + public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); } + + /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */ + public TensorType inputType(String inputName) { return owner().arguments().get(inputs.get(inputName)); } + + /** Returns an immutable list of the expressions of this */ + public Map<String, RankingExpression> outputs() { return Collections.unmodifiableMap(outputs); } + + @Override + public String toString() { return "signature '" + name + "'"; } + + } + } |