aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
author Arne Juul <arnej@yahooinc.com> 2023-11-30 13:48:01 +0000
committer Arne Juul <arnej@yahooinc.com> 2023-12-11 08:47:15 +0000
commit 055b84652f6a0c9b517c76588c145d92216f6e02 (patch)
tree 635c1763de83261409293d6ae9edb8fc03e9a51d
parent 18e3fb5c91e9e40d46fccc1b8988c445f27ec19e (diff)
add parsing of special strings for inf/nan cell values
-rw-r--r-- config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java 16
-rw-r--r-- config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java 9
-rw-r--r-- document/src/main/java/com/yahoo/document/json/readers/TensorReader.java 4
-rw-r--r-- document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java 4
-rw-r--r-- eval/src/tests/eval/value_cache/dense-special.json 7
-rw-r--r-- eval/src/tests/eval/value_cache/tensor_loader_test.cpp 23
-rw-r--r-- eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp 38
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java 32
-rw-r--r-- vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java 18
9 files changed, 135 insertions, 16 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
index fcb99215565..c41c6e59a4f 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
@@ -7,6 +7,8 @@ import com.fasterxml.jackson.core.JsonToken;
import com.google.common.base.Joiner;
import com.yahoo.tensor.TensorType;
+import static com.yahoo.tensor.serialization.JsonFormat.decodeNumberString;
+
import java.io.IOException;
import java.io.Reader;
import java.util.ArrayList;
@@ -282,9 +284,19 @@ public class ConstantTensorJsonValidator {
}
+ /**
+ * Verifies that a cell value token is acceptable as a number.
+ * <p>
+ * JSON number tokens are accepted as they are. A string token is accepted only if
+ * {@code decodeNumberString} can decode it; otherwise the message of the resulting
+ * NumberFormatException is reported. Every other token is rejected.
+ *
+ * @param where name of the enclosing JSON element, used in error messages
+ * @param token the token holding the cell value
+ * @throws InvalidConstantTensorException if the token is not an acceptable cell value
+ */
private void validateNumeric(String where, JsonToken token) throws IOException {
- if (token != JsonToken.VALUE_NUMBER_FLOAT && token != JsonToken.VALUE_NUMBER_INT) {
- throw new InvalidConstantTensorException(parser, String.format("Inside '%s': cell value is not a number (%s)", where, token.toString()));
+ if (token == JsonToken.VALUE_NUMBER_FLOAT || token == JsonToken.VALUE_NUMBER_INT) {
+ return; // plain JSON number
+ }
+ if (token == JsonToken.VALUE_STRING) {
+ String input = parser.getValueAsString();
+ try {
+ decodeNumberString(input); // called for validation only, the decoded value is not needed
+ return;
+ } catch (NumberFormatException e) {
+ throw new InvalidConstantTensorException(parser, String.format("Inside '%s': %s", where, e.getMessage()));
+ }
}
+ // Let %s format the token instead of calling toString(), so a null token is reported
+ // as an invalid cell value rather than failing with a NullPointerException.
+ throw new InvalidConstantTensorException(parser, String.format("Inside '%s': cell value is not a number (%s)", where, token));
}
private void assertCurrentTokenIs(JsonToken wantedToken) {
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
index 4892c9acefa..9171aae170c 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
@@ -208,7 +208,7 @@ public class ConstantTensorJsonValidatorTest {
" ]",
"}"));
});
- assertTrue(exception.getMessage().contains("Inside 'value': cell value is not a number (VALUE_STRING)"));
+ assertTrue(exception.getMessage().contains("Inside 'value': Excepted a number, got string 'fruit'"));
}
@Test
@@ -295,6 +295,13 @@ public class ConstantTensorJsonValidatorTest {
}
@Test
+ void ensure_that_values_can_contain_special_values() {
+ // String-encoded infinities and NaNs must pass validation as dense cell values.
+ TensorType fiveCells = TensorType.fromSpec("tensor(x[5])");
+ validateTensorJson(fiveCells, inputJsonToReader("['Infinity','+inf','NaN','-infinity','-nan']"));
+ }
+
+ @Test
void ensure_that_simple_object_for_map_works() {
validateTensorJson(
TensorType.fromSpec("tensor(x{})"),
diff --git a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
index e487fd2ec57..0b7b1ae9996 100644
--- a/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
+++ b/document/src/main/java/com/yahoo/document/json/readers/TensorReader.java
@@ -14,6 +14,7 @@ import com.yahoo.tensor.TensorType;
import static com.yahoo.document.json.readers.JsonParserHelpers.*;
import static com.yahoo.tensor.serialization.JsonFormat.decodeHexString;
+import static com.yahoo.tensor.serialization.JsonFormat.decodeNumberString;
/**
* Reads the tensor format defined at
@@ -243,6 +244,9 @@ public class TensorReader {
private static double readDouble(TokenBuffer buffer) {
try {
+ if (buffer.current() == JsonToken.VALUE_STRING) {
+ return decodeNumberString(buffer.currentText());
+ }
return Double.parseDouble(buffer.currentText());
}
catch (NumberFormatException e) {
diff --git a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
index 28e5293e96b..8a45fe95fa2 100644
--- a/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
+++ b/document/src/test/java/com/yahoo/document/json/JsonReaderTestCase.java
@@ -1462,14 +1462,14 @@ public class JsonReaderTestCase {
builder.cell().label("x", 0).label("y", 0).value(2.0);
builder.cell().label("x", 0).label("y", 1).value(3.0);
builder.cell().label("x", 0).label("y", 2).value(4.0);
- builder.cell().label("x", 1).label("y", 0).value(5.0);
+ builder.cell().label("x", 1).label("y", 0).value(Double.POSITIVE_INFINITY);
builder.cell().label("x", 1).label("y", 1).value(6.0);
builder.cell().label("x", 1).label("y", 2).value(7.0);
Tensor expected = builder.build();
Tensor tensor = assertTensorField(expected,
createPutWithTensor(inputJson("{",
- " 'values': [2.0, 3.0, 4.0, 5.0, 6.0, 7.0]",
+ " 'values': [2.0, 3.0, 4.0, 'inf', 6.0, 7.0]",
"}"), "dense_tensor"), "dense_tensor");
assertTrue(tensor instanceof IndexedTensor); // this matters for performance
}
diff --git a/eval/src/tests/eval/value_cache/dense-special.json b/eval/src/tests/eval/value_cache/dense-special.json
new file mode 100644
index 00000000000..644f0f436e7
--- /dev/null
+++ b/eval/src/tests/eval/value_cache/dense-special.json
@@ -0,0 +1,7 @@
+[
+ "Infinity", "+Infinity",
+ "INF", "+INF",
+ "-Infinity", "-INF",
+ "NAN", "+NAN",
+ "-nan", "-NAN"
+]
diff --git a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
index 22847a1d08e..24d6aaab007 100644
--- a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
+++ b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
@@ -130,4 +130,27 @@ TEST_F("require that bad lz4 file fails to load creating empty result", Constant
TEST_DO(verify_tensor(sparse_tensor_nocells(), f1.create(TEST_PATH("bad_lz4.json.lz4"), "tensor(x{},y{})")));
}
+// Expects two doubles to have identical bit patterns. A value comparison cannot be
+// used for NaN (it never compares equal to itself), and comparing the bits also
+// distinguishes the sign of a NaN.
+void checkBitEq(double a, double b) {
+ // The bits are compared through size_t, so both types must have the same size;
+ // otherwise the memcpy calls below would copy only part of each double.
+ static_assert(sizeof(size_t) == sizeof(double), "size_t must be as wide as double");
+ size_t aa, bb;
+ memcpy(&aa, &a, sizeof(aa));
+ memcpy(&bb, &b, sizeof(bb));
+ EXPECT_EQUAL(aa, bb);
+}
+
+TEST_F("require that special string-encoded values work", ConstantTensorLoader(factory)) {
+ // Expected cell order: 4 x +inf, 2 x -inf, 2 x +nan, 2 x -nan
+ const float inf = std::numeric_limits<float>::infinity();
+ const float qnan = std::numeric_limits<float>::quiet_NaN();
+ auto constant = f1.create(TEST_PATH("dense-special.json"), "tensor<float>(z[10])");
+ const auto &value = constant->value();
+ auto cells = value.cells().template typify<float>();
+ // Infinities compare equal by value
+ EXPECT_EQUAL(inf, cells[0]);
+ EXPECT_EQUAL(inf, cells[1]);
+ EXPECT_EQUAL(inf, cells[2]);
+ EXPECT_EQUAL(inf, cells[3]);
+ EXPECT_EQUAL(-inf, cells[4]);
+ EXPECT_EQUAL(-inf, cells[5]);
+ // NaNs must be compared by bit pattern, which also verifies their sign
+ TEST_DO(checkBitEq(qnan, cells[6]));
+ TEST_DO(checkBitEq(qnan, cells[7]));
+ TEST_DO(checkBitEq(-qnan, cells[8]));
+ TEST_DO(checkBitEq(-qnan, cells[9]));
+}
+
TEST_MAIN() { TEST_RUN_ALL(); }
diff --git a/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp b/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp
index 189f7aa14ce..aea3e8a76fa 100644
--- a/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp
+++ b/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp
@@ -3,10 +3,11 @@
#include "constant_tensor_loader.h"
#include <vespa/eval/eval/tensor_spec.h>
#include <vespa/eval/eval/value_codec.h>
-#include <vespa/vespalib/objects/nbostream.h>
-#include <vespa/vespalib/io/mapped_file_input.h>
#include <vespa/vespalib/data/lz4_input_decoder.h>
#include <vespa/vespalib/data/slime/slime.h>
+#include <vespa/vespalib/io/mapped_file_input.h>
+#include <vespa/vespalib/objects/nbostream.h>
+#include <vespa/vespalib/text/lowercase.h>
#include <vespa/vespalib/util/size_literals.h>
#include <set>
@@ -20,6 +21,31 @@ using ObjectTraverser = slime::ObjectTraverser;
namespace {
+// Decodes a cell value. Strings are matched case-insensitively against the special
+// spellings of infinity and NaN (with optional sign); everything else, including an
+// unrecognized string (after logging a warning), is converted with asDouble().
+double decodeDouble(const Inspector &inspector) {
+ if (inspector.type().getId() != vespalib::slime::STRING::ID) {
+ return inspector.asDouble();
+ }
+ const double inf = std::numeric_limits<double>::infinity();
+ const double qnan = std::numeric_limits<double>::quiet_NaN();
+ auto orig = inspector.asString().make_stringref();
+ auto lower = vespalib::LowerCase::convert(orig);
+ if (lower == "infinity" || lower == "+infinity" || lower == "inf" || lower == "+inf") {
+ return inf;
+ }
+ if (lower == "-infinity" || lower == "-inf") {
+ return -inf;
+ }
+ if (lower == "nan" || lower == "+nan") {
+ return qnan;
+ }
+ if (lower == "-nan") {
+ return -qnan;
+ }
+ LOG(warning, "bad string-encoded numeric value '%.*s'", (int)orig.size(), orig.data());
+ return inspector.asDouble();
+}
+
struct Target {
const ValueType tensor_type;
TensorSpec spec;
@@ -110,7 +136,7 @@ struct SingleMappedExtractor : ObjectTraverser {
{}
void field(const Memory &symbol, const Inspector &inspector) override {
vespalib::string label = symbol.make_string();
- double value = inspector.asDouble();
+ double value = decodeDouble(inspector);
TensorSpec::Address address;
address.emplace(dimension, label);
target.check_add(address, value);
@@ -128,7 +154,7 @@ void decodeSingleDenseForm(const Inspector &values, const ValueType &value_type,
for (size_t i = 0; i < values.entries(); ++i) {
TensorSpec::Address address;
address.emplace(dimension, TensorSpec::Label(i));
- target.check_add(address, values[i].asDouble());
+ target.check_add(address, decodeDouble(values[i]));
}
}
@@ -137,7 +163,7 @@ struct DenseValuesDecoder {
Target &_target;
void decode(const Inspector &input, const TensorSpec::Address &address, size_t dim_idx) {
if (dim_idx == _idims.size()) {
- _target.check_add(address, input.asDouble());
+ _target.check_add(address, decodeDouble(input));
} else {
const auto &dimension = _idims[dim_idx];
if (input.entries() != dimension.size) {
@@ -209,7 +235,7 @@ void decodeLiteralForm(const Inspector &cells, const ValueType &value_type, Targ
TensorSpec::Address address;
AddressExtractor extractor(indexed, address);
cells[i]["address"].traverse(extractor);
- target.check_add(address, cells[i]["value"].asDouble());
+ target.check_add(address, decodeDouble(cells[i]["value"]));
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 28f14c8d7ca..204c0331e3a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -234,10 +234,11 @@ public class JsonFormat {
TensorAddress address = decodeAddress(cell.field("address"), builder.type());
Inspector value = cell.field("value");
- if (value.type() != Type.LONG && value.type() != Type.DOUBLE)
+ if (value.type() == Type.STRING || value.type() == Type.LONG || value.type() == Type.DOUBLE) {
+ builder.cell(address, decodeNumeric(value));
+ } else {
throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'");
-
- builder.cell(address, value.asDouble());
+ }
}
private static void decodeSingleDimensionCell(String key, Inspector value, Tensor.Builder builder) {
@@ -268,8 +269,8 @@ public class JsonFormat {
values.traverse((ArrayTraverser) (__, value) -> {
if (value.type() == Type.ARRAY)
decodeNestedValues(value, builder, index);
- else if (value.type() == Type.LONG || value.type() == Type.DOUBLE)
- indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
+ else if (value.type() == Type.LONG || value.type() == Type.DOUBLE || value.type() == Type.STRING)
+ indexedBuilder.cellByDirectIndex(index.next(), decodeNumeric(value));
else
throw new IllegalArgumentException("Excepted the values array to contain numbers or nested arrays, not " + value.type());
});
@@ -445,10 +446,31 @@ public class JsonFormat {
return new TensorAddress.Builder(type).add(type.dimensions().get(0).name(), label).build();
}
+
+ /**
+ * Decodes a cell value field to a double.
+ * A STRING field is handed to decodeNumberString; LONG and DOUBLE fields are read with asDouble().
+ *
+ * @throws IllegalArgumentException if the field is neither a string nor a number
+ */
private static double decodeNumeric(Inspector numericField) {
+ if (numericField.type() == Type.STRING) { // string-encoded value, such as the inf/nan spellings
+ return decodeNumberString(numericField.asString());
+ }
if (numericField.type() != Type.LONG && numericField.type() != Type.DOUBLE)
throw new IllegalArgumentException("Excepted a number, not " + numericField.type());
return numericField.asDouble();
}
+ /**
+ * Decodes a string-encoded special cell value.
+ * <p>
+ * The following spellings are recognized, ignoring case: "infinity", "+infinity", "inf",
+ * "+inf", "-infinity", "-inf", "nan", "+nan" and "-nan".
+ *
+ * @param input the string to decode; must not be null
+ * @return positive or negative infinity, or NaN ("-nan" is built by copying a negative sign onto NaN)
+ * @throws NumberFormatException if the string is not one of the recognized spellings
+ */
+ public static double decodeNumberString(String input) {
+ // Lower-case with a fixed locale. With the default locale the result depends on the
+ // environment: a Turkish locale maps 'I' to a dotless 'ı', so "INF" and "Infinity"
+ // would not be recognized.
+ String s = input.toLowerCase(java.util.Locale.ROOT);
+ if (s.equals("infinity") || s.equals("+infinity") || s.equals("inf") || s.equals("+inf")) {
+ return Double.POSITIVE_INFINITY;
+ }
+ if (s.equals("-infinity") || s.equals("-inf")) {
+ return Double.NEGATIVE_INFINITY;
+ }
+ if (s.equals("nan") || s.equals("+nan")) {
+ return Double.NaN;
+ }
+ if (s.equals("-nan")) {
+ return Math.copySign(Double.NaN, -1.0); // or Double.longBitsToDouble(0xfff8000000000000L);
+ }
+ throw new NumberFormatException("Excepted a number, got string '" + input + "'");
+ }
+
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index d95396aca50..66d3a0e824d 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -669,6 +669,24 @@ public class JsonFormatTestCase {
"{\"type\":\"tensor<float>(x[1])\",\"values\":[0.3333333432674408]}");
}
+ @Test
+ public void testSpecialNumberStrings() {
+ // Every accepted spelling of positive and negative infinity
+ for (String text : new String[] { "Infinity", "+Infinity", "Inf", "+Inf", "infinity" }) {
+ assertEquals(Double.POSITIVE_INFINITY, JsonFormat.decodeNumberString(text), 0.0);
+ }
+ for (String text : new String[] { "-Infinity", "-Inf", "-infinity", "-inf" }) {
+ assertEquals(Double.NEGATIVE_INFINITY, JsonFormat.decodeNumberString(text), 0.0);
+ }
+ // NaN is checked through its raw bit pattern, which also verifies the sign bit
+ final long positiveNaN = 0x7FF8000000000000L;
+ final long negativeNaN = 0xFFF8000000000000L;
+ for (String text : new String[] { "nan", "NaN", "+NaN" }) {
+ assertEquals(positiveNaN, Double.doubleToRawLongBits(JsonFormat.decodeNumberString(text)));
+ }
+ for (String text : new String[] { "-nan", "-NaN" }) {
+ assertEquals(negativeNaN, Double.doubleToRawLongBits(JsonFormat.decodeNumberString(text)));
+ }
+ }
+
private void assertEncodeShortForm(String tensor, String expected) {
assertEncodeShortForm(Tensor.from(tensor), expected);
}