aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
committerLester Solbakken <lesters@oath.com>2021-04-08 11:24:52 +0200
commit049e9a325c8142958909d0464da12a56e5a8f638 (patch)
tree31d857ec4a5ad3415464e480ae473c39224623b2 /vespajlib
parentbccd68f8f9a7eb0830d136f8b034ae4f40cc819c (diff)
Add bfloat16 and int8 tensor cell types in Java
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java29
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java39
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java15
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java62
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java42
15 files changed, 277 insertions, 14 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index e51569da988..9ad2c55f7e3 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1411,7 +1411,9 @@
],
"fields": [
"public static final enum com.yahoo.tensor.TensorType$Value DOUBLE",
- "public static final enum com.yahoo.tensor.TensorType$Value FLOAT"
+ "public static final enum com.yahoo.tensor.TensorType$Value FLOAT",
+ "public static final enum com.yahoo.tensor.TensorType$Value INT8",
+ "public static final enum com.yahoo.tensor.TensorType$Value BFLOAT16"
]
},
"com.yahoo.tensor.TensorType": {
@@ -3463,4 +3465,4 @@
],
"fields": []
}
-}
+} \ No newline at end of file
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index dc17c657db9..9f3d7c01c6b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -225,6 +225,10 @@ public abstract class IndexedTensor implements Tensor {
b.append(tensor.get(index));
else if (tensor.type().valueType() == TensorType.Value.FLOAT)
b.append(tensor.getFloat(index));
+ else if (tensor.type().valueType() == TensorType.Value.BFLOAT16)
+ b.append(tensor.getFloat(index));
+ else if (tensor.type().valueType() == TensorType.Value.INT8)
+ b.append(tensor.getFloat(index));
else
throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
@@ -295,6 +299,10 @@ public abstract class IndexedTensor implements Tensor {
if (type.valueType() == TensorType.Value.FLOAT)
return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ else if (type.valueType() == TensorType.Value.BFLOAT16)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ else if (type.valueType() == TensorType.Value.INT8)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
else if (type.valueType() == TensorType.Value.DOUBLE)
return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
else
@@ -315,6 +323,10 @@ public abstract class IndexedTensor implements Tensor {
if (type.valueType() == TensorType.Value.FLOAT)
return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ else if (type.valueType() == TensorType.Value.BFLOAT16)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ else if (type.valueType() == TensorType.Value.INT8)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
else if (type.valueType() == TensorType.Value.DOUBLE)
return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
else
@@ -335,6 +347,10 @@ public abstract class IndexedTensor implements Tensor {
if (type.valueType() == TensorType.Value.FLOAT)
return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ else if (type.valueType() == TensorType.Value.BFLOAT16)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ else if (type.valueType() == TensorType.Value.INT8)
+ return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
else if (type.valueType() == TensorType.Value.DOUBLE)
return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
else
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index f608aead347..be5f4143f54 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -533,6 +533,10 @@ public class MixedTensor implements Tensor {
b.append(getDouble(subspaceIndex, index, tensor));
else if (tensor.type().valueType() == TensorType.Value.FLOAT)
b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats
+ else if (tensor.type().valueType() == TensorType.Value.BFLOAT16)
+ b.append(getDouble(subspaceIndex, index, tensor));
+ else if (tensor.type().valueType() == TensorType.Value.INT8)
+ b.append(getDouble(subspaceIndex, index, tensor));
else
throw new IllegalStateException("Unexpected value type " + type.valueType());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index becec1a4493..126fee878ab 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -201,6 +201,10 @@ class TensorParser {
return Double.parseDouble(cellValueString);
else if (cellValueType == TensorType.Value.FLOAT)
return Float.parseFloat(cellValueString);
+ else if (cellValueType == TensorType.Value.BFLOAT16)
+ return Float.parseFloat(cellValueString);
+ else if (cellValueType == TensorType.Value.INT8)
+ return Float.parseFloat(cellValueString);
else
throw new IllegalArgumentException(cellValueType + " is not supported");
} catch (NumberFormatException e) {
@@ -291,6 +295,10 @@ class TensorParser {
builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number);
else if (builder.type().valueType() == TensorType.Value.FLOAT)
builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number);
+ else if (builder.type().valueType() == TensorType.Value.BFLOAT16)
+ builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number);
+ else if (builder.type().valueType() == TensorType.Value.INT8)
+ builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number);
}
}
@@ -355,6 +363,10 @@ class TensorParser {
builder.cell(address, (Double)number);
else if (builder.type().valueType() == TensorType.Value.FLOAT)
builder.cell(address, (Float)number);
+ else if (builder.type().valueType() == TensorType.Value.BFLOAT16)
+ builder.cell(address, (Float)number);
+ else if (builder.type().valueType() == TensorType.Value.INT8)
+ builder.cell(address, (Float)number);
}
}
@@ -392,6 +404,10 @@ class TensorParser {
builder.cell(address, Double.parseDouble(cellValueString));
else if (cellValueType == TensorType.Value.FLOAT)
builder.cell(address, Float.parseFloat(cellValueString));
+ else if (cellValueType == TensorType.Value.BFLOAT16)
+ builder.cell(address, Float.parseFloat(cellValueString));
+ else if (cellValueType == TensorType.Value.INT8)
+ builder.cell(address, Float.parseFloat(cellValueString));
else
throw new IllegalArgumentException(cellValueType + " is not supported");
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 236e9d31c39..d7cf5bffcfe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -33,7 +33,7 @@ public class TensorType {
public enum Value {
// Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
- DOUBLE("double"), FLOAT("float");
+ DOUBLE("double"), FLOAT("float"), INT8("int8"), BFLOAT16("bfloat16");
private final String id;
@@ -59,6 +59,9 @@ public class TensorType {
public static Value largestOf(Value value1, Value value2) {
if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE;
+ if (value1 == FLOAT || value2 == FLOAT) return FLOAT;
+ if (value1 == BFLOAT16 || value2 == BFLOAT16) return FLOAT;
+ if (value1 == INT8 || value2 == INT8) return FLOAT;
return FLOAT;
}
@@ -69,8 +72,10 @@ public class TensorType {
switch (valueTypeString) {
case "double" : return Value.DOUBLE;
case "float" : return Value.FLOAT;
- default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
- " but was '" + valueTypeString + "'");
+ case "bfloat16" : return Value.BFLOAT16;
+ case "int8" : return Value.INT8;
+ default : throw new IllegalArgumentException("Value type must be either 'double', 'float', " +
+ "'bfloat16', or 'int8' but was '" + valueTypeString + "'");
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 0cec09157fb..edb68025d45 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -53,6 +53,8 @@ public class DenseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: encodeDoubleCells(tensor, buffer); break;
case FLOAT: encodeFloatCells(tensor, buffer); break;
+ case BFLOAT16: encodeBFloat16Cells(tensor, buffer); break;
+ case INT8: encodeInt8Cells(tensor, buffer); break;
}
}
@@ -66,6 +68,16 @@ public class DenseBinaryFormat implements BinaryFormat {
buffer.putFloat(tensor.getFloat(i));
}
+ private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putShort((short)(Float.floatToRawIntBits(tensor.getFloat(i)) >>> 16));
+ }
+
+ private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.put((byte) tensor.getFloat(i));
+ }
+
@Override
public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
TensorType type;
@@ -111,6 +123,8 @@ public class DenseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break;
case FLOAT: decodeFloatCells(sizes, builder, buffer); break;
+ case BFLOAT16: decodeBFloat16Cells(sizes, builder, buffer); break;
+ case INT8: decodeInt8Cells(sizes, builder, buffer); break;
}
}
@@ -124,4 +138,16 @@ public class DenseBinaryFormat implements BinaryFormat {
builder.cellByDirectIndex(i, buffer.getFloat());
}
+ private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++) {
+ builder.cellByDirectIndex(i, Float.intBitsToFloat(buffer.getShort() << 16));
+ }
+ }
+
+ private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++) {
+ builder.cellByDirectIndex(i, (float) buffer.get());
+ }
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index bc247e5561f..7d500614caa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -64,6 +64,9 @@ class MixedBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
+ case BFLOAT16: encodeCells(buffer, tensor, (val) ->
+ buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break;
+ case INT8: encodeCells(buffer, tensor, (val) -> buffer.put(((byte)val.floatValue()))); break;
}
}
@@ -127,6 +130,9 @@ class MixedBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
+ case BFLOAT16: decodeCells(buffer, builder, type, () ->
+ (double)Float.intBitsToFloat(buffer.getShort() << 16)); break;
+ case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index cd671f824fa..160cf660e1b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -52,6 +52,9 @@ class SparseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
+ case BFLOAT16: encodeCells(buffer, tensor, (val) ->
+ buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break;
+ case INT8: encodeCells(buffer, tensor, (val) -> buffer.put((byte)(val.floatValue()))); break;
}
}
@@ -102,6 +105,9 @@ class SparseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
+ case BFLOAT16: decodeCells(buffer, builder, type, () ->
+ (double)Float.intBitsToFloat(buffer.getShort() << 16)); break;
+ case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 5c47572c779..cddb283489c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -29,6 +29,8 @@ public class TypedBinaryFormat {
private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing
private static final int FLOAT_VALUE_TYPE = 1;
+ private static final int BFLOAT16_VALUE_TYPE = 2;
+ private static final int INT8_VALUE_TYPE = 3;
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
@@ -113,6 +115,8 @@ public class TypedBinaryFormat {
switch (valueType) {
case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break;
case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break;
+ case BFLOAT16: buffer.putInt1_4Bytes(BFLOAT16_VALUE_TYPE); break;
+ case INT8: buffer.putInt1_4Bytes(INT8_VALUE_TYPE); break;
default:
throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType);
}
@@ -123,8 +127,11 @@ public class TypedBinaryFormat {
switch (valueType) {
case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE;
case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT;
+ case BFLOAT16_VALUE_TYPE: return TensorType.Value.BFLOAT16;
+ case INT8_VALUE_TYPE: return TensorType.Value.INT8;
}
- throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal.");
+ throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. " +
+ "Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal.");
}
private static byte[] asByteArray(GrowableByteBuffer buffer) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 5bd1bbdba37..b47c0873535 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -50,6 +50,35 @@ public class TensorTestCase {
assertEquals(Tensor.from("tensor<float>(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class);
assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[1])")).cell(5.0, 0).build().getClass(),
IndexedFloatTensor.class);
+
+ assertEquals(Tensor.from("tensor<bfloat16>(x[1]):[5]").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+
+ assertEquals(Tensor.from("tensor<int8>(x[1]):[5]").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+ }
+
+ private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) {
+ Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }");
+ Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }");
+ assertEquals(valueType, t1.multiply(t2).type().valueType());
+ assertEquals(valueType, t2.multiply(t1).type().valueType());
+ }
+
+ @Test
+ public void testValueTypeResolving() {
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "float");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16");
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16");
+ assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8");
+ assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8");
}
@Test
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index a547f941d8e..caa125dfef7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -96,8 +96,12 @@ public class TensorTypeTestCase {
assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
+ assertValueType(TensorType.Value.BFLOAT16, "tensor<bfloat16>(x[])");
+ assertValueType(TensorType.Value.INT8, "tensor<int8>(x[])");
assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString());
assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString());
+ assertEquals("tensor<bfloat16>(x[])", TensorType.fromSpec("tensor<bfloat16>(x[])").toString());
+ assertEquals("tensor<int8>(x[])", TensorType.fromSpec("tensor<int8>(x[])").toString());
}
private static void assertTensorType(String typeSpec) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 5d1bc7b0c3f..3c79b0c769c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -41,7 +41,7 @@ public class DenseBinaryFormatTestCase {
}
@Test
- public void requireThatDefaultSerializationFormatDoNotChange() {
+ public void requireThatDefaultSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[]{2, // binary format type
2, // dimension count
2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
@@ -54,7 +54,7 @@ public class DenseBinaryFormatTestCase {
}
@Test
- public void requireThatFloatSerializationFormatDoNotChange() {
+ public void requireThatFloatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[]{6, // binary format type
1, // float type
2, // dimension count
@@ -68,9 +68,44 @@ public class DenseBinaryFormatTestCase {
}
@Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[]{6, // binary format type
+ 2, // bfloat16 type
+ 2, // dimension count
+ 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
+ 1, (byte) 'z', 1, // dimension z with size
+ 64, 0, // value 1
+ 64, 64, // value 2
+ };
+ Tensor tensor = Tensor.from("tensor<bfloat16>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[]{6, // binary format type
+ 3, // int8 type
+ 2, // dimension count
+ 2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
+ 1, (byte) 'z', 1, // dimension z with size
+ 2, // value 1
+ 3, // value 2
+ };
+ Tensor tensor = Tensor.from("tensor<int8>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
public void testSerializationOfDifferentValueTypes() {
+ assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<int8>(x[],y[]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}");
+ assertSerialization("tensor<double>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]");
+ assertSerialization("tensor<float>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]");
+ assertSerialization("tensor<bfloat16>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]");
+ assertSerialization("tensor<int8>(x[2],y[2]):[2, 3, 4, 5]");
}
private void assertSerialization(String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 81de8a9db4c..3ca20661587 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -134,4 +134,19 @@ public class JsonFormatTestCase {
}
}
+ private void assertEncodeDecode(Tensor tensor) {
+ Tensor decoded = JsonFormat.decode(tensor.type(), JsonFormat.encodeWithType(tensor));
+ assertEquals(tensor, decoded);
+ assertEquals(tensor.type(), decoded.type());
+ }
+
+ @Test
+ public void testTensorCellTypes() {
+ assertEncodeDecode(Tensor.from("tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]"));
+ assertEncodeDecode(Tensor.from("tensor<double>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]"));
+ assertEncodeDecode(Tensor.from("tensor<float>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]"));
+ assertEncodeDecode(Tensor.from("tensor<bfloat16>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]"));
+ assertEncodeDecode(Tensor.from("tensor<int8>(x[2],y[2]):[2,3,5,8]"));
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
index 69ef4922d8d..e9f8c81f21b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import java.util.Arrays;
import java.util.Optional;
import static org.junit.Assert.assertEquals;
@@ -78,9 +79,70 @@ public class MixedBinaryFormatTestCase {
}
@Test
+ public void requireThatDefaultSerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {3, // binary format type
+ 1, // number of sparse dimensions
+ 2, (byte)'x', (byte)'y', // name of sparse dimension
+ 1, // number of dense dimensions
+ 1, (byte)'z', 1, // name and size of dense dimension
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1
+ Tensor tensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatFloatSerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {7, // binary format type
+ 1, // float type
+ 1, // number of sparse dimensions
+ 2, (byte)'x', (byte)'y', // name of sparse dimension
+ 1, // number of dense dimensions
+ 1, (byte)'z', 1, // name and size of dense dimension
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, 0, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 64, 64, 0, 0}; // cell 1
+ Tensor tensor = Tensor.from("tensor<float>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {7, // binary format type
+ 2, // bfloat16 type
+ 1, // number of sparse dimensions
+ 2, (byte)'x', (byte)'y', // name of sparse dimension
+ 1, // number of dense dimensions
+ 1, (byte)'z', 1, // name and size of dense dimension
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 64, 64}; // cell 1
+ Tensor tensor = Tensor.from("tensor<bfloat16>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {7, // binary format type
+ 3, // int8 type
+ 1, // number of sparse dimensions
+ 2, (byte)'x', (byte)'y', // name of sparse dimension
+ 1, // number of dense dimensions
+ 1, (byte)'z', 1, // name and size of dense dimension
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 2, // cell 0
+ 2, (byte)'c', (byte)'d', 3}; // cell 1
+ Tensor tensor = Tensor.from("tensor<int8>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
public void testSerializationOfDifferentValueTypes() {
assertSerialization("tensor<double>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<int8>(x{},y[2]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}");
}
private void assertSerialization(String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 50b71024ddf..2a622b73513 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -55,19 +55,19 @@ public class SparseBinaryFormatTestCase {
}
@Test
- public void requireThatSerializationFormatDoNotChange() {
+ public void requireThatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[] {1, // binary format type
2, // num dimensions
2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
2, // num cells,
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
- public void requireThatFloatSerializationFormatDoNotChange() {
+ public void requireThatFloatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[] {
5, // binary format type
1, // float type
@@ -76,14 +76,44 @@ public class SparseBinaryFormatTestCase {
2, // num cells,
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Tensor tensor = Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {
+ 5, // binary format type
+ 2, // bfloat16 type
+ 2, // num dimensions
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, // cell 0
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64}; // cell 1
+ Tensor tensor = Tensor.from("tensor<bfloat16>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] encodedTensor = new byte[] {
+ 5, // binary format type
+ 3, // int8 type
+ 2, // num dimensions
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 2, // cell 0
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 3}; // cell 1
+ Tensor tensor = Tensor.from("tensor<int8>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
}
@Test
public void testSerializationOfDifferentValueTypes() {
assertSerialization("tensor<double>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<int8>(x{},y{}):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}");
}
private void assertSerialization(String tensorString) {