summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-08 11:44:23 +0100
committerJon Bratseth <bratseth@oath.com>2018-03-08 11:44:23 +0100
commitf19b783d4014f799482daa13f8f8c26d5c4c84d9 (patch)
tree2679f94f20120b60317d54b7775e4de3d5ed4a2e /searchlib
parent692d43c3c85352c8f8e40615ce37aa3d2f83b5d3 (diff)
Log OrderedTensorType of imported TensorFlow variables
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java31
1 files changed, 28 insertions, 3 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 9ff88103f12..853a84ae226 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -8,7 +8,9 @@ import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.Dim
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
@@ -24,10 +26,12 @@ import org.tensorflow.framework.TensorInfo;
import java.io.File;
import java.io.IOException;
+import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Logger;
import java.util.stream.Collectors;
/**
@@ -38,6 +42,8 @@ import java.util.stream.Collectors;
*/
public class TensorFlowImporter {
+ private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
+
/**
* Imports a saved TensorFlow model from a directory.
* The model should be saved as a .pbtxt or .pb file.
@@ -83,6 +89,7 @@ public class TensorFlowImporter {
importExpressions(model, index, bundle);
reportWarnings(model, index);
+ logVariableTypes(index);
return model;
}
@@ -268,7 +275,7 @@ public class TensorFlowImporter {
List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
if (importedTensors.size() != 1) {
throw new IllegalStateException("Expected 1 tensor from fetching " + operation.node().getName() + ", but got " +
- importedTensors.size());
+ importedTensors.size());
}
// Here we use the type from the operation, which will have correct dimension names after name resolving
tensor = TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get());
@@ -308,7 +315,7 @@ public class TensorFlowImporter {
}
catch (ParseException e) {
throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
+ " cannot be parsed as a ranking expression", e);
}
}
}
@@ -331,6 +338,22 @@ public class TensorFlowImporter {
}
}
+ /**
+ * Log all TensorFlow Variables (i.e file constants) imported as part of this with their ordered type.
+ * This allows users to learn the exact types (including dimension order after renaming) of the Variables
+ * such that these can be converted and fed to a parent document independently of the rest of the model
+ * for fast model weight updates.
+ */
+ private static void logVariableTypes(OperationIndex index) {
+ for (TensorFlowOperation operation : index.operations()) {
+ if ( ! (operation instanceof Variable)) continue;
+ if ( ! operation.type().isPresent()) continue; // will not happen
+
+ log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() +
+ " of type " + operation.type().get());
+ }
+ }
+
private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) {
for (String warning : operation.warnings()) {
signature.importWarning(warning);
@@ -364,12 +387,14 @@ public class TensorFlowImporter {
return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
}
-
private static class OperationIndex {
+
private final Map<String, TensorFlowOperation> index = new HashMap<>();
public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
public TensorFlowOperation get(String key) { return index.get(key); }
public boolean alreadyImported(String key) { return index.containsKey(key); }
+ public Collection<TensorFlowOperation> operations() { return index.values(); }
+
}
}