diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-08 11:44:23 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-03-08 11:44:23 +0100 |
commit | f19b783d4014f799482daa13f8f8c26d5c4c84d9 (patch) | |
tree | 2679f94f20120b60317d54b7775e4de3d5ed4a2e /searchlib | |
parent | 692d43c3c85352c8f8e40615ce37aa3d2f83b5d3 (diff) |
Log OrderedTensorType of imported TensorFlow variables
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java | 31 |
1 files changed, 28 insertions, 3 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 9ff88103f12..853a84ae226 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -8,7 +8,9 @@ import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.Dim import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const; import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; @@ -24,10 +26,12 @@ import org.tensorflow.framework.TensorInfo; import java.io.File; import java.io.IOException; +import java.util.Collection; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Logger; import java.util.stream.Collectors; /** @@ -38,6 +42,8 @@ import java.util.stream.Collectors; */ public class TensorFlowImporter { + private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); + /** * Imports a saved TensorFlow model from a directory. * The model should be saved as a .pbtxt or .pb file. @@ -83,6 +89,7 @@ public class TensorFlowImporter { importExpressions(model, index, bundle); reportWarnings(model, index); + logVariableTypes(index); return model; } @@ -268,7 +275,7 @@ public class TensorFlowImporter { List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); if (importedTensors.size() != 1) { throw new IllegalStateException("Expected 1 tensor from fetching " + operation.node().getName() + ", but got " + - importedTensors.size()); + importedTensors.size()); } // Here we use the type from the operation, which will have correct dimension names after name resolving tensor = TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get()); @@ -308,7 +315,7 @@ public class TensorFlowImporter { } catch (ParseException e) { throw new RuntimeException("Tensorflow function " + function + - " cannot be parsed as a ranking expression", e); + " cannot be parsed as a ranking expression", e); } } } @@ -331,6 +338,22 @@ public class TensorFlowImporter { } } + /** + * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type. + * This allows users to learn the exact types (including dimension order after renaming) of the Variables + * such that these can be converted and fed to a parent document independently of the rest of the model + * for fast model weight updates. + */ + private static void logVariableTypes(OperationIndex index) { + for (TensorFlowOperation operation : index.operations()) { + if ( ! (operation instanceof Variable)) continue; + if ( ! operation.type().isPresent()) continue; // will not happen + + log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() + + " of type " + operation.type().get()); + } + } + private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) { for (String warning : operation.warnings()) { signature.importWarning(warning); @@ -364,12 +387,14 @@ public class TensorFlowImporter { return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1)); } - private static class OperationIndex { + private final Map<String, TensorFlowOperation> index = new HashMap<>(); public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); } public TensorFlowOperation get(String key) { return index.get(key); } public boolean alreadyImported(String key) { return index.containsKey(key); } + public Collection<TensorFlowOperation> operations() { return index.values(); } + } } |