summary | refs | log | tree | commit | diff | stats
path: root/searchlib
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java 53
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java 59
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java 58
-rw-r--r-- searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java 8
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java 21
-rw-r--r-- searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java 20
6 files changed, 184 insertions, 35 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
index 9ff88103f12..80a9262afeb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorFlowImporter.java
@@ -8,7 +8,9 @@ import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.Dim
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OperationMapper;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Const;
import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.TensorFlowOperation;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.operations.Variable;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
@@ -24,10 +26,12 @@ import org.tensorflow.framework.TensorInfo;
import java.io.File;
import java.io.IOException;
+import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Logger;
import java.util.stream.Collectors;
/**
@@ -38,6 +42,8 @@ import java.util.stream.Collectors;
*/
public class TensorFlowImporter {
+ private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
+
/**
* Imports a saved TensorFlow model from a directory.
* The model should be saved as a .pbtxt or .pb file.
@@ -83,6 +89,7 @@ public class TensorFlowImporter {
importExpressions(model, index, bundle);
reportWarnings(model, index);
+ logVariableTypes(index);
return model;
}
@@ -233,7 +240,8 @@ public class TensorFlowImporter {
return operation.function();
}
- private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model, SavedModelBundle bundle) {
+ private static void importInputExpressions(TensorFlowOperation operation, TensorFlowModel model,
+ SavedModelBundle bundle) {
operation.inputs().forEach(input -> importExpression(input, model, bundle));
}
@@ -250,7 +258,8 @@ public class TensorFlowImporter {
}
}
- private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation, SavedModelBundle bundle) {
+ private static Optional<TensorFunction> importConstant(TensorFlowModel model, TensorFlowOperation operation,
+ SavedModelBundle bundle) {
String name = operation.vespaName();
if (model.largeConstants().containsKey(name) || model.smallConstants().containsKey(name)) {
return operation.function();
@@ -264,14 +273,9 @@ public class TensorFlowImporter {
}
tensor = value.asTensor();
} else {
- Session.Runner fetched = bundle.session().runner().fetch(operation.node().getName());
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1) {
- throw new IllegalStateException("Expected 1 tensor from fetching " + operation.node().getName() + ", but got " +
- importedTensors.size());
- }
// Here we use the type from the operation, which will have correct dimension names after name resolving
- tensor = TensorConverter.toVespaTensor(importedTensors.get(0), operation.type().get());
+ tensor = TensorConverter.toVespaTensor(readVariable(operation.node().getName(), bundle),
+ operation.type().get());
operation.setConstantValue(new TensorValue(tensor));
}
@@ -283,6 +287,15 @@ public class TensorFlowImporter {
return operation.function();
}
+ static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
+ Session.Runner fetched = bundle.session().runner().fetch(name);
+ List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
+ if (importedTensors.size() != 1)
+ throw new IllegalStateException("Expected 1 tensor from fetching " + name +
+ ", but got " + importedTensors.size());
+ return importedTensors.get(0);
+ }
+
private static void importRankingExpression(TensorFlowModel model, TensorFlowOperation operation) {
if (operation.function().isPresent()) {
String name = operation.node().getName();
@@ -308,7 +321,7 @@ public class TensorFlowImporter {
}
catch (ParseException e) {
throw new RuntimeException("Tensorflow function " + function +
- " cannot be parsed as a ranking expression", e);
+ " cannot be parsed as a ranking expression", e);
}
}
}
@@ -331,6 +344,22 @@ public class TensorFlowImporter {
}
}
+ /**
+ * Logs all TensorFlow Variables (i.e. file constants) imported as part of this, with their ordered type.
+ * This allows users to learn the exact types (including dimension order after renaming) of the Variables
+ * such that these can be converted and fed to a parent document independently of the rest of the model
+ * for fast model weight updates.
+ */
+ private static void logVariableTypes(OperationIndex index) {
+ for (TensorFlowOperation operation : index.operations()) {
+ if ( ! (operation instanceof Variable)) continue;
+ if ( ! operation.type().isPresent()) continue; // will not happen
+
+ log.info("Importing TensorFlow variable " + operation.node().getName() + " as " + operation.vespaName() +
+ " of type " + operation.type().get());
+ }
+ }
+
private static void reportWarnings(TensorFlowOperation operation, TensorFlowModel.Signature signature) {
for (String warning : operation.warnings()) {
signature.importWarning(warning);
@@ -364,12 +393,14 @@ public class TensorFlowImporter {
return i < 0 ? 0 : Integer.parseInt(name.substring(i + 1));
}
-
private static class OperationIndex {
+
private final Map<String, TensorFlowOperation> index = new HashMap<>();
public TensorFlowOperation put(String key, TensorFlowOperation operation) { return index.put(key, operation); }
public TensorFlowOperation get(String key) { return index.get(key); }
public boolean alreadyImported(String key) { return index.containsKey(key); }
+ public Collection<TensorFlowOperation> operations() { return index.values(); }
+
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java
new file mode 100644
index 00000000000..c5ac7ace0fc
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverter.java
@@ -0,0 +1,59 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.TensorConverter;
+import com.yahoo.tensor.serialization.JsonFormat;
+import com.yahoo.yolean.Exceptions;
+import org.tensorflow.SavedModelBundle;
+
+import java.nio.charset.StandardCharsets;
+
+/**
+ * Converts TensorFlow Variables to the Vespa document format.
+ * Intended to be used from the command line to convert trained tensors to document form.
+ *
+ * @author bratseth
+ */
+public class VariableConverter {
+
+ /**
+ * Reads the tensor with the given TensorFlow name at the given model location,
+ * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type.
+ * Note that the order of dimensions in the tensor type does matter, as the TensorFlow
+ * tensor dimensions are implicitly ordered.
+ */
+ public static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
+ try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
+ return JsonFormat.encode(TensorConverter.toVespaTensor(TensorFlowImporter.readVariable(tensorFlowVariableName,
+ bundle),
+ OrderedTensorType.fromSpec(orderedTypeSpec)));
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
+ }
+ }
+
+ public static void main(String[] args) {
+ if ( args.length != 3) {
+ System.out.println("Converts a TensorFlow variable into Vespa tensor document field value JSON:");
+ System.out.println("A JSON map containing a 'cells' array, see");
+ System.out.println("http://docs.vespa.ai/documentation/reference/document-json-put-format.html#tensor)");
+ System.out.println("");
+ System.out.println("Arguments: modelDirectory tensorFlowVariableName orderedTypeSpec");
+ System.out.println(" - modelDirectory: The directory of the TensorFlow SavedModel");
+ System.out.println(" - tensorFlowVariableName: The name of the TensorFlow variable to convert");
+ System.out.println(" - orderedTypeSpec: The tensor type, e.g tensor(b[],a[10]), where dimensions are ");
+ System.out.println(" ordered as given in the deployment log message starting by ");
+ System.out.println(" 'Importing TensorFlow variable'");
+ return;
+ }
+
+ try {
+ System.out.println(new String(importVariable(args[0], args[1], args[2]), StandardCharsets.UTF_8));
+ }
+ catch (Exception e) {
+ System.err.println("Import failed: " + Exceptions.toMessageString(e));
+ }
+ }
+
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
index db762d5ddb0..03a65333192 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorTypeParser;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;
@@ -119,23 +120,20 @@ public class OrderedTensorType {
return true;
}
- public static void verifyType(NodeDef node, OrderedTensorType type) {
- if (type == null) {
- return;
- }
+ public void verifyType(NodeDef node) {
TensorShapeProto shape = tensorFlowShape(node);
- if (shape != null && type.type != null) {
- if (shape.getDimCount() != type.type.rank()) {
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
+ "does not match Vespa shape");
}
- for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions.size(); ++tensorFlowIndex) {
- int vespaIndex = type.dimensionMap[tensorFlowIndex];
+ for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) {
+ int vespaIndex = dimensionMap[tensorFlowIndex];
TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
- TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
+ "does not match Vespa dimensions");
}
}
}
@@ -145,23 +143,23 @@ public class OrderedTensorType {
AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
if (attrValueList == null) {
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
+ "does not exist");
}
if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
+ "is not of expected type");
}
List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
return shapeList.get(0); // support multiple outputs?
}
- public static OrderedTensorType rename(OrderedTensorType type, DimensionRenamer renamer) {
- List<TensorType.Dimension> renamedDimensions = new ArrayList<>(type.dimensions.size());
- for (TensorType.Dimension dimension : type.dimensions) {
+ public OrderedTensorType rename(DimensionRenamer renamer) {
+ List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
+ for (TensorType.Dimension dimension : dimensions) {
String oldName = dimension.name();
Optional<String> newName = renamer.dimensionNameOf(oldName);
if (!newName.isPresent())
- return type; // presumably, already renamed
+ return this; // presumably, already renamed
TensorType.Dimension.Type dimensionType = dimension.type();
if (dimensionType == TensorType.Dimension.Type.indexedBound) {
renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
@@ -174,6 +172,23 @@ public class OrderedTensorType {
return new OrderedTensorType(renamedDimensions);
}
+ /**
+ * Returns a string representation of this: A standard tensor type string where dimensions
+ * are listed in the order of this rather than in the natural order of their names.
+ */
+ @Override
+ public String toString() {
+ return "tensor(" + dimensions.stream().map(TensorType.Dimension::toString).collect(Collectors.joining(",")) + ")";
+ }
+
+ /**
+ * Creates an instance from the string representation of this: A standard tensor type string
+ * where dimensions are listed in the order of this rather than the natural order of their names.
+ */
+ public static OrderedTensorType fromSpec(String typeSpec) {
+ return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ }
+
public static OrderedTensorType fromTensorFlowType(NodeDef node) {
return fromTensorFlowType(node, "d"); // standard naming convention: d0, d1, ...
}
@@ -210,20 +225,21 @@ public class OrderedTensorType {
if (size >= 0) {
if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
+ "dimension types");
}
if (!vespaDimension.size().isPresent()) {
throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
- "not have a size");
+ "not have a size");
}
if (vespaDimension.size().get() != size) {
throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension sizes. TensorFlow: " + size + " Vespa: " + vespaDimension.size().get());
+ "dimension sizes. TensorFlow: " + size + " Vespa: " +
+ vespaDimension.size().get());
}
} else {
if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
+ "dimension types");
}
}
this.dimensions.add(vespaDimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
index 5d711aac100..2533148e5be 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -60,7 +60,9 @@ public abstract class TensorFlowOperation {
if (type == null) {
type = lazyGetType();
}
- OrderedTensorType.verifyType(node, type);
+ if (type != null) {
+ type.verifyType(node);
+ }
return Optional.ofNullable(type);
}
@@ -96,7 +98,7 @@ public abstract class TensorFlowOperation {
public void addDimensionNameConstraints(DimensionRenamer renamer) { }
/** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = OrderedTensorType.rename(type, renamer); }
+ public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
/** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
public boolean isInput() { return false; }
@@ -131,7 +133,7 @@ public abstract class TensorFlowOperation {
}
if (inputs.size() != expected) {
throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + node.getName() + "', got " + inputs.size());
+ "for '" + node.getName() + "', got " + inputs.size());
}
return inputs.stream().map(func).allMatch(Optional::isPresent);
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
new file mode 100644
index 00000000000..beec2ab1ead
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/OrderedTensorTypeTestCase.java
@@ -0,0 +1,21 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import com.yahoo.searchlib.rankingexpression.integration.tensorflow.importer.OrderedTensorType;
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class OrderedTensorTypeTestCase {
+
+ @Test
+ public void testToFromSpec() {
+ String spec = "tensor(b[],c{},a[3])";
+ OrderedTensorType type = OrderedTensorType.fromSpec(spec);
+ assertEquals(spec, type.toString());
+ assertEquals("tensor(a[3],b[],c{})", type.type().toString());
+ }
+
+}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
new file mode 100644
index 00000000000..051c2c60c95
--- /dev/null
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/VariableConverterTestCase.java
@@ -0,0 +1,20 @@
+package com.yahoo.searchlib.rankingexpression.integration.tensorflow;
+
+import org.junit.Test;
+
+import java.nio.charset.StandardCharsets;
+
+import static org.junit.Assert.assertEquals;
+
+public class VariableConverterTestCase {
+
+ @Test
+ public void testConversion() {
+ byte[] converted = VariableConverter.importVariable("src/test/files/integration/tensorflow/mnist_softmax/saved",
+ "Variable_1",
+ "tensor(d0[10],d1[1])");
+ assertEquals("{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"1\",\"d1\":\"0\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"2\",\"d1\":\"0\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"3\",\"d1\":\"0\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"4\",\"d1\":\"0\"},\"value\":0.01795101352035999},{\"address\":{\"d0\":\"5\",\"d1\":\"0\"},\"value\":1.289906740188599},{\"address\":{\"d0\":\"6\",\"d1\":\"0\"},\"value\":-0.1038961559534073},{\"address\":{\"d0\":\"7\",\"d1\":\"0\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"8\",\"d1\":\"0\"},\"value\":-1.413674473762512},{\"address\":{\"d0\":\"9\",\"d1\":\"0\"},\"value\":-0.2573896050453186}]}",
+ new String(converted, StandardCharsets.UTF_8));
+ }
+
+}