diff options
author | Håvard Pettersen <havardpe@oath.com> | 2019-07-09 08:59:07 +0000 |
---|---|---|
committer | Håvard Pettersen <havardpe@oath.com> | 2019-07-09 08:59:07 +0000 |
commit | fde3e35a631fdfdbcf41a2466a65712b9a3f5ee2 (patch) | |
tree | 6861dde6cfc3176b3e0a336f1b4b24b8dda3da86 /model-integration | |
parent | 4e8a65ed3701c814459b5ce58291d9764446d873 (diff) |
propagate float and stop using it too much
Diffstat (limited to 'model-integration')
3 files changed, 24 insertions, 24 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java index 29d600fa7c6..8c9fe60e1d4 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/onnx/TypeConverter.java @@ -53,17 +53,17 @@ class TypeConverter { private static TensorType.Value toValueType(Onnx.TensorProto.DataType dataType) { switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; + case FLOAT: return TensorType.Value.DOUBLE; case DOUBLE: return TensorType.Value.DOUBLE; // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT8: return TensorType.Value.FLOAT; - case INT16: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; + case BOOL: return TensorType.Value.DOUBLE; + case INT8: return TensorType.Value.DOUBLE; + case INT16: return TensorType.Value.DOUBLE; + case INT32: return TensorType.Value.DOUBLE; case INT64: return TensorType.Value.DOUBLE; - case UINT8: return TensorType.Value.FLOAT; - case UINT16: return TensorType.Value.FLOAT; - case UINT32: return TensorType.Value.FLOAT; + case UINT8: return TensorType.Value.DOUBLE; + case UINT16: return TensorType.Value.DOUBLE; + case UINT32: return TensorType.Value.DOUBLE; case UINT64: return TensorType.Value.DOUBLE; default: throw new IllegalArgumentException("A ONNX tensor with data type " + dataType + " cannot be converted to a Vespa tensor type"); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java index d8ddb01b650..08c0564ed8a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/tensorflow/TypeConverter.java @@ -85,19 +85,19 @@ class TypeConverter { private static TensorType.Value toValueType(DataType dataType) { switch (dataType) { - case DT_FLOAT: return TensorType.Value.FLOAT; + case DT_FLOAT: return TensorType.Value.DOUBLE; case DT_DOUBLE: return TensorType.Value.DOUBLE; // Imperfect conversion, for now: - case DT_BOOL: return TensorType.Value.FLOAT; - case DT_BFLOAT16: return TensorType.Value.FLOAT; - case DT_HALF: return TensorType.Value.FLOAT; - case DT_INT8: return TensorType.Value.FLOAT; - case DT_INT16: return TensorType.Value.FLOAT; - case DT_INT32: return TensorType.Value.FLOAT; + case DT_BOOL: return TensorType.Value.DOUBLE; + case DT_BFLOAT16: return TensorType.Value.DOUBLE; + case DT_HALF: return TensorType.Value.DOUBLE; + case DT_INT8: return TensorType.Value.DOUBLE; + case DT_INT16: return TensorType.Value.DOUBLE; + case DT_INT32: return TensorType.Value.DOUBLE; case DT_INT64: return TensorType.Value.DOUBLE; - case DT_UINT8: return TensorType.Value.FLOAT; - case DT_UINT16: return TensorType.Value.FLOAT; - case DT_UINT32: return TensorType.Value.FLOAT; + case DT_UINT8: return TensorType.Value.DOUBLE; + case DT_UINT16: return TensorType.Value.DOUBLE; + case DT_UINT32: return TensorType.Value.DOUBLE; case DT_UINT64: return TensorType.Value.DOUBLE; default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + " cannot be converted to a Vespa tensor type"); @@ -106,12 +106,12 @@ class TypeConverter { private static TensorType.Value toValueType(org.tensorflow.DataType dataType) { switch (dataType) { - case FLOAT: return TensorType.Value.FLOAT; + case FLOAT: return TensorType.Value.DOUBLE; case DOUBLE: return TensorType.Value.DOUBLE; // Imperfect conversion, for now: - case BOOL: return TensorType.Value.FLOAT; - case INT32: return TensorType.Value.FLOAT; - case UINT8: return TensorType.Value.FLOAT; + case BOOL: return TensorType.Value.DOUBLE; + case INT32: return TensorType.Value.DOUBLE; + case UINT8: return TensorType.Value.DOUBLE; case INT64: return TensorType.Value.DOUBLE; default: throw new IllegalArgumentException("A TensorFlow tensor with data type " + dataType + " cannot be converted to a Vespa tensor type"); diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java index 07814687dc6..424e4d6c57c 100644 --- a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/onnx/OnnxMnistSoftmaxImportTestCase.java @@ -43,14 +43,14 @@ public class OnnxMnistSoftmaxImportTestCase { // Check inputs assertEquals(1, model.inputs().size()); assertTrue(model.inputs().containsKey("Placeholder")); - assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), model.inputs().get("Placeholder")); + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); // Check signature ImportedMlFunction output = model.defaultSignature().outputFunction("add", "add"); assertNotNull(output); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", output.expression()); - assertEquals(TensorType.fromSpec("tensor<float>(d0[],d1[784])"), + assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); } |