aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-08 15:05:01 +0200
committerLester Solbakken <lesters@oath.com>2021-04-08 15:05:01 +0200
commitcb1ea8c336adb05c90200468b25fe4ab89ee803c (patch)
tree1d0b5da3bf31ec8be03c38818ad9a08d592de120
parent8ae580ad9d7a2b7bd2420c5f543a21ac0d7a7c9a (diff)
Resolve feedback from PR review
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java74
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java66
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java18
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java17
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java8
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java8
10 files changed, 98 insertions, 126 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 9f3d7c01c6b..c369fe96562 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -221,16 +221,14 @@ public abstract class IndexedTensor implements Tensor {
b.append("[");
// value
- if (tensor.type().valueType() == TensorType.Value.DOUBLE)
- b.append(tensor.get(index));
- else if (tensor.type().valueType() == TensorType.Value.FLOAT)
- b.append(tensor.getFloat(index));
- else if (tensor.type().valueType() == TensorType.Value.BFLOAT16)
- b.append(tensor.getFloat(index));
- else if (tensor.type().valueType() == TensorType.Value.INT8)
- b.append(tensor.getFloat(index));
- else
- throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
+ switch (tensor.type().valueType()) {
+ case DOUBLE: b.append(tensor.get(index)); break;
+ case FLOAT: b.append(tensor.getFloat(index)); break;
+ case BFLOAT16: b.append(tensor.getFloat(index)); break;
+ case INT8: b.append(tensor.getFloat(index)); break;
+ default:
+ throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
+ }
// end bracket and comma
for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
@@ -296,17 +294,14 @@ public abstract class IndexedTensor implements Tensor {
*/
public static Builder of(TensorType type, DimensionSizes sizes) {
validate(type, sizes);
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.BFLOAT16)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.INT8)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes); // Default
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes);
+ }
}
/**
@@ -320,17 +315,14 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, float[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
-
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
- else if (type.valueType() == TensorType.Value.BFLOAT16)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.INT8)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes, values);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes).fill(values); // Default
+ }
}
/**
@@ -344,17 +336,15 @@ public abstract class IndexedTensor implements Tensor {
public static Builder of(TensorType type, DimensionSizes sizes, double[] values) {
validate(type, sizes);
validateSizes(sizes, values.length);
+ switch (type.valueType()) {
+ case DOUBLE: return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
+ case FLOAT: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case BFLOAT16: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ case INT8: return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
+ default:
+ return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
- if (type.valueType() == TensorType.Value.FLOAT)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.BFLOAT16)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.INT8)
- return new IndexedFloatTensor.BoundFloatBuilder(type, sizes).fill(values);
- else if (type.valueType() == TensorType.Value.DOUBLE)
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values);
- else
- return new IndexedDoubleTensor.BoundDoubleBuilder(type, sizes, values); // Default
+ }
}
private static void validateSizes(DimensionSizes sizes, int length) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index be5f4143f54..606509bbfd8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -529,16 +529,14 @@ public class MixedTensor implements Tensor {
b.append("[");
// value
- if (type.valueType() == TensorType.Value.DOUBLE)
- b.append(getDouble(subspaceIndex, index, tensor));
- else if (tensor.type().valueType() == TensorType.Value.FLOAT)
- b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats
- else if (tensor.type().valueType() == TensorType.Value.BFLOAT16)
- b.append(getDouble(subspaceIndex, index, tensor));
- else if (tensor.type().valueType() == TensorType.Value.INT8)
- b.append(getDouble(subspaceIndex, index, tensor));
- else
- throw new IllegalStateException("Unexpected value type " + type.valueType());
+ switch (type.valueType()) {
+ case DOUBLE: b.append(getDouble(subspaceIndex, index, tensor)); break;
+ case FLOAT: b.append(getDouble(subspaceIndex, index, tensor)); break; // TODO: Really use floats
+ case BFLOAT16: b.append(getDouble(subspaceIndex, index, tensor)); break;
+ case INT8: b.append(getDouble(subspaceIndex, index, tensor)); break;
+ default:
+ throw new IllegalStateException("Unexpected value type " + type.valueType());
+ }
// end bracket and comma
for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 126fee878ab..0a1d9b6cf6e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -197,16 +197,14 @@ class TensorParser {
try {
String cellValueString = string.substring(position, nextNumberEnd);
try {
- if (cellValueType == TensorType.Value.DOUBLE)
- return Double.parseDouble(cellValueString);
- else if (cellValueType == TensorType.Value.FLOAT)
- return Float.parseFloat(cellValueString);
- else if (cellValueType == TensorType.Value.BFLOAT16)
- return Float.parseFloat(cellValueString);
- else if (cellValueType == TensorType.Value.INT8)
- return Float.parseFloat(cellValueString);
- else
- throw new IllegalArgumentException(cellValueType + " is not supported");
+ switch (cellValueType) {
+ case DOUBLE: return Double.parseDouble(cellValueString);
+ case FLOAT: return Float.parseFloat(cellValueString);
+ case BFLOAT16: return Float.parseFloat(cellValueString);
+ case INT8: return Float.parseFloat(cellValueString);
+ default:
+ throw new IllegalArgumentException(cellValueType + " is not supported");
+ }
} catch (NumberFormatException e) {
throw new IllegalArgumentException("At value position " + position + ": '" +
cellValueString + "' is not a valid " + cellValueType);
@@ -291,16 +289,13 @@ class TensorParser {
protected void consumeNumber() {
Number number = consumeNumber(builder.type().valueType());
- if (builder.type().valueType() == TensorType.Value.DOUBLE)
- builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number);
- else if (builder.type().valueType() == TensorType.Value.FLOAT)
- builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number);
- else if (builder.type().valueType() == TensorType.Value.BFLOAT16)
- builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number);
- else if (builder.type().valueType() == TensorType.Value.INT8)
- builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number);
+ switch (builder.type().valueType()) {
+ case DOUBLE: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); break;
+ case FLOAT: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
+ case BFLOAT16: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
+ case INT8: builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); break;
+ }
}
-
}
/**
@@ -359,16 +354,13 @@ class TensorParser {
private void consumeNumber(TensorAddress address) {
Number number = consumeNumber(builder.type().valueType());
- if (builder.type().valueType() == TensorType.Value.DOUBLE)
- builder.cell(address, (Double)number);
- else if (builder.type().valueType() == TensorType.Value.FLOAT)
- builder.cell(address, (Float)number);
- else if (builder.type().valueType() == TensorType.Value.BFLOAT16)
- builder.cell(address, (Float)number);
- else if (builder.type().valueType() == TensorType.Value.INT8)
- builder.cell(address, (Float)number);
+ switch (builder.type().valueType()) {
+ case DOUBLE: builder.cell(address, (Double)number); break;
+ case FLOAT: builder.cell(address, (Float)number); break;
+ case BFLOAT16: builder.cell(address, (Float)number); break;
+ case INT8: builder.cell(address, (Float)number); break;
+ }
}
-
}
private static class MappedValueParser extends ValueParser {
@@ -400,16 +392,14 @@ class TensorParser {
TensorType.Value cellValueType = builder.type().valueType();
String cellValueString = string.substring(position, valueEnd).trim();
try {
- if (cellValueType == TensorType.Value.DOUBLE)
- builder.cell(address, Double.parseDouble(cellValueString));
- else if (cellValueType == TensorType.Value.FLOAT)
- builder.cell(address, Float.parseFloat(cellValueString));
- else if (cellValueType == TensorType.Value.BFLOAT16)
- builder.cell(address, Float.parseFloat(cellValueString));
- else if (cellValueType == TensorType.Value.INT8)
- builder.cell(address, Float.parseFloat(cellValueString));
- else
- throw new IllegalArgumentException(cellValueType + " is not supported");
+ switch (cellValueType) {
+ case DOUBLE: builder.cell(address, Double.parseDouble(cellValueString)); break;
+ case FLOAT: builder.cell(address, Float.parseFloat(cellValueString)); break;
+ case BFLOAT16: builder.cell(address, Float.parseFloat(cellValueString)); break;
+ case INT8: builder.cell(address, Float.parseFloat(cellValueString)); break;
+ default:
+ throw new IllegalArgumentException(cellValueType + " is not supported");
+ }
}
catch (NumberFormatException e) {
throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" +
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index d7cf5bffcfe..0f67c25337b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -60,23 +60,21 @@ public class TensorType {
public static Value largestOf(Value value1, Value value2) {
if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE;
if (value1 == FLOAT || value2 == FLOAT) return FLOAT;
- if (value1 == BFLOAT16 || value2 == BFLOAT16) return FLOAT;
- if (value1 == INT8 || value2 == INT8) return FLOAT;
- return FLOAT;
+ if (value1 == BFLOAT16 || value2 == BFLOAT16) return BFLOAT16;
+ return INT8;
}
@Override
public String toString() { return name().toLowerCase(); }
public static Value fromId(String valueTypeString) {
- switch (valueTypeString) {
- case "double" : return Value.DOUBLE;
- case "float" : return Value.FLOAT;
- case "bfloat16" : return Value.BFLOAT16;
- case "int8" : return Value.INT8;
- default : throw new IllegalArgumentException("Value type must be either 'double', 'float', " +
- "'bfloat16', or 'int8' but was '" + valueTypeString + "'");
+ for(Value value : Value.values()) {
+ if (value.id.equals(valueTypeString)) {
+ return value;
+ }
}
+ throw new IllegalArgumentException("Value type must be either 'double', 'float', " +
+ "'bfloat16', or 'int8' but was '" + valueTypeString + "'");
}
};
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index d853c4a9069..b6ea0d04a50 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -64,16 +64,13 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
TensorType.Value fromValueType = tensor.type().valueType();
for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
Tensor.Cell cell = i.next();
- if (fromValueType == TensorType.Value.DOUBLE) {
- builder.cell(cell.getKey(), cell.getDoubleValue());
- } else if (fromValueType == TensorType.Value.FLOAT) {
- builder.cell(cell.getKey(), cell.getFloatValue());
- } else if (fromValueType == TensorType.Value.BFLOAT16) {
- builder.cell(cell.getKey(), cell.getFloatValue());
- } else if (fromValueType == TensorType.Value.INT8) {
- builder.cell(cell.getKey(), cell.getFloatValue());
- } else {
- builder.cell(cell.getKey(), cell.getValue());
+ switch (fromValueType) {
+ case DOUBLE: builder.cell(cell.getKey(), cell.getDoubleValue()); break;
+ case FLOAT: builder.cell(cell.getKey(), cell.getFloatValue()); break;
+ case BFLOAT16: builder.cell(cell.getKey(), cell.getFloatValue()); break;
+ case INT8: builder.cell(cell.getKey(), cell.getFloatValue()); break;
+ default:
+ builder.cell(cell.getKey(), cell.getValue());
}
}
return builder.build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
index edb68025d45..1567c95c9fa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java
@@ -7,10 +7,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
-import java.util.Iterator;
import java.util.Optional;
-import java.util.function.Consumer;
-import java.util.function.Supplier;
/**
* Implementation of a dense binary format for a tensor on the form:
@@ -70,7 +67,7 @@ public class DenseBinaryFormat implements BinaryFormat {
private void encodeBFloat16Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
for (int i = 0; i < tensor.size(); i++)
- buffer.putShort((short)(Float.floatToRawIntBits(tensor.getFloat(i)) >>> 16));
+ buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(tensor.getFloat(i)));
}
private void encodeInt8Cells(IndexedTensor tensor, GrowableByteBuffer buffer) {
@@ -140,7 +137,7 @@ public class DenseBinaryFormat implements BinaryFormat {
private void decodeBFloat16Cells(DimensionSizes sizes, IndexedTensor.BoundBuilder builder, GrowableByteBuffer buffer) {
for (long i = 0; i < sizes.totalSize(); i++) {
- builder.cellByDirectIndex(i, Float.intBitsToFloat(buffer.getShort() << 16));
+ builder.cellByDirectIndex(i, TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort()));
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index 7d500614caa..6cb9a63fe68 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -65,7 +65,7 @@ class MixedBinaryFormat implements BinaryFormat {
case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
case BFLOAT16: encodeCells(buffer, tensor, (val) ->
- buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break;
+ buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue()))); break;
case INT8: encodeCells(buffer, tensor, (val) -> buffer.put(((byte)val.floatValue()))); break;
}
}
@@ -131,7 +131,7 @@ class MixedBinaryFormat implements BinaryFormat {
case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
case BFLOAT16: decodeCells(buffer, builder, type, () ->
- (double)Float.intBitsToFloat(buffer.getShort() << 16)); break;
+ (double)TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); break;
case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
index 160cf660e1b..763b722a90c 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java
@@ -53,7 +53,7 @@ class SparseBinaryFormat implements BinaryFormat {
case DOUBLE: encodeCells(buffer, tensor, buffer::putDouble); break;
case FLOAT: encodeCells(buffer, tensor, (val) -> buffer.putFloat(val.floatValue())); break;
case BFLOAT16: encodeCells(buffer, tensor, (val) ->
- buffer.putShort((short) (Float.floatToRawIntBits(val.floatValue()) >>> 16))); break;
+ buffer.putShort(TypedBinaryFormat.bFloat16BitsFromFloat(val.floatValue()))); break;
case INT8: encodeCells(buffer, tensor, (val) -> buffer.put((byte)(val.floatValue()))); break;
}
}
@@ -106,7 +106,7 @@ class SparseBinaryFormat implements BinaryFormat {
case DOUBLE: decodeCells(buffer, builder, type, buffer::getDouble); break;
case FLOAT: decodeCells(buffer, builder, type, () -> (double)buffer.getFloat()); break;
case BFLOAT16: decodeCells(buffer, builder, type, () ->
- (double)Float.intBitsToFloat(buffer.getShort() << 16)); break;
+ (double)TypedBinaryFormat.floatFromBFloat16Bits(buffer.getShort())); break;
case INT8: decodeCells(buffer, builder, type, () -> (double)buffer.get()); break;
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
index cddb283489c..be04be80ed9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/TypedBinaryFormat.java
@@ -141,4 +141,12 @@ public class TypedBinaryFormat {
return result;
}
+ static short bFloat16BitsFromFloat(float val) {
+ return (short) (Float.floatToRawIntBits(val) >>> 16);
+ }
+
+ static float floatFromBFloat16Bits(short bits) {
+ return Float.intBitsToFloat(bits << 16);
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index b47c0873535..572dc433d71 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -71,14 +71,8 @@ public class TensorTestCase {
public void testValueTypeResolving() {
assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double");
assertCellTypeResult(TensorType.Value.DOUBLE, "double", "float");
- assertCellTypeResult(TensorType.Value.DOUBLE, "double", "bfloat16");
- assertCellTypeResult(TensorType.Value.DOUBLE, "double", "int8");
assertCellTypeResult(TensorType.Value.FLOAT, "float", "float");
- assertCellTypeResult(TensorType.Value.FLOAT, "float", "bfloat16");
- assertCellTypeResult(TensorType.Value.FLOAT, "float", "int8");
- assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "bfloat16");
- assertCellTypeResult(TensorType.Value.FLOAT, "bfloat16", "int8");
- assertCellTypeResult(TensorType.Value.FLOAT, "int8", "int8");
+ // Test bfloat16 and int8 when we have proper cell type resolving in place.
}
@Test