summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-06-04 11:52:19 +0200
committerLester Solbakken <lesters@oath.com>2018-06-04 11:52:19 +0200
commit30ac849f0893b4d98e9392648a2f59e014d6f617 (patch)
tree2c1a540223867b0844c0e2a1c144c754bdf54e8c /searchlib
parent6d5e6caa958f3f5913922530f8656e4126d26817 (diff)
Clean up a bit and add some clarification comments
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java22
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java334
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java19
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java6
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java5
10 files changed, 59 insertions, 347 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index e6fa607539b..4b49f17f74e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -15,7 +15,6 @@ import java.util.regex.Pattern;
* The result of importing a model (TensorFlow or ONNX) into Vespa.
*
* @author bratseth
- * @author lesters
*/
public class ImportedModel {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index dc70e694446..a658833b426 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -19,6 +19,15 @@ import java.util.Map;
import java.util.Optional;
import java.util.logging.Logger;
+/**
+ * Base class for importing ML models (ONNX/TensorFlow) as native Vespa
+ * ranking expressions. The general mechanism for import is for the
+ * specific ML platform import implementations to create an
+ * IntermediateGraph. This class offers common code to convert the
+ * IntermediateGraph to Vespa ranking expressions and macros.
+ *
+ * @author lesters
+ */
public abstract class ModelImporter {
private static final Logger log = Logger.getLogger(ModelImporter.class.getName());
@@ -32,6 +41,10 @@ public abstract class ModelImporter {
return importModel(modelName, modelDir.toString());
}
+ /**
+ * Takes an IntermediateGraph and converts it to a ImportedModel containing
+ * the actual Vespa ranking expressions.
+ */
static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) {
ImportedModel model = new ImportedModel(graph.name());
@@ -40,7 +53,7 @@ public abstract class ModelImporter {
importSignatures(graph, model);
importExpressions(graph, model);
reportWarnings(graph, model);
- logVariableTypes(graph, model);
+ logVariableTypes(graph);
return model;
}
@@ -192,9 +205,9 @@ public abstract class ModelImporter {
}
/**
- * Convert intermediate representation to Vespa ranking expressions.
+ * Add any import warnings to the signature in the ImportedModel.
*/
- static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
+ private static void reportWarnings(IntermediateGraph graph, ImportedModel model) {
for (ImportedModel.Signature signature : model.signatures().values()) {
for (String outputName : signature.outputs().values()) {
reportWarnings(graph.get(outputName), model);
@@ -217,11 +230,10 @@ public abstract class ModelImporter {
* such that these can be converted and fed to a parent document independently of the rest of the model
* for fast model weight updates.
*/
- private static void logVariableTypes(IntermediateGraph graph, ImportedModel model) {
+ private static void logVariableTypes(IntermediateGraph graph) {
for (IntermediateOperation operation : graph.operations()) {
if ( ! (operation instanceof Constant)) continue;
if ( ! operation.type().isPresent()) continue; // will not happen
-
log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() +
" of type " + operation.type().get());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
index b3f701eb88d..d3dd2a1d418 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
@@ -8,7 +8,6 @@ import onnx.Onnx;
import java.io.FileInputStream;
import java.io.IOException;
-import java.util.logging.Logger;
/**
* Converts a ONNX model into a ranking expression and set of constants.
@@ -17,8 +16,6 @@ import java.util.logging.Logger;
*/
public class OnnxImporter extends ModelImporter {
- private static final Logger log = Logger.getLogger(OnnxImporter.class.getName());
-
@Override
public ImportedModel importModel(String modelName, String modelPath) {
try (FileInputStream inputStream = new FileInputStream(modelPath)) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
index 3c77d026ec4..ff584559a83 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java
@@ -6,7 +6,6 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow.
import org.tensorflow.SavedModelBundle;
import java.io.IOException;
-import java.util.logging.Logger;
/**
* Converts a saved TensorFlow model into a ranking expression and set of constants.
@@ -16,8 +15,6 @@ import java.util.logging.Logger;
*/
public class TensorFlowImporter extends ModelImporter {
- private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
-
/**
* Imports a saved TensorFlow model from a directory.
* The model should be saved as a .pbtxt or .pb file.
@@ -46,336 +43,5 @@ public class TensorFlowImporter extends ModelImporter {
}
}
-// /**
-// * Imports the TensorFlow graph by first importing the tensor types, then
-// * finding a suitable set of dimensions names for each
-// * placeholder/constant/variable, then importing the expressions.
-// */
-// private static ImportedModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) {
-// ImportedModel model = new ImportedModel(modelName);
-// OperationIndex index = new OperationIndex();
-//
-// importSignatures(graph, model);
-// importNodes(graph, model, index);
-// findDimensionNames(model, index);
-// importExpressions(model, index, bundle);
-//
-// reportWarnings(model, index);
-// logVariableTypes(index);
-//
-// return model;
-// }
-
-// private static void importSignatures(MetaGraphDef graph, ImportedModel model) {
-// for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) {
-// String signatureName = signatureEntry.getKey();
-// ImportedModel.Signature signature = model.signature(signatureName);
-//
-// Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
-// for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
-// String inputName = input.getKey();
-// signature.input(inputName, namePartOf(input.getValue().getName()));
-// }
-//
-// Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
-// for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
-// String outputName = output.getKey();
-// signature.output(outputName, namePartOf(output.getValue().getName()));
-// }
-// }
-// }
-//
-// private static boolean isSignatureInput(ImportedModel model, TensorFlowOperation operation) {
-// for (ImportedModel.Signature signature : model.signatures().values()) {
-// for (String inputName : signature.inputs().values()) {
-// if (inputName.equals(operation.node().getName())) {
-// return true;
-// }
-// }
-// }
-// return false;
-// }
-//
-// private static boolean isSignatureOutput(ImportedModel model, TensorFlowOperation operation) {
-// for (ImportedModel.Signature signature : model.signatures().values()) {
-// for (String outputName : signature.outputs().values()) {
-// if (outputName.equals(operation.node().getName())) {
-// return true;
-// }
-// }
-// }
-// return false;
-// }
-//
-// private static void importNodes(MetaGraphDef graph, ImportedModel model, OperationIndex index) {
-// for (ImportedModel.Signature signature : model.signatures().values()) {
-// for (String outputName : signature.outputs().values()) {
-// importNode(model.name(), outputName, graph.getGraphDef(), index);
-// }
-// }
-// }
-//
-// private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) {
-// if (index.alreadyImported(nodeName)) {
-// return index.get(nodeName);
-// }
-// NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph);
-// List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index);
-// TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName));
-// index.put(nodeName, operation);
-//
-// List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index);
-// if (controlInputs.size() > 0) {
-// operation.setControlInputs(controlInputs);
-// }
-//
-// return operation;
-// }
-//
-// private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
-// return node.getInputList().stream()
-// .filter(name -> ! isControlDependency(name))
-// .map(nodeName -> importNode(modelName, nodeName, graph, index))
-// .collect(Collectors.toList());
-// }
-//
-// private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) {
-// return node.getInputList().stream()
-// .filter(nodeName -> isControlDependency(nodeName))
-// .map(nodeName -> importNode(modelName, nodeName, graph, index))
-// .collect(Collectors.toList());
-// }
-//
-// private static boolean isControlDependency(String name) {
-// return name.startsWith("^");
-// }
-//
-// /** Find dimension names to avoid excessive renaming while evaluating the model. */
-// private static void findDimensionNames(ImportedModel model, OperationIndex index) {
-// DimensionRenamer renamer = new DimensionRenamer();
-// for (ImportedModel.Signature signature : model.signatures().values()) {
-// for (String output : signature.outputs().values()) {
-// addDimensionNameConstraints(index.get(output), renamer);
-// }
-// }
-// renamer.solve();
-// for (ImportedModel.Signature signature : model.signatures().values()) {
-// for (String output : signature.outputs().values()) {
-// renameDimensions(index.get(output), renamer);
-// }
-// }
-// }
-//
-// private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) {
-// if (operation.type().isPresent()) {
-// operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer));
-// operation.addDimensionNameConstraints(renamer);
-// }
-// }
-//
-// private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) {
-// if (operation.type().isPresent()) {
-// operation.inputs().forEach(input -> renameDimensions(input, renamer));
-// operation.renameDimensions(renamer);
-// }
-// }
-//
-// private static void importExpressions(ImportedModel model, OperationIndex index, SavedModelBundle bundle) {
-// for (ImportedModel.Signature signature : model.signatures().values()) {
-// for (String outputName : signature.outputs().values()) {
-// try {
-// Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle);
-// if (!function.isPresent()) {
-// signature.skippedOutput(outputName, "No valid output function could be found.");
-// }
-// }
-// catch (IllegalArgumentException e) {
-// signature.skippedOutput(outputName, Exceptions.toMessageString(e));
-// }
-// }
-// }
-// }
-//
-// private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, ImportedModel model, SavedModelBundle bundle) {
-// if (!operation.type().isPresent()) {
-// return Optional.empty();
-// }
-// if (operation.isConstant()) {
-// return importConstant(model, operation, bundle);
-// }
-//
-// importInputExpressions(operation, model, bundle);
-// importRankingExpression(model, operation);
-// importInputExpression(model, operation);
-// importMacroExpression(model, operation);
-//
-// return operation.function();
-// }
-//
-// private static void importInputExpressions(TensorFlowOperation operation, ImportedModel model,
-// SavedModelBundle bundle) {
-// operation.inputs().forEach(input -> importExpression(input, model, bundle));
-// }
-//
-// private static void importMacroExpression(ImportedModel model, TensorFlowOperation operation) {
-// if (operation.macro().isPresent()) {
-// TensorFunction function = operation.macro().get();
-// try {
-// model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString()));
-// }
-// catch (ParseException e) {
-// throw new RuntimeException("Tensorflow function " + function +
-// " cannot be parsed as a ranking expression", e);
-// }
-// }
-// }
-//
-// private static Optional<TensorFunction> importConstant(ImportedModel model, TensorFlowOperation operation,
-// SavedModelBundle bundle) {
-// String name = operation.vespaName();
-// if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
-// return operation.function();
-// }
-//
-// Tensor tensor;
-// if (operation.getConstantValue().isPresent()) {
-// Value value = operation.getConstantValue().get();
-// if ( ! (value instanceof TensorValue)) {
-// return operation.function(); // scalar values are inserted directly into the expression
-// }
-// tensor = value.asTensor();
-// } else {
-// // Here we use the type from the operation, which will have correct dimension names after name resolving
-// tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle),
-// operation.type().get());
-// operation.setConstantValue(new TensorValue(tensor));
-// }
-//
-// if (tensor.type().rank() == 0) {
-// model.smallConstant(name, tensor);
-// } else {
-// model.largeConstant(name, tensor);
-// }
-// return operation.function();
-// }
-//
-// static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
-// Session.Runner fetched = bundle.session().runner().fetch(name);
-// List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
-// if (importedTensors.size() != 1)
-// throw new IllegalStateException("Expected 1 tensor from fetching " + name +
-// ", but got " + importedTensors.size());
-// return importedTensors.get(0);
-// }
-//
-// private static void importRankingExpression(ImportedModel model, TensorFlowOperation operation) {
-// if (operation.function().isPresent()) {
-// String name = operation.node().getName();
-// if (!model.expressions().containsKey(operation.node().getName())) {
-// TensorFunction function = operation.function().get();
-//
-// // Make sure output adheres to standard naming convention
-// if (isSignatureOutput(model, operation)) {
-// OrderedTensorType operationType = operation.type().get();
-// OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node());
-// if ( ! operationType.equals(standardNamingType)) {
-// List<String> renameFrom = operationType.dimensionNames();
-// List<String> renameTo = standardNamingType.dimensionNames();
-// function = new Rename(function, renameFrom, renameTo);
-// }
-// }
-//
-// try {
-// // We add all intermediate nodes imported as separate expressions. Only
-// // those referenced in a signature output will be used. We parse the
-// // TensorFunction here to convert it to a RankingExpression tree.
-// model.expression(name, new RankingExpression(name, function.toString()));
-// }
-// catch (ParseException e) {
-// throw new RuntimeException("Tensorflow function " + function +
-// " cannot be parsed as a ranking expression", e);
-// }
-// }
-// }
-// }
-//
-// private static void importInputExpression(ImportedModel model, TensorFlowOperation operation) {
-// if (operation.isInput() && isSignatureInput(model, operation)) {
-// // All inputs must have dimensions with standard naming convention: d0, d1, ...
-// OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node());
-// model.argument(operation.node().getName(), standardNamingConvention.type());
-// model.requiredMacro(operation.vespaName(), standardNamingConvention.type());
-// }
-// }
-//
-// private static void reportWarnings(ImportedModel model, OperationIndex index) {
-// for (ImportedModel.Signature signature : model.signatures().values()) {
-// for (String output : signature.outputs().values()) {
-// reportWarnings(index.get(output), signature);
-// }
-// }
-// }
-//
-// /**
-// * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
-// * This allows users to learn the exact types (including dimension order after renaming) of the Variables
-// * such that these can be converted and fed to a parent document independently of the rest of the model
-// * for fast model weight updates.
-// */
-// private static void logVariableTypes(OperationIndex index) {
-// for (TensorFlowOperation operation : index.operations()) {
-// if ( ! (operation instanceof Variable)) continue;
-// if ( ! operation.type().isPresent()) continue; // will not happen
-//
-// log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() +
-// " of type " + operation.type().get());
-// }
-// }
-//
-// private static void reportWarnings(TensorFlowOperation operation, ImportedModel.Signature signature) {
-// for (String warning : operation.warnings()) {
-// signature.importWarning(warning);
-// }
-// for (TensorFlowOperation input : operation.inputs()) {
-// reportWarnings(input, signature);
-// }
-// }
-//
-// private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) {
-// for (NodeDef node : graph.getNodeList()) {
-// if (node.getName().equals(name)) {
-// return node;
-// }
-// }
-// throw new IllegalArgumentException("Could not find node '" + name + "'");
-// }
-//
-// /**
-// * A method signature input and output has the form name:index.
-// * This returns the name part without the index.
-// */
-// private static String namePartOf(String name) {
-// name = name.startsWith("^") ? name.substring(1) : name;
-// return name.split(":")[0];
-// }
-//
-// /**
-// * This return the output port part. Indexes are used for nodes with
-// * multiple outputs.
-// */
-// private static int portPartOf(String name) {
-// int i = name.indexOf(":");
-// return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
-// }
-//
-// private static class OperationIndex {
-//
-// private final Map<String, TensorFlowOperation> index = new HashMap<>();
-// public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
-// public TensorFlowOperation get(String key) { return index.get(key); }
-// public boolean alreadyImported(String key) { return index.containsKey(key); }
-// public Collection<TensorFlowOperation> operations() { return index.values(); }
-//
-// }
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
index 2416d8697c1..470b04fb44a 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java
@@ -22,6 +22,12 @@ import onnx.Onnx;
import java.util.List;
import java.util.stream.Collectors;
+/**
+ * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis
+ * for generating Vespa ranking expressions.
+ *
+ * @author lesters
+ */
public class GraphImporter {
public static IntermediateOperation mapOperation(Onnx.NodeProto node,
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
index 8b25da7939e..715c55d8323 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java
@@ -6,6 +6,11 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTens
import com.yahoo.tensor.TensorType;
import onnx.Onnx;
+/**
+ * Converts and verifies ONNX tensor types into Vespa tensor types.
+ *
+ * @author lesters
+ */
public class TypeConverter {
public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
index e24b2a828b5..43de29cedd5 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java
@@ -18,22 +18,33 @@ import java.util.List;
import java.util.Optional;
import java.util.function.Function;
+/**
+ * Wraps an imported operation node and produces the respective Vespa tensor
+ * operation. During import, a graph of these operations are constructed. Then,
+ * the types are used to deduce sensible dimension names using the
+ * DimensionRenamer. After the types have been renamed, the proper Vespa
+ * expressions can be extracted.
+ *
+ * @author lesters
+ */
public abstract class IntermediateOperation {
- protected final static String MACRO_PREFIX = "imported_ml_macro_";
+ private final static String MACRO_PREFIX = "imported_ml_macro_";
protected final String name;
protected final String modelName;
protected final List<IntermediateOperation> inputs;
protected final List<IntermediateOperation> outputs = new ArrayList<>();
- protected final List<String> importWarnings = new ArrayList<>();
protected OrderedTensorType type;
protected TensorFunction function;
protected TensorFunction macro = null;
- protected Value constantValue = null;
+
+ private final List<String> importWarnings = new ArrayList<>();
+ private Value constantValue = null;
+ private List<IntermediateOperation> controlInputs = Collections.emptyList();
+
protected Function<OrderedTensorType, Value> constantValueFunction = null;
- protected List<IntermediateOperation> controlInputs = Collections.emptyList();
IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) {
this.name = name;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
index 6a05a7d9f9a..a815cbc3944 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java
@@ -14,6 +14,11 @@ import java.util.Map;
import java.util.Optional;
import java.util.stream.Collectors;
+/**
+ * Converts TensorFlow node attributes to Vespa attribute values.
+ *
+ * @author lesters
+ */
public class AttributeConverter implements IntermediateOperation.AttributeMap {
private final Map<String, AttrValue> attributeMap;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
index 3d70e551776..e1b292f9e61 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java
@@ -37,6 +37,12 @@ import java.io.IOException;
import java.util.List;
import java.util.stream.Collectors;
+/**
+ * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis
+ * for generating Vespa ranking expressions.
+ *
+ * @author lesters
+ */
public class GraphImporter {
public static IntermediateOperation mapOperation(NodeDef node,
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
index 596cb31bc37..67ad1edc312 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java
@@ -10,6 +10,11 @@ import org.tensorflow.framework.TensorShapeProto;
import java.util.List;
+/**
+ * Converts and verifies TensorFlow tensor types into Vespa tensor types.
+ *
+ * @author lesters
+ */
public class TypeConverter {
public static void verifyType(NodeDef node, OrderedTensorType type) {