summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-10-04 13:38:42 +0200
committerLester Solbakken <lesters@oath.com>2021-10-04 13:38:42 +0200
commit765571b29407071271806bf95766f100dde82dd8 (patch)
tree573e62f6382dd5cec386b3c74b31acb8d87cf7f8 /model-integration/src/main/java/ai
parentfed386f182a0b600a72b333bb308c15870c3f04e (diff)
Remove Java dependencies to tensorflow
Diffstat (limited to 'model-integration/src/main/java/ai')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java87
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java254
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java238
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java128
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java59
6 files changed, 1 insertions, 780 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
deleted file mode 100644
index f2c6dfd9069..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java
+++ /dev/null
@@ -1,87 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.rankingexpression.importer.tensorflow;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
-import ai.vespa.rankingexpression.importer.OrderedTensorType;
-import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.NodeDef;
-
-import java.util.List;
-import java.util.Map;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * Converts TensorFlow node attributes to Vespa attribute values.
- *
- * @author lesters
- */
-class AttributeConverter implements IntermediateOperation.AttributeMap {
-
- private final Map<String, AttrValue> attributeMap;
-
- private AttributeConverter(NodeDef node) {
- attributeMap = node.getAttrMap();
- }
-
- static AttributeConverter convert(NodeDef node) {
- return new AttributeConverter(node);
- }
-
- @Override
- public Optional<Value> get(String key) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return Optional.empty(); // requires type
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.B) {
- return Optional.of(new BooleanValue(attrValue.getB()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.I) {
- return Optional.of(new DoubleValue(attrValue.getI()));
- }
- if (attrValue.getValueCase() == AttrValue.ValueCase.F) {
- return Optional.of(new DoubleValue(attrValue.getF()));
- }
- }
- return Optional.empty();
- }
-
- @Override
- public Optional<Value> get(String key, OrderedTensorType type) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) {
- return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type)));
- }
- }
- return get(key);
- }
-
- @Override
- public Optional<List<Value>> getList(String key) {
- if (attributeMap.containsKey(key)) {
- AttrValue attrValue = attributeMap.get(key);
- if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) {
- AttrValue.ListValue listValue = attrValue.getList();
- if ( ! listValue.getBList().isEmpty()) {
- return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList()));
- }
- if ( ! listValue.getIList().isEmpty()) {
- return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList()));
- }
- if ( ! listValue.getFList().isEmpty()) {
- return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList()));
- }
- // add the rest
- }
- }
- return Optional.empty();
- }
-
-}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
deleted file mode 100644
index 0d2ba0cc714..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java
+++ /dev/null
@@ -1,254 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.rankingexpression.importer.tensorflow;
-
-import ai.vespa.rankingexpression.importer.operations.Softmax;
-import ai.vespa.rankingexpression.importer.operations.Sum;
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import ai.vespa.rankingexpression.importer.IntermediateGraph;
-import ai.vespa.rankingexpression.importer.OrderedTensorType;
-import ai.vespa.rankingexpression.importer.operations.Argument;
-import ai.vespa.rankingexpression.importer.operations.ConcatV2;
-import ai.vespa.rankingexpression.importer.operations.Const;
-import ai.vespa.rankingexpression.importer.operations.Constant;
-import ai.vespa.rankingexpression.importer.operations.ExpandDims;
-import ai.vespa.rankingexpression.importer.operations.Identity;
-import ai.vespa.rankingexpression.importer.operations.IntermediateOperation;
-import ai.vespa.rankingexpression.importer.operations.Join;
-import ai.vespa.rankingexpression.importer.operations.Map;
-import ai.vespa.rankingexpression.importer.operations.MatMul;
-import ai.vespa.rankingexpression.importer.operations.Mean;
-import ai.vespa.rankingexpression.importer.operations.Merge;
-import ai.vespa.rankingexpression.importer.operations.NoOp;
-import ai.vespa.rankingexpression.importer.operations.PlaceholderWithDefault;
-import ai.vespa.rankingexpression.importer.operations.Reshape;
-import ai.vespa.rankingexpression.importer.operations.Select;
-import ai.vespa.rankingexpression.importer.operations.Shape;
-import ai.vespa.rankingexpression.importer.operations.Squeeze;
-import ai.vespa.rankingexpression.importer.operations.Switch;
-import com.yahoo.tensor.functions.ScalarFunctions;
-import org.tensorflow.SavedModelBundle;
-import org.tensorflow.Session;
-import org.tensorflow.framework.GraphDef;
-import org.tensorflow.framework.MetaGraphDef;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.SignatureDef;
-import org.tensorflow.framework.TensorInfo;
-
-import java.io.IOException;
-import java.util.List;
-import java.util.stream.Collectors;
-
-/**
- * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis
- * for generating Vespa ranking expressions.
- *
- * @author lesters
- */
-class GraphImporter {
-
- private static IntermediateOperation mapOperation(NodeDef node,
- List<IntermediateOperation> inputs,
- IntermediateGraph graph) {
- String nodeName = node.getName();
- String modelName = graph.name();
- int nodePort = IntermediateOperation.indexPartOf(nodeName);
- OrderedTensorType nodeType = TypeConverter.typeFrom(node);
- AttributeConverter attributes = AttributeConverter.convert(node);
-
- switch (node.getOp().toLowerCase()) {
- // array ops
- case "concatv2": return new ConcatV2(modelName, nodeName, inputs);
- case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType);
- case "expanddims": return new ExpandDims(modelName, nodeName, inputs);
- case "identity": return new Identity(modelName, nodeName, inputs);
- case "placeholder": return new Argument(modelName, nodeName, nodeType);
- case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs);
- case "reshape": return new Reshape(modelName, nodeName, inputs, attributes);
- case "shape": return new Shape(modelName, nodeName, inputs);
- case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes);
-
- // control flow
- case "merge": return new Merge(modelName, nodeName, inputs);
- case "switch": return new Switch(modelName, nodeName, inputs, nodePort);
-
- // math ops
- case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs());
- case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos());
- case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin());
- case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan());
- case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil());
- case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos());
- case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp());
- case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide());
- case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor());
- case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log());
- case "matmul": return new MatMul(modelName, nodeName, inputs);
- case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max());
- case "mean": return new Mean(modelName, nodeName, inputs, attributes);
- case "reducemean": return new Mean(modelName, nodeName, inputs, attributes);
- case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
- case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply());
- case "negate": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg());
- case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal());
- case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt());
- case "select": return new Select(modelName, nodeName, inputs);
- case "where3": return new Select(modelName, nodeName, inputs);
- case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid());
- case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin());
- case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference());
- case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
- case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract());
- case "sum": return new Sum(modelName, nodeName, inputs, attributes);
- case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square());
- case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt());
- case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan());
- case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh());
-
- // nn ops
- case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add());
- case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu());
- case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu());
- case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu());
- case "softmax": return new Softmax(modelName, nodeName, inputs, attributes);
-
- // state ops
- case "variable": return new Constant(modelName, nodeName, nodeType);
- case "variablev2": return new Constant(modelName, nodeName, nodeType);
- case "varhandleop": return new Constant(modelName, nodeName, nodeType);
- case "readvariableop":return new Identity(modelName, nodeName, inputs);
-
- // evaluation no-ops
- case "stopgradient":return new Identity(modelName, nodeName, inputs);
- case "noop": return new NoOp(modelName, nodeName, inputs);
-
- }
-
- IntermediateOperation op = new NoOp(modelName, node.getName(), inputs);
- op.warning("Operation '" + node.getOp() + "' is currently not implemented");
- return op;
- }
-
- static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException {
- MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef());
-
- IntermediateGraph intermediateGraph = new IntermediateGraph(modelName);
- importSignatures(tfGraph, intermediateGraph);
- importOperations(tfGraph, intermediateGraph, bundle);
- verifyOutputTypes(tfGraph, intermediateGraph);
-
- return intermediateGraph;
- }
-
- private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
- for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) {
- String signatureName = signatureEntry.getKey();
- java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap();
- for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) {
- String inputName = input.getKey();
- String nodeName = input.getValue().getName();
- intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName));
- }
- java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap();
- for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) {
- String outputName = output.getKey();
- String nodeName = output.getValue().getName();
- intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName));
- }
- }
- }
-
- private static void importOperations(MetaGraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- for (String signatureName : intermediateGraph.signatures()) {
- for (String outputName : intermediateGraph.outputs(signatureName).values()) {
- importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle);
- }
- }
- }
-
- private static IntermediateOperation importOperation(String nodeName,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- if (intermediateGraph.alreadyImported(nodeName)) {
- return intermediateGraph.get(nodeName);
- }
- NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph);
- List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle);
- IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph);
- intermediateGraph.put(nodeName, operation);
-
- List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle);
- if (controlInputs.size() > 0) {
- operation.setControlInputs(controlInputs);
- }
-
- if (operation.isConstant()) {
- operation.setConstantValueFunction(
- type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type)));
- }
-
- return operation;
- }
-
- private static List<IntermediateOperation> importOperationInputs(NodeDef node,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- return node.getInputList().stream()
- .filter(name -> ! isControlDependency(name))
- .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
- .collect(Collectors.toList());
- }
-
- private static List<IntermediateOperation> importControlInputs(NodeDef node,
- GraphDef tfGraph,
- IntermediateGraph intermediateGraph,
- SavedModelBundle bundle) {
- return node.getInputList().stream()
- .filter(nodeName -> isControlDependency(nodeName))
- .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle))
- .collect(Collectors.toList());
- }
-
- private static boolean isControlDependency(String name) {
- return name.startsWith("^");
- }
-
- private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) {
- for (NodeDef node : tfGraph.getNodeList()) {
- if (node.getName().equals(name)) {
- return node;
- }
- }
- throw new IllegalArgumentException("Could not find node '" + name + "'");
- }
-
- static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) {
- Session.Runner fetched = bundle.session().runner().fetch(name);
- List<org.tensorflow.Tensor<?>> importedTensors = fetched.run();
- if (importedTensors.size() != 1)
- throw new IllegalStateException("Expected 1 tensor from fetching " + name +
- ", but got " + importedTensors.size());
- return importedTensors.get(0);
- }
-
- private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) {
- for (String signatureName : intermediateGraph.signatures()) {
- for (String outputName : intermediateGraph.outputs(signatureName).values()) {
- IntermediateOperation operation = intermediateGraph.get(outputName);
- NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef());
- OrderedTensorType type = operation.type().orElseThrow(
- () -> new IllegalArgumentException("Output of '" + outputName + "' has no type."));
- TypeConverter.verifyType(node, type);
- }
- }
-
- }
-
-}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
deleted file mode 100644
index 95727acb5b4..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ /dev/null
@@ -1,238 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.rankingexpression.importer.tensorflow;
-
-import ai.vespa.rankingexpression.importer.OrderedTensorType;
-import com.yahoo.tensor.IndexedTensor;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-import org.tensorflow.framework.DataType;
-import org.tensorflow.framework.TensorProto;
-import org.tensorflow.framework.TensorShapeProto;
-
-import java.nio.ByteBuffer;
-import java.nio.DoubleBuffer;
-import java.nio.FloatBuffer;
-import java.nio.IntBuffer;
-import java.nio.LongBuffer;
-import java.util.List;
-
-/**
- * Converts TensorFlow tensors into Vespa tensors.
- *
- * @author bratseth
- * @author lesters
- */
-public class TensorConverter {
-
- public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) {
- return toVespaTensor(tfTensor, "d");
- }
-
- private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType type = TypeConverter.typeFrom(tfTensor, dimensionPrefix);
- Values values = readValuesOf(tfTensor);
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
- for (int i = 0; i < values.size(); i++)
- builder.cellByDirectIndex(i, values.get(i));
- return builder.build();
- }
-
- static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) {
- Values values = readValuesOf(tfTensor);
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type());
- for (int i = 0; i < values.size(); i++) {
- builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i));
- }
- return builder.build();
- }
-
- static Tensor toVespaTensor(TensorProto tensorProto, OrderedTensorType type) {
- Values values = readValuesOf(tensorProto);
- if (values.size() == 0) { // Might be stored as "tensor_content" instead
- return toVespaTensor(readTensorContentOf(tensorProto), type);
- }
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type.type());
- for (int i = 0; i < values.size(); ++i)
- builder.cellByDirectIndex(i, values.get(i));
- return builder.build();
- }
-
- public static Long tensorSize(TensorType type) {
- Long size = 1L;
- for (TensorType.Dimension dimension : type.dimensions()) {
- size *= dimensionSize(dimension);
- }
- return size;
- }
-
- private static Long dimensionSize(TensorType.Dimension dim) {
- return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size"));
- }
-
- private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) {
- switch (tfTensor.dataType()) {
- case DOUBLE: return new DoubleValues(tfTensor);
- case FLOAT: return new FloatValues(tfTensor);
- case BOOL: return new BoolValues(tfTensor);
- case UINT8: return new IntValues(tfTensor);
- case INT32: return new IntValues(tfTensor);
- case INT64: return new LongValues(tfTensor);
- default: throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
- }
- }
-
- private static Values readValuesOf(TensorProto tensorProto) {
- switch (tensorProto.getDtype()) {
- case DT_BOOL: return new ProtoBoolValues(tensorProto);
- case DT_HALF: return new ProtoHalfValues(tensorProto);
- case DT_INT16: case DT_INT32: return new ProtoIntValues(tensorProto);
- case DT_INT64: return new ProtoInt64Values(tensorProto);
- case DT_FLOAT: return new ProtoFloatValues(tensorProto);
- case DT_DOUBLE: return new ProtoDoubleValues(tensorProto);
- default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
- }
- }
-
- private static Class dataTypeToClass(DataType dataType) {
- switch (dataType) {
- case DT_BOOL: return Boolean.class;
- case DT_INT16: return Short.class;
- case DT_INT32: return Integer.class;
- case DT_INT64: return Long.class;
- case DT_HALF: return Float.class;
- case DT_FLOAT: return Float.class;
- case DT_DOUBLE: return Double.class;
- default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
- }
- }
-
- private static org.tensorflow.Tensor readTensorContentOf(TensorProto tensorProto) {
- return org.tensorflow.Tensor.create(dataTypeToClass(tensorProto.getDtype()),
- asSizeArray(tensorProto.getTensorShape().getDimList()),
- tensorProto.getTensorContent().asReadOnlyByteBuffer());
- }
-
- private static long[] asSizeArray(List<TensorShapeProto.Dim> dimensions) {
- long[] sizes = new long[dimensions.size()];
- for (int i = 0; i < dimensions.size(); i++)
- sizes[i] = dimensions.get(i).getSize();
- return sizes;
- }
-
- /** Allows reading values from buffers of various numeric types as bytes */
- private static abstract class Values {
- abstract double get(int i);
- abstract int size();
- }
-
- private static abstract class TensorFlowValues extends Values {
- private final int size;
- TensorFlowValues(int size) {
- this.size = size;
- }
- @Override int size() { return this.size; }
- }
-
- private static class DoubleValues extends TensorFlowValues {
- private final DoubleBuffer values;
- DoubleValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = DoubleBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static class FloatValues extends TensorFlowValues {
- private final FloatBuffer values;
- FloatValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = FloatBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static class BoolValues extends TensorFlowValues {
- private final ByteBuffer values;
- BoolValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = ByteBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static class IntValues extends TensorFlowValues {
- private final IntBuffer values;
- IntValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = IntBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static class LongValues extends TensorFlowValues {
- private final LongBuffer values;
- LongValues(org.tensorflow.Tensor<?> tfTensor) {
- super(tfTensor.numElements());
- values = LongBuffer.allocate(tfTensor.numElements());
- tfTensor.writeTo(values);
- }
- @Override double get(int i) {
- return values.get(i);
- }
- }
-
- private static abstract class ProtoValues extends Values {
- final TensorProto tensorProto;
- ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; }
- }
-
- private static class ProtoBoolValues extends ProtoValues {
- ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; }
- @Override int size() { return tensorProto.getBoolValCount(); }
- }
-
- private static class ProtoHalfValues extends ProtoValues {
- ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getHalfVal(i); }
- @Override int size() { return tensorProto.getHalfValCount(); }
- }
-
- private static class ProtoIntValues extends ProtoValues {
- ProtoIntValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getIntVal(i); }
- @Override int size() { return tensorProto.getIntValCount(); }
- }
-
- private static class ProtoInt64Values extends ProtoValues {
- ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getInt64Val(i); }
- @Override int size() { return tensorProto.getInt64ValCount(); }
- }
-
- private static class ProtoFloatValues extends ProtoValues {
- ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getFloatVal(i); }
- @Override int size() { return tensorProto.getFloatValCount(); }
- }
-
- private static class ProtoDoubleValues extends ProtoValues {
- ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); }
- @Override double get(int i) { return tensorProto.getDoubleVal(i); }
- @Override int size() { return tensorProto.getDoubleValCount(); }
- }
-
-}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
index 0e307992143..04ddb48e859 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java
@@ -2,14 +2,12 @@
package ai.vespa.rankingexpression.importer.tensorflow;
import ai.vespa.rankingexpression.importer.ImportedModel;
-import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.ModelImporter;
import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel;
import ai.vespa.rankingexpression.importer.onnx.OnnxImporter;
import com.yahoo.collections.Pair;
import com.yahoo.io.IOUtils;
import com.yahoo.system.ProcessExecuter;
-import org.tensorflow.SavedModelBundle;
import java.io.File;
import java.io.IOException;
@@ -27,7 +25,7 @@ public class TensorFlowImporter extends ModelImporter {
private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName());
- private final static int[] onnxOpsetsToTry = {8, 10, 12};
+ private final static int[] onnxOpsetsToTry = {12, 10, 8};
private final OnnxImporter onnxImporter = new OnnxImporter();
@@ -56,17 +54,6 @@ public class TensorFlowImporter extends ModelImporter {
return convertToOnnxAndImport(modelName, modelDir);
}
- /** Imports a TensorFlow model - DEPRECATED */
- public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) {
- try {
- IntermediateGraph graph = GraphImporter.importGraph(modelName, model);
- return convertIntermediateGraphToModel(graph, modelDir, ImportedMlModel.ModelType.TENSORFLOW);
- }
- catch (IOException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e);
- }
- }
-
private ImportedModel convertToOnnxAndImport(String modelName, String modelDir) {
Path tempDir = null;
try {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
deleted file mode 100644
index 3102d5431d4..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
+++ /dev/null
@@ -1,128 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-
-package ai.vespa.rankingexpression.importer.tensorflow;
-
-import ai.vespa.rankingexpression.importer.OrderedTensorType;
-import com.yahoo.tensor.TensorType;
-import org.tensorflow.framework.AttrValue;
-import org.tensorflow.framework.DataType;
-import org.tensorflow.framework.NodeDef;
-import org.tensorflow.framework.TensorShapeProto;
-
-/**
- * Converts and verifies TensorFlow tensor types into Vespa tensor types.
- *
- * @author lesters
- */
-class TypeConverter {
-
- static void verifyType(NodeDef node, OrderedTensorType type) {
- TensorShapeProto shape = tensorFlowShape(node);
- if (shape != null) {
- if (shape.getDimCount() != type.rank()) {
- throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
- }
- for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
- int vespaIndex = type.dimensionMap(tensorFlowIndex);
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
- TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
- if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
- throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
- }
- }
- }
- }
-
- static OrderedTensorType typeFrom(NodeDef node) {
- String dimensionPrefix = "d"; // standard naming convention: d0, d1, ...
- TensorShapeProto shape = tensorFlowShape(node);
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node)));
- for (int i = 0; i < shape.getDimCount(); ++ i) {
- String dimensionName = dimensionPrefix + i;
- TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
- if (tensorFlowDimension.getSize() >= 0) {
- builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize()));
- } else {
- builder.add(TensorType.Dimension.indexed(dimensionName));
- }
- }
- return builder.build();
- }
-
- static TensorType typeFrom(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType()));
- int dimensionIndex = 0;
- for (long dimensionSize : tfTensor.shape()) {
- if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
- b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
- }
- return b.build();
- }
-
- private static TensorShapeProto tensorFlowShape(NodeDef node) {
- // Use specific shape if available...
- AttrValue attrShape = node.getAttrMap().get("shape");
- if (attrShape != null && attrShape.getValueCase() == AttrValue.ValueCase.SHAPE) {
- return attrShape.getShape();
- }
-
- // ... else use inferred shape
- AttrValue attrOutputShapes = node.getAttrMap().get("_output_shapes");
- if (attrOutputShapes == null)
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- if (attrOutputShapes.getValueCase() != AttrValue.ValueCase.LIST)
- throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
-
- return attrOutputShapes.getList().getShape(0); // support multiple outputs?
- }
-
- private static DataType tensorFlowValueType(NodeDef node) {
- AttrValue attrValueList = node.getAttrMap().get("dtypes");
- if (attrValueList == null)
- return DataType.DT_DOUBLE; // default. This will usually (always?) be used. TODO: How can we do better?
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST)
- return DataType.DT_DOUBLE; // default
-
- return attrValueList.getList().getType(0); // support multiple outputs?
- }
-
- private static TensorType.Value toValueType(DataType dataType) {
- switch (dataType) {
- case DT_FLOAT: return TensorType.Value.FLOAT;
- case DT_DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case DT_BOOL: return TensorType.Value.FLOAT;
- case DT_BFLOAT16: return TensorType.Value.FLOAT;
- case DT_HALF: return TensorType.Value.FLOAT;
- case DT_INT8: return TensorType.Value.FLOAT;
- case DT_INT16: return TensorType.Value.DOUBLE;
- case DT_INT32: return TensorType.Value.DOUBLE;
- case DT_INT64: return TensorType.Value.DOUBLE;
- case DT_UINT8: return TensorType.Value.FLOAT;
- case DT_UINT16: return TensorType.Value.DOUBLE;
- case DT_UINT32: return TensorType.Value.DOUBLE;
- case DT_UINT64: return TensorType.Value.DOUBLE;
- default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
- }
- }
-
- private static TensorType.Value toValueType(org.tensorflow.DataType dataType) {
- switch (dataType) {
- case FLOAT: return TensorType.Value.FLOAT;
- case DOUBLE: return TensorType.Value.DOUBLE;
- // Imperfect conversion, for now:
- case BOOL: return TensorType.Value.FLOAT;
- case INT32: return TensorType.Value.DOUBLE;
- case UINT8: return TensorType.Value.FLOAT;
- case INT64: return TensorType.Value.DOUBLE;
- default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
- " cannot be converted to a Vespa tensor type");
- }
- }
-
-}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
deleted file mode 100644
index 85ae5238bae..00000000000
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java
+++ /dev/null
@@ -1,59 +0,0 @@
-// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.rankingexpression.importer.tensorflow;
-
-import ai.vespa.rankingexpression.importer.OrderedTensorType;
-import com.yahoo.tensor.serialization.JsonFormat;
-import com.yahoo.yolean.Exceptions;
-import org.tensorflow.SavedModelBundle;
-
-import java.nio.charset.StandardCharsets;
-
-/**
- * Converts TensorFlow Variables to the Vespa document format.
- * Intended to be used from the command line to convert trained tensors to document form.
- *
- * @author bratseth
- */
-class VariableConverter {
-
- /**
- * Reads the tensor with the given TensorFlow name at the given model location,
- * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type.
- * Note that order of dimensions in the tensor type does matter as the TensorFlow tensor
- * tensor dimensions are implicitly ordered.
- */
- static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) {
- try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) {
- return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName,
- bundle),
- OrderedTensorType.fromSpec(orderedTypeSpec)));
- }
- catch (IllegalArgumentException e) {
- throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e);
- }
- }
-
- public static void main(String[] args) {
- if ( args.length != 3) {
- System.out.println("Converts a TensorFlow variable into Vespa tensor document field value JSON:");
- System.out.println("A JSON map containing a 'cells' array, see");
- System.out.println("https://docs.vespa.ai/en/reference/document-json-format.html#tensor");
- System.out.println("");
- System.out.println("Arguments: modelDirectory tensorFlowVariableName orderedTypeSpec");
- System.out.println(" - modelDirectory: The directory of the TensorFlow SavedModel");
- System.out.println(" - tensorFlowVariableName: The name of the TensorFlow variable to convert");
- System.out.println(" - orderedTypeSpec: The tensor type, e.g tensor(b[],a[10]), where dimensions are ");
- System.out.println(" ordered as given in the deployment log message starting by ");
- System.out.println(" 'Importing TensorFlow variable'");
- return;
- }
-
- try {
- System.out.println(new String(importVariable(args[0], args[1], args[2]), StandardCharsets.UTF_8));
- }
- catch (Exception e) {
- System.err.println("Import failed: " + Exceptions.toMessageString(e));
- }
- }
-
-}