aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java70
1 files changed, 55 insertions, 15 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
index b4a9b363ade..b3c1708a0f4 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/ImportResult.java
@@ -12,40 +12,80 @@ import java.util.Map;
import java.util.stream.Collectors;
/**
- * The result of importing a TensorFlow model into Vespa:
- * - A list of ranking expressions reproducing the computations of the outputs in the TensorFlow model
- * - A list of named constant tensors
- * - A list of expected input tensors, with their tensor type
- * - A list of warning messages
+ * The result of importing a TensorFlow model into Vespa.
+ * - A set of signatures which are named collections of inputs and outputs.
+ * - A set of named constant tensors represented by Variable nodes in TensorFlow.
+ * - A list of warning messages.
*
* @author bratseth
*/
// This object can be built incrementally within this package, but is immutable when observed from outside the package
-// TODO: Retain signature structure in ImportResult (input + output-expression bundles)
public class ImportResult {
- private final List<RankingExpression> expressions = new ArrayList<>();
- private final Map<String, Tensor> constants = new HashMap<>();
+ private final Map<String, Signature> signatures = new HashMap<>();
private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, Tensor> constants = new HashMap<>();
private final List<String> warnings = new ArrayList<>();
- void add(RankingExpression expression) { expressions.add(expression); }
- void set(String name, Tensor constant) { constants.put(name, constant); }
- void set(String name, TensorType argument) { arguments.put(name, argument); }
+ void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void constant(String name, Tensor constant) { constants.put(name, constant); }
void warn(String warning) { warnings.add(warning); }
- /** Returns an immutable list of the expressions of this */
- public List<RankingExpression> expressions() { return Collections.unmodifiableList(expressions); }
+ /** Returns the given signature. If it does not already exist it is added to this. */
+ Signature signature(String name) {
+ return signatures.computeIfAbsent(name, n -> new Signature(n));
+ }
+
+ /** Returns an immutable map of the arguments ("Placeholders") of this */
+ public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
/** Returns an immutable map of the constants of this */
public Map<String, Tensor> constants() { return Collections.unmodifiableMap(constants); }
- /** Returns an immutable map of the arguments of this */
- public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+ /** Returns an immutable map of the signatures of this */
+ public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
/** Returns an immutable list, in natural sort order of the warnings generated while importing this */
public List<String> warnings() {
return warnings.stream().sorted().collect(Collectors.toList());
}
+ /**
+ * A signature is a set of named inputs and outputs, where the inputs maps to argument ("placeholder") names+types,
+ * and outputs maps to ranking expressions stemming from conversion of TensorFlow nodes and the inputs make up the
+ * context which is needed to evaluate the expression.
+ */
+ public class Signature {
+
+ private final String name;
+ private final Map<String, String> inputs = new HashMap<>();
+ private final Map<String, RankingExpression> outputs = new HashMap<>();
+
+ Signature(String name) {
+ this.name = name;
+ }
+
+ void input(String inputName, String argumentName) { inputs.put(inputName, argumentName); }
+ void output(String name, RankingExpression expression) { outputs.put(name, expression); }
+
+ /** Returns the result this is part of */
+ ImportResult owner() { return ImportResult.this; }
+
+ /**
+ * Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
+ * to argument (Placeholder) name in the owner of this
+ */
+ public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
+
+ /** Returns owner().arguments().get(inputs.get(name)), e.g the type of the argument this input references */
+ public TensorType inputType(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
+
+ /** Returns an immutable list of the expressions of this */
+ public Map<String, RankingExpression> outputs() { return Collections.unmodifiableMap(outputs); }
+
+ @Override
+ public String toString() { return "signature '" + name + "'"; }
+
+ }
+
}