diff options
Diffstat (limited to 'searchlib/src/main')
3 files changed, 75 insertions, 10 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java index 853a84ae226..80a9262afeb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java @@ -240,7 +240,8 @@ public class TensorFlowImporter { return operation.function(); } - private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) { + private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, + SavedModelBundle bundle) { operation.inputs().forEach(input -> importExpression(input, model, bundle)); } @@ -257,7 +258,8 @@ public class TensorFlowImporter { } } - private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, SavedModelBundle bundle) { + private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, + SavedModelBundle bundle) { String name = operation.vespaName(); if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) { return operation.function(); @@ -271,14 +273,9 @@ public class TensorFlowImporter { } tensor = value.asTensor(); } else { - Session.Runner fetched = bundle.session().runner().fetch(operation.node().getName()); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) { - throw new IllegalStateException("Expected 1 tensor from fetching " + operation.node().getName() + ", but got " + - importedTensors.size()); - } // Here we use the type from the operation, which will have correct dimension names after name resolving - tensor = TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get()); + tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle), + operation.type().get()); operation.setConstantValue(new TensorValue(tensor)); } @@ -290,6 +287,15 @@ public class TensorFlowImporter { return operation.function(); } + static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { + Session.Runner fetched = bundle.session().runner().fetch(name); + List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); + if (importedTensors.size() != 1) + throw new IllegalStateException("Expected 1 tensor from fetching " + name + + ", but got " + importedTensors.size()); + return importedTensors.get(0); + } + private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) { if (operation.function().isPresent()) { String name = operation.node().getName(); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java new file mode 100644 index 00000000000..c5ac7ace0fc --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java @@ -0,0 +1,59 @@ +package com.yahoo.searchlib.rankingexpression.integration.tensorflow; + +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter; +import com.yahoo.tensor.serialization.JsonFormat; +import com.yahoo.yolean.Exceptions; +import org.tensorflow.SavedModelBundle; + +import java.nio.charset.StandardCharsets; + +/** + * Converts TensorFlow Variables to the Vespa document format. + * Intended to be used from the command line to convert trained tensors to document form. + * + * @author bratseth + */ +public class VariableConverter { + + /** + * Reads the tensor with the given TensorFlow name at the given model location, + * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type. + * Note that order of dimensions in the tensor type does matter as the TensorFlow tensor + * tensor dimensions are implicitly ordered. + */ + public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { + try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) { + return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName, + bundle), + OrderedTensorType.fromSpec(orderedTypeSpec))); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); + } + } + + public static void main(String[] args) { + if ( args.length != 3) { + System.out.println("Converts a TensorFlow variable into Vespa tensor document field value JSON:"); + System.out.println("A JSON map containing a 'cells' array, see"); + System.out.println("http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor)"); + System.out.println(""); + System.out.println("Arguments: modelDirectory tensorFlowVariableName orderedTypeSpec"); + System.out.println(" - modelDirectory: The directory of the TensorFlow SavedModel"); + System.out.println(" - tensorFlowVariableName: The name of the TensorFlow variable to convert"); + System.out.println(" - orderedTypeSpec: The tensor type, e.g tensor(b[],a[10]), where dimensions are "); + System.out.println(" ordered as given in the deployment log message starting by "); + System.out.println(" 'Importing TensorFlow variable'"); + return; + } + + try { + System.out.println(new String(importVariable(args[0], args[1], args[2]), StandardCharsets.UTF_8)); + } + catch (Exception e) { + System.err.println("Import failed: " + Exceptions.toMessageString(e)); + } + } + +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java index 3a6e5e6ebe9..03a65333192 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java @@ -186,7 +186,7 @@ public class OrderedTensorType { * where dimensions are listed in the order of this rather than the natural order of their names. */ public static OrderedTensorType fromSpec(String typeSpec) { - return new OrderedTensorType(TensorTypeParser.fromSpec(typeSpec)); + return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); } public static OrderedTensorType fromTensorFlowType(NodeDef node) { |