summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2019-04-01 15:59:42 +0200
committerHenning Baldersheim <balder@yahoo-inc.com>2019-04-01 15:59:42 +0200
commit06b999904e735420ad5d1a74ae551f88573d2657 (patch)
treeb8d8f134aab5a4adeb166ba56cedb64281d231ae /vespajlib
parente8440bd1dbafac4ce09797bb3395b0cac54c6d82 (diff)
Add support for different values during decoding too.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java44
3 files changed, 62 insertions, 11 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 766b0daedcc..ccfe99119ee 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1162,8 +1162,11 @@
],
"methods": [
"public void <init>()",
+ "public void <init>(com.yahoo.tensor.TensorType$ValueType)",
"public varargs void <init>(com.yahoo.tensor.TensorType[])",
+ "public varargs void <init>(com.yahoo.tensor.TensorType$ValueType, com.yahoo.tensor.TensorType[])",
"public void <init>(java.lang.Iterable)",
+ "public void <init>(com.yahoo.tensor.TensorType$ValueType, java.lang.Iterable)",
"public int rank()",
"public com.yahoo.tensor.TensorType$Builder set(com.yahoo.tensor.TensorType$Dimension)",
"public com.yahoo.tensor.TensorType$Builder indexed(java.lang.String, long)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index bb7a976af9b..d548b13601f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -27,16 +27,17 @@ public class TensorType {
public enum ValueType { DOUBLE, FLOAT};
/** The empty tensor type - which is the same as a double */
- public static final TensorType empty = new TensorType(Collections.emptyList());
+ public static final TensorType empty = new TensorType(ValueType.DOUBLE, Collections.emptyList());
- private final ValueType valueType = ValueType.DOUBLE;
+ private final ValueType valueType;
public final ValueType valueType() { return valueType; }
/** Sorted list of the dimensions of this */
private final ImmutableList<Dimension> dimensions;
- private TensorType(Collection<Dimension> dimensions) {
+ private TensorType(ValueType valueType, Collection<Dimension> dimensions) {
+ this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
this.dimensions = ImmutableList.copyOf(dimensionList);
@@ -379,8 +380,15 @@ public class TensorType {
private final Map<String, Dimension> dimensions = new LinkedHashMap<>();
- /** Creates an empty builder */
+ private final ValueType valueType;
+
+ /** Creates an empty builder with cells of type double*/
public Builder() {
+ this(ValueType.DOUBLE);
+ }
+
+ public Builder(ValueType valueType) {
+ this.valueType = valueType;
}
/**
@@ -391,6 +399,10 @@ public class TensorType {
* If it is indexed in one and mapped in the other it will become mapped.
*/
public Builder(TensorType ... types) {
+ this(ValueType.DOUBLE, types);
+ }
+ public Builder(ValueType valueType, TensorType ... types) {
+ this.valueType = valueType;
for (TensorType type : types)
addDimensionsOf(type);
}
@@ -399,6 +411,10 @@ public class TensorType {
* Creates a builder from the given dimensions.
*/
public Builder(Iterable<Dimension> dimensions) {
+ this(ValueType.DOUBLE, dimensions);
+ }
+ public Builder(ValueType valueType, Iterable<Dimension> dimensions) {
+ this.valueType = valueType;
for (TensorType.Dimension dimension : dimensions) {
dimension(dimension);
}
@@ -497,7 +513,7 @@ public class TensorType {
}
public TensorType build() {
- return new TensorType(dimensions.values());
+ return new TensorType(valueType, dimensions.values());
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 6dcc870ef94..2537e7d8669 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -93,24 +93,41 @@ public class DenseBinaryFormat implements BinaryFormat {
DimensionSizes sizes;
if (optionalType.isPresent()) {
type = optionalType.get();
- TensorType serializedType = decodeType(buffer);
+ TensorType serializedType = decodeType(buffer, type.valueType());
if ( ! serializedType.isAssignableTo(type))
throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
" cannot be assigned to type " + type);
sizes = sizesFromType(serializedType);
}
else {
- type = decodeType(buffer);
+ type = decodeType(buffer, TensorType.ValueType.DOUBLE);
sizes = sizesFromType(type);
}
Tensor.Builder builder = Tensor.Builder.of(type, sizes);
- decodeCells(sizes, buffer, (IndexedTensor.BoundBuilder)builder);
+ decodeCells(type.valueType(), sizes, buffer, (IndexedTensor.BoundBuilder)builder);
return builder.build();
}
- private TensorType decodeType(GrowableByteBuffer buffer) {
+ private TensorType decodeType(GrowableByteBuffer buffer, TensorType.ValueType valueType) {
+ TensorType.ValueType serializedValueType = TensorType.ValueType.DOUBLE;
+ if ((valueType != TensorType.ValueType.DOUBLE) || (encodeType != EncodeType.DOUBLE_IS_DEFAULT)) {
+ int type = buffer.getInt1_4Bytes();
+ switch (type) {
+ case DOUBLE_VALUE_TYPE:
+ serializedValueType = TensorType.ValueType.DOUBLE;
+ break;
+ case FLOAT_VALUE_TYPE:
+ serializedValueType = TensorType.ValueType.DOUBLE;
+ break;
+ default:
+ throw new IllegalArgumentException("Received tensor value type '" + serializedValueType + "'. Only 0(double), or 1(float) are legal.");
+ }
+ }
+ if (valueType != serializedValueType) {
+ throw new IllegalArgumentException("Expected " + valueType + ", got " + serializedValueType);
+ }
+ TensorType.Builder builder = new TensorType.Builder(serializedValueType);
int dimensionCount = buffer.getInt1_4Bytes();
- TensorType.Builder builder = new TensorType.Builder();
for (int i = 0; i < dimensionCount; i++)
builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation
return builder.build();
@@ -124,9 +141,24 @@ public class DenseBinaryFormat implements BinaryFormat {
return builder.build();
}
- private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ private void decodeCells(TensorType.ValueType valueType, DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ switch (valueType) {
+ case DOUBLE:
+ decodeCellsAsDouble(sizes, buffer, builder);
+ break;
+ case FLOAT:
+ decodeCellsAsFloat(sizes, buffer, builder);
+ break;
+ }
+ }
+
+ private void decodeCellsAsDouble(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
for (long i = 0; i < sizes.totalSize(); i++)
builder.cellByDirectIndex(i, buffer.getDouble());
}
+ private void decodeCellsAsFloat(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) {
+ for (long i = 0; i < sizes.totalSize(); i++)
+ builder.cellByDirectIndex(i, buffer.getFloat());
+ }
}