summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-09-02 09:20:54 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-09-02 09:20:54 +0200
commit96e2cf880899cb204000e0693bb1bc51e2f52520 (patch)
treea75f048347f4c806d9332f81f0ffafdb7c549f30 /model-integration
parentb6fd9b3e3381b733923263b667cd9a7d52ed8715 (diff)
Propagate float value types from Onnx and TF
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java18
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java10
3 files changed, 20 insertions, 20 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
index 8c9fe60e1d4..98ff8ca735f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java
@@ -53,16 +53,16 @@ class TypeConverter {
private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) {
switch (dataType) {
- case FLOAT: return TensorType.Value.DOUBLE;
+ case FLOAT: return TensorType.Value.FLOAT;
case DOUBLE: return TensorType.Value.DOUBLE;
// Imperfect conversion, for now:
- case BOOL: return TensorType.Value.DOUBLE;
- case INT8: return TensorType.Value.DOUBLE;
- case INT16: return TensorType.Value.DOUBLE;
+ case BOOL: return TensorType.Value.FLOAT;
+ case INT8: return TensorType.Value.FLOAT;
+ case INT16: return TensorType.Value.FLOAT;
case INT32: return TensorType.Value.DOUBLE;
case INT64: return TensorType.Value.DOUBLE;
- case UINT8: return TensorType.Value.DOUBLE;
- case UINT16: return TensorType.Value.DOUBLE;
+ case UINT8: return TensorType.Value.FLOAT;
+ case UINT16: return TensorType.Value.FLOAT;
case UINT32: return TensorType.Value.DOUBLE;
case UINT64: return TensorType.Value.DOUBLE;
default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType +
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
index d9bb5c2fe45..3102d5431d4 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java
@@ -92,17 +92,17 @@ class TypeConverter {
private static TensorType.Value toValueType(DataType dataType) {
switch (dataType) {
- case DT_FLOAT: return TensorType.Value.DOUBLE;
+ case DT_FLOAT: return TensorType.Value.FLOAT;
case DT_DOUBLE: return TensorType.Value.DOUBLE;
// Imperfect conversion, for now:
- case DT_BOOL: return TensorType.Value.DOUBLE;
- case DT_BFLOAT16: return TensorType.Value.DOUBLE;
- case DT_HALF: return TensorType.Value.DOUBLE;
- case DT_INT8: return TensorType.Value.DOUBLE;
+ case DT_BOOL: return TensorType.Value.FLOAT;
+ case DT_BFLOAT16: return TensorType.Value.FLOAT;
+ case DT_HALF: return TensorType.Value.FLOAT;
+ case DT_INT8: return TensorType.Value.FLOAT;
case DT_INT16: return TensorType.Value.DOUBLE;
case DT_INT32: return TensorType.Value.DOUBLE;
case DT_INT64: return TensorType.Value.DOUBLE;
- case DT_UINT8: return TensorType.Value.DOUBLE;
+ case DT_UINT8: return TensorType.Value.FLOAT;
case DT_UINT16: return TensorType.Value.DOUBLE;
case DT_UINT32: return TensorType.Value.DOUBLE;
case DT_UINT64: return TensorType.Value.DOUBLE;
@@ -113,12 +113,12 @@ class TypeConverter {
private static TensorType.Value toValueType(org.tensorflow.DataType dataType) {
switch (dataType) {
- case FLOAT: return TensorType.Value.DOUBLE;
+ case FLOAT: return TensorType.Value.FLOAT;
case DOUBLE: return TensorType.Value.DOUBLE;
// Imperfect conversion, for now:
- case BOOL: return TensorType.Value.DOUBLE;
+ case BOOL: return TensorType.Value.FLOAT;
case INT32: return TensorType.Value.DOUBLE;
- case UINT8: return TensorType.Value.DOUBLE;
+ case UINT8: return TensorType.Value.FLOAT;
case INT64: return TensorType.Value.DOUBLE;
default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType +
" cannot be converted to a Vespa tensor type");
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
index 68df59bf93f..35c853bd746 100644
--- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java
@@ -31,28 +31,28 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor constant0 = Tensor.from(model.largeConstants().get("test_Variable"));
assertNotNull(constant0);
- assertEquals(new TensorType.Builder(TensorType.Value.DOUBLE).indexed("d2", 784).indexed("d1", 10).build(),
+ assertEquals(new TensorType.Builder(TensorType.Value.FLOAT).indexed("d2", 784).indexed("d1", 10).build(),
constant0.type());
assertEquals(7840, constant0.size());
Tensor constant1 = Tensor.from(model.largeConstants().get("test_Variable_1"));
assertNotNull(constant1);
- assertEquals(new TensorType.Builder(TensorType.Value.DOUBLE).indexed("d1", 10).build(), constant1.type());
+ assertEquals(new TensorType.Builder(TensorType.Value.FLOAT).indexed("d1", 10).build(), constant1.type());
assertEquals(10, constant1.size());
// Check inputs
assertEquals(1, model.inputs().size());
assertTrue(model.inputs().containsKey("Placeholder"));
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), model.inputs().get("Placeholder"));
// Check signature
ImportedMlFunction output = model.defaultSignature().outputFunction("add", "add");
assertNotNull(output);
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
output.expression());
- assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
+ assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"),
model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
- assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ assertEquals("{Placeholder=tensor<float>(d0[],d1[784])}", output.argumentTypes().toString());
}
@Test