diff options
author | Lester Solbakken <lesters@oath.com> | 2018-06-04 11:52:19 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-06-04 11:52:19 +0200 |
commit | 30ac849f0893b4d98e9392648a2f59e014d6f617 (patch) | |
tree | 2c1a540223867b0844c0e2a1c144c754bdf54e8c /searchlib | |
parent | 6d5e6caa958f3f5913922530f8656e4126d26817 (diff) |
Clean up a bit and add some clarification comments
Diffstat (limited to 'searchlib')
10 files changed, 59 insertions, 347 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index e6fa607539b..4b49f17f74e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -15,7 +15,6 @@ import java.util.regex.Pattern; * The result of importing a model (TensorFlow or ONNX) into Vespa. * * @author bratseth - * @author lesters */ public class ImportedModel { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java index dc70e694446..a658833b426 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java @@ -19,6 +19,15 @@ import java.util.Map; import java.util.Optional; import java.util.logging.Logger; +/** + * Base class for importing ML models (ONNX/TensorFlow) as native Vespa + * ranking expressions. The general mechanism for import is for the + * specific ML platform import implementations to create an + * IntermediateGraph. This class offers common code to convert the + * IntermediateGraph to Vespa ranking expressions and macros. + * + * @author lesters + */ public abstract class ModelImporter { private static final Logger log = Logger.getLogger(ModelImporter.class.getName()); @@ -32,6 +41,10 @@ public abstract class ModelImporter { return importModel(modelName, modelDir.toString()); } + /** + * Takes an IntermediateGraph and converts it to a ImportedModel containing + * the actual Vespa ranking expressions. + */ static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph) { ImportedModel model = new ImportedModel(graph.name()); @@ -40,7 +53,7 @@ public abstract class ModelImporter { importSignatures(graph, model); importExpressions(graph, model); reportWarnings(graph, model); - logVariableTypes(graph, model); + logVariableTypes(graph); return model; } @@ -192,9 +205,9 @@ public abstract class ModelImporter { } /** - * Convert intermediate representation to Vespa ranking expressions. + * Add any import warnings to the signature in the ImportedModel. */ - static void reportWarnings(IntermediateGraph graph, ImportedModel model) { + private static void reportWarnings(IntermediateGraph graph, ImportedModel model) { for (ImportedModel.Signature signature : model.signatures().values()) { for (String outputName : signature.outputs().values()) { reportWarnings(graph.get(outputName), model); @@ -217,11 +230,10 @@ public abstract class ModelImporter { * such that these can be converted and fed to a parent document independently of the rest of the model * for fast model weight updates. */ - private static void logVariableTypes(IntermediateGraph graph, ImportedModel model) { + private static void logVariableTypes(IntermediateGraph graph) { for (IntermediateOperation operation : graph.operations()) { if ( ! (operation instanceof Constant)) continue; if ( ! operation.type().isPresent()) continue; // will not happen - log.info("Importing TensorFlow variable " + operation.name() + " as " + operation.vespaName() + " of type " + operation.type().get()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java index b3f701eb88d..d3dd2a1d418 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java @@ -8,7 +8,6 @@ import onnx.Onnx; import java.io.FileInputStream; import java.io.IOException; -import java.util.logging.Logger; /** * Converts a ONNX model into a ranking expression and set of constants. @@ -17,8 +16,6 @@ import java.util.logging.Logger; */ public class OnnxImporter extends ModelImporter { - private static final Logger log = Logger.getLogger(OnnxImporter.class.getName()); - @Override public ImportedModel importModel(String modelName, String modelPath) { try (FileInputStream inputStream = new FileInputStream(modelPath)) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java index 3c77d026ec4..ff584559a83 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowImporter.java @@ -6,7 +6,6 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.importer.tensorflow. import org.tensorflow.SavedModelBundle; import java.io.IOException; -import java.util.logging.Logger; /** * Converts a saved TensorFlow model into a ranking expression and set of constants. @@ -16,8 +15,6 @@ import java.util.logging.Logger; */ public class TensorFlowImporter extends ModelImporter { - private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); - /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a .pbtxt or .pb file. @@ -46,336 +43,5 @@ public class TensorFlowImporter extends ModelImporter { } } -// /** -// * Imports the TensorFlow graph by first importing the tensor types, then -// * finding a suitable set of dimensions names for each -// * placeholder/constant/variable, then importing the expressions. -// */ -// private static ImportedModel importGraph(String modelName, MetaGraphDef graph, SavedModelBundle bundle) { -// ImportedModel model = new ImportedModel(modelName); -// OperationIndex index = new OperationIndex(); -// -// importSignatures(graph, model); -// importNodes(graph, model, index); -// findDimensionNames(model, index); -// importExpressions(model, index, bundle); -// -// reportWarnings(model, index); -// logVariableTypes(index); -// -// return model; -// } - -// private static void importSignatures(MetaGraphDef graph, ImportedModel model) { -// for (Map.Entry<String, SignatureDef> signatureEntry : graph.getSignatureDefMap().entrySet()) { -// String signatureName = signatureEntry.getKey(); -// ImportedModel.Signature signature = model.signature(signatureName); -// -// Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); -// for (Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { -// String inputName = input.getKey(); -// signature.input(inputName, namePartOf(input.getValue().getName())); -// } -// -// Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); -// for (Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { -// String outputName = output.getKey(); -// signature.output(outputName, namePartOf(output.getValue().getName())); -// } -// } -// } -// -// private static boolean isSignatureInput(ImportedModel model, TensorFlowOperation operation) { -// for (ImportedModel.Signature signature : model.signatures().values()) { -// for (String inputName : signature.inputs().values()) { -// if (inputName.equals(operation.node().getName())) { -// return true; -// } -// } -// } -// return false; -// } -// -// private static boolean isSignatureOutput(ImportedModel model, TensorFlowOperation operation) { -// for (ImportedModel.Signature signature : model.signatures().values()) { -// for (String outputName : signature.outputs().values()) { -// if (outputName.equals(operation.node().getName())) { -// return true; -// } -// } -// } -// return false; -// } -// -// private static void importNodes(MetaGraphDef graph, ImportedModel model, OperationIndex index) { -// for (ImportedModel.Signature signature : model.signatures().values()) { -// for (String outputName : signature.outputs().values()) { -// importNode(model.name(), outputName, graph.getGraphDef(), index); -// } -// } -// } -// -// private static TensorFlowOperation importNode(String modelName, String nodeName, GraphDef graph, OperationIndex index) { -// if (index.alreadyImported(nodeName)) { -// return index.get(nodeName); -// } -// NodeDef node = getTensorFlowNodeFromGraph(namePartOf(nodeName), graph); -// List<TensorFlowOperation> inputs = importNodeInputs(modelName, node, graph, index); -// TensorFlowOperation operation = OperationMapper.get(modelName, node, inputs, portPartOf(nodeName)); -// index.put(nodeName, operation); -// -// List<TensorFlowOperation> controlInputs = importControlInputs(modelName, node, graph, index); -// if (controlInputs.size() > 0) { -// operation.setControlInputs(controlInputs); -// } -// -// return operation; -// } -// -// private static List<TensorFlowOperation> importNodeInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { -// return node.getInputList().stream() -// .filter(name -> ! isControlDependency(name)) -// .map(nodeName -> importNode(modelName, nodeName, graph, index)) -// .collect(Collectors.toList()); -// } -// -// private static List<TensorFlowOperation> importControlInputs(String modelName, NodeDef node, GraphDef graph, OperationIndex index) { -// return node.getInputList().stream() -// .filter(nodeName -> isControlDependency(nodeName)) -// .map(nodeName -> importNode(modelName, nodeName, graph, index)) -// .collect(Collectors.toList()); -// } -// -// private static boolean isControlDependency(String name) { -// return name.startsWith("^"); -// } -// -// /** Find dimension names to avoid excessive renaming while evaluating the model. */ -// private static void findDimensionNames(ImportedModel model, OperationIndex index) { -// DimensionRenamer renamer = new DimensionRenamer(); -// for (ImportedModel.Signature signature : model.signatures().values()) { -// for (String output : signature.outputs().values()) { -// addDimensionNameConstraints(index.get(output), renamer); -// } -// } -// renamer.solve(); -// for (ImportedModel.Signature signature : model.signatures().values()) { -// for (String output : signature.outputs().values()) { -// renameDimensions(index.get(output), renamer); -// } -// } -// } -// -// private static void addDimensionNameConstraints(TensorFlowOperation operation, DimensionRenamer renamer) { -// if (operation.type().isPresent()) { -// operation.inputs().forEach(input -> addDimensionNameConstraints(input, renamer)); -// operation.addDimensionNameConstraints(renamer); -// } -// } -// -// private static void renameDimensions(TensorFlowOperation operation, DimensionRenamer renamer) { -// if (operation.type().isPresent()) { -// operation.inputs().forEach(input -> renameDimensions(input, renamer)); -// operation.renameDimensions(renamer); -// } -// } -// -// private static void importExpressions(ImportedModel model, OperationIndex index, SavedModelBundle bundle) { -// for (ImportedModel.Signature signature : model.signatures().values()) { -// for (String outputName : signature.outputs().values()) { -// try { -// Optional<TensorFunction> function = importExpression(index.get(outputName), model, bundle); -// if (!function.isPresent()) { -// signature.skippedOutput(outputName, "No valid output function could be found."); -// } -// } -// catch (IllegalArgumentException e) { -// signature.skippedOutput(outputName, Exceptions.toMessageString(e)); -// } -// } -// } -// } -// -// private static Optional<TensorFunction> importExpression(TensorFlowOperation operation, ImportedModel model, SavedModelBundle bundle) { -// if (!operation.type().isPresent()) { -// return Optional.empty(); -// } -// if (operation.isConstant()) { -// return importConstant(model, operation, bundle); -// } -// -// importInputExpressions(operation, model, bundle); -// importRankingExpression(model, operation); -// importInputExpression(model, operation); -// importMacroExpression(model, operation); -// -// return operation.function(); -// } -// -// private static void importInputExpressions(TensorFlowOperation operation, ImportedModel model, -// SavedModelBundle bundle) { -// operation.inputs().forEach(input -> importExpression(input, model, bundle)); -// } -// -// private static void importMacroExpression(ImportedModel model, TensorFlowOperation operation) { -// if (operation.macro().isPresent()) { -// TensorFunction function = operation.macro().get(); -// try { -// model.macro(operation.macroName(), new RankingExpression(operation.macroName(), function.toString())); -// } -// catch (ParseException e) { -// throw new RuntimeException("Tensorflow function " + function + -// " cannot be parsed as a ranking expression", e); -// } -// } -// } -// -// private static Optional<TensorFunction> importConstant(ImportedModel model, TensorFlowOperation operation, -// SavedModelBundle bundle) { -// String name = operation.vespaName(); -// if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { -// return operation.function(); -// } -// -// Tensor tensor; -// if (operation.getConstantValue().isPresent()) { -// Value value = operation.getConstantValue().get(); -// if ( ! (value instanceof TensorValue)) { -// return operation.function(); // scalar values are inserted directly into the expression -// } -// tensor = value.asTensor(); -// } else { -// // Here we use the type from the operation, which will have correct dimension names after name resolving -// tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle), -// operation.type().get()); -// operation.setConstantValue(new TensorValue(tensor)); -// } -// -// if (tensor.type().rank() == 0) { -// model.smallConstant(name, tensor); -// } else { -// model.largeConstant(name, tensor); -// } -// return operation.function(); -// } -// -// static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { -// Session.Runner fetched = bundle.session().runner().fetch(name); -// List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); -// if (importedTensors.size() != 1) -// throw new IllegalStateException("Expected 1 tensor from fetching " + name + -// ", but got " + importedTensors.size()); -// return importedTensors.get(0); -// } -// -// private static void importRankingExpression(ImportedModel model, TensorFlowOperation operation) { -// if (operation.function().isPresent()) { -// String name = operation.node().getName(); -// if (!model.expressions().containsKey(operation.node().getName())) { -// TensorFunction function = operation.function().get(); -// -// // Make sure output adheres to standard naming convention -// if (isSignatureOutput(model, operation)) { -// OrderedTensorType operationType = operation.type().get(); -// OrderedTensorType standardNamingType = OrderedTensorType.fromTensorFlowType(operation.node()); -// if ( ! operationType.equals(standardNamingType)) { -// List<String> renameFrom = operationType.dimensionNames(); -// List<String> renameTo = standardNamingType.dimensionNames(); -// function = new Rename(function, renameFrom, renameTo); -// } -// } -// -// try { -// // We add all intermediate nodes imported as separate expressions. Only -// // those referenced in a signature output will be used. We parse the -// // TensorFunction here to convert it to a RankingExpression tree. -// model.expression(name, new RankingExpression(name, function.toString())); -// } -// catch (ParseException e) { -// throw new RuntimeException("Tensorflow function " + function + -// " cannot be parsed as a ranking expression", e); -// } -// } -// } -// } -// -// private static void importInputExpression(ImportedModel model, TensorFlowOperation operation) { -// if (operation.isInput() && isSignatureInput(model, operation)) { -// // All inputs must have dimensions with standard naming convention: d0, d1, ... -// OrderedTensorType standardNamingConvention = OrderedTensorType.fromTensorFlowType(operation.node()); -// model.argument(operation.node().getName(), standardNamingConvention.type()); -// model.requiredMacro(operation.vespaName(), standardNamingConvention.type()); -// } -// } -// -// private static void reportWarnings(ImportedModel model, OperationIndex index) { -// for (ImportedModel.Signature signature : model.signatures().values()) { -// for (String output : signature.outputs().values()) { -// reportWarnings(index.get(output), signature); -// } -// } -// } -// -// /** -// * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. -// * This allows users to learn the exact types (including dimension order after renaming) of the Variables -// * such that these can be converted and fed to a parent document independently of the rest of the model -// * for fast model weight updates. -// */ -// private static void logVariableTypes(OperationIndex index) { -// for (TensorFlowOperation operation : index.operations()) { -// if ( ! (operation instanceof Variable)) continue; -// if ( ! operation.type().isPresent()) continue; // will not happen -// -// log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() + -// " of type " + operation.type().get()); -// } -// } -// -// private static void reportWarnings(TensorFlowOperation operation, ImportedModel.Signature signature) { -// for (String warning : operation.warnings()) { -// signature.importWarning(warning); -// } -// for (TensorFlowOperation input : operation.inputs()) { -// reportWarnings(input, signature); -// } -// } -// -// private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef graph) { -// for (NodeDef node : graph.getNodeList()) { -// if (node.getName().equals(name)) { -// return node; -// } -// } -// throw new IllegalArgumentException("Could not find node '" + name + "'"); -// } -// -// /** -// * A method signature input and output has the form name:index. -// * This returns the name part without the index. -// */ -// private static String namePartOf(String name) { -// name = name.startsWith("^") ? name.substring(1) : name; -// return name.split(":")[0]; -// } -// -// /** -// * This return the output port part. Indexes are used for nodes with -// * multiple outputs. -// */ -// private static int portPartOf(String name) { -// int i = name.indexOf(":"); -// return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); -// } -// -// private static class OperationIndex { -// -// private final Map<String, TensorFlowOperation> index = new HashMap<>(); -// public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); } -// public TensorFlowOperation get(String key) { return index.get(key); } -// public boolean alreadyImported(String key) { return index.containsKey(key); } -// public Collection<TensorFlowOperation> operations() { return index.values(); } -// -// } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java index 2416d8697c1..470b04fb44a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/GraphImporter.java @@ -22,6 +22,12 @@ import onnx.Onnx; import java.util.List; import java.util.stream.Collectors; +/** + * Converts an ONNX graph to a Vespa IntermediateGraph which is the basis + * for generating Vespa ranking expressions. + * + * @author lesters + */ public class GraphImporter { public static IntermediateOperation mapOperation(Onnx.NodeProto node, diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java index 8b25da7939e..715c55d8323 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/onnx/TypeConverter.java @@ -6,6 +6,11 @@ import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTens import com.yahoo.tensor.TensorType; import onnx.Onnx; +/** + * Converts and verifies ONNX tensor types into Vespa tensor types. + * + * @author lesters + */ public class TypeConverter { public static void verifyType(Onnx.TypeProto typeProto, OrderedTensorType type) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java index e24b2a828b5..43de29cedd5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/IntermediateOperation.java @@ -18,22 +18,33 @@ import java.util.List; import java.util.Optional; import java.util.function.Function; +/** + * Wraps an imported operation node and produces the respective Vespa tensor + * operation. During import, a graph of these operations are constructed. Then, + * the types are used to deduce sensible dimension names using the + * DimensionRenamer. After the types have been renamed, the proper Vespa + * expressions can be extracted. + * + * @author lesters + */ public abstract class IntermediateOperation { - protected final static String MACRO_PREFIX = "imported_ml_macro_"; + private final static String MACRO_PREFIX = "imported_ml_macro_"; protected final String name; protected final String modelName; protected final List<IntermediateOperation> inputs; protected final List<IntermediateOperation> outputs = new ArrayList<>(); - protected final List<String> importWarnings = new ArrayList<>(); protected OrderedTensorType type; protected TensorFunction function; protected TensorFunction macro = null; - protected Value constantValue = null; + + private final List<String> importWarnings = new ArrayList<>(); + private Value constantValue = null; + private List<IntermediateOperation> controlInputs = Collections.emptyList(); + protected Function<OrderedTensorType, Value> constantValueFunction = null; - protected List<IntermediateOperation> controlInputs = Collections.emptyList(); IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) { this.name = name; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java index 6a05a7d9f9a..a815cbc3944 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/AttributeConverter.java @@ -14,6 +14,11 @@ import java.util.Map; import java.util.Optional; import java.util.stream.Collectors; +/** + * Converts TensorFlow node attributes to Vespa attribute values. + * + * @author lesters + */ public class AttributeConverter implements IntermediateOperation.AttributeMap { private final Map<String, AttrValue> attributeMap; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java index 3d70e551776..e1b292f9e61 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/GraphImporter.java @@ -37,6 +37,12 @@ import java.io.IOException; import java.util.List; import java.util.stream.Collectors; +/** + * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis + * for generating Vespa ranking expressions. + * + * @author lesters + */ public class GraphImporter { public static IntermediateOperation mapOperation(NodeDef node, diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java index 596cb31bc37..67ad1edc312 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/tensorflow/TypeConverter.java @@ -10,6 +10,11 @@ import org.tensorflow.framework.TensorShapeProto; import java.util.List; +/** + * Converts and verifies TensorFlow tensor types into Vespa tensor types. + * + * @author lesters + */ public class TypeConverter { public static void verifyType(NodeDef node, OrderedTensorType type) { |