summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHåvard Pettersen <3535158+havardpe@users.noreply.github.com>2021-04-08 15:41:24 +0200
committerGitHub <noreply@github.com>2021-04-08 15:41:24 +0200
commit449f608d2770977732fd3de4a24d52354a84c747 (patch)
tree856ef66e6416745dd6d633b72c9a6d407fc314fb
parent79f00a5e8536e7c2956daccedd1c4be463eb933e (diff)
parent4a33700665782a9ac22522dc5a8f8138f07b5b73 (diff)
Merge pull request #17307 from vespa-engine/lesters/new-tensor-cell-types-java
Add bfloat16 and int8 tensor cell types in Java
-rw-r--r--eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp31
-rw-r--r--eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json22
-rw-r--r--vespajlib/abi-spec.json6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java58
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java14
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java50
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java29
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java17
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java23
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java39
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java15
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java62
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java42
18 files changed, 362 insertions, 92 deletions
diff --git a/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp b/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp
index 6e882fc3d9d..974f95a2add 100644
--- a/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp
+++ b/eval/src/apps/make_tensor_binary_format_test_spec/make_tensor_binary_format_test_spec.cpp
@@ -3,6 +3,8 @@
#include <vespa/vespalib/data/slime/slime.h>
#include <vespa/vespalib/objects/nbostream.h>
#include <vespa/vespalib/util/stringfmt.h>
+#include <vespa/vespalib/util/bfloat16.h>
+#include <vespa/eval/eval/int8float.h>
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/value_type.h>
#include <vespa/eval/eval/test/test_io.h>
@@ -20,14 +22,20 @@ using Dict = std::vector<vespalib::string>;
template <typename T> std::vector<bool> with_cell_type_opts();
template <> std::vector<bool> with_cell_type_opts<double>() { return {false, true}; }
template <> std::vector<bool> with_cell_type_opts<float>() { return {true}; }
+template <> std::vector<bool> with_cell_type_opts<BFloat16>() { return {true}; }
+template <> std::vector<bool> with_cell_type_opts<Int8Float>() { return {true}; }
template <typename T> uint8_t cell_type_id();
template <> uint8_t cell_type_id<double>() { return 0; }
template <> uint8_t cell_type_id<float>() { return 1; }
+template <> uint8_t cell_type_id<BFloat16>() { return 2; }
+template <> uint8_t cell_type_id<Int8Float>() { return 3; }
template <typename T> const char *cell_type_str();
template <> const char *cell_type_str<double>() { return ""; }
template <> const char *cell_type_str<float>() { return "<float>"; }
+template <> const char *cell_type_str<BFloat16>() { return "<bfloat16>"; }
+template <> const char *cell_type_str<Int8Float>() { return "<int8>"; }
template <typename T> nbostream make_sparse(bool with_cell_type) {
nbostream data;
@@ -62,7 +70,8 @@ template <typename T> nbostream make_mixed(bool with_cell_type) {
return data;
}
-void set_tensor(Cursor &test, const TensorSpec &spec) {
+void set_tensor(Cursor &test, const TensorSpec &spec_in) {
+ auto spec = spec_in.normalize();
const Inspector &old_tensor = test["tensor"];
if (old_tensor.valid()) {
TensorSpec old_spec = TensorSpec::from_slime(old_tensor);
@@ -183,8 +192,8 @@ void make_vector_test(Cursor &test, size_t x_size) {
for (size_t x = 0; x < x_size; ++x) {
double value = val(x);
spec.add({{"x", x}}, value);
- dense << static_cast<T>(value);
- mixed << static_cast<T>(value);
+ dense << T(value);
+ mixed << T(value);
}
set_tensor(test, spec);
add_binary(test, {dense, mixed});
@@ -212,8 +221,8 @@ void make_matrix_test(Cursor &test, size_t x_size, size_t y_size) {
for (size_t y = 0; y < y_size; ++y) {
double value = mix({val(x), val(y)});
spec.add({{"x", x}, {"y", y}}, value);
- dense << static_cast<T>(value);
- mixed << static_cast<T>(value);
+ dense << T(value);
+ mixed << T(value);
}
}
set_tensor(test, spec);
@@ -245,8 +254,8 @@ void make_map_test(Cursor &test, const Dict &x_dict_in) {
spec.add({{"x", x}}, value);
sparse.writeSmallString(x);
mixed.writeSmallString(x);
- sparse << static_cast<T>(value);
- mixed << static_cast<T>(value);
+ sparse << T(value);
+ mixed << T(value);
}
set_tensor(test, spec);
add_binary(test, {sparse, mixed});
@@ -285,8 +294,8 @@ void make_mesh_test(Cursor &test, const Dict &x_dict_in, const vespalib::string
sparse.writeSmallString(y);
mixed.writeSmallString(x);
mixed.writeSmallString(y);
- sparse << static_cast<T>(value);
- mixed << static_cast<T>(value);
+ sparse << T(value);
+ mixed << T(value);
}
set_tensor(test, spec);
add_binary(test, {sparse, mixed});
@@ -326,7 +335,7 @@ void make_vector_map_test(Cursor &test,
for (size_t idx = 0; idx < indexed_size; ++idx) {
double value = mix({val(label), val(idx)});
spec.add({{mapped_name, label}, {indexed_name, idx}}, value);
- mixed << static_cast<T>(value);
+ mixed << T(value);
}
}
set_tensor(test, spec);
@@ -360,6 +369,8 @@ void make_tests(test::TestWriter &writer) {
make_number_test(writer.create(), 42.0);
make_typed_tests<double>(writer);
make_typed_tests<float>(writer);
+ make_typed_tests<BFloat16>(writer);
+ make_typed_tests<Int8Float>(writer);
}
int main(int, char **) {
diff --git a/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json b/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json
index f6b535e071a..b7710eadf5d 100644
--- a/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json
+++ b/eval/src/apps/make_tensor_binary_format_test_spec/test_spec.json
@@ -20,4 +20,24 @@
{"tensor":{"type":"tensor<float>(x[10],y{})","cells":[]},"binary":["0x07010101790101780A00"]}
{"tensor":{"type":"tensor<float>(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x070101017801017903020161413000004140000041500000016241A8000041B0000041B80000","0x07010101780101790302016241A8000041B0000041B800000161413000004140000041500000"]}
{"tensor":{"type":"tensor<float>(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x070101017901017803020161413000004140000041500000016241A8000041B0000041B80000","0x07010101790101780302016241A8000041B0000041B800000161413000004140000041500000"]}
-{"num_tests":22}
+{"tensor":{"type":"tensor<bfloat16>(x[3])","cells":[{"address":{"x":0},"value":1},{"address":{"x":1},"value":2},{"address":{"x":2},"value":3}]},"binary":["0x0602010178033F8040004040","0x070200010178033F8040004040"]}
+{"tensor":{"type":"tensor<bfloat16>(x[2],y[3])","cells":[{"address":{"x":0,"y":0},"value":11},{"address":{"x":0,"y":1},"value":12},{"address":{"x":0,"y":2},"value":13},{"address":{"x":1,"y":0},"value":21},{"address":{"x":1,"y":1},"value":22},{"address":{"x":1,"y":2},"value":23}]},"binary":["0x06020201780201790341304140415041A841B041B8","0x0702000201780201790341304140415041A841B041B8"]}
+{"tensor":{"type":"tensor<bfloat16>(x{})","cells":[]},"binary":["0x050201017800","0x07020101780000"]}
+{"tensor":{"type":"tensor<bfloat16>(x{})","cells":[{"address":{"x":"a"},"value":1},{"address":{"x":"b"},"value":2},{"address":{"x":"c"},"value":3}]},"binary":["0x05020101780301613F800162400001634040","0x0702010178000301613F800162400001634040","0x05020101780301613F800163404001624000","0x0702010178000301613F800163404001624000","0x0502010178030162400001613F8001634040","0x070201017800030162400001613F8001634040","0x050201017803016240000163404001613F80","0x07020101780003016240000163404001613F80","0x0502010178030163404001613F8001624000","0x070201017800030163404001613F8001624000","0x050201017803016340400162400001613F80","0x07020101780003016340400162400001613F80"]}
+{"tensor":{"type":"tensor<bfloat16>(x{},y{})","cells":[]},"binary":["0x0502020178017900","0x070202017801790000"]}
+{"tensor":{"type":"tensor<bfloat16>(x{},y{})","cells":[{"address":{"x":"bar","y":"a"},"value":21},{"address":{"x":"foo","y":"a"},"value":11}]},"binary":["0x050202017801790203666F6F0161413003626172016141A8","0x07020201780179000203666F6F0161413003626172016141A8","0x050202017801790203626172016141A803666F6F01614130","0x07020201780179000203626172016141A803666F6F01614130"]}
+{"tensor":{"type":"tensor<bfloat16>(x{},y[10])","cells":[]},"binary":["0x07020101780101790A00"]}
+{"tensor":{"type":"tensor<bfloat16>(x[10],y{})","cells":[]},"binary":["0x07020101790101780A00"]}
+{"tensor":{"type":"tensor<bfloat16>(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x070201017801017903020161413041404150016241A841B041B8","0x07020101780101790302016241A841B041B80161413041404150"]}
+{"tensor":{"type":"tensor<bfloat16>(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x070201017901017803020161413041404150016241A841B041B8","0x07020101790101780302016241A841B041B80161413041404150"]}
+{"tensor":{"type":"tensor<int8>(x[3])","cells":[{"address":{"x":0},"value":1},{"address":{"x":1},"value":2},{"address":{"x":2},"value":3}]},"binary":["0x060301017803010203","0x07030001017803010203"]}
+{"tensor":{"type":"tensor<int8>(x[2],y[3])","cells":[{"address":{"x":0,"y":0},"value":11},{"address":{"x":0,"y":1},"value":12},{"address":{"x":0,"y":2},"value":13},{"address":{"x":1,"y":0},"value":21},{"address":{"x":1,"y":1},"value":22},{"address":{"x":1,"y":2},"value":23}]},"binary":["0x0603020178020179030B0C0D151617","0x070300020178020179030B0C0D151617"]}
+{"tensor":{"type":"tensor<int8>(x{})","cells":[]},"binary":["0x050301017800","0x07030101780000"]}
+{"tensor":{"type":"tensor<int8>(x{})","cells":[{"address":{"x":"a"},"value":1},{"address":{"x":"b"},"value":2},{"address":{"x":"c"},"value":3}]},"binary":["0x050301017803016101016202016303","0x07030101780003016101016202016303","0x050301017803016101016303016202","0x07030101780003016101016303016202","0x050301017803016202016101016303","0x07030101780003016202016101016303","0x050301017803016202016303016101","0x07030101780003016202016303016101","0x050301017803016303016101016202","0x07030101780003016303016101016202","0x050301017803016303016202016101","0x07030101780003016303016202016101"]}
+{"tensor":{"type":"tensor<int8>(x{},y{})","cells":[]},"binary":["0x0503020178017900","0x070302017801790000"]}
+{"tensor":{"type":"tensor<int8>(x{},y{})","cells":[{"address":{"x":"bar","y":"a"},"value":21},{"address":{"x":"foo","y":"a"},"value":11}]},"binary":["0x050302017801790203666F6F01610B03626172016115","0x07030201780179000203666F6F01610B03626172016115","0x05030201780179020362617201611503666F6F01610B","0x0703020178017900020362617201611503666F6F01610B"]}
+{"tensor":{"type":"tensor<int8>(x{},y[10])","cells":[]},"binary":["0x07030101780101790A00"]}
+{"tensor":{"type":"tensor<int8>(x[10],y{})","cells":[]},"binary":["0x07030101790101780A00"]}
+{"tensor":{"type":"tensor<int8>(x{},y[3])","cells":[{"address":{"x":"a","y":0},"value":11},{"address":{"x":"a","y":1},"value":12},{"address":{"x":"a","y":2},"value":13},{"address":{"x":"b","y":0},"value":21},{"address":{"x":"b","y":1},"value":22},{"address":{"x":"b","y":2},"value":23}]},"binary":["0x0703010178010179030201610B0C0D0162151617","0x07030101780101790302016215161701610B0C0D"]}
+{"tensor":{"type":"tensor<int8>(x[3],y{})","cells":[{"address":{"x":0,"y":"a"},"value":11},{"address":{"x":0,"y":"b"},"value":21},{"address":{"x":1,"y":"a"},"value":12},{"address":{"x":1,"y":"b"},"value":22},{"address":{"x":2,"y":"a"},"value":13},{"address":{"x":2,"y":"b"},"value":23}]},"binary":["0x0703010179010178030201610B0C0D0162151617","0x07030101790101780302016215161701610B0C0D"]}
+{"num_tests":42}
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index e51569da988..9ad2c55f7e3 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1411,7 +1411,9 @@
],
"fields": [
"public static final enum com.yahoo.tensor.TensorType$Value DOUBLE",
- "public static final enum com.yahoo.tensor.TensorType$Value FLOAT"
+ "public static final enum com.yahoo.tensor.TensorType$Value FLOAT",
+ "public static final enum com.yahoo.tensor.TensorType$Value INT8",
+ "public static final enum com.yahoo.tensor.TensorType$Value BFLOAT16"
]
},
"com.yahoo.tensor.TensorType": {
@@ -3463,4 +3465,4 @@
],
"fields": []
}
-}
+} \ No newline at end of file
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index dc17c657db9..c369fe96562 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -221,12 +221,14 @@ public abstract class IndexedTensor implements Tensor {
b.append("[");
// value
- if (tensor.type().valueType() == TensorType.Value.DOUBLE)
- b.append(tensor.get(index));
- else if (tensor.type().valueType() == TensorType.Value.FLOAT)
- b.append(tensor.getFloat(index));
- else
- throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
+ switch (tensor.type().valueType()) {
+ case DOUBLE: b.append(tensor.get(index)); break;
+ case FLOAT: b.append(tensor.getFloat(index)); break;
+ case BFLOAT16: b.append(tensor.getFloat(index)); break;
+ case INT8: b.append(tensor.getFloat(index)); break;
+ default:
+ throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
+ }
// end bracket and comma
for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
@@ -292,13 +294,14 @@ public abstract class IndexedTensor implements Tensor {
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
validate(type, sizes);
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ }
}
/**
@@ -312,13 +315,14 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ }
}
/**
@@ -332,13 +336,15 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
+ }
}
private static void validateSizes(DimensionSizes sizes, int length) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index f608aead347..606509bbfd8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -529,12 +529,14 @@ public class MixedTensor implements Tensor {
b.append("[");
// value
- if (type.valueType() == TensorType.Value.DOUBLE)
- b.append(getDouble(subspaceIndex, index, tensor));
- else if (tensor.type().valueType() == TensorType.Value.FLOAT)
- b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats
- else
- throw new IllegalStateException("Unexpected value type " + type.valueType());
+ switch (type.valueType()) {
+ case DOUBLE: b.append(getDouble(subspaceIndex, index, tensor)); break;
+ case FLOAT: b.append(getDouble(subspaceIndex, index, tensor)); break; // TODO: Really use floats
+ case BFLOAT16: b.append(getDouble(subspaceIndex, index, tensor)); break;
+ case INT8: b.append(getDouble(subspaceIndex, index, tensor)); break;
+ default:
+ throw new IllegalStateException("Unexpected value type " + type.valueType());
+ }
// end bracket and comma
for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index becec1a4493..0a1d9b6cf6e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -197,12 +197,14 @@ class TensorParser {
try {
String cellValueString = string.substring(position, nextNumberEnd);
try {
- if (cellValueType == TensorType.Value.DOUBLE)
- return Double.parseDouble(cellValueString);
- else if (cellValueType == TensorType.Value.FLOAT)
- return Float.parseFloat(cellValueString);
- else
- throw new IllegalArgumentException(cellValueType + " is not supported");
+ switch (cellValueType) {
+ case DOUBLE: return Double.parseDouble(cellValueString);
+ case FLOAT: return Float.parseFloat(cellValueString);
+ case BFLOAT16: return Float.parseFloat(cellValueString);
+ case INT8: return Float.parseFloat(cellValueString);
+ default:
+ throw new IllegalArgumentException(cellValueType + " is not supported");
+ }
} catch (NumberFormatException e) {
throw new IllegalArgumentException("At value position " + position + ": '" +
cellValueString + "' is not a valid " + cellValueType);
@@ -287,12 +289,13 @@ class TensorParser {
protected void consumeNumber() {
Number number = consumeNumber(builder.type().valueType());
- if (builder.type().valueType() == TensorType.Value.DOUBLE)
- builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number);
- else if (builder.type().valueType() == TensorType.Value.FLOAT)
- builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number);
+ switch (builder.type().valueType()) {
+ case DOUBLE: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); break;
+ case FLOAT: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
+ case BFLOAT16: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
+ case INT8: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
+ }
}
-
}
/**
@@ -351,12 +354,13 @@ class TensorParser {
private void consumeNumber(TensorAddress address) {
Number number = consumeNumber(builder.type().valueType());
- if (builder.type().valueType() == TensorType.Value.DOUBLE)
- builder.cell(address, (Double)number);
- else if (builder.type().valueType() == TensorType.Value.FLOAT)
- builder.cell(address, (Float)number);
+ switch (builder.type().valueType()) {
+ case DOUBLE: builder.cell(address, (Double)number); break;
+ case FLOAT: builder.cell(address, (Float)number); break;
+ case BFLOAT16: builder.cell(address, (Float)number); break;
+ case INT8: builder.cell(address, (Float)number); break;
+ }
}
-
}
private static class MappedValueParser extends ValueParser {
@@ -388,12 +392,14 @@ class TensorParser {
TensorType.Value cellValueType = builder.type().valueType();
String cellValueString = string.substring(position, valueEnd).trim();
try {
- if (cellValueType == TensorType.Value.DOUBLE)
- builder.cell(address, Double.parseDouble(cellValueString));
- else if (cellValueType == TensorType.Value.FLOAT)
- builder.cell(address, Float.parseFloat(cellValueString));
- else
- throw new IllegalArgumentException(cellValueType + " is not supported");
+ switch (cellValueType) {
+ case DOUBLE: builder.cell(address, Double.parseDouble(cellValueString)); break;
+ case FLOAT: builder.cell(address, Float.parseFloat(cellValueString)); break;
+ case BFLOAT16: builder.cell(address, Float.parseFloat(cellValueString)); break;
+ case INT8: builder.cell(address, Float.parseFloat(cellValueString)); break;
+ default:
+ throw new IllegalArgumentException(cellValueType + " is not supported");
+ }
}
catch (NumberFormatException e) {
throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" +
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 236e9d31c39..0f67c25337b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -33,7 +33,7 @@ public class TensorType {
public enum Value {
// Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
- DOUBLE("double"), FLOAT("float");
+ DOUBLE("double"), FLOAT("float"), INT8("int8"), BFLOAT16("bfloat16");
private final String id;
@@ -59,19 +59,22 @@ public class TensorType {
public static Value largestOf(Value value1, Value value2) {
if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE;
- return FLOAT;
+ if (value1 == FLOAT || value2 == FLOAT) return FLOAT;
+ if (value1 == BFLOAT16 || value2 == BFLOAT16) return BFLOAT16;
+ return INT8;
}
@Override
public String toString() { return name().toLowerCase(); }
public static Value fromId(String valueTypeString) {
- switch (valueTypeString) {
- case "double" : return Value.DOUBLE;
- case "float" : return Value.FLOAT;
- default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
- " but was '" + valueTypeString + "'");
+ for(Value value : Value.values()) {
+ if (value.id.equals(valueTypeString)) {
+ return value;
+ }
}
+ throw new IllegalArgumentException("Value type must be either 'double', 'float', " +
+ "'bfloat16', or 'int8' but was '" + valueTypeString + "'");
}
};
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index d052e383c85..b6ea0d04a50 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -64,12 +64,13 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
TensorType.Value fromValueType = tensor.type().valueType();
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Tensor.Cell cell = i.next();
- if (fromValueType == TensorType.Value.FLOAT) {
- builder.cell(cell.getKey(), cell.getFloatValue());
- } else if (fromValueType == TensorType.Value.DOUBLE) {
- builder.cell(cell.getKey(), cell.getDoubleValue());
- } else {
- builder.cell(cell.getKey(), cell.getValue());
+ switch (fromValueType) {
+ case DOUBLE: builder.cell(cell.getKey(), cell.getDoubleValue()); break;
+ case FLOAT: builder.cell(cell.getKey(), cell.getFloatValue()); break;
+ case BFLOAT16: builder.cell(cell.getKey(), cell.getFloatValue()); break;
+ case INT8: builder.cell(cell.getKey(), cell.getFloatValue()); break;
+ default:
+ builder.cell(cell.getKey(), cell.getValue());
}
}
return builder.build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index 0cec09157fb..1567c95c9fa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -7,10 +7,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.Iterator;
import java.util.Optional;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
/**
* Implementation of a dense binary format for a tensor on the form:
@@ -53,6 +50,8 @@ public class DenseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: encodeDoubleCells(tensor, buffer); break;
case FLOAT: encodeFloatCells(tensor, buffer); break;
+ case BFLOAT16: encodeBFloat16Cells(tensor, buffer); break;
+ case INT8: encodeInt8Cells(tensor, buffer); break;
}
}
@@ -66,6 +65,16 @@ public class DenseBinaryFormat implements BinaryFormat {
buffer.putFloat(tensor.getFloat(i));
}
+ private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i)));
+ }
+
+ private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
+ for (int i = 0; i < tensor.size(); i++)
+ buffer.put((byte) tensor.getFloat(i));
+ }
+
@Override
public Tensor decode(Optional<TensorType> optionalType, GrowableByteBuffer buffer) {
TensorType type;
@@ -111,6 +120,8 @@ public class DenseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: decodeDoubleCells(sizes, builder, buffer); break;
case FLOAT: decodeFloatCells(sizes, builder, buffer); break;
+ case BFLOAT16: decodeBFloat16Cells(sizes, builder, buffer); break;
+ case INT8: decodeInt8Cells(sizes, builder, buffer); break;
}
}
@@ -124,4 +135,16 @@ public class DenseBinaryFormat implements BinaryFormat {
builder.cellByDirectIndex(i, buffer.getFloat());
}
+ private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++) {
+ builder.cellByDirectIndex(i, TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort()));
+ }
+ }
+
+ private void decodeInt8Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
+ for (long i = 0; i < sizes.totalSize(); i++) {
+ builder.cellByDirectIndex(i, (float) buffer.get());
+ }
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index bc247e5561f..6cb9a63fe68 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -64,6 +64,9 @@ class MixedBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
+ case BFLOAT16: encodeCells(buffer, tensor, (val) ->
+ buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue()))); break;
+ case INT8: encodeCells(buffer, tensor, (val) -> buffer.put(((byte)val.floatValue()))); break;
}
}
@@ -127,6 +130,9 @@ class MixedBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
+ case BFLOAT16: decodeCells(buffer, builder, type, () ->
+ (double)TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); break;
+ case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index cd671f824fa..763b722a90c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -52,6 +52,9 @@ class SparseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
+ case BFLOAT16: encodeCells(buffer, tensor, (val) ->
+ buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue()))); break;
+ case INT8: encodeCells(buffer, tensor, (val) -> buffer.put((byte)(val.floatValue()))); break;
}
}
@@ -102,6 +105,9 @@ class SparseBinaryFormat implements BinaryFormat {
switch (serializationValueType) {
case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
+ case BFLOAT16: decodeCells(buffer, builder, type, () ->
+ (double)TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); break;
+ case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index 5c47572c779..be04be80ed9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -29,6 +29,8 @@ public class TypedBinaryFormat {
private static final int DOUBLE_VALUE_TYPE = 0; // Not encoded as it is default, and you know the type when deserializing
private static final int FLOAT_VALUE_TYPE = 1;
+ private static final int BFLOAT16_VALUE_TYPE = 2;
+ private static final int INT8_VALUE_TYPE = 3;
public static byte[] encode(Tensor tensor) {
GrowableByteBuffer buffer = new GrowableByteBuffer();
@@ -113,6 +115,8 @@ public class TypedBinaryFormat {
switch (valueType) {
case DOUBLE: buffer.putInt1_4Bytes(DOUBLE_VALUE_TYPE); break;
case FLOAT: buffer.putInt1_4Bytes(FLOAT_VALUE_TYPE); break;
+ case BFLOAT16: buffer.putInt1_4Bytes(BFLOAT16_VALUE_TYPE); break;
+ case INT8: buffer.putInt1_4Bytes(INT8_VALUE_TYPE); break;
default:
throw new IllegalArgumentException("Attempt to encode unknown tensor value type: " + valueType);
}
@@ -123,8 +127,11 @@ public class TypedBinaryFormat {
switch (valueType) {
case DOUBLE_VALUE_TYPE: return TensorType.Value.DOUBLE;
case FLOAT_VALUE_TYPE: return TensorType.Value.FLOAT;
+ case BFLOAT16_VALUE_TYPE: return TensorType.Value.BFLOAT16;
+ case INT8_VALUE_TYPE: return TensorType.Value.INT8;
}
- throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. Only 0(double), or 1(float) are legal.");
+ throw new IllegalArgumentException("Received tensor value type '" + valueType + "'. " +
+ "Only 0(double), 1(float), 2(bfloat16), or 3(int8) is legal.");
}
private static byte[] asByteArray(GrowableByteBuffer buffer) {
@@ -134,4 +141,12 @@ public class TypedBinaryFormat {
return result;
}
+ static short bFloat16BitsFromFloat(float val) { // bfloat16 keeps the upper 16 bits of the 32-bit float pattern; the lower 16 bits are truncated, not rounded
+ return (short) (Float.floatToIntBits(val) >>> 16); // floatToIntBits canonicalizes NaN to 0x7fc00000, so truncation cannot turn a NaN into an infinity
+ }
+
+ static float floatFromBFloat16Bits(short bits) { // places the 16 bfloat16 bits in the high half of a 32-bit float pattern
+ return Float.intBitsToFloat((bits & 0xFFFF) << 16); // mask drops the sign extension of the short; the low 16 bits of the float are zero
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 5bd1bbdba37..572dc433d71 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -50,6 +50,29 @@ public class TensorTestCase {
assertEquals(Tensor.from("tensor<float>(x[1]):{{x:0}:5}").getClass(), IndexedFloatTensor.class);
assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[1])")).cell(5.0, 0).build().getClass(),
IndexedFloatTensor.class);
+
+ assertEquals(Tensor.from("tensor<bfloat16>(x[1]):[5]").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+
+ assertEquals(Tensor.from("tensor<int8>(x[1]):[5]").getClass(), IndexedFloatTensor.class);
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[1])")).cell(5.0, 0).build().getClass(),
+ IndexedFloatTensor.class);
+ }
+
+ private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { // asserts the cell type of the product of two one-cell tensors, in both operand orders
+ Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3]");
+ Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5]");
+ assertEquals(valueType, t1.multiply(t2).type().valueType());
+ assertEquals(valueType, t2.multiply(t1).type().valueType());
+ }
+
+ @Test
+ public void testValueTypeResolving() {
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double");
+ assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float"); // mixing double and float is expected to resolve to double
+ assertCellTypeResult(TensorType.Value.FLOAT, "float", "float");
+ // TODO: Test bfloat16 and int8 when we have proper cell type resolving in place.
}
@Test
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
index a547f941d8e..caa125dfef7 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTypeTestCase.java
@@ -96,8 +96,12 @@ public class TensorTypeTestCase {
assertValueType(TensorType.Value.DOUBLE, "tensor(x[])");
assertValueType(TensorType.Value.DOUBLE, "tensor<double>(x[])");
assertValueType(TensorType.Value.FLOAT, "tensor<float>(x[])");
+ assertValueType(TensorType.Value.BFLOAT16, "tensor<bfloat16>(x[])");
+ assertValueType(TensorType.Value.INT8, "tensor<int8>(x[])");
assertEquals("tensor(x[])", TensorType.fromSpec("tensor<double>(x[])").toString());
assertEquals("tensor<float>(x[])", TensorType.fromSpec("tensor<float>(x[])").toString());
+ assertEquals("tensor<bfloat16>(x[])", TensorType.fromSpec("tensor<bfloat16>(x[])").toString());
+ assertEquals("tensor<int8>(x[])", TensorType.fromSpec("tensor<int8>(x[])").toString());
}
private static void assertTensorType(String typeSpec) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
index 5d1bc7b0c3f..3c79b0c769c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/DenseBinaryFormatTestCase.java
@@ -41,7 +41,7 @@ public class DenseBinaryFormatTestCase {
}
@Test
- public void requireThatDefaultSerializationFormatDoNotChange() {
+ public void requireThatDefaultSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[]{2, // binary format type
2, // dimension count
2, (byte) 'x', (byte) 'y', 2, // dimension xy with size
@@ -54,7 +54,7 @@ public class DenseBinaryFormatTestCase {
}
@Test
- public void requireThatFloatSerializationFormatDoNotChange() {
+ public void requireThatFloatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[]{6, // binary format type
1, // float type
2, // dimension count
@@ -68,9 +68,44 @@ public class DenseBinaryFormatTestCase {
}
@Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] expectedBytes = new byte[]{6, // binary format type (followed by a cell type byte)
+ 2, // cell type: bfloat16
+ 2, // number of dimensions
+ 2, (byte) 'x', (byte) 'y', 2, // dimension "xy" and its size
+ 1, (byte) 'z', 1, // dimension "z" and its size
+ 64, 0, // first cell: 2.0 (bfloat16 bits 0x4000)
+ 64, 64, // second cell: 3.0 (bfloat16 bits 0x4040)
+ };
+ Tensor bfloat16Tensor = Tensor.from("tensor<bfloat16>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(expectedBytes), Arrays.toString(TypedBinaryFormat.encode(bfloat16Tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] expectedBytes = new byte[]{6, // binary format type (followed by a cell type byte)
+ 3, // cell type: int8
+ 2, // number of dimensions
+ 2, (byte) 'x', (byte) 'y', 2, // dimension "xy" and its size
+ 1, (byte) 'z', 1, // dimension "z" and its size
+ 2, // first cell, one byte
+ 3, // second cell, one byte
+ };
+ Tensor int8Tensor = Tensor.from("tensor<int8>(xy[],z[]):{{xy:0,z:0}:2.0,{xy:1,z:0}:3.0}");
+ assertEquals(Arrays.toString(expectedBytes), Arrays.toString(TypedBinaryFormat.encode(int8Tensor)));
+ }
+
+ @Test
public void testSerializationOfDifferentValueTypes() {
+ assertSerialization("tensor(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); // no explicit cell type
assertSerialization("tensor<double>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x[],y[]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); // 2.0..5.0 are exactly representable in bfloat16
+ assertSerialization("tensor<int8>(x[],y[]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); // integer literals within the signed byte range
+ assertSerialization("tensor<double>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]"); // same cell types again, with explicit dimension sizes and the value-list literal form
+ assertSerialization("tensor<float>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]");
+ assertSerialization("tensor<bfloat16>(x[2],y[2]):[2.0, 3.0, 4.0, 5.0]");
+ assertSerialization("tensor<int8>(x[2],y[2]):[2, 3, 4, 5]");
}
private void assertSerialization(String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 81de8a9db4c..3ca20661587 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -134,4 +134,19 @@ public class JsonFormatTestCase {
}
}
+ private void assertEncodeDecode(Tensor tensor) { // encodes to typed JSON, decodes again, and requires an equal tensor of the same type
+ Tensor roundTripped = JsonFormat.decode(tensor.type(), JsonFormat.encodeWithType(tensor));
+ assertEquals(tensor, roundTripped);
+ assertEquals(tensor.type(), roundTripped.type());
+ }
+
+ @Test
+ public void testTensorCellTypes() {
+ for (String spec : new String[] {"tensor(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]", // one spec per cell type, checked in this order
+ "tensor<double>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]",
+ "tensor<float>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]",
+ "tensor<bfloat16>(x[2],y[2]):[2.0, 3.0, 5.0 ,8.0]",
+ "tensor<int8>(x[2],y[2]):[2,3,5,8]"}) assertEncodeDecode(Tensor.from(spec));
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
index 69ef4922d8d..e9f8c81f21b 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/MixedBinaryFormatTestCase.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
+import java.util.Arrays;
import java.util.Optional;
import static org.junit.Assert.assertEquals;
@@ -78,9 +79,70 @@ public class MixedBinaryFormatTestCase {
}
@Test
+ public void requireThatDefaultSerializationFormatDoesNotChange() {
+ byte[] expectedBytes = new byte[] {3, // binary format type (no cell type byte follows)
+ 1, // sparse dimension count
+ 2, (byte)'x', (byte)'y', // sparse dimension "xy"
+ 1, // dense dimension count
+ 1, (byte)'z', 1, // dense dimension "z" and its size
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, 0, 0, 0, 0, 0, 0, // label "ab", then 2.0 as 8 bytes
+ 2, (byte)'c', (byte)'d', 64, 8, 0, 0, 0, 0, 0, 0}; // label "cd", then 3.0 as 8 bytes
+ Tensor doubleTensor = Tensor.from("tensor(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(expectedBytes), Arrays.toString(TypedBinaryFormat.encode(doubleTensor)));
+ }
+
+ @Test
+ public void requireThatFloatSerializationFormatDoesNotChange() {
+ byte[] expectedBytes = new byte[] {7, // binary format type (followed by a cell type byte)
+ 1, // cell type: float
+ 1, // sparse dimension count
+ 2, (byte)'x', (byte)'y', // sparse dimension "xy"
+ 1, // dense dimension count
+ 1, (byte)'z', 1, // dense dimension "z" and its size
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, 0, 0, // label "ab", then 2.0 as 4 bytes
+ 2, (byte)'c', (byte)'d', 64, 64, 0, 0}; // label "cd", then 3.0 as 4 bytes
+ Tensor floatTensor = Tensor.from("tensor<float>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(expectedBytes), Arrays.toString(TypedBinaryFormat.encode(floatTensor)));
+ }
+
+ @Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] expectedBytes = new byte[] {7, // binary format type (followed by a cell type byte)
+ 2, // cell type: bfloat16
+ 1, // sparse dimension count
+ 2, (byte)'x', (byte)'y', // sparse dimension "xy"
+ 1, // dense dimension count
+ 1, (byte)'z', 1, // dense dimension "z" and its size
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 64, 0, // label "ab", then 2.0 as 2 bytes (0x4000)
+ 2, (byte)'c', (byte)'d', 64, 64}; // label "cd", then 3.0 as 2 bytes (0x4040)
+ Tensor bfloat16Tensor = Tensor.from("tensor<bfloat16>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(expectedBytes), Arrays.toString(TypedBinaryFormat.encode(bfloat16Tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] expectedBytes = new byte[] {7, // binary format type (followed by a cell type byte)
+ 3, // cell type: int8
+ 1, // sparse dimension count
+ 2, (byte)'x', (byte)'y', // sparse dimension "xy"
+ 1, // dense dimension count
+ 1, (byte)'z', 1, // dense dimension "z" and its size
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 2, // label "ab", then the value in one byte
+ 2, (byte)'c', (byte)'d', 3}; // label "cd", then the value in one byte
+ Tensor int8Tensor = Tensor.from("tensor<int8>(xy{},z[1]):{{xy:ab,z:0}:2.0,{xy:cd,z:0}:3.0}");
+ assertEquals(Arrays.toString(expectedBytes), Arrays.toString(TypedBinaryFormat.encode(int8Tensor)));
+ }
+
+ @Test
public void testSerializationOfDifferentValueTypes() {
assertSerialization("tensor<double>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x{},y[2]):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); // 2.0..5.0 are exactly representable in bfloat16
+ assertSerialization("tensor<int8>(x{},y[2]):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); // integer literals within the signed byte range
}
private void assertSerialization(String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
index 50b71024ddf..2a622b73513 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/SparseBinaryFormatTestCase.java
@@ -55,19 +55,19 @@ public class SparseBinaryFormatTestCase {
}
@Test
- public void requireThatSerializationFormatDoNotChange() {
+ public void requireThatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[] {1, // binary format type
2, // num dimensions
2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions
2, // num cells,
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, 0, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 8, 0, 0, 0, 0, 0, 0}; // cell 1
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Tensor tensor = Tensor.from("tensor(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"); // no explicit cell type; the expected bytes hold each value in 8 bytes
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor))); // compared as strings, presumably so a failure prints both byte sequences
}
@Test
- public void requireThatFloatSerializationFormatDoNotChange() {
+ public void requireThatFloatSerializationFormatDoesNotChange() {
byte[] encodedTensor = new byte[] {
5, // binary format type
1, // float type
@@ -76,14 +76,44 @@ public class SparseBinaryFormatTestCase {
2, // num cells,
2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, 0, 0, // cell 0
2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64, 0, 0}; // cell 1
- assertEquals(Arrays.toString(encodedTensor),
- Arrays.toString(TypedBinaryFormat.encode(Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}"))));
+ Tensor tensor = Tensor.from("tensor<float>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(encodedTensor), Arrays.toString(TypedBinaryFormat.encode(tensor)));
+ }
+
+ @Test
+ public void requireThatBFloat16SerializationFormatDoesNotChange() {
+ byte[] expectedBytes = new byte[] {
+ 5, // binary format type (followed by a cell type byte)
+ 2, // cell type: bfloat16
+ 2, // dimension count
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions "xy" and "z"
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 64, 0, // labels "ab","e", then 2.0 as 2 bytes (0x4000)
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 64, 64}; // labels "cd","e", then 3.0 as 2 bytes (0x4040)
+ Tensor bfloat16Tensor = Tensor.from("tensor<bfloat16>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(expectedBytes), Arrays.toString(TypedBinaryFormat.encode(bfloat16Tensor)));
+ }
+
+ @Test
+ public void requireThatInt8SerializationFormatDoesNotChange() {
+ byte[] expectedBytes = new byte[] {
+ 5, // binary format type (followed by a cell type byte)
+ 3, // cell type: int8
+ 2, // dimension count
+ 2, (byte)'x', (byte)'y', 1, (byte)'z', // dimensions "xy" and "z"
+ 2, // num cells,
+ 2, (byte)'a', (byte)'b', 1, (byte)'e', 2, // labels "ab","e", then the value in one byte
+ 2, (byte)'c', (byte)'d', 1, (byte)'e', 3}; // labels "cd","e", then the value in one byte
+ Tensor int8Tensor = Tensor.from("tensor<int8>(xy{},z{}):{{xy:ab,z:e}:2.0,{xy:cd,z:e}:3.0}");
+ assertEquals(Arrays.toString(expectedBytes), Arrays.toString(TypedBinaryFormat.encode(int8Tensor)));
}
@Test
public void testSerializationOfDifferentValueTypes() {
assertSerialization("tensor<double>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
assertSerialization("tensor<float>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}");
+ assertSerialization("tensor<bfloat16>(x{},y{}):{{x:0,y:0}:2.0, {x:0,y:1}:3.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0}"); // 2.0..5.0 are exactly representable in bfloat16
+ assertSerialization("tensor<int8>(x{},y{}):{{x:0,y:0}:2, {x:0,y:1}:3, {x:1,y:0}:4, {x:1,y:1}:5}"); // integer literals within the signed byte range
}
private void assertSerialization(String tensorString) {