aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
commit5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch)
tree2b65d4f48b92bf7ec846b3efd5d5259244bc234a /model-integration
parent6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff)
Add tensor value type
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java40
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java24
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java26
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java21
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java15
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java20
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java18
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java22
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java41
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java7
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java26
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java53
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java5
15 files changed, 198 insertions, 134 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
index c4acfeb3235..9c8f6238731 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/OrderedTensorType.java
@@ -29,9 +29,17 @@ public class OrderedTensorType {
private final long[] innerSizesVespa;
private final int[] dimensionMap;
- private OrderedTensorType(List<TensorType.Dimension> dimensions) {
+ private OrderedTensorType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
this.dimensions = Collections.unmodifiableList(dimensions);
- this.type = new TensorType.Builder(dimensions).build();
+ this.type = new TensorType.Builder(valueType, dimensions).build();
+ this.innerSizesOriginal = new long[dimensions.size()];
+ this.innerSizesVespa = new long[dimensions.size()];
+ this.dimensionMap = createDimensionMap();
+ }
+
+ private OrderedTensorType(TensorType type) {
+ this.dimensions = type.dimensions();
+ this.type = type;
this.innerSizesOriginal = new long[dimensions.size()];
this.innerSizesVespa = new long[dimensions.size()];
this.dimensionMap = createDimensionMap();
@@ -136,11 +144,11 @@ public class OrderedTensorType {
renamedDimensions.add(TensorType.Dimension.mapped(newName.get()));
}
}
- return new OrderedTensorType(renamedDimensions);
+ return new OrderedTensorType(type.valueType(), renamedDimensions);
}
public OrderedTensorType rename(String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.valueType());
for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
Optional<Long> dimSize = dimensions.get(i).size();
@@ -154,7 +162,7 @@ public class OrderedTensorType {
}
public static OrderedTensorType standardType(OrderedTensorType type) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(type.type().valueType());
for (int i = 0; i < type.dimensions().size(); ++ i) {
TensorType.Dimension dim = type.dimensions().get(i);
String dimensionName = "d" + i;
@@ -193,18 +201,18 @@ public class OrderedTensorType {
* where dimensions are listed in the order of this rather than the natural order of their names.
*/
public static OrderedTensorType fromSpec(String typeSpec) {
- return new OrderedTensorType(TensorTypeParser.dimensionsFromSpec(typeSpec));
+ return new OrderedTensorType(TensorType.fromSpec(typeSpec));
}
- public static OrderedTensorType fromDimensionList(List<Long> dims) {
- return fromDimensionList(dims, "d"); // standard naming convention: d0, d1, ...
+ public static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions) {
+ return fromDimensionList(valueType, dimensions, "d"); // standard naming convention: d0, d1, ...
}
- private static OrderedTensorType fromDimensionList(List<Long> dims, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
- for (int i = 0; i < dims.size(); ++ i) {
+ private static OrderedTensorType fromDimensionList(TensorType.Value valueType, List<Long> dimensions, String dimensionPrefix) {
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(valueType);
+ for (int i = 0; i < dimensions.size(); ++ i) {
String dimensionName = dimensionPrefix + i;
- Long dimSize = dims.get(i);
+ Long dimSize = dimensions.get(i);
if (dimSize >= 0) {
builder.add(TensorType.Dimension.indexed(dimensionName, dimSize));
} else {
@@ -216,9 +224,15 @@ public class OrderedTensorType {
public static class Builder {
+ private final TensorType.Value valueType;
private final List<TensorType.Dimension> dimensions;
public Builder() {
+ this(TensorType.Value.DOUBLE);
+ }
+
+ public Builder(TensorType.Value valueType) {
+ this.valueType = valueType;
this.dimensions = new ArrayList<>();
}
@@ -228,7 +242,7 @@ public class OrderedTensorType {
}
public OrderedTensorType build() {
- return new OrderedTensorType(dimensions);
+ return new OrderedTensorType(valueType, dimensions);
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
index dd2add973e4..5cc1defc010 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/GraphImporter.java
@@ -16,8 +16,10 @@ import ai.vespa.rankingexpression.importer.operations.MatMul;
import ai.vespa.rankingexpression.importer.operations.NoOp;
import ai.vespa.rankingexpression.importer.operations.Reshape;
import ai.vespa.rankingexpression.importer.operations.Shape;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.ScalarFunctions;
import onnx.Onnx;
+import onnx.Onnx.TensorProto.DataType;
import java.util.List;
import java.util.stream.Collectors;
@@ -114,7 +116,8 @@ class GraphImporter {
} else if (isConstantTensor(name, onnxGraph)) {
Onnx.TensorProto tensorProto = getConstantTensor(name, onnxGraph);
- OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(tensorProto.getDimsList());
+ OrderedTensorType defaultType = OrderedTensorType.fromDimensionList(toValueType(tensorProto.getDataType()),
+ tensorProto.getDimsList());
operation = new Constant(intermediateGraph.name(), name, defaultType);
operation.setConstantValueFunction(type -> new TensorValue(TensorConverter.toVespaTensor(tensorProto, type)));
@@ -133,6 +136,25 @@ class GraphImporter {
return operation;
}
+ private static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
+ case UINT32: return TensorType.Value.FLOAT;
+ case UINT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
private static boolean isArgumentTensor(String name, Onnx.GraphProto graph) {
Onnx.ValueInfoProto value = getArgumentTensor(name, graph);
Onnx.TensorProto tensor = getConstantTensor(name, graph);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index f251a14213b..79b399f2c6f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -36,7 +36,7 @@ class TypeConverter {
private static OrderedTensorType fromOnnxType(Onnx.TypeProto type, String dimensionPrefix) {
Onnx.TensorShapeProto shape = type.getTensorType().getShape();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(TensorType.Value.DOUBLE);
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
Onnx.TensorShapeProto.Dimension onnxDimension = shape.getDim(i);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index 1a564661ccb..7ae50a0549d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -21,20 +21,15 @@ public class ConcatV2 extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
- return null;
- }
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null;
IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
- if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a constant.");
- }
+ if ( ! concatDimOp.getConstantValue().isPresent())
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a constant.");
+
Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
- if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a scalar.");
- }
+ if (concatDimTensor.type().rank() != 0)
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a scalar.");
OrderedTensorType aType = inputs.get(0).type().get();
concatDimensionIndex = (int)concatDimTensor.asDouble();
@@ -42,10 +37,9 @@ public class ConcatV2 extends IntermediateOperation {
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
- if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "inputs must have save rank.");
- }
+ if (bType.rank() != aType.rank())
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Inputs must have the same rank.");
+
for (int j = 0; j < aType.rank(); ++j) {
long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
@@ -58,7 +52,7 @@ public class ConcatV2 extends IntermediateOperation {
}
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
if (dimensionIndex == concatDimensionIndex) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 8ae6d81b8d4..c64b9ded601 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -27,20 +27,15 @@ public class ExpandDims extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation axisOperation = inputs().get(1);
if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis must be a constant.");
+ throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
- if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis argument must be a scalar.");
- }
+ if (axis.type().rank() != 0)
+ throw new IllegalArgumentException("ExpandDims in " + name + ": Axis argument must be a scalar.");
OrderedTensorType inputType = inputs.get(0).type().get();
int dimensionToInsert = (int)axis.asDouble();
@@ -48,7 +43,7 @@ public class ExpandDims extends IntermediateOperation {
dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
@@ -66,12 +61,10 @@ public class ExpandDims extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputFunctionsPresent(2)) return null;
// multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
for (String name : expandDimensions) {
typeBuilder.indexed(name, 1);
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 3b77f9527ca..0ee54f839bc 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -9,6 +9,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
+import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.TensorFunction;
@@ -17,6 +18,7 @@ import java.util.Collections;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
+import java.util.stream.Collectors;
/**
* Wraps an imported operation node and produces the respective Vespa tensor
@@ -161,6 +163,19 @@ public abstract class IntermediateOperation {
}
/**
+ * Returns the largest value type among the input value types.
+ * This should only be called after it has been verified that input types are available.
+ *
+ * @throws IllegalArgumentException if a type cannot be uniquely determined
+ * @throws RuntimeException if called when input types are not available
+ */
+ TensorType.Value resultValueType() {
+ return TensorType.Value.largestOf(inputs.stream()
+ .map(input -> input.type().get().type().valueType())
+ .collect(Collectors.toList()));
+ }
+
+ /**
* A method signature input and output has the form name:index.
* This returns the name part without the index.
*/
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index fed95e13bb7..c2d75153586 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -22,13 +22,12 @@ public class Join extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
OrderedTensorType a = largestInput().type().get();
OrderedTensorType b = smallestInput().type().get();
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
int sizeDifference = a.rank() - b.rank();
for (int i = 0; i < a.rank(); ++i) {
TensorType.Dimension aDim = a.dimensions().get(i);
@@ -52,12 +51,8 @@ public class Join extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
IntermediateOperation a = largestInput();
IntermediateOperation b = smallestInput();
@@ -92,9 +87,8 @@ public class Join extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
+ if ( ! allInputTypesPresent(2)) return;
+
OrderedTensorType a = largestInput().type().get();
OrderedTensorType b = smallestInput().type().get();
int sizeDifference = a.rank() - b.rank();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 1dbfd6e40dc..9a76662529d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -17,10 +17,9 @@ public class MatMul extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ if ( ! allInputTypesPresent(2)) return null;
+
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
typeBuilder.add(inputs.get(0).type().get().dimensions().get(0));
typeBuilder.add(inputs.get(1).type().get().dimensions().get(1));
return typeBuilder.build();
@@ -28,9 +27,8 @@ public class MatMul extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
OrderedTensorType aType = inputs.get(0).type().get();
OrderedTensorType bType = inputs.get(1).type().get();
if (aType.type().rank() < 2 || bType.type().rank() < 2)
@@ -48,9 +46,8 @@ public class MatMul extends IntermediateOperation {
@Override
public void addDimensionNameConstraints(DimensionRenamer renamer) {
- if (!allInputTypesPresent(2)) {
- return;
- }
+ if ( ! allInputTypesPresent(2)) return;
+
List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
@@ -69,4 +66,5 @@ public class MatMul extends IntermediateOperation {
renamer.addConstraint(aDim0, aDim1, DimensionRenamer::lesserThan, this);
renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
}
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
index 4be220db9d5..d8e9950c61f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Mean.java
@@ -32,13 +32,11 @@ public class Mean extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
IntermediateOperation reductionIndices = inputs.get(1);
- if (!reductionIndices.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Mean in " + name + ": " +
- "reduction indices must be a constant.");
+ if ( ! reductionIndices.getConstantValue().isPresent()) {
+ throw new IllegalArgumentException("Mean in " + name + ": Reduction indices must be a constant.");
}
Tensor indices = reductionIndices.getConstantValue().get().asTensor();
reduceDimensions = new ArrayList<>();
@@ -59,14 +57,14 @@ public class Mean extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
+
TensorFunction inputFunction = inputs.get(0).function().get();
TensorFunction output = new Reduce(inputFunction, Reduce.Aggregator.avg, reduceDimensions);
if (shouldKeepDimensions()) {
// multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
for (String name : reduceDimensions) {
typeBuilder.indexed(name, 1);
}
@@ -99,9 +97,9 @@ public class Mean extends IntermediateOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType, boolean keepDimensions) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
- if (!reduceDimensions.contains(dimension.name())) {
+ if ( ! reduceDimensions.contains(dimension.name())) {
builder.add(dimension);
} else if (keepDimensions) {
builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
index 18f3cc1cc39..4a0fe236c9f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Reshape.java
@@ -32,18 +32,16 @@ public class Reshape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+
IntermediateOperation newShape = inputs.get(1);
- if (!newShape.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("Reshape in " + name + ": " +
- "shape input must be a constant.");
- }
+ if ( ! newShape.getConstantValue().isPresent())
+ throw new IllegalArgumentException("Reshape in " + name + ": Shape input must be a constant.");
+
Tensor shape = newShape.getConstantValue().get().asTensor();
OrderedTensorType inputType = inputs.get(0).type().get();
- OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder outputTypeBuilder = new OrderedTensorType.Builder(resultValueType());
int dimensionIndex = 0;
for (Iterator<Tensor.Cell> cellIterator = shape.cellIterator(); cellIterator.hasNext();) {
Tensor.Cell cell = cellIterator.next();
@@ -61,12 +59,9 @@ public class Reshape extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
+ if ( ! allInputFunctionsPresent(2)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
TensorFunction inputFunction = inputs.get(0).function().get();
return reshape(inputFunction, inputType.type(), type.type());
@@ -80,9 +75,8 @@ public class Reshape extends IntermediateOperation {
}
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
+ if ( ! OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType)))
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
- }
// Conceptually, reshaping consists on unrolling a tensor to an array using the dimension order,
// then use the dimension order of the new shape to roll back into a tensor.
@@ -96,20 +90,17 @@ public class Reshape extends IntermediateOperation {
TensorType transformationType = new TensorType.Builder(inputType, outputType).build();
Generate transformTensor = new Generate(transformationType,
- new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
-
- TensorFunction outputFunction = new Reduce(
- new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
+ new GeneratorLambdaFunctionNode(transformationType, transformExpression).asLongListToDoubleOperator());
- return outputFunction;
+ return new Reduce(new com.yahoo.tensor.functions.Join(inputFunction, transformTensor, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ inputType.dimensions().stream().map(TensorType.Dimension::name).collect(Collectors.toList()));
}
private static ExpressionNode unrollTensorExpression(TensorType type) {
- if (type.rank() == 0) {
+ if (type.rank() == 0)
return new ConstantNode(DoubleValue.zero);
- }
+
List<ExpressionNode> children = new ArrayList<>();
List<ArithmeticOperator> operators = new ArrayList<>();
int size = 1;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
index 361729a8c14..79f3012c327 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Shape.java
@@ -19,11 +19,10 @@ public class Shape extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
+ if ( ! allInputTypesPresent(1)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
- return new OrderedTensorType.Builder()
+ return new OrderedTensorType.Builder(resultValueType())
.add(TensorType.Dimension.indexed(vespaName(), inputType.dimensions().size()))
.build();
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
index 2eeefcbe8a2..52d40144f61 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Squeeze.java
@@ -25,9 +25,8 @@ public class Squeeze extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(1)) {
- return null;
- }
+ if ( ! allInputTypesPresent(1)) return null;
+
OrderedTensorType inputType = inputs.get(0).type().get();
squeezeDimensions = new ArrayList<>();
@@ -51,9 +50,8 @@ public class Squeeze extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(1)) {
- return null;
- }
+ if ( ! allInputFunctionsPresent(1)) return null;
+
TensorFunction inputFunction = inputs.get(0).function().get();
return new Reduce(inputFunction, Reduce.Aggregator.sum, squeezeDimensions);
}
@@ -73,7 +71,7 @@ public class Squeeze extends IntermediateOperation {
}
private OrderedTensorType reducedType(OrderedTensorType inputType) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(resultValueType());
for (TensorType.Dimension dimension: inputType.type().dimensions()) {
if ( ! squeezeDimensions.contains(dimension.name())) {
builder.add(dimension);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
index 6c92ffa6055..a4fe38cce95 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TensorConverter.java
@@ -5,6 +5,7 @@ import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import org.tensorflow.DataType;
import org.tensorflow.framework.TensorProto;
import java.nio.ByteBuffer;
@@ -27,7 +28,7 @@ public class TensorConverter {
}
private static Tensor toVespaTensor(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
- TensorType type = toVespaTensorType(tfTensor.shape(), dimensionPrefix);
+ TensorType type = toVespaTensorType(tfTensor, dimensionPrefix);
Values values = readValuesOf(tfTensor);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type);
for (int i = 0; i < values.size(); i++)
@@ -53,10 +54,10 @@ public class TensorConverter {
return builder.build();
}
- private static TensorType toVespaTensorType(long[] shape, String dimensionPrefix) {
- TensorType.Builder b = new TensorType.Builder();
+ private static TensorType toVespaTensorType(org.tensorflow.Tensor<?> tfTensor, String dimensionPrefix) {
+ TensorType.Builder b = new TensorType.Builder(toValueType(tfTensor.dataType()));
int dimensionIndex = 0;
- for (long dimensionSize : shape) {
+ for (long dimensionSize : tfTensor.shape()) {
if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ...
b.indexed(dimensionPrefix + (dimensionIndex++), dimensionSize);
}
@@ -85,7 +86,7 @@ public class TensorConverter {
case INT64: return new LongValues(tfTensor);
}
throw new IllegalArgumentException("Cannot convert a tensor with elements of type " +
- tfTensor.dataType() + " to a Vespa tensor");
+ tfTensor.dataType() + " to a Vespa tensor");
}
private static Values readValuesOf(TensorProto tensorProto) {
@@ -107,6 +108,21 @@ public class TensorConverter {
throw new IllegalArgumentException("Unsupported data type in attribute tensor import");
}
+ /** TensorFlow has two different DataType classes. This must be kept in sync with TypeConverter.toValueType */
+ static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case FLOAT: return TensorType.Value.FLOAT;
+ case DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT32: return TensorType.Value.FLOAT;
+ case UINT8: return TensorType.Value.FLOAT;
+ case INT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
/** Allows reading values from buffers of various numeric types as bytes */
private static abstract class Values {
abstract double get(int i);
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index 63a605ce97a..3e825026b0e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -5,6 +5,7 @@ package ai.vespa.rankingexpression.importer.tensorflow;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import org.tensorflow.framework.AttrValue;
+import org.tensorflow.framework.DataType;
import org.tensorflow.framework.NodeDef;
import org.tensorflow.framework.TensorShapeProto;
@@ -22,7 +23,7 @@ class TypeConverter {
if (shape != null) {
if (shape.getDimCount() != type.rank()) {
throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
+ "does not match Vespa shape");
}
for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions().size(); ++tensorFlowIndex) {
int vespaIndex = type.dimensionMap(tensorFlowIndex);
@@ -30,7 +31,7 @@ class TypeConverter {
TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
+ "does not match Vespa dimensions");
}
}
}
@@ -38,16 +39,24 @@ class TypeConverter {
private static TensorShapeProto tensorFlowShape(NodeDef node) {
AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
- if (attrValueList == null) {
+ if (attrValueList == null)
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
- }
- if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
+ "does not exist");
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST)
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
- }
- List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
- return shapeList.get(0); // support multiple outputs?
+ "is not of expected type");
+
+ return attrValueList.getList().getShape(0); // support multiple outputs?
+ }
+
+ private static DataType tensorFlowValueType(NodeDef node) {
+ AttrValue attrValueList = node.getAttrMap().get("dtypes");
+ if (attrValueList == null)
+ return DataType.DT_DOUBLE; // default. This will usually (always?) be used. TODO: How can we do better?
+ if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST)
+ return DataType.DT_DOUBLE; // default
+
+ return attrValueList.getList().getType(0); // support multiple outputs?
}
static OrderedTensorType fromTensorFlowType(NodeDef node) {
@@ -55,8 +64,8 @@ class TypeConverter {
}
private static OrderedTensorType fromTensorFlowType(NodeDef node, String dimensionPrefix) {
- OrderedTensorType.Builder builder = new OrderedTensorType.Builder();
TensorShapeProto shape = tensorFlowShape(node);
+ OrderedTensorType.Builder builder = new OrderedTensorType.Builder(toValueType(tensorFlowValueType(node)));
for (int i = 0; i < shape.getDimCount(); ++ i) {
String dimensionName = dimensionPrefix + i;
TensorShapeProto.Dim tensorFlowDimension = shape.getDim(i);
@@ -69,4 +78,26 @@ class TypeConverter {
return builder.build();
}
+ /** TensorFlow has two different DataType classes. This must be kept in sync with TensorConverter.toValueType */
+ static TensorType.Value toValueType(DataType dataType) {
+ switch (dataType) {
+ case DT_FLOAT: return TensorType.Value.FLOAT;
+ case DT_DOUBLE: return TensorType.Value.DOUBLE;
+ // Imperfect conversion, for now:
+ case DT_BOOL: return TensorType.Value.FLOAT;
+ case DT_BFLOAT16: return TensorType.Value.FLOAT;
+ case DT_HALF: return TensorType.Value.FLOAT;
+ case DT_INT8: return TensorType.Value.FLOAT;
+ case DT_INT16: return TensorType.Value.FLOAT;
+ case DT_INT32: return TensorType.Value.FLOAT;
+ case DT_INT64: return TensorType.Value.DOUBLE;
+ case DT_UINT8: return TensorType.Value.FLOAT;
+ case DT_UINT16: return TensorType.Value.FLOAT;
+ case DT_UINT32: return TensorType.Value.FLOAT;
+ case DT_UINT64: return TensorType.Value.DOUBLE;
+ default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
+ " cannot be converted to a Vespa tensor type");
+ }
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
index afe699d6e05..61f332327be 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/OrderedTensorTypeTestCase.java
@@ -13,9 +13,10 @@ public class OrderedTensorTypeTestCase {
@Test
public void testToFromSpec() {
String spec = "tensor(b[],c{},a[3])";
+ String orderedSpec = "tensor(a[3],b[],c{})";
OrderedTensorType type = OrderedTensorType.fromSpec(spec);
- assertEquals(spec, type.toString());
- assertEquals("tensor(a[3],b[],c{})", type.type().toString());
+ assertEquals(orderedSpec, type.toString());
+ assertEquals(orderedSpec, type.type().toString());
}
}