diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
commit | 5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch) | |
tree | 2b65d4f48b92bf7ec846b3efd5d5259244bc234a /model-integration | |
parent | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff) |
Add tensor value type
Diffstat (limited to 'model-integration')
15 files changed, 198 insertions, 134 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java index c4acfeb3235..9c8f6238731 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java @@ -29,9 +29,17 @@ public class OrderedTensorType { private final long[] innerSizesVespa; private final int[] dimensionMap; - private OrderedTensorType(List<TensorType.Dimension> dimensions) { + private OrderedTensorType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { this.dimensions = Collections.unmodifiableList(dimensions); - this.type = new TensorType.Builder(dimensions).build(); + this.type = new TensorType.Builder(valueType, dimensions).build(); + this.innerSizesOriginal = new long[dimensions.size()]; + this.innerSizesVespa = new long[dimensions.size()]; + this.dimensionMap = createDimensionMap(); + } + + private OrderedTensorType(TensorType type) { + this.dimensions = type.dimensions(); + this.type = type; this.innerSizesOriginal = new long[dimensions.size()]; this.innerSizesVespa = new long[dimensions.size()]; this.dimensionMap = createDimensionMap(); @@ -136,11 +144,11 @@ public class OrderedTensorType { renamedDimensions.add(TensorType.Dimension.mapped(newName.get())); } } - return new OrderedTensorType(renamedDimensions); + return new OrderedTensorType(type.valueType(), renamedDimensions); } public OrderedTensorType rename(String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.valueType()); for (int i = 0; i < dimensions.size(); ++ i) { String dimensionName = dimensionPrefix + i; Optional<Long> dimSize = dimensions.get(i).size(); @@ -154,7 +162,7 @@ public class OrderedTensorType { } public static OrderedTensorType standardType(OrderedTensorType type) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.type().valueType()); for (int i = 0; i < type.dimensions().size(); ++ i) { TensorType.Dimension dim = type.dimensions().get(i); String dimensionName = "d" + i; @@ -193,18 +201,18 @@ public class OrderedTensorType { * where dimensions are listed in the order of this rather than the natural order of their names. */ public static OrderedTensorType fromSpec(String typeSpec) { - return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec)); + return new OrderedTensorType(TensorType.fromSpec(typeSpec)); } - public static OrderedTensorType fromDimensionList(List<Long> dims) { - return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ... + public static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions) { + return fromDimensionList(valueType, dimensions, "d"); // standard naming convention: d0, d1, ... } - private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); - for (int i = 0; i < dims.size(); ++ i) { + private static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions, String dimensionPrefix) { + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueType); + for (int i = 0; i < dimensions.size(); ++ i) { String dimensionName = dimensionPrefix + i; - Long dimSize = dims.get(i); + Long dimSize = dimensions.get(i); if (dimSize >= 0) { builder.add(TensorType.Dimension.indexed(dimensionName, dimSize)); } else { @@ -216,9 +224,15 @@ public class OrderedTensorType { public static class Builder { + private final TensorType.Value valueType; private final List<TensorType.Dimension> dimensions; public Builder() { + this(TensorType.Value.DOUBLE); + } + + public Builder(TensorType.Value valueType) { + this.valueType = valueType; this.dimensions = new ArrayList<>(); } @@ -228,7 +242,7 @@ public class OrderedTensorType { } public OrderedTensorType build() { - return new OrderedTensorType(dimensions); + return new OrderedTensorType(valueType, dimensions); } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java index dd2add973e4..5cc1defc010 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java @@ -16,8 +16,10 @@ import ai.vespa.rankingexpression.importer.operations.MatMul; import ai.vespa.rankingexpression.importer.operations.NoOp; import ai.vespa.rankingexpression.importer.operations.Reshape; import ai.vespa.rankingexpression.importer.operations.Shape; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.ScalarFunctions; import onnx.Onnx; +import onnx.Onnx.TensorProto.DataType; import java.util.List; import java.util.stream.Collectors; @@ -114,7 +116,8 @@ class GraphImporter { } else if (isConstantTensor(name, onnxGraph)) { Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph); - OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList()); + OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()), + tensorProto.getDimsList()); operation = new Constant(intermediateGraph.name(), name, defaultType); operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type))); @@ -133,6 +136,25 @@ class GraphImporter { return operation; } + private static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT8: return TensorType.Value.FLOAT; + case INT16: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + case UINT8: return TensorType.Value.FLOAT; + case UINT16: return TensorType.Value.FLOAT; + case UINT32: return TensorType.Value.FLOAT; + case UINT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) { Onnx.ValueInfoProto value = getArgumentTensor(name, graph); Onnx.TensorProto tensor = getConstantTensor(name, graph); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index f251a14213b..79b399f2c6f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -36,7 +36,7 @@ class TypeConverter { private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) { Onnx.TensorShapeProto shape = type.getTensorType().getShape(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(TensorType.Value.DOUBLE); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 1a564661ccb..7ae50a0549d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -21,20 +21,15 @@ public class ConcatV2 extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { - return null; - } + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null; IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input - if (!concatDimOp.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a constant."); - } + if ( ! concatDimOp.getConstantValue().isPresent()) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a constant."); + Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); - if (concatDimTensor.type().rank() != 0) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a scalar."); - } + if (concatDimTensor.type().rank() != 0) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a scalar."); OrderedTensorType aType = inputs.get(0).type().get(); concatDimensionIndex = (int)concatDimTensor.asDouble(); @@ -42,10 +37,9 @@ public class ConcatV2 extends IntermediateOperation { for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType bType = inputs.get(i).type().get(); - if (bType.rank() != aType.rank()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "inputs must have save rank."); - } + if (bType.rank() != aType.rank()) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Inputs must have the same rank."); + for (int j = 0; j < aType.rank(); ++j) { long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); long dimSizeB = bType.dimensions().get(j).size().orElse(-1L); @@ -58,7 +52,7 @@ public class ConcatV2 extends IntermediateOperation { } } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); int dimensionIndex = 0; for (TensorType.Dimension dimension : aType.dimensions()) { if (dimensionIndex == concatDimensionIndex) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 8ae6d81b8d4..c64b9ded601 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -27,20 +27,15 @@ public class ExpandDims extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; IntermediateOperation axisOperation = inputs().get(1); if (!axisOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis must be a constant."); + throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis argument must be a scalar."); - } + if (axis.type().rank() != 0) + throw new IllegalArgumentException("ExpandDims in " + name + ": Axis argument must be a scalar."); OrderedTensorType inputType = inputs.get(0).type().get(); int dimensionToInsert = (int)axis.asDouble(); @@ -48,7 +43,7 @@ public class ExpandDims extends IntermediateOperation { dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { @@ -66,12 +61,10 @@ public class ExpandDims extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputFunctionsPresent(2)) return null; // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); for (String name : expandDimensions) { typeBuilder.indexed(name, 1); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 3b77f9527ca..0ee54f839bc 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -9,6 +9,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.TensorFunction; @@ -17,6 +18,7 @@ import java.util.Collections; import java.util.List; import java.util.Optional; import java.util.function.Function; +import java.util.stream.Collectors; /** * Wraps an imported operation node and produces the respective Vespa tensor @@ -161,6 +163,19 @@ public abstract class IntermediateOperation { } /** + * Returns the largest value type among the input value types. + * This should only be called after it has been verified that input types are available. + * + * @throws IllegalArgumentException if a type cannot be uniquely determined + * @throws RuntimeException if called when input types are not available + */ + TensorType.Value resultValueType() { + return TensorType.Value.largestOf(inputs.stream() + .map(input -> input.type().get().type().valueType()) + .collect(Collectors.toList())); + } + + /** * A method signature input and output has the form name:index. * This returns the name part without the index. */ diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index fed95e13bb7..c2d75153586 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -22,13 +22,12 @@ public class Join extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType a = largestInput().type().get(); OrderedTensorType b = smallestInput().type().get(); - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); int sizeDifference = a.rank() - b.rank(); for (int i = 0; i < a.rank(); ++i) { TensorType.Dimension aDim = a.dimensions().get(i); @@ -52,12 +51,8 @@ public class Join extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; IntermediateOperation a = largestInput(); IntermediateOperation b = smallestInput(); @@ -92,9 +87,8 @@ public class Join extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } + if ( ! allInputTypesPresent(2)) return; + OrderedTensorType a = largestInput().type().get(); OrderedTensorType b = smallestInput().type().get(); int sizeDifference = a.rank() - b.rank(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 1dbfd6e40dc..9a76662529d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -17,10 +17,9 @@ public class MatMul extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + if ( ! allInputTypesPresent(2)) return null; + + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); typeBuilder.add(inputs.get(0).type().get().dimensions().get(0)); typeBuilder.add(inputs.get(1).type().get().dimensions().get(1)); return typeBuilder.build(); @@ -28,9 +27,8 @@ public class MatMul extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + OrderedTensorType aType = inputs.get(0).type().get(); OrderedTensorType bType = inputs.get(1).type().get(); if (aType.type().rank() < 2 || bType.type().rank() < 2) @@ -48,9 +46,8 @@ public class MatMul extends IntermediateOperation { @Override public void addDimensionNameConstraints(DimensionRenamer renamer) { - if (!allInputTypesPresent(2)) { - return; - } + if ( ! allInputTypesPresent(2)) return; + List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); @@ -69,4 +66,5 @@ public class MatMul extends IntermediateOperation { renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this); renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java index 4be220db9d5..d8e9950c61f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java @@ -32,13 +32,11 @@ public class Mean extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + IntermediateOperation reductionIndices = inputs.get(1); - if (!reductionIndices.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Mean in " + name + ": " + - "reduction indices must be a constant."); + if ( ! reductionIndices.getConstantValue().isPresent()) { + throw new IllegalArgumentException("Mean in " + name + ": Reduction indices must be a constant."); } Tensor indices = reductionIndices.getConstantValue().get().asTensor(); reduceDimensions = new ArrayList<>(); @@ -59,14 +57,14 @@ public class Mean extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + + TensorFunction inputFunction = inputs.get(0).function().get(); TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions); if (shouldKeepDimensions()) { // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); for (String name : reduceDimensions) { typeBuilder.indexed(name, 1); } @@ -99,9 +97,9 @@ public class Mean extends IntermediateOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { - if (!reduceDimensions.contains(dimension.name())) { + if ( ! reduceDimensions.contains(dimension.name())) { builder.add(dimension); } else if (keepDimensions) { builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java index 18f3cc1cc39..4a0fe236c9f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java @@ -32,18 +32,16 @@ public class Reshape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + IntermediateOperation newShape = inputs.get(1); - if (!newShape.getConstantValue().isPresent()) { - throw new IllegalArgumentException("Reshape in " + name + ": " + - "shape input must be a constant."); - } + if ( ! newShape.getConstantValue().isPresent()) + throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant."); + Tensor shape = newShape.getConstantValue().get().asTensor(); OrderedTensorType inputType = inputs.get(0).type().get(); - OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType()); int dimensionIndex = 0; for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) { Tensor.Cell cell = cellIterator.next(); @@ -61,12 +59,9 @@ public class Reshape extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputTypesPresent(2)) { - return null; - } - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; + if ( ! allInputFunctionsPresent(2)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); TensorFunction inputFunction = inputs.get(0).function().get(); return reshape(inputFunction, inputType.type(), type.type()); @@ -80,9 +75,8 @@ public class Reshape extends IntermediateOperation { } public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) { + if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); - } // Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order, // then use the dimension order of the new shape to roll back into a tensor. @@ -96,20 +90,17 @@ public class Reshape extends IntermediateOperation { TensorType transformationType = new TensorType.Builder(inputType, outputType).build(); Generate transformTensor = new Generate(transformationType, - new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - - TensorFunction outputFunction = new Reduce( - new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); + new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator()); - return outputFunction; + return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList())); } private static ExpressionNode unrollTensorExpression(TensorType type) { - if (type.rank() == 0) { + if (type.rank() == 0) return new ConstantNode(DoubleValue.zero); - } + List<ExpressionNode> children = new ArrayList<>(); List<ArithmeticOperator> operators = new ArrayList<>(); int size = 1; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java index 361729a8c14..79f3012c327 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java @@ -19,11 +19,10 @@ public class Shape extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } + if ( ! allInputTypesPresent(1)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); - return new OrderedTensorType.Builder() + return new OrderedTensorType.Builder(resultValueType()) .add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size())) .build(); } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java index 2eeefcbe8a2..52d40144f61 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java @@ -25,9 +25,8 @@ public class Squeeze extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(1)) { - return null; - } + if ( ! allInputTypesPresent(1)) return null; + OrderedTensorType inputType = inputs.get(0).type().get(); squeezeDimensions = new ArrayList<>(); @@ -51,9 +50,8 @@ public class Squeeze extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(1)) { - return null; - } + if ( ! allInputFunctionsPresent(1)) return null; + TensorFunction inputFunction = inputs.get(0).function().get(); return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions); } @@ -73,7 +71,7 @@ public class Squeeze extends IntermediateOperation { } private OrderedTensorType reducedType(OrderedTensorType inputType) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType()); for (TensorType.Dimension dimension: inputType.type().dimensions()) { if ( ! squeezeDimensions.contains(dimension.name())) { builder.add(dimension); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java index 6c92ffa6055..a4fe38cce95 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java @@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import org.tensorflow.DataType; import org.tensorflow.framework.TensorProto; import java.nio.ByteBuffer; @@ -27,7 +28,7 @@ public class TensorConverter { } private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { - TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix); + TensorType type = toVespaTensorType(tfTensor, dimensionPrefix); Values values = readValuesOf(tfTensor); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type); for (int i = 0; i < values.size(); i++) @@ -53,10 +54,10 @@ public class TensorConverter { return builder.build(); } - private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) { - TensorType.Builder b = new TensorType.Builder(); + private static TensorType toVespaTensorType(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) { + TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType())); int dimensionIndex = 0; - for (long dimensionSize : shape) { + for (long dimensionSize : tfTensor.shape()) { if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize); } @@ -85,7 +86,7 @@ public class TensorConverter { case INT64: return new LongValues(tfTensor); } throw new IllegalArgumentException("Cannot convert a tensor with elements of type " + - tfTensor.dataType() + " to a Vespa tensor"); + tfTensor.dataType() + " to a Vespa tensor"); } private static Values readValuesOf(TensorProto tensorProto) { @@ -107,6 +108,21 @@ public class TensorConverter { throw new IllegalArgumentException("Unsupported data type in attribute tensor import"); } + /** TensorFlow has two different DataType classes. This must be kept in sync with TypeConverter.toValueType */ + static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case FLOAT: return TensorType.Value.FLOAT; + case DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case BOOL: return TensorType.Value.FLOAT; + case INT32: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.FLOAT; + case INT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + /** Allows reading values from buffers of various numeric types as bytes */ private static abstract class Values { abstract double get(int i); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index 63a605ce97a..3e825026b0e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -5,6 +5,7 @@ package ai.vespa.rankingexpression.importer.tensorflow; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.DataType; import org.tensorflow.framework.NodeDef; import org.tensorflow.framework.TensorShapeProto; @@ -22,7 +23,7 @@ class TypeConverter { if (shape != null) { if (shape.getDimCount() != type.rank()) { throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + - "does not match Vespa shape"); + "does not match Vespa shape"); } for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) { int vespaIndex = type.dimensionMap(tensorFlowIndex); @@ -30,7 +31,7 @@ class TypeConverter { TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + - "does not match Vespa dimensions"); + "does not match Vespa dimensions"); } } } @@ -38,16 +39,24 @@ class TypeConverter { private static TensorShapeProto tensorFlowShape(NodeDef node) { AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); - if (attrValueList == null) { + if (attrValueList == null) throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "does not exist"); - } - if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { + "does not exist"); + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "is not of expected type"); - } - List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); - return shapeList.get(0); // support multiple outputs? + "is not of expected type"); + + return attrValueList.getList().getShape(0); // support multiple outputs? + } + + private static DataType tensorFlowValueType(NodeDef node) { + AttrValue attrValueList = node.getAttrMap().get("dtypes"); + if (attrValueList == null) + return DataType.DT_DOUBLE; // default. This will usually (always?) be used. TODO: How can we do better? + if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) + return DataType.DT_DOUBLE; // default + + return attrValueList.getList().getType(0); // support multiple outputs? } static OrderedTensorType fromTensorFlowType(NodeDef node) { @@ -55,8 +64,8 @@ class TypeConverter { } private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) { - OrderedTensorType.Builder builder = new OrderedTensorType.Builder(); TensorShapeProto shape = tensorFlowShape(node); + OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node))); for (int i = 0; i < shape.getDimCount(); ++ i) { String dimensionName = dimensionPrefix + i; TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i); @@ -69,4 +78,26 @@ class TypeConverter { return builder.build(); } + /** TensorFlow has two different DataType classes. This must be kept in sync with TensorConverter.toValueType */ + static TensorType.Value toValueType(DataType dataType) { + switch (dataType) { + case DT_FLOAT: return TensorType.Value.FLOAT; + case DT_DOUBLE: return TensorType.Value.DOUBLE; + // Imperfect conversion, for now: + case DT_BOOL: return TensorType.Value.FLOAT; + case DT_BFLOAT16: return TensorType.Value.FLOAT; + case DT_HALF: return TensorType.Value.FLOAT; + case DT_INT8: return TensorType.Value.FLOAT; + case DT_INT16: return TensorType.Value.FLOAT; + case DT_INT32: return TensorType.Value.FLOAT; + case DT_INT64: return TensorType.Value.DOUBLE; + case DT_UINT8: return TensorType.Value.FLOAT; + case DT_UINT16: return TensorType.Value.FLOAT; + case DT_UINT32: return TensorType.Value.FLOAT; + case DT_UINT64: return TensorType.Value.DOUBLE; + default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + + " cannot be converted to a Vespa tensor type"); + } + } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java index afe699d6e05..61f332327be 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java @@ -13,9 +13,10 @@ public class OrderedTensorTypeTestCase { @Test public void testToFromSpec() { String spec = "tensor(b[],c{},a[3])"; + String orderedSpec = "tensor(a[3],b[],c{})"; OrderedTensorType type = OrderedTensorType.fromSpec(spec); - assertEquals(spec, type.toString()); - assertEquals("tensor(a[3],b[],c{})", type.type().toString()); + assertEquals(orderedSpec, type.toString()); + assertEquals(orderedSpec, type.type().toString()); } } |