diff options
author | Lester Solbakken <lesters@oath.com> | 2021-10-04 13:38:42 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-10-04 13:38:42 +0200 |
commit | 765571b29407071271806bf95766f100dde82dd8 (patch) | |
tree | 573e62f6382dd5cec386b3c74b31acb8d87cf7f8 /model-integration/src/main/java/ai | |
parent | fed386f182a0b600a72b333bb308c15870c3f04e (diff) |
Remove Java dependencies to tensorflow
Diffstat (limited to 'model-integration/src/main/java/ai')
6 files changed, 1 insertions, 780 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java deleted file mode 100644 index f2c6dfd9069..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/AttributeConverter.java +++ /dev/null @@ -1,87 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer.tensorflow; - -import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; -import ai.vespa.rankingexpression.importer.OrderedTensorType; -import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.NodeDef; - -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.stream.Collectors; - -/** - * Converts TensorFlow node attributes to Vespa attribute values. - * - * @author lesters - */ -class AttributeConverter implements IntermediateOperation.AttributeMap { - - private final Map<String, AttrValue> attributeMap; - - private AttributeConverter(NodeDef node) { - attributeMap = node.getAttrMap(); - } - - static AttributeConverter convert(NodeDef node) { - return new AttributeConverter(node); - } - - @Override - public Optional<Value> get(String key) { - if (attributeMap.containsKey(key)) { - AttrValue attrValue = attributeMap.get(key); - if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { - return Optional.empty(); // requires type - } - if (attrValue.getValueCase() == AttrValue.ValueCase.B) { - return Optional.of(new BooleanValue(attrValue.getB())); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.I) { - return Optional.of(new DoubleValue(attrValue.getI())); - } - if (attrValue.getValueCase() == AttrValue.ValueCase.F) { - return Optional.of(new DoubleValue(attrValue.getF())); - } - } - return Optional.empty(); - } - - @Override - public Optional<Value> get(String key, OrderedTensorType type) { - if (attributeMap.containsKey(key)) { - AttrValue attrValue = attributeMap.get(key); - if (attrValue.getValueCase() == AttrValue.ValueCase.TENSOR) { - return Optional.of(new TensorValue(TensorConverter.toVespaTensor(attrValue.getTensor(), type))); - } - } - return get(key); - } - - @Override - public Optional<List<Value>> getList(String key) { - if (attributeMap.containsKey(key)) { - AttrValue attrValue = attributeMap.get(key); - if (attrValue.getValueCase() == AttrValue.ValueCase.LIST) { - AttrValue.ListValue listValue = attrValue.getList(); - if ( ! listValue.getBList().isEmpty()) { - return Optional.of(listValue.getBList().stream().map(BooleanValue::new).collect(Collectors.toList())); - } - if ( ! listValue.getIList().isEmpty()) { - return Optional.of(listValue.getIList().stream().map(DoubleValue::new).collect(Collectors.toList())); - } - if ( ! listValue.getFList().isEmpty()) { - return Optional.of(listValue.getFList().stream().map(DoubleValue::new).collect(Collectors.toList())); - } - // add the rest - } - } - return Optional.empty(); - } - -} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java deleted file mode 100644 index 0d2ba0cc714..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/GraphImporter.java +++ /dev/null @@ -1,254 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.rankingexpression.importer.tensorflow; - -import ai.vespa.rankingexpression.importer.operations.Softmax; -import ai.vespa.rankingexpression.importer.operations.Sum; -import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; -import ai.vespa.rankingexpression.importer.IntermediateGraph; -import ai.vespa.rankingexpression.importer.OrderedTensorType; -import ai.vespa.rankingexpression.importer.operations.Argument; -import ai.vespa.rankingexpression.importer.operations.ConcatV2; -import ai.vespa.rankingexpression.importer.operations.Const; -import ai.vespa.rankingexpression.importer.operations.Constant; -import ai.vespa.rankingexpression.importer.operations.ExpandDims; -import ai.vespa.rankingexpression.importer.operations.Identity; -import ai.vespa.rankingexpression.importer.operations.IntermediateOperation; -import ai.vespa.rankingexpression.importer.operations.Join; -import ai.vespa.rankingexpression.importer.operations.Map; -import ai.vespa.rankingexpression.importer.operations.MatMul; -import ai.vespa.rankingexpression.importer.operations.Mean; -import ai.vespa.rankingexpression.importer.operations.Merge; -import ai.vespa.rankingexpression.importer.operations.NoOp; -import ai.vespa.rankingexpression.importer.operations.PlaceholderWithDefault; -import ai.vespa.rankingexpression.importer.operations.Reshape; -import ai.vespa.rankingexpression.importer.operations.Select; -import ai.vespa.rankingexpression.importer.operations.Shape; -import ai.vespa.rankingexpression.importer.operations.Squeeze; -import ai.vespa.rankingexpression.importer.operations.Switch; -import com.yahoo.tensor.functions.ScalarFunctions; -import org.tensorflow.SavedModelBundle; -import org.tensorflow.Session; -import org.tensorflow.framework.GraphDef; -import org.tensorflow.framework.MetaGraphDef; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.SignatureDef; -import org.tensorflow.framework.TensorInfo; - -import java.io.IOException; -import java.util.List; -import java.util.stream.Collectors; - -/** - * Converts a TensorFlow graph to a Vespa IntermediateGraph which is the basis - * for generating Vespa ranking expressions. - * - * @author lesters - */ -class GraphImporter { - - private static IntermediateOperation mapOperation(NodeDef node, - List<IntermediateOperation> inputs, - IntermediateGraph graph) { - String nodeName = node.getName(); - String modelName = graph.name(); - int nodePort = IntermediateOperation.indexPartOf(nodeName); - OrderedTensorType nodeType = TypeConverter.typeFrom(node); - AttributeConverter attributes = AttributeConverter.convert(node); - - switch (node.getOp().toLowerCase()) { - // array ops - case "concatv2": return new ConcatV2(modelName, nodeName, inputs); - case "const": return new Const(modelName, nodeName, inputs, attributes, nodeType); - case "expanddims": return new ExpandDims(modelName, nodeName, inputs); - case "identity": return new Identity(modelName, nodeName, inputs); - case "placeholder": return new Argument(modelName, nodeName, nodeType); - case "placeholderwithdefault": return new PlaceholderWithDefault(modelName, nodeName, inputs); - case "reshape": return new Reshape(modelName, nodeName, inputs, attributes); - case "shape": return new Shape(modelName, nodeName, inputs); - case "squeeze": return new Squeeze(modelName, nodeName, inputs, attributes); - - // control flow - case "merge": return new Merge(modelName, nodeName, inputs); - case "switch": return new Switch(modelName, nodeName, inputs, nodePort); - - // math ops - case "abs": return new Map(modelName, nodeName, inputs, ScalarFunctions.abs()); - case "acos": return new Map(modelName, nodeName, inputs, ScalarFunctions.acos()); - case "add": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); - case "add_n": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); - case "asin": return new Map(modelName, nodeName, inputs, ScalarFunctions.asin()); - case "atan": return new Map(modelName, nodeName, inputs, ScalarFunctions.atan()); - case "ceil": return new Map(modelName, nodeName, inputs, ScalarFunctions.ceil()); - case "cos": return new Map(modelName, nodeName, inputs, ScalarFunctions.cos()); - case "div": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); - case "exp": return new Map(modelName, nodeName, inputs, ScalarFunctions.exp()); - case "realdiv": return new Join(modelName, nodeName, inputs, ScalarFunctions.divide()); - case "floor": return new Map(modelName, nodeName, inputs, ScalarFunctions.floor()); - case "log": return new Map(modelName, nodeName, inputs, ScalarFunctions.log()); - case "matmul": return new MatMul(modelName, nodeName, inputs); - case "maximum": return new Join(modelName, nodeName, inputs, ScalarFunctions.max()); - case "mean": return new Mean(modelName, nodeName, inputs, attributes); - case "reducemean": return new Mean(modelName, nodeName, inputs, attributes); - case "mul": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); - case "multiply": return new Join(modelName, nodeName, inputs, ScalarFunctions.multiply()); - case "negate": return new Map(modelName, nodeName, inputs, ScalarFunctions.neg()); - case "reciprocal": return new Map(modelName, nodeName, inputs, ScalarFunctions.reciprocal()); - case "rsqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.rsqrt()); - case "select": return new Select(modelName, nodeName, inputs); - case "where3": return new Select(modelName, nodeName, inputs); - case "sigmoid": return new Map(modelName, nodeName, inputs, ScalarFunctions.sigmoid()); - case "sin": return new Map(modelName, nodeName, inputs, ScalarFunctions.sin()); - case "squareddifference": return new Join(modelName, nodeName, inputs, ScalarFunctions.squareddifference()); - case "sub": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); - case "subtract": return new Join(modelName, nodeName, inputs, ScalarFunctions.subtract()); - case "sum": return new Sum(modelName, nodeName, inputs, attributes); - case "square": return new Map(modelName, nodeName, inputs, ScalarFunctions.square()); - case "sqrt": return new Map(modelName, nodeName, inputs, ScalarFunctions.sqrt()); - case "tan": return new Map(modelName, nodeName, inputs, ScalarFunctions.tan()); - case "tanh": return new Map(modelName, nodeName, inputs, ScalarFunctions.tanh()); - - // nn ops - case "biasadd": return new Join(modelName, nodeName, inputs, ScalarFunctions.add()); - case "elu": return new Map(modelName, nodeName, inputs, ScalarFunctions.elu()); - case "relu": return new Map(modelName, nodeName, inputs, ScalarFunctions.relu()); - case "selu": return new Map(modelName, nodeName, inputs, ScalarFunctions.selu()); - case "softmax": return new Softmax(modelName, nodeName, inputs, attributes); - - // state ops - case "variable": return new Constant(modelName, nodeName, nodeType); - case "variablev2": return new Constant(modelName, nodeName, nodeType); - case "varhandleop": return new Constant(modelName, nodeName, nodeType); - case "readvariableop":return new Identity(modelName, nodeName, inputs); - - // evaluation no-ops - case "stopgradient":return new Identity(modelName, nodeName, inputs); - case "noop": return new NoOp(modelName, nodeName, inputs); - - } - - IntermediateOperation op = new NoOp(modelName, node.getName(), inputs); - op.warning("Operation '" + node.getOp() + "' is currently not implemented"); - return op; - } - - static IntermediateGraph importGraph(String modelName, SavedModelBundle bundle) throws IOException { - MetaGraphDef tfGraph = MetaGraphDef.parseFrom(bundle.metaGraphDef()); - - IntermediateGraph intermediateGraph = new IntermediateGraph(modelName); - importSignatures(tfGraph, intermediateGraph); - importOperations(tfGraph, intermediateGraph, bundle); - verifyOutputTypes(tfGraph, intermediateGraph); - - return intermediateGraph; - } - - private static void importSignatures(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) { - for (java.util.Map.Entry<String, SignatureDef> signatureEntry : tfGraph.getSignatureDefMap().entrySet()) { - String signatureName = signatureEntry.getKey(); - java.util.Map<String, TensorInfo> inputInfoMap = signatureEntry.getValue().getInputsMap(); - for (java.util.Map.Entry<String, TensorInfo> input : inputInfoMap.entrySet()) { - String inputName = input.getKey(); - String nodeName = input.getValue().getName(); - intermediateGraph.inputs(signatureName).put(inputName, IntermediateOperation.namePartOf(nodeName)); - } - java.util.Map<String, TensorInfo> outputInfoMap = signatureEntry.getValue().getOutputsMap(); - for (java.util.Map.Entry<String, TensorInfo> output : outputInfoMap.entrySet()) { - String outputName = output.getKey(); - String nodeName = output.getValue().getName(); - intermediateGraph.outputs(signatureName).put(outputName, IntermediateOperation.namePartOf(nodeName)); - } - } - } - - private static void importOperations(MetaGraphDef tfGraph, - IntermediateGraph intermediateGraph, - SavedModelBundle bundle) { - for (String signatureName : intermediateGraph.signatures()) { - for (String outputName : intermediateGraph.outputs(signatureName).values()) { - importOperation(outputName, tfGraph.getGraphDef(), intermediateGraph, bundle); - } - } - } - - private static IntermediateOperation importOperation(String nodeName, - GraphDef tfGraph, - IntermediateGraph intermediateGraph, - SavedModelBundle bundle) { - if (intermediateGraph.alreadyImported(nodeName)) { - return intermediateGraph.get(nodeName); - } - NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(nodeName), tfGraph); - List<IntermediateOperation> inputs = importOperationInputs(node, tfGraph, intermediateGraph, bundle); - IntermediateOperation operation = mapOperation(node, inputs, intermediateGraph); - intermediateGraph.put(nodeName, operation); - - List<IntermediateOperation> controlInputs = importControlInputs(node, tfGraph, intermediateGraph, bundle); - if (controlInputs.size() > 0) { - operation.setControlInputs(controlInputs); - } - - if (operation.isConstant()) { - operation.setConstantValueFunction( - type -> new TensorValue(TensorConverter.toVespaTensor(readVariable(nodeName, bundle), type))); - } - - return operation; - } - - private static List<IntermediateOperation> importOperationInputs(NodeDef node, - GraphDef tfGraph, - IntermediateGraph intermediateGraph, - SavedModelBundle bundle) { - return node.getInputList().stream() - .filter(name -> ! isControlDependency(name)) - .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle)) - .collect(Collectors.toList()); - } - - private static List<IntermediateOperation> importControlInputs(NodeDef node, - GraphDef tfGraph, - IntermediateGraph intermediateGraph, - SavedModelBundle bundle) { - return node.getInputList().stream() - .filter(nodeName -> isControlDependency(nodeName)) - .map(nodeName -> importOperation(nodeName, tfGraph, intermediateGraph, bundle)) - .collect(Collectors.toList()); - } - - private static boolean isControlDependency(String name) { - return name.startsWith("^"); - } - - private static NodeDef getTensorFlowNodeFromGraph(String name, GraphDef tfGraph) { - for (NodeDef node : tfGraph.getNodeList()) { - if (node.getName().equals(name)) { - return node; - } - } - throw new IllegalArgumentException("Could not find node '" + name + "'"); - } - - static org.tensorflow.Tensor<?> readVariable(String name, SavedModelBundle bundle) { - Session.Runner fetched = bundle.session().runner().fetch(name); - List<org.tensorflow.Tensor<?>> importedTensors = fetched.run(); - if (importedTensors.size() != 1) - throw new IllegalStateException("Expected 1 tensor from fetching " + name + - ", but got " + importedTensors.size()); - return importedTensors.get(0); - } - - private static void verifyOutputTypes(MetaGraphDef tfGraph, IntermediateGraph intermediateGraph) { - for (String signatureName : intermediateGraph.signatures()) { - for (String outputName : intermediateGraph.outputs(signatureName).values()) { - IntermediateOperation operation = intermediateGraph.get(outputName); - NodeDef node = getTensorFlowNodeFromGraph(IntermediateOperation.namePartOf(operation.name()), tfGraph.getGraphDef()); - OrderedTensorType type = operation.type().orElseThrow( - () -> new IllegalArgumentException("Output of '" + outputName + "' has no type.")); - TypeConverter.verifyType(node, type); - } - } - - } - -} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java deleted file mode 100644 index 95727acb5b4..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ /dev/null @@ -1,238 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer.tensorflow; - -import ai.vespa.rankingexpression.importer.OrderedTensorType; -import com.yahoo.tensor.IndexedTensor; -import com.yahoo.tensor.Tensor; -import com.yahoo.tensor.TensorType; -import org.tensorflow.framework.DataType; -import org.tensorflow.framework.TensorProto; -import org.tensorflow.framework.TensorShapeProto; - -import java.nio.ByteBuffer; -import java.nio.DoubleBuffer; -import java.nio.FloatBuffer; -import java.nio.IntBuffer; -import java.nio.LongBuffer; -import java.util.List; - -/** - * Converts TensorFlow tensors into Vespa tensors. - * - * @author bratseth - * @author lesters - */ -public class TensorConverter { - - public static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor) { - return toVespaTensor(tfTensor, "d"); - } - - private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType type = TypeConverter.typeFrom(tfTensor, dimensionPrefix); - Values values = readValuesOf(tfTensor); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); - for (int i = 0; i < values.size(); i++) - builder.cellByDirectIndex(i, values.get(i)); - return builder.build(); - } - - static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, OrderedTensorType type) { - Values values = readValuesOf(tfTensor); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type.type()); - for (int i = 0; i < values.size(); i++) { - builder.cellByDirectIndex(type.toDirectIndex(i), values.get(i)); - } - return builder.build(); - } - - static Tensor toVespaTensor(TensorProto tensorProto, OrderedTensorType type) { - Values values = readValuesOf(tensorProto); - if (values.size() == 0) { // Might be stored as "tensor_content" instead - return toVespaTensor(readTensorContentOf(tensorProto), type); - } - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type.type()); - for (int i = 0; i < values.size(); ++i) - builder.cellByDirectIndex(i, values.get(i)); - return builder.build(); - } - - public static Long tensorSize(TensorType type) { - Long size = 1L; - for (TensorType.Dimension dimension : type.dimensions()) { - size *= dimensionSize(dimension); - } - return size; - } - - private static Long dimensionSize(TensorType.Dimension dim) { - return dim.size().orElseThrow(() -> new IllegalArgumentException("Dimension has no size")); - } - - private static Values readValuesOf(org.tensorflow.Tensor<?> tfTensor) { - switch (tfTensor.dataType()) { - case DOUBLE: return new DoubleValues(tfTensor); - case FLOAT: return new FloatValues(tfTensor); - case BOOL: return new BoolValues(tfTensor); - case UINT8: return new IntValues(tfTensor); - case INT32: return new IntValues(tfTensor); - case INT64: return new LongValues(tfTensor); - default: throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); - } - } - - private static Values readValuesOf(TensorProto tensorProto) { - switch (tensorProto.getDtype()) { - case DT_BOOL: return new ProtoBoolValues(tensorProto); - case DT_HALF: return new ProtoHalfValues(tensorProto); - case DT_INT16: case DT_INT32: return new ProtoIntValues(tensorProto); - case DT_INT64: return new ProtoInt64Values(tensorProto); - case DT_FLOAT: return new ProtoFloatValues(tensorProto); - case DT_DOUBLE: return new ProtoDoubleValues(tensorProto); - default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); - } - } - - private static Class dataTypeToClass(DataType dataType) { - switch (dataType) { - case DT_BOOL: return Boolean.class; - case DT_INT16: return Short.class; - case DT_INT32: return Integer.class; - case DT_INT64: return Long.class; - case DT_HALF: return Float.class; - case DT_FLOAT: return Float.class; - case DT_DOUBLE: return Double.class; - default: throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); - } - } - - private static org.tensorflow.Tensor readTensorContentOf(TensorProto tensorProto) { - return org.tensorflow.Tensor.create(dataTypeToClass(tensorProto.getDtype()), - asSizeArray(tensorProto.getTensorShape().getDimList()), - tensorProto.getTensorContent().asReadOnlyByteBuffer()); - } - - private static long[] asSizeArray(List<TensorShapeProto.Dim> dimensions) { - long[] sizes = new long[dimensions.size()]; - for (int i = 0; i < dimensions.size(); i++) - sizes[i] = dimensions.get(i).getSize(); - return sizes; - } - - /** Allows reading values from buffers of various numeric types as bytes */ - private static abstract class Values { - abstract double get(int i); - abstract int size(); - } - - private static abstract class TensorFlowValues extends Values { - private final int size; - TensorFlowValues(int size) { - this.size = size; - } - @Override int size() { return this.size; } - } - - private static class DoubleValues extends TensorFlowValues { - private final DoubleBuffer values; - DoubleValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = DoubleBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static class FloatValues extends TensorFlowValues { - private final FloatBuffer values; - FloatValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = FloatBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static class BoolValues extends TensorFlowValues { - private final ByteBuffer values; - BoolValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = ByteBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static class IntValues extends TensorFlowValues { - private final IntBuffer values; - IntValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = IntBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static class LongValues extends TensorFlowValues { - private final LongBuffer values; - LongValues(org.tensorflow.Tensor<?> tfTensor) { - super(tfTensor.numElements()); - values = LongBuffer.allocate(tfTensor.numElements()); - tfTensor.writeTo(values); - } - @Override double get(int i) { - return values.get(i); - } - } - - private static abstract class ProtoValues extends Values { - final TensorProto tensorProto; - ProtoValues(TensorProto tensorProto) { this.tensorProto = tensorProto; } - } - - private static class ProtoBoolValues extends ProtoValues { - ProtoBoolValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getBoolVal(i) ? 1.0 : 0.0; } - @Override int size() { return tensorProto.getBoolValCount(); } - } - - private static class ProtoHalfValues extends ProtoValues { - ProtoHalfValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getHalfVal(i); } - @Override int size() { return tensorProto.getHalfValCount(); } - } - - private static class ProtoIntValues extends ProtoValues { - ProtoIntValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getIntVal(i); } - @Override int size() { return tensorProto.getIntValCount(); } - } - - private static class ProtoInt64Values extends ProtoValues { - ProtoInt64Values(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getInt64Val(i); } - @Override int size() { return tensorProto.getInt64ValCount(); } - } - - private static class ProtoFloatValues extends ProtoValues { - ProtoFloatValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getFloatVal(i); } - @Override int size() { return tensorProto.getFloatValCount(); } - } - - private static class ProtoDoubleValues extends ProtoValues { - ProtoDoubleValues(TensorProto tensorProto) { super(tensorProto); } - @Override double get(int i) { return tensorProto.getDoubleVal(i); } - @Override int size() { return tensorProto.getDoubleValCount(); } - } - -} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java index 0e307992143..04ddb48e859 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorFlowImporter.java @@ -2,14 +2,12 @@ package ai.vespa.rankingexpression.importer.tensorflow; import ai.vespa.rankingexpression.importer.ImportedModel; -import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.ModelImporter; import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlModel; import ai.vespa.rankingexpression.importer.onnx.OnnxImporter; import com.yahoo.collections.Pair; import com.yahoo.io.IOUtils; import com.yahoo.system.ProcessExecuter; -import org.tensorflow.SavedModelBundle; import java.io.File; import java.io.IOException; @@ -27,7 +25,7 @@ public class TensorFlowImporter extends ModelImporter { private static final Logger log = Logger.getLogger(TensorFlowImporter.class.getName()); - private final static int[] onnxOpsetsToTry = {8, 10, 12}; + private final static int[] onnxOpsetsToTry = {12, 10, 8}; private final OnnxImporter onnxImporter = new OnnxImporter(); @@ -56,17 +54,6 @@ public class TensorFlowImporter extends ModelImporter { return convertToOnnxAndImport(modelName, modelDir); } - /** Imports a TensorFlow model - DEPRECATED */ - public ImportedModel importModel(String modelName, String modelDir, SavedModelBundle model) { - try { - IntermediateGraph graph = GraphImporter.importGraph(modelName, model); - return convertIntermediateGraphToModel(graph, modelDir, ImportedMlModel.ModelType.TENSORFLOW); - } - catch (IOException e) { - throw new IllegalArgumentException("Could not import TensorFlow model '" + model + "'", e); - } - } - private ImportedModel convertToOnnxAndImport(String modelName, String modelDir) { Path tempDir = null; try { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java deleted file mode 100644 index 3102d5431d4..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java +++ /dev/null @@ -1,128 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. - -package ai.vespa.rankingexpression.importer.tensorflow; - -import ai.vespa.rankingexpression.importer.OrderedTensorType; -import com.yahoo.tensor.TensorType; -import org.tensorflow.framework.AttrValue; -import org.tensorflow.framework.DataType; -import org.tensorflow.framework.NodeDef; -import org.tensorflow.framework.TensorShapeProto; - -/** - * Converts and verifies TensorFlow tensor types into Vespa tensor types. - * - * @author lesters - */ -class TypeConverter { - - static void verifyType(NodeDef node, OrderedTensorType type) { - TensorShapeProto shape = tensorFlowShape(node); - if (shape != null) { - if (shape.getDimCount() != type.rank()) { - throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + - "does not match Vespa shape"); - } - for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) { - int vespaIndex = type.dimensionMap(tensorFlowIndex); - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); - TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); - if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { - throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + - "does not match Vespa dimensions"); - } - } - } - } - - static OrderedTensorType typeFrom(NodeDef node) { - String dimensionPrefix = "d"; // standard naming convention: d0, d1, ... - TensorShapeProto shape = tensorFlowShape(node); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node))); - for (int i = 0; i < shape.getDimCount(); ++ i) { - String dimensionName = dimensionPrefix + i; - TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); - if (tensorFlowDimension.getSize() >= 0) { - builder.add(TensorType.Dimension.indexed(dimensionName, tensorFlowDimension.getSize())); - } else { - builder.add(TensorType.Dimension.indexed(dimensionName)); - } - } - return builder.build(); - } - - static TensorType typeFrom(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType())); - int dimensionIndex = 0; - for (long dimensionSize : tfTensor.shape()) { - if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... - b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); - } - return b.build(); - } - - private static TensorShapeProto tensorFlowShape(NodeDef node) { - // Use specific shape if available... - AttrValue attrShape = node.getAttrMap().get("shape"); - if (attrShape != null && attrShape.getValueCase() == AttrValue.ValueCase.SHAPE) { - return attrShape.getShape(); - } - - // ... else use inferred shape - AttrValue attrOutputShapes = node.getAttrMap().get("_output_shapes"); - if (attrOutputShapes == null) - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "does not exist"); - if (attrOutputShapes.getValueCase() != AttrValue.ValueCase.LIST) - throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "is not of expected type"); - - return attrOutputShapes.getList().getShape(0); // support multiple outputs? - } - - private static DataType tensorFlowValueType(NodeDef node) { - AttrValue attrValueList = node.getAttrMap().get("dtypes"); - if (attrValueList == null) - return DataType.DT_DOUBLE; // default. This will usually (always?) be used. TODO: How can we do better? - if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) - return DataType.DT_DOUBLE; // default - - return attrValueList.getList().getType(0); // support multiple outputs? - } - - private static TensorType.Value toValueType(DataType dataType) { - switch (dataType) { - case DT_FLOAT: return TensorType.Value.FLOAT; - case DT_DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case DT_BOOL: return TensorType.Value.FLOAT; - case DT_BFLOAT16: return TensorType.Value.FLOAT; - case DT_HALF: return TensorType.Value.FLOAT; - case DT_INT8: return TensorType.Value.FLOAT; - case DT_INT16: return TensorType.Value.DOUBLE; - case DT_INT32: return TensorType.Value.DOUBLE; - case DT_INT64: return TensorType.Value.DOUBLE; - case DT_UINT8: return TensorType.Value.FLOAT; - case DT_UINT16: return TensorType.Value.DOUBLE; - case DT_UINT32: return TensorType.Value.DOUBLE; - case DT_UINT64: return TensorType.Value.DOUBLE; - default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + - " cannot be converted to a Vespa tensor type"); - } - } - - private static TensorType.Value toValueType(org.tensorflow.DataType dataType) { - switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; - case DOUBLE: return TensorType.Value.DOUBLE; - // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.DOUBLE; - case UINT8: return TensorType.Value.FLOAT; - case INT64: return TensorType.Value.DOUBLE; - default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + - " cannot be converted to a Vespa tensor type"); - } - } - -} diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java deleted file mode 100644 index 85ae5238bae..00000000000 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/VariableConverter.java +++ /dev/null @@ -1,59 +0,0 @@ -// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. -package ai.vespa.rankingexpression.importer.tensorflow; - -import ai.vespa.rankingexpression.importer.OrderedTensorType; -import com.yahoo.tensor.serialization.JsonFormat; -import com.yahoo.yolean.Exceptions; -import org.tensorflow.SavedModelBundle; - -import java.nio.charset.StandardCharsets; - -/** - * Converts TensorFlow Variables to the Vespa document format. - * Intended to be used from the command line to convert trained tensors to document form. - * - * @author bratseth - */ -class VariableConverter { - - /** - * Reads the tensor with the given TensorFlow name at the given model location, - * and encodes it as UTF-8 Vespa document tensor JSON having the given ordered tensor type. - * Note that order of dimensions in the tensor type does matter as the TensorFlow tensor - * tensor dimensions are implicitly ordered. - */ - static byte[] importVariable(String modelDir, String tensorFlowVariableName, String orderedTypeSpec) { - try (SavedModelBundle bundle = SavedModelBundle.load(modelDir, "serve")) { - return JsonFormat.encode(TensorConverter.toVespaTensor(GraphImporter.readVariable(tensorFlowVariableName, - bundle), - OrderedTensorType.fromSpec(orderedTypeSpec))); - } - catch (IllegalArgumentException e) { - throw new IllegalArgumentException("Could not import TensorFlow model from directory '" + modelDir + "'", e); - } - } - - public static void main(String[] args) { - if ( args.length != 3) { - System.out.println("Converts a TensorFlow variable into Vespa tensor document field value JSON:"); - System.out.println("A JSON map containing a 'cells' array, see"); - System.out.println("https://docs.vespa.ai/en/reference/document-json-format.html#tensor"); - System.out.println(""); - System.out.println("Arguments: modelDirectory tensorFlowVariableName orderedTypeSpec"); - System.out.println(" - modelDirectory: The directory of the TensorFlow SavedModel"); - System.out.println(" - tensorFlowVariableName: The name of the TensorFlow variable to convert"); - System.out.println(" - orderedTypeSpec: The tensor type, e.g tensor(b[],a[10]), where dimensions are "); - System.out.println(" ordered as given in the deployment log message starting by "); - System.out.println(" 'Importing TensorFlow variable'"); - return; - } - - try { - System.out.println(new String(importVariable(args[0], args[1], args[2]), StandardCharsets.UTF_8)); - } - catch (Exception e) { - System.err.println("Import failed: " + Exceptions.toMessageString(e)); - } - } - -} |