aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-05-22 11:42:17 +0200
committerGitHub <noreply@github.com>2023-05-22 11:42:17 +0200
commit07d0699378a589bcdffa08bba840b695a185ef99 (patch)
treeae4a6057d52f77bd2a14fab214774ffa90b86a03
parentc7a07adf43c13165e49e2aa2ef509ecb2526a48c (diff)
parent95d2e5194acd87facd594201dd1db254a41b1f73 (diff)
Merge pull request #27137 from vespa-engine/arnej/allow-some-short-forms-for-constant-tensors
allow short-form JSON for 1-d constants
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java76
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java20
-rw-r--r--eval/src/tests/eval/value_cache/dense-short1.json1
-rw-r--r--eval/src/tests/eval/value_cache/dense-short2.json3
-rw-r--r--eval/src/tests/eval/value_cache/sparse-short1.json5
-rw-r--r--eval/src/tests/eval/value_cache/sparse-short2.json7
-rw-r--r--eval/src/tests/eval/value_cache/tensor_loader_test.cpp24
-rw-r--r--eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp77
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java18
9 files changed, 199 insertions, 32 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
index 66da43856b1..eccb6910866 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidator.java
@@ -28,11 +28,14 @@ public class ConstantTensorJsonValidator {
private static final String FIELD_CELLS = "cells";
private static final String FIELD_ADDRESS = "address";
private static final String FIELD_VALUE = "value";
+ private static final String FIELD_VALUES = "values";
private static final JsonFactory jsonFactory = new JsonFactory();
private JsonParser parser;
private Map<String, TensorType.Dimension> tensorDimensions;
+ private boolean isSingleDenseType = false;
+ private boolean isSingleMappedType = false;
public void validate(String fileName, TensorType type, Reader tensorData) {
if (fileName.endsWith(".json")) {
@@ -57,19 +60,69 @@ public class ConstantTensorJsonValidator {
.dimensions()
.stream()
.collect(Collectors.toMap(TensorType.Dimension::name, Function.identity()));
+ if (type.dimensions().size() == 1) {
+ this.isSingleMappedType = (type.indexedSubtype() == TensorType.empty);
+ this.isSingleDenseType = (type.mappedSubtype() == TensorType.empty);
+ }
+ var top = parser.nextToken();
+ if (top == JsonToken.START_ARRAY) {
+ consumeValuesArray();
+ } else if (top == JsonToken.START_OBJECT) {
+ consumeTopObject();
+ }
+ });
+ }
- assertNextTokenIs(JsonToken.START_OBJECT);
- assertNextTokenIs(JsonToken.FIELD_NAME);
- assertFieldNameIs(FIELD_CELLS);
+ private void consumeValuesArray() throws IOException {
+ if (! isSingleDenseType) {
+ throw new InvalidConstantTensorException(parser, String.format("Field 'values' is only valid for simple vectors (1-d dense tensors"));
+ }
+ assertCurrentTokenIs(JsonToken.START_ARRAY);
+ while (parser.nextToken() != JsonToken.END_ARRAY) {
+ validateNumeric(parser.getCurrentToken());
+ }
+ }
+ private void consumeTopObject() throws IOException {
+ assertCurrentTokenIs(JsonToken.START_OBJECT);
+ assertNextTokenIs(JsonToken.FIELD_NAME);
+ String fieldName = parser.getCurrentName();
+ if (fieldName.equals(FIELD_VALUES)) {
assertNextTokenIs(JsonToken.START_ARRAY);
+ consumeValuesArray();
+ } else if (fieldName.equals(FIELD_CELLS)) {
+ consumeCellsField();
+ } else {
+ throw new InvalidConstantTensorException(parser, String.format("Expected 'cells' or 'values', got '%s'", fieldName));
+ }
+ assertNextTokenIs(JsonToken.END_OBJECT);
+ }
- while (parser.nextToken() != JsonToken.END_ARRAY) {
- validateTensorCell();
- }
+ private void consumeCellsField() throws IOException {
+ var token = parser.nextToken();
+ if (token == JsonToken.START_ARRAY) {
+ consumeLiteralFormArray();
+ } else if (token == JsonToken.START_OBJECT) {
+ consumeSimpleMappedObject();
+ } else {
+ throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be object or array, but got %s", token.toString()));
+ }
+ }
- assertNextTokenIs(JsonToken.END_OBJECT);
- });
+ private void consumeLiteralFormArray() throws IOException {
+ while (parser.nextToken() != JsonToken.END_ARRAY) {
+ validateTensorCell();
+ }
+ }
+
+ private void consumeSimpleMappedObject() throws IOException {
+ if (! isSingleMappedType) {
+ throw new InvalidConstantTensorException(parser, String.format("Field 'cells' must be an array of address/value objects"));
+ }
+ while (parser.nextToken() != JsonToken.END_OBJECT) {
+ assertCurrentTokenIs(JsonToken.FIELD_NAME);
+ validateTensorCellValue();
+ }
}
private void validateTensorCell() {
@@ -87,7 +140,7 @@ public class ConstantTensorJsonValidator {
if (fieldName.equals(FIELD_ADDRESS)) {
validateTensorAddress();
} else if (fieldName.equals(FIELD_VALUE)) {
- validateTensorValue();
+ validateTensorCellValue();
}
} else {
throw new InvalidConstantTensorException(parser, "Only 'address' or 'value' fields are permitted within a cell object");
@@ -169,9 +222,12 @@ public class ConstantTensorJsonValidator {
throw new InvalidConstantTensorException(parser, String.format("Index '%s' for dimension '%s' is not an integer", value, dimensionName));
}
- private void validateTensorValue() throws IOException {
+ private void validateTensorCellValue() throws IOException {
JsonToken token = parser.nextToken();
+ validateNumeric(token);
+ }
+ private void validateNumeric(JsonToken token) throws IOException {
if (token != JsonToken.VALUE_NUMBER_FLOAT && token != JsonToken.VALUE_NUMBER_INT) {
throw new InvalidConstantTensorException(parser, String.format("Tensor value is not a number (%s)", token.toString()));
}
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
index 80643917a58..42be1592eca 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/application/validation/ConstantTensorJsonValidatorTest.java
@@ -281,7 +281,25 @@ public class ConstantTensorJsonValidatorTest {
" }",
"}"));
});
- assertTrue(exception.getMessage().contains("Expected field name 'cells', got 'stats'"));
+ System.err.println("msg: " + exception.getMessage());
+ assertTrue(exception.getMessage().contains("Expected 'cells' or 'values', got 'stats'"));
+ }
+
+ @Test
+ void ensure_that_values_array_for_vector_works() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x[5])"),
+ inputJsonToReader("[5,4.0,3.1,-2,-1.0]"));
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x[5])"),
+ inputJsonToReader("{'values':[5,4.0,3.1,-2,-1.0]}"));
+ }
+
+ @Test
+ void ensure_that_simple_object_for_map_works() {
+ validateTensorJson(
+ TensorType.fromSpec("tensor(x{})"),
+ inputJsonToReader("{'cells':{'a':5,'b':4.0,'c':3.1,'d':-2,'e':-1.0}}"));
}
}
diff --git a/eval/src/tests/eval/value_cache/dense-short1.json b/eval/src/tests/eval/value_cache/dense-short1.json
new file mode 100644
index 00000000000..4e170001c96
--- /dev/null
+++ b/eval/src/tests/eval/value_cache/dense-short1.json
@@ -0,0 +1 @@
+[ 1, 2.0, 3.5 ]
diff --git a/eval/src/tests/eval/value_cache/dense-short2.json b/eval/src/tests/eval/value_cache/dense-short2.json
new file mode 100644
index 00000000000..40121135544
--- /dev/null
+++ b/eval/src/tests/eval/value_cache/dense-short2.json
@@ -0,0 +1,3 @@
+{
+ "values": [ 1, 2.0, 3.5 ]
+}
diff --git a/eval/src/tests/eval/value_cache/sparse-short1.json b/eval/src/tests/eval/value_cache/sparse-short1.json
new file mode 100644
index 00000000000..949b7b2b8bd
--- /dev/null
+++ b/eval/src/tests/eval/value_cache/sparse-short1.json
@@ -0,0 +1,5 @@
+{
+ "foo": 1.0,
+ "bar": 2.0,
+ "three": 3.0
+}
diff --git a/eval/src/tests/eval/value_cache/sparse-short2.json b/eval/src/tests/eval/value_cache/sparse-short2.json
new file mode 100644
index 00000000000..f10b1b6f9fb
--- /dev/null
+++ b/eval/src/tests/eval/value_cache/sparse-short2.json
@@ -0,0 +1,7 @@
+{
+ "cells": {
+ "foo": 1.0,
+ "bar": 2.0,
+ "three": 3.0
+ }
+}
diff --git a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
index 1a77cfe847b..4b4ba3fc0d3 100644
--- a/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
+++ b/eval/src/tests/eval/value_cache/tensor_loader_test.cpp
@@ -19,12 +19,26 @@ TensorSpec make_dense_tensor() {
.add({{"x", 1}, {"y", 1}}, 4.0);
}
+TensorSpec make_simple_dense_tensor() {
+ return TensorSpec("tensor(z[3])")
+ .add({{"z", 0}}, 1.0)
+ .add({{"z", 1}}, 2.0)
+ .add({{"z", 2}}, 3.5);
+}
+
TensorSpec make_sparse_tensor() {
return TensorSpec("tensor(x{},y{})")
.add({{"x", "foo"}, {"y", "bar"}}, 1.0)
.add({{"x", "bar"}, {"y", "foo"}}, 2.0);
}
+TensorSpec make_simple_sparse_tensor() {
+ return TensorSpec("tensor(mydim{})")
+ .add({{"mydim", "foo"}}, 1.0)
+ .add({{"mydim", "three"}}, 3.0)
+ .add({{"mydim", "bar"}}, 2.0);
+}
+
TensorSpec make_mixed_tensor() {
return TensorSpec("tensor(x{},y[2])")
.add({{"x", "foo"}, {"y", 0}}, 1.0)
@@ -75,6 +89,16 @@ TEST_F("require that lz4 compressed sparse tensor can be loaded", ConstantTensor
TEST_DO(verify_tensor(make_sparse_tensor(), f1.create(TEST_PATH("sparse.json.lz4"), "tensor(x{},y{})")));
}
+TEST_F("require that sparse tensor short form can be loaded", ConstantTensorLoader(factory)) {
+ TEST_DO(verify_tensor(make_simple_sparse_tensor(), f1.create(TEST_PATH("sparse-short1.json"), "tensor(mydim{})")));
+ TEST_DO(verify_tensor(make_simple_sparse_tensor(), f1.create(TEST_PATH("sparse-short2.json"), "tensor(mydim{})")));
+}
+
+TEST_F("require that dense tensor short form can be loaded", ConstantTensorLoader(factory)) {
+ TEST_DO(verify_tensor(make_simple_dense_tensor(), f1.create(TEST_PATH("dense-short1.json"), "tensor(z[3])")));
+ TEST_DO(verify_tensor(make_simple_dense_tensor(), f1.create(TEST_PATH("dense-short2.json"), "tensor(z[3])")));
+}
+
TEST_F("require that bad lz4 file fails to load creating empty result", ConstantTensorLoader(factory)) {
TEST_DO(verify_tensor(sparse_tensor_nocells(), f1.create(TEST_PATH("bad_lz4.json.lz4"), "tensor(x{},y{})")));
}
diff --git a/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp b/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp
index 9af473f1f94..5654a3abcbe 100644
--- a/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp
+++ b/eval/src/vespa/eval/eval/value_cache/constant_tensor_loader.cpp
@@ -41,6 +41,52 @@ struct AddressExtractor : ObjectTraverser {
}
};
+struct SingleMappedExtractor : ObjectTraverser {
+ const vespalib::string &dimension;
+ TensorSpec &spec;
+ SingleMappedExtractor(const vespalib::string &dimension_in, TensorSpec &spec_in)
+ : dimension(dimension_in),
+ spec(spec_in)
+ {}
+ void field(const Memory &symbol, const Inspector &inspector) override {
+ vespalib::string label = symbol.make_string();
+ double value = inspector.asDouble();
+ TensorSpec::Address address;
+ address.emplace(dimension, label);
+ spec.add(address, value);
+ }
+};
+
+
+void decodeSingleMappedForm(const Inspector &root, const ValueType &value_type, TensorSpec &spec) {
+ auto extractor = SingleMappedExtractor(value_type.dimensions()[0].name, spec);
+ root.traverse(extractor);
+}
+
+void decodeSingleDenseForm(const Inspector &values, const ValueType &value_type, TensorSpec &spec) {
+ const auto &dimension = value_type.dimensions()[0].name;
+ for (size_t i = 0; i < values.entries(); ++i) {
+ TensorSpec::Address address;
+ address.emplace(dimension, TensorSpec::Label(i));
+ spec.add(address, values[i].asDouble());
+ }
+}
+
+void decodeLiteralForm(const Inspector &cells, const ValueType &value_type, TensorSpec &spec) {
+ std::set<vespalib::string> indexed;
+ for (const auto &dimension: value_type.dimensions()) {
+ if (dimension.is_indexed()) {
+ indexed.insert(dimension.name);
+ }
+ }
+ for (size_t i = 0; i < cells.entries(); ++i) {
+ TensorSpec::Address address;
+ AddressExtractor extractor(indexed, address);
+ cells[i]["address"].traverse(extractor);
+ spec.add(address, cells[i]["value"].asDouble());
+ }
+}
+
void decode_json(const vespalib::string &path, Input &input, Slime &slime) {
if (slime::JsonFormat::decode(input, slime) == 0) {
LOG(warning, "file contains invalid json: %s", path.c_str());
@@ -90,19 +136,26 @@ ConstantTensorLoader::create(const vespalib::string &path, const vespalib::strin
}
Slime slime;
decode_json(path, slime);
- std::set<vespalib::string> indexed;
- for (const auto &dimension: value_type.dimensions()) {
- if (dimension.is_indexed()) {
- indexed.insert(dimension.name);
- }
- }
TensorSpec spec(type);
- const Inspector &cells = slime.get()["cells"];
- for (size_t i = 0; i < cells.entries(); ++i) {
- TensorSpec::Address address;
- AddressExtractor extractor(indexed, address);
- cells[i]["address"].traverse(extractor);
- spec.add(address, cells[i]["value"].asDouble());
+ bool isSingleDenseType = value_type.is_dense() && (value_type.count_indexed_dimensions() == 1);
+ bool isSingleMappedType = value_type.is_sparse() && (value_type.count_mapped_dimensions() == 1);
+ const Inspector &root = slime.get();
+ const Inspector &cells = root["cells"];
+ const Inspector &values = root["values"];
+ if (cells.type().getId() == vespalib::slime::ARRAY::ID) {
+ decodeLiteralForm(cells, value_type, spec);
+ }
+ else if (cells.type().getId() == vespalib::slime::OBJECT::ID && isSingleMappedType) {
+ decodeSingleMappedForm(cells, value_type, spec);
+ }
+ else if (values.type().getId() == vespalib::slime::ARRAY::ID && isSingleDenseType) {
+ decodeSingleDenseForm(values, value_type, spec);
+ }
+ else if (root.type().getId() == vespalib::slime::OBJECT::ID && isSingleMappedType) {
+ decodeSingleMappedForm(root, value_type, spec);
+ }
+ else if (root.type().getId() == vespalib::slime::ARRAY::ID && isSingleDenseType) {
+ decodeSingleDenseForm(root, value_type, spec);
}
try {
return std::make_unique<SimpleConstantValue>(value_from_spec(spec, _factory));
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index 45e581d73e8..9c34875dfd7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -213,7 +213,7 @@ public class JsonFormat {
if (root.field("cells").valid() && ! primitiveContent(root.field("cells")))
decodeCells(root.field("cells"), builder);
else if (root.field("values").valid() && builder.type().dimensions().stream().allMatch(d -> d.isIndexed()))
- decodeValues(root.field("values"), builder);
+ decodeValuesAtTop(root.field("values"), builder);
else if (root.field("blocks").valid())
decodeBlocks(root.field("blocks"), builder);
else
@@ -252,11 +252,11 @@ public class JsonFormat {
builder.cell(asAddress(key, builder.type()), decodeNumeric(value));
}
- private static void decodeValues(Inspector values, Tensor.Builder builder) {
- decodeValues(values, builder, new MutableInteger(0));
+ private static void decodeValuesAtTop(Inspector values, Tensor.Builder builder) {
+ decodeNestedValues(values, builder, new MutableInteger(0));
}
- private static void decodeValues(Inspector values, Tensor.Builder builder, MutableInteger index) {
+ private static void decodeNestedValues(Inspector values, Tensor.Builder builder, MutableInteger index) {
if ( ! (builder instanceof IndexedTensor.BoundBuilder indexedBuilder))
throw new IllegalArgumentException("An array of values can only be used with a dense tensor. Use a map instead");
if (values.type() == Type.STRING) {
@@ -275,7 +275,7 @@ public class JsonFormat {
values.traverse((ArrayTraverser) (__, value) -> {
if (value.type() == Type.ARRAY)
- decodeValues(value, builder, index);
+ decodeNestedValues(value, builder, index);
else if (value.type() == Type.LONG || value.type() == Type.DOUBLE)
indexedBuilder.cellByDirectIndex(index.next(), value.asDouble());
else
@@ -300,7 +300,7 @@ public class JsonFormat {
if (block.type() != Type.OBJECT)
throw new IllegalArgumentException("Expected an item in a blocks array to be an object, not " + block.type());
mixedBuilder.block(decodeAddress(block.field("address"), mixedBuilder.type().mappedSubtype()),
- decodeValues(block.field("values"), mixedBuilder));
+ decodeValuesInBlock(block.field("values"), mixedBuilder));
}
/** Decodes a tensor value directly at the root, where the format is decided by the tensor type. */
@@ -311,7 +311,7 @@ public class JsonFormat {
if (isArrayOfObjects(root))
decodeCells(root, builder);
else if ( ! hasMapped)
- decodeValues(root, builder);
+ decodeValuesAtTop(root, builder);
else if (hasMapped && hasIndexed)
decodeBlocks(root, builder);
else
@@ -330,7 +330,7 @@ public class JsonFormat {
if (value.type() != Type.ARRAY)
throw new IllegalArgumentException("Expected an item in a blocks array to be an array, not " + value.type());
mixedBuilder.block(asAddress(key, mixedBuilder.type().mappedSubtype()),
- decodeValues(value, mixedBuilder));
+ decodeValuesInBlock(value, mixedBuilder));
}
private static byte decodeHex(String input, int index) {
@@ -408,7 +408,7 @@ public class JsonFormat {
};
}
- private static double[] decodeValues(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
+ private static double[] decodeValuesInBlock(Inspector valuesField, MixedTensor.BoundBuilder mixedBuilder) {
double[] values = new double[(int)mixedBuilder.denseSubspaceSize()];
if (valuesField.type() == Type.ARRAY) {
if (valuesField.entries() == 0) {