summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <jonbratseth@yahoo.com>2018-01-17 13:51:14 +0100
committerGitHub <noreply@github.com>2018-01-17 13:51:14 +0100
commitfd26b36e3607df463b35e856b37d24b5e3514fb7 (patch)
tree403836969d050736403f6512a455198a2c63edad /searchlib
parentceec6d572c06ff812715c97d2c35383c48402f24 (diff)
parentc84b8f952ef5857aa44fad479551eda1f3a4e106 (diff)
Merge pull request #4692 from vespa-engine/bratseth/store-converted-expressions-in-zk
Bratseth/store converted expressions in zk
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java7
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java2
2 files changed, 8 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 42945c59105..45f2b21343f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -14,6 +14,7 @@ import org.tensorflow.framework.SignatureDef;
import org.tensorflow.framework.TensorInfo;
import org.tensorflow.framework.TensorShapeProto;
+import java.io.File;
import java.io.IOException;
import java.util.List;
import java.util.Map;
@@ -30,7 +31,7 @@ public class TensorFlowImporter {
/**
* Imports a saved TensorFlow model from a directory.
- * The model should be saved as a pbtxt file.
+ * The model should be saved as a .pbtxt or .pb file.
* The name of the model is taken as the db/pbtxt file name (not including the file ending).
*
* @param modelDir the directory containing the TensorFlow model files to import
@@ -44,6 +45,10 @@ public class TensorFlowImporter {
}
}
+ public TensorFlowModel importModel(File modelDir) {
+ return importModel(modelDir.toString());
+ }
+
/** Imports a TensorFlow model */
public TensorFlowModel importModel(SavedModelBundle model) {
try {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
index fd981a14c3e..9fdc45ab3bc 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowModel.java
@@ -69,6 +69,8 @@ public class TensorFlowModel {
void output(String name, String expressionName) { outputs.put(name, expressionName); }
void skippedOutput(String name, String reason) { skippedOutputs.put(name, reason); }
+ public String name() { return name; }
+
/** Returns the result this is part of */
TensorFlowModel owner() { return TensorFlowModel.this; }