diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-11-30 13:48:01 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-12-11 08:47:15 +0000 |
commit | 055b84652f6a0c9b517c76588c145d92216f6e02 (patch) | |
tree | 635c1763de83261409293d6ae9edb8fc03e9a51d /vespajlib | |
parent | 18e3fb5c91e9e40d46fccc1b8988c445f27ec19e (diff) |
add parsing of special strings for inf/nan cell values
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java | 32 | ||||
-rw-r--r-- | vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java | 18 |
2 files changed, 45 insertions, 5 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 28f14c8d7ca..204c0331e3a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -234,10 +234,11 @@ public class JsonFormat { TensorAddress address = decodeAddress(cell.field("address"), builder.type()); Inspector value = cell.field("value"); - if (value.type() != Type.LONG && value.type() != Type.DOUBLE) + if (value.type() == Type.STRING || value.type() == Type.LONG || value.type() == Type.DOUBLE) { + builder.cell(address, decodeNumeric(value)); + } else { throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'"); - - builder.cell(address, value.asDouble()); + } } private static void decodeSingleDimensionCell(String key, Inspector value, Tensor.Builder builder) { @@ -268,8 +269,8 @@ public class JsonFormat { values.traverse((ArrayTraverser) (__, value) -> { if (value.type() == Type.ARRAY) decodeNestedValues(value, builder, index); - else if (value.type() == Type.LONG || value.type() == Type.DOUBLE) - indexedBuilder.cellByDirectIndex(index.next(), value.asDouble()); + else if (value.type() == Type.LONG || value.type() == Type.DOUBLE || value.type() == Type.STRING) + indexedBuilder.cellByDirectIndex(index.next(), decodeNumeric(value)); else throw new IllegalArgumentException("Excepted the values array to contain numbers or nested arrays, not " + value.type()); }); @@ -445,10 +446,31 @@ public class JsonFormat { return new TensorAddress.Builder(type).add(type.dimensions().get(0).name(), label).build(); } + private static double decodeNumeric(Inspector numericField) { + if (numericField.type() == Type.STRING) { + return decodeNumberString(numericField.asString()); + } if (numericField.type() != Type.LONG && numericField.type() != Type.DOUBLE) throw new IllegalArgumentException("Excepted a number, not " + numericField.type()); return numericField.asDouble(); } + public static double decodeNumberString(String input) { + String s = input.toLowerCase(); + if (s.equals("infinity") || s.equals("+infinity") || s.equals("inf") || s.equals("+inf")) { + return Double.POSITIVE_INFINITY; + } + if (s.equals("-infinity") || s.equals("-inf")) { + return Double.NEGATIVE_INFINITY; + } + if (s.equals("nan") || s.equals("+nan")) { + return Double.NaN; + } + if (s.equals("-nan")) { + return Math.copySign(Double.NaN, -1.0); // or Double.longBitsToDouble(0xfff8000000000000L); + } + throw new NumberFormatException("Excepted a number, got string '" + input + "'"); + } + } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index d95396aca50..66d3a0e824d 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -669,6 +669,24 @@ public class JsonFormatTestCase { "{\"type\":\"tensor<float>(x[1])\",\"values\":[0.3333333432674408]}"); } + @Test + public void testSpecialNumberStrings() { + assertEquals(Double.POSITIVE_INFINITY, JsonFormat.decodeNumberString("Infinity"), 0.0); + assertEquals(Double.POSITIVE_INFINITY, JsonFormat.decodeNumberString("+Infinity"), 0.0); + assertEquals(Double.POSITIVE_INFINITY, JsonFormat.decodeNumberString("Inf"), 0.0); + assertEquals(Double.POSITIVE_INFINITY, JsonFormat.decodeNumberString("+Inf"), 0.0); + assertEquals(Double.POSITIVE_INFINITY, JsonFormat.decodeNumberString("infinity"), 0.0); + assertEquals(Double.NEGATIVE_INFINITY, JsonFormat.decodeNumberString("-Infinity"), 0.0); + assertEquals(Double.NEGATIVE_INFINITY, JsonFormat.decodeNumberString("-Inf"), 0.0); + assertEquals(Double.NEGATIVE_INFINITY, JsonFormat.decodeNumberString("-infinity"), 0.0); + assertEquals(Double.NEGATIVE_INFINITY, JsonFormat.decodeNumberString("-inf"), 0.0); + assertEquals(0x7FF8000000000000L, Double.doubleToRawLongBits(JsonFormat.decodeNumberString("nan"))); + assertEquals(0x7FF8000000000000L, Double.doubleToRawLongBits(JsonFormat.decodeNumberString("NaN"))); + assertEquals(0x7FF8000000000000L, Double.doubleToRawLongBits(JsonFormat.decodeNumberString("+NaN"))); + assertEquals(0xFFF8000000000000L, Double.doubleToRawLongBits(JsonFormat.decodeNumberString("-nan"))); + assertEquals(0xFFF8000000000000L, Double.doubleToRawLongBits(JsonFormat.decodeNumberString("-NaN"))); + } + private void assertEncodeShortForm(String tensor, String expected) { assertEncodeShortForm(Tensor.from(tensor), expected); } |