diff options
author | Arne H Juul <arnej27959@users.noreply.github.com> | 2023-05-22 11:42:17 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-05-22 11:42:17 +0200 |
commit | 07d0699378a589bcdffa08bba840b695a185ef99 (patch) | |
tree | ae4a6057d52f77bd2a14fab214774ffa90b86a03 | |
parent | c7a07adf43c13165e49e2aa2ef509ecb2526a48c (diff) | |
parent | 95d2e5194acd87facd594201dd1db254a41b1f73 (diff) |
Merge pull request #27137 from vespa-engine/arnej/allow-some-short-forms-for-constant-tensors
allow short-form JSON for 1-d constants
9 files changed, 199 insertions, 32 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java index 66da43856b1..eccb6910866 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java @@ -28,11 +28,14 @@ public class ConstantTensorJsonValidator { private static final String FIELD_CELLS = "cells"; private static final String FIELD_ADDRESS = "address"; private static final String FIELD_VALUE = "value"; + private static final String FIELD_VALUES = "values"; private static final JsonFactory jsonFactory = new JsonFactory(); private JsonParser parser; private Map<String, TensorType.Dimension> tensorDimensions; + private boolean isSingleDenseType = false; + private boolean isSingleMappedType = false; public void validate(String fileName, TensorType type, Reader tensorData) { if (fileName.endsWith(".json")) { @@ -57,19 +60,69 @@ public class ConstantTensorJsonValidator { .dimensions() .stream() .collect(Collectors.toMap(TensorType.Dimension::name, Function.identity())); + if (type.dimensions().size() == 1) { + this.isSingleMappedType = (type.indexedSubtype() == TensorType.empty); + this.isSingleDenseType = (type.mappedSubtype() == TensorType.empty); + } + var top = parser.nextToken(); + if (top == JsonToken.START_ARRAY) { + consumeValuesArray(); + } else if (top == JsonToken.START_OBJECT) { + consumeTopObject(); + } + }); + } - assertNextTokenIs(JsonToken.START_OBJECT); - assertNextTokenIs(JsonToken.FIELD_NAME); - assertFieldNameIs(FIELD_CELLS); + private void consumeValuesArray() throws IOException { + if (! isSingleDenseType) { + throw new InvalidConstantTensorException(parser, String.format("Field 'values' is only valid for simple vectors (1-d dense tensors")); + } + assertCurrentTokenIs(JsonToken.START_ARRAY); + while (parser.nextToken() != JsonToken.END_ARRAY) { + validateNumeric(parser.getCurrentToken()); + } + } + private void consumeTopObject() throws IOException { + assertCurrentTokenIs(JsonToken.START_OBJECT); + assertNextTokenIs(JsonToken.FIELD_NAME); + String fieldName = parser.getCurrentName(); + if (fieldName.equals(FIELD_VALUES)) { assertNextTokenIs(JsonToken.START_ARRAY); + consumeValuesArray(); + } else if (fieldName.equals(FIELD_CELLS)) { + consumeCellsField(); + } else { + throw new InvalidConstantTensorException(parser, String.format("Expected 'cells' or 'values', got '%s'", fieldName)); + } + assertNextTokenIs(JsonToken.END_OBJECT); + } - while (parser.nextToken() != JsonToken.END_ARRAY) { - validateTensorCell(); - } + private void consumeCellsField() throws IOException { + var token = parser.nextToken(); + if (token == JsonToken.START_ARRAY) { + consumeLiteralFormArray(); + } else if (token == JsonToken.START_OBJECT) { + consumeSimpleMappedObject(); + } else { + throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be object or array, but got %s", token.toString())); + } + } - assertNextTokenIs(JsonToken.END_OBJECT); - }); + private void consumeLiteralFormArray() throws IOException { + while (parser.nextToken() != JsonToken.END_ARRAY) { + validateTensorCell(); + } + } + + private void consumeSimpleMappedObject() throws IOException { + if (! isSingleMappedType) { + throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be an array of address/value objects")); + } + while (parser.nextToken() != JsonToken.END_OBJECT) { + assertCurrentTokenIs(JsonToken.FIELD_NAME); + validateTensorCellValue(); + } } private void validateTensorCell() { @@ -87,7 +140,7 @@ public class ConstantTensorJsonValidator { if (fieldName.equals(FIELD_ADDRESS)) { validateTensorAddress(); } else if (fieldName.equals(FIELD_VALUE)) { - validateTensorValue(); + validateTensorCellValue(); } } else { throw new InvalidConstantTensorException(parser, "Only 'address' or 'value' fields are permitted within a cell object"); @@ -169,9 +222,12 @@ public class ConstantTensorJsonValidator { throw new InvalidConstantTensorException(parser, String.format("Index '%s' for dimension '%s' is not an integer", value, dimensionName)); } - private void validateTensorValue() throws IOException { + private void validateTensorCellValue() throws IOException { JsonToken token = parser.nextToken(); + validateNumeric(token); + } + private void validateNumeric(JsonToken token) throws IOException { if (token != JsonToken.VALUE_NUMBER_FLOAT && token != JsonToken.VALUE_NUMBER_INT) { throw new InvalidConstantTensorException(parser, String.format("Tensor value is not a number (%s)", token.toString())); } diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java index 80643917a58..42be1592eca 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java @@ -281,7 +281,25 @@ public class ConstantTensorJsonValidatorTest { " }", "}")); }); - assertTrue(exception.getMessage().contains("Expected field name 'cells', got 'stats'")); + System.err.println("msg: " + exception.getMessage()); + assertTrue(exception.getMessage().contains("Expected 'cells' or 'values', got 'stats'")); + } + + @Test + void ensure_that_values_array_for_vector_works() { + validateTensorJson( + TensorType.fromSpec("tensor(x[5])"), + inputJsonToReader("[5,4.0,3.1,-2,-1.0]")); + validateTensorJson( + TensorType.fromSpec("tensor(x[5])"), + inputJsonToReader("{'values':[5,4.0,3.1,-2,-1.0]}")); + } + + @Test + void ensure_that_simple_object_for_map_works() { + validateTensorJson( + TensorType.fromSpec("tensor(x{})"), + inputJsonToReader("{'cells':{'a':5,'b':4.0,'c':3.1,'d':-2,'e':-1.0}}")); } } diff --git a/eval/src/tests/eval/value_cache/dense-short1.json b/eval/src/tests/eval/value_cache/dense-short1.json new file mode 100644 index 00000000000..4e170001c96 --- /dev/null +++ b/eval/src/tests/eval/value_cache/dense-short1.json @@ -0,0 +1 @@ +[ 1, 2.0, 3.5 ] diff --git a/eval/src/tests/eval/value_cache/dense-short2.json b/eval/src/tests/eval/value_cache/dense-short2.json new file mode 100644 index 00000000000..40121135544 --- /dev/null +++ b/eval/src/tests/eval/value_cache/dense-short2.json @@ -0,0 +1,3 @@ +{ + "values": [ 1, 2.0, 3.5 ] +} diff --git a/eval/src/tests/eval/value_cache/sparse-short1.json b/eval/src/tests/eval/value_cache/sparse-short1.json new file mode 100644 index 00000000000..949b7b2b8bd --- /dev/null +++ b/eval/src/tests/eval/value_cache/sparse-short1.json @@ -0,0 +1,5 @@ +{ + "foo": 1.0, + "bar": 2.0, + "three": 3.0 +} diff --git a/eval/src/tests/eval/value_cache/sparse-short2.json b/eval/src/tests/eval/value_cache/sparse-short2.json new file mode 100644 index 00000000000..f10b1b6f9fb --- /dev/null +++ b/eval/src/tests/eval/value_cache/sparse-short2.json @@ -0,0 +1,7 @@ +{ + "cells": { + "foo": 1.0, + "bar": 2.0, + "three": 3.0 + } +} diff --git a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp index 1a77cfe847b..4b4ba3fc0d3 100644 --- a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp +++ b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp @@ -19,12 +19,26 @@ TensorSpec make_dense_tensor() { .add({{"x", 1}, {"y", 1}}, 4.0); } +TensorSpec make_simple_dense_tensor() { + return TensorSpec("tensor(z[3])") + .add({{"z", 0}}, 1.0) + .add({{"z", 1}}, 2.0) + .add({{"z", 2}}, 3.5); +} + TensorSpec make_sparse_tensor() { return TensorSpec("tensor(x{},y{})") .add({{"x", "foo"}, {"y", "bar"}}, 1.0) .add({{"x", "bar"}, {"y", "foo"}}, 2.0); } +TensorSpec make_simple_sparse_tensor() { + return TensorSpec("tensor(mydim{})") + .add({{"mydim", "foo"}}, 1.0) + .add({{"mydim", "three"}}, 3.0) + .add({{"mydim", "bar"}}, 2.0); +} + TensorSpec make_mixed_tensor() { return TensorSpec("tensor(x{},y[2])") .add({{"x", "foo"}, {"y", 0}}, 1.0) @@ -75,6 +89,16 @@ TEST_F("require that lz4 compressed sparse tensor can be loaded", ConstantTensor TEST_DO(verify_tensor(make_sparse_tensor(), f1.create(TEST_PATH("sparse.json.lz4"), "tensor(x{},y{})"))); } +TEST_F("require that sparse tensor short form can be loaded", ConstantTensorLoader(factory)) { + TEST_DO(verify_tensor(make_simple_sparse_tensor(), f1.create(TEST_PATH("sparse-short1.json"), "tensor(mydim{})"))); + TEST_DO(verify_tensor(make_simple_sparse_tensor(), f1.create(TEST_PATH("sparse-short2.json"), "tensor(mydim{})"))); +} + +TEST_F("require that dense tensor short form can be loaded", ConstantTensorLoader(factory)) { + TEST_DO(verify_tensor(make_simple_dense_tensor(), f1.create(TEST_PATH("dense-short1.json"), "tensor(z[3])"))); + TEST_DO(verify_tensor(make_simple_dense_tensor(), f1.create(TEST_PATH("dense-short2.json"), "tensor(z[3])"))); +} + TEST_F("require that bad lz4 file fails to load creating empty result", ConstantTensorLoader(factory)) { TEST_DO(verify_tensor(sparse_tensor_nocells(), f1.create(TEST_PATH("bad_lz4.json.lz4"), "tensor(x{},y{})"))); } diff --git a/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp b/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp index 9af473f1f94..5654a3abcbe 100644 --- a/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp +++ b/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp @@ -41,6 +41,52 @@ struct AddressExtractor : ObjectTraverser { } }; +struct SingleMappedExtractor : ObjectTraverser { + const vespalib::string &dimension; + TensorSpec &spec; + SingleMappedExtractor(const vespalib::string &dimension_in, TensorSpec &spec_in) + : dimension(dimension_in), + spec(spec_in) + {} + void field(const Memory &symbol, const Inspector &inspector) override { + vespalib::string label = symbol.make_string(); + double value = inspector.asDouble(); + TensorSpec::Address address; + address.emplace(dimension, label); + spec.add(address, value); + } +}; + + +void decodeSingleMappedForm(const Inspector &root, const ValueType &value_type, TensorSpec &spec) { + auto extractor = SingleMappedExtractor(value_type.dimensions()[0].name, spec); + root.traverse(extractor); +} + +void decodeSingleDenseForm(const Inspector &values, const ValueType &value_type, TensorSpec &spec) { + const auto &dimension = value_type.dimensions()[0].name; + for (size_t i = 0; i < values.entries(); ++i) { + TensorSpec::Address address; + address.emplace(dimension, TensorSpec::Label(i)); + spec.add(address, values[i].asDouble()); + } +} + +void decodeLiteralForm(const Inspector &cells, const ValueType &value_type, TensorSpec &spec) { + std::set<vespalib::string> indexed; + for (const auto &dimension: value_type.dimensions()) { + if (dimension.is_indexed()) { + indexed.insert(dimension.name); + } + } + for (size_t i = 0; i < cells.entries(); ++i) { + TensorSpec::Address address; + AddressExtractor extractor(indexed, address); + cells[i]["address"].traverse(extractor); + spec.add(address, cells[i]["value"].asDouble()); + } +} + void decode_json(const vespalib::string &path, Input &input, Slime &slime) { if (slime::JsonFormat::decode(input, slime) == 0) { LOG(warning, "file contains invalid json: %s", path.c_str()); @@ -90,19 +136,26 @@ ConstantTensorLoader::create(const vespalib::string &path, const vespalib::strin } Slime slime; decode_json(path, slime); - std::set<vespalib::string> indexed; - for (const auto &dimension: value_type.dimensions()) { - if (dimension.is_indexed()) { - indexed.insert(dimension.name); - } - } TensorSpec spec(type); - const Inspector &cells = slime.get()["cells"]; - for (size_t i = 0; i < cells.entries(); ++i) { - TensorSpec::Address address; - AddressExtractor extractor(indexed, address); - cells[i]["address"].traverse(extractor); - spec.add(address, cells[i]["value"].asDouble()); + bool isSingleDenseType = value_type.is_dense() && (value_type.count_indexed_dimensions() == 1); + bool isSingleMappedType = value_type.is_sparse() && (value_type.count_mapped_dimensions() == 1); + const Inspector &root = slime.get(); + const Inspector &cells = root["cells"]; + const Inspector &values = root["values"]; + if (cells.type().getId() == vespalib::slime::ARRAY::ID) { + decodeLiteralForm(cells, value_type, spec); + } + else if (cells.type().getId() == vespalib::slime::OBJECT::ID && isSingleMappedType) { + decodeSingleMappedForm(cells, value_type, spec); + } + else if (values.type().getId() == vespalib::slime::ARRAY::ID && isSingleDenseType) { + decodeSingleDenseForm(values, value_type, spec); + } + else if (root.type().getId() == vespalib::slime::OBJECT::ID && isSingleMappedType) { + decodeSingleMappedForm(root, value_type, spec); + } + else if (root.type().getId() == vespalib::slime::ARRAY::ID && isSingleDenseType) { + decodeSingleDenseForm(root, value_type, spec); } try { return std::make_unique<SimpleConstantValue>(value_from_spec(spec, _factory)); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 45e581d73e8..9c34875dfd7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -213,7 +213,7 @@ public class JsonFormat { if (root.field("cells").valid() && ! primitiveContent(root.field("cells"))) decodeCells(root.field("cells"), builder); else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed())) - decodeValues(root.field("values"), builder); + decodeValuesAtTop(root.field("values"), builder); else if (root.field("blocks").valid()) decodeBlocks(root.field("blocks"), builder); else @@ -252,11 +252,11 @@ public class JsonFormat { builder.cell(asAddress(key, builder.type()), decodeNumeric(value)); } - private static void decodeValues(Inspector values, Tensor.Builder builder) { - decodeValues(values, builder, new MutableInteger(0)); + private static void decodeValuesAtTop(Inspector values, Tensor.Builder builder) { + decodeNestedValues(values, builder, new MutableInteger(0)); } - private static void decodeValues(Inspector values, Tensor.Builder builder, MutableInteger index) { + private static void decodeNestedValues(Inspector values, Tensor.Builder builder, MutableInteger index) { if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder)) throw new IllegalArgumentException("An array of values can only be used with a dense tensor. Use a map instead"); if (values.type() == Type.STRING) { @@ -275,7 +275,7 @@ public class JsonFormat { values.traverse((ArrayTraverser) (__, value) -> { if (value.type() == Type.ARRAY) - decodeValues(value, builder, index); + decodeNestedValues(value, builder, index); else if (value.type() == Type.LONG || value.type() == Type.DOUBLE) indexedBuilder.cellByDirectIndex(index.next(), value.asDouble()); else @@ -300,7 +300,7 @@ public class JsonFormat { if (block.type() != Type.OBJECT) throw new IllegalArgumentException("Expected an item in a blocks array to be an object, not " + block.type()); mixedBuilder.block(decodeAddress(block.field("address"), mixedBuilder.type().mappedSubtype()), - decodeValues(block.field("values"), mixedBuilder)); + decodeValuesInBlock(block.field("values"), mixedBuilder)); } /** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */ @@ -311,7 +311,7 @@ public class JsonFormat { if (isArrayOfObjects(root)) decodeCells(root, builder); else if ( ! hasMapped) - decodeValues(root, builder); + decodeValuesAtTop(root, builder); else if (hasMapped && hasIndexed) decodeBlocks(root, builder); else @@ -330,7 +330,7 @@ public class JsonFormat { if (value.type() != Type.ARRAY) throw new IllegalArgumentException("Expected an item in a blocks array to be an array, not " + value.type()); mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()), - decodeValues(value, mixedBuilder)); + decodeValuesInBlock(value, mixedBuilder)); } private static byte decodeHex(String input, int index) { @@ -408,7 +408,7 @@ public class JsonFormat { }; } - private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { + private static double[] decodeValuesInBlock(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) { double[] values = new double[(int)mixedBuilder.denseSubspaceSize()]; if (valuesField.type() == Type.ARRAY) { if (valuesField.entries() == 0) { |