summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2021-09-29 09:58:13 +0200
committerGitHub <noreply@github.com>2021-09-29 09:58:13 +0200
commit8923accf7e72d147d6d57185eecc4faf2b4adeb7 (patch)
tree0f856be32d11455e89547c98507a2a2d315e3225
parenta50c3b478de99e23ee5dd1af12efd3ace03d5b28 (diff)
parentac28a2c925e90d0b1c651d8019e113ae4aa5cad9 (diff)
Merge pull request #19304 from vespa-engine/lesters/additional-short-forms-stateless-rest-api
Stateless REST API: short forms for sparse and mixed tensors
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java8
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java2
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java36
-rw-r--r--model-evaluation/src/test/resources/config/models/rank-profiles.cfg9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java96
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java59
7 files changed, 188 insertions, 24 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
index bbd9962be77..9e365056355 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/handler/ModelsEvaluationHandler.java
@@ -20,6 +20,7 @@ import java.io.IOException;
import java.io.OutputStream;
import java.net.URI;
import java.nio.charset.Charset;
+import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import java.util.Map;
import java.util.Optional;
@@ -90,8 +91,11 @@ public class ModelsEvaluationHandler extends ThreadedHttpRequestHandler {
Tensor result = evaluator.evaluate();
Optional<String> format = property(request, "format");
- if (format.isPresent() && format.get().equalsIgnoreCase("short") && result instanceof IndexedTensor) {
- return new Response(200, JsonFormat.encodeShortForm((IndexedTensor) result));
+ if (format.isPresent() && format.get().equalsIgnoreCase("short")) {
+ return new Response(200, JsonFormat.encodeShortForm(result));
+ }
+ else if (format.isPresent() && format.get().equalsIgnoreCase("string")) {
+ return new Response(200, result.toString().getBytes(StandardCharsets.UTF_8));
}
return new Response(200, JsonFormat.encode(result));
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
index 0d13b7d4660..3bbdd36e777 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/MlModelsImportingTest.java
@@ -25,7 +25,7 @@ public class MlModelsImportingTest {
public void testImportingModels() {
ModelTester tester = new ModelTester("src/test/resources/config/models/");
- assertEquals(5, tester.models().size());
+ assertEquals(6, tester.models().size());
// TODO: When we get type information in Models, replace the evaluator.context().names() check below by that
{
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 8034be6bb22..7029be24a60 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -48,7 +48,7 @@ public class ModelsEvaluationHandlerTest {
public void testListModels() {
String url = "http://localhost/model-evaluation/v1";
String expected =
- "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}";
+ "{\"mnist_softmax\":\"http://localhost/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost/model-evaluation/v1/lightgbm_regression\"}";
handler.assertResponse(url, 200, expected);
}
@@ -56,7 +56,7 @@ public class ModelsEvaluationHandlerTest {
public void testListModelsWithDifferentHost() {
String url = "http://localhost/model-evaluation/v1";
String expected =
- "{\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\",\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"}";
+ "{\"mnist_softmax\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax\",\"mnist_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_saved\",\"mnist_softmax_saved\":\"http://localhost:8088/model-evaluation/v1/mnist_softmax_saved\",\"vespa_model\":\"http://localhost:8088/model-evaluation/v1/vespa_model\",\"xgboost_2_2\":\"http://localhost:8088/model-evaluation/v1/xgboost_2_2\",\"lightgbm_regression\":\"http://localhost:8088/model-evaluation/v1/lightgbm_regression\"}";
handler.assertResponse(url, 200, expected, Map.of("Host", "localhost:8088"));
}
@@ -188,7 +188,7 @@ public class ModelsEvaluationHandlerTest {
properties.put("Placeholder", inputTensorShortForm());
properties.put("format", "short");
String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
- String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"value\":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]}";
+ String expected = "{\"type\":\"tensor(d0[],d1[10])\",\"values\":[[-0.3546536862850189,0.3759574592113495,0.06054411828517914,-0.251544713973999,0.017951013520359993,1.2899067401885986,-0.10389615595340729,0.6367976665496826,-1.4136744737625122,-0.2573896050453186]]}";
handler.assertResponse(url, properties, 200, expected);
}
@@ -214,6 +214,36 @@ public class ModelsEvaluationHandlerTest {
}
@Test
+ public void testVespaModelShortOutput() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("format", "short");
+ String url = "http://localhost/model-evaluation/v1/vespa_model/";
+ handler.assertResponse(url + "test_mapped/eval", properties, 200,
+ "{\"type\":\"tensor(d0{})\",\"cells\":{\"a\":1.0,\"b\":2.0}}");
+ handler.assertResponse(url + "test_indexed/eval", properties, 200,
+ "{\"type\":\"tensor(d0[2],d1[3])\",\"values\":[[1.0,2.0,3.0],[4.0,5.0,6.0]]}");
+ handler.assertResponse(url + "test_mixed/eval", properties, 200,
+ "{\"type\":\"tensor(x{},y[3])\",\"blocks\":{\"a\":[1.0,2.0,3.0],\"b\":[4.0,5.0,6.0]}}");
+ handler.assertResponse(url + "test_mixed_2/eval", properties, 200,
+ "{\"type\":\"tensor(a[2],b[2],c{},d[2])\",\"blocks\":{\"a\":[[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]],\"b\":[[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]}}");
+ }
+
+ @Test
+ public void testVespaModelLiteralOutput() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("format", "string");
+ String url = "http://localhost/model-evaluation/v1/vespa_model/";
+ handler.assertResponse(url + "test_mapped/eval", properties, 200,
+ "tensor(d0{}):{a:1.0,b:2.0}");
+ handler.assertResponse(url + "test_indexed/eval", properties, 200,
+ "tensor(d0[2],d1[3]):[[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]");
+ handler.assertResponse(url + "test_mixed/eval", properties, 200,
+ "tensor(x{},y[3]):{a:[1.0, 2.0, 3.0],b:[4.0, 5.0, 6.0]}");
+ handler.assertResponse(url + "test_mixed_2/eval", properties, 200,
+ "tensor(a[2],b[2],c{},d[2]):{a:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]],b:[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]}");
+ }
+
+ @Test
public void testMnistSavedEvaluateSpecificFunction() {
Map<String, String> properties = new HashMap<>();
properties.put("input", inputTensor());
diff --git a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
index 385115b7cd4..4877a24f171 100644
--- a/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
+++ b/model-evaluation/src/test/resources/config/models/rank-profiles.cfg
@@ -29,3 +29,12 @@ rankprofile[3].fef.property[4].value "tensor(d1[10])"
rankprofile[4].name "lightgbm_regression"
rankprofile[4].fef.property[0].name "rankingExpression(lightgbm_regression).rankingScript"
rankprofile[4].fef.property[0].value "if (!(numerical_2 >= 0.46643291586559305), 2.1594397038037663, if (categorical_2 in ["k", "l", "m"], 2.235297305276056, 2.1792953471546546)) + if (categorical_1 in ["d", "e"], 0.03070842919354316, if (!(numerical_1 >= 0.5102250691730842), -0.04439151147520909, 0.005117411709368601)) + if (!(numerical_2 >= 0.668665477622446), if (!(numerical_2 >= 0.008118820676863816), -0.15361238490967524, -0.01192330846157292), 0.03499044894987518) + if (!(numerical_1 >= 0.5201391072644542), -0.02141000620783247, if (categorical_1 in ["a", "b"], -0.004121485787596721, 0.04534090904886873)) + if (categorical_2 in ["k", "l", "m"], if (!(numerical_2 >= 0.27283279016959255), -0.01924803254356527, 0.03643772842347651), -0.02701711918923075)"
+rankprofile[5].name "vespa_model"
+rankprofile[5].fef.property[0].name "rankingExpression(test_mapped).rankingScript"
+rankprofile[5].fef.property[0].value "tensor(d0{}):{a:1, b:2}"
+rankprofile[5].fef.property[1].name "rankingExpression(test_indexed).rankingScript"
+rankprofile[5].fef.property[1].value "tensor(d0[2],d1[3]):[[1,2,3],[4,5,6]]"
+rankprofile[5].fef.property[2].name "rankingExpression(test_mixed).rankingScript"
+rankprofile[5].fef.property[2].value "tensor(x{},y[3]):{a:[1,2,3], b:[4,5,6]}"
+rankprofile[5].fef.property[3].name "rankingExpression(test_mixed_2).rankingScript"
+rankprofile[5].fef.property[3].value "tensor(a[2],b[2],c{},d[2]):{a:[[[1,2], [3,4]], [[5,6], [7,8]]], b:[[[1,2], [3,4]], [[5,6], [7,8]]] }"
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 71ed347219e..33dcd458980 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -91,7 +91,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
- /** Returns a label as a string with approriate quoting/escaping when necessary */
+ /** Returns a label as a string with appropriate quoting/escaping when necessary */
public static String labelToString(String label) {
if (TensorType.labelMatcher.matches(label)) return label; // no quoting
if (label.contains("'")) return "\"" + label + "\"";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
index cb7539d8565..87157495485 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java
@@ -11,12 +11,21 @@ import com.yahoo.slime.Slime;
import com.yahoo.slime.Type;
import com.yahoo.tensor.DimensionSizes;
import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.MappedTensor;
import com.yahoo.tensor.MixedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.functions.ConstantTensor;
+import com.yahoo.tensor.functions.Slice;
+import java.util.ArrayList;
+import java.util.HashSet;
import java.util.Iterator;
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
/**
* Writes tensors on the JSON format used in Vespa tensor document fields:
@@ -46,12 +55,33 @@ public class JsonFormat {
}
/** Serializes the given tensor type and value into a short-form JSON format */
- public static byte[] encodeShortForm(IndexedTensor tensor) {
+ public static byte[] encodeShortForm(Tensor tensor) {
Slime slime = new Slime();
Cursor root = slime.setObject();
root.setString("type", tensor.type().toString());
- Cursor value = root.setArray("value");
- encodeList(tensor, value, new long[tensor.dimensionSizes().dimensions()], 0);
+
+ // Encode as nested lists if indexed tensor
+ if (tensor instanceof IndexedTensor) {
+ IndexedTensor denseTensor = (IndexedTensor) tensor;
+ encodeValues(denseTensor, root.setArray("values"), new long[denseTensor.dimensionSizes().dimensions()], 0);
+ }
+
+ // Short form for a single mapped dimension
+ else if (tensor instanceof MappedTensor && tensor.type().dimensions().size() == 1) {
+ encodeSingleDimensionCells((MappedTensor) tensor, root);
+ }
+
+ // Short form for a mixed tensor
+ else if (tensor instanceof MixedTensor &&
+ tensor.type().dimensions().stream().filter(TensorType.Dimension::isMapped).count() >= 1) {
+ encodeBlocks((MixedTensor) tensor, root);
+ }
+
+ // No other short forms exist: default to standard cell address output
+ else {
+ encodeCells(tensor, root);
+ }
+
return com.yahoo.slime.JsonFormat.toJsonBytes(slime);
}
@@ -65,22 +95,78 @@ public class JsonFormat {
}
}
+ private static void encodeSingleDimensionCells(MappedTensor tensor, Cursor cursor) {
+ Cursor cells = cursor.setObject("cells");
+ if (tensor.type().dimensions().size() > 1)
+ throw new IllegalStateException("JSON encode of mapped tensor can only contain a single dimension");
+ tensor.cells().forEach((k,v) -> cells.setDouble(k.label(0), v));
+ }
+
private static void encodeAddress(TensorType type, TensorAddress address, Cursor addressObject) {
for (int i = 0; i < address.size(); i++)
addressObject.setString(type.dimensions().get(i).name(), address.label(i));
}
- private static void encodeList(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
+ private static void encodeValues(IndexedTensor tensor, Cursor cursor, long[] indexes, int dimension) {
DimensionSizes sizes = tensor.dimensionSizes();
for (indexes[dimension] = 0; indexes[dimension] < sizes.size(dimension); ++indexes[dimension]) {
if (dimension < (sizes.dimensions() - 1)) {
- encodeList(tensor, cursor.addArray(), indexes, dimension + 1);
+ encodeValues(tensor, cursor.addArray(), indexes, dimension + 1);
} else {
cursor.addDouble(tensor.get(indexes));
}
}
}
+ private static void encodeBlocks(MixedTensor tensor, Cursor cursor) {
+ var mappedDimensions = tensor.type().dimensions().stream().filter(d -> d.isMapped())
+ .map(d -> TensorType.Dimension.mapped(d.name())).collect(Collectors.toList());
+ if (mappedDimensions.size() < 1) {
+ throw new IllegalArgumentException("Should be ensured by caller");
+ }
+ cursor = (mappedDimensions.size() == 1) ? cursor.setObject("blocks") : cursor.setArray("blocks");
+
+ // Create tensor type for mapped dimensions subtype
+ TensorType mappedSubType = new TensorType.Builder(mappedDimensions).build();
+
+ // Find all unique indices for the mapped dimensions
+ Set<TensorAddress> denseSubSpaceAddresses = new HashSet<>();
+ tensor.cellIterator().forEachRemaining((cell) -> {
+ denseSubSpaceAddresses.add(subAddress(cell.getKey(), mappedSubType, tensor.type()));
+ });
+
+ // Slice out dense subspace of each and encode dense subspace as a list
+ for (TensorAddress denseSubSpaceAddress : denseSubSpaceAddresses) {
+ IndexedTensor denseSubspace = (IndexedTensor) sliceSubAddress(tensor, denseSubSpaceAddress, mappedSubType);
+
+ if (mappedDimensions.size() == 1) {
+ encodeValues(denseSubspace, cursor.setArray(denseSubSpaceAddress.label(0)), new long[denseSubspace.dimensionSizes().dimensions()], 0);
+ } else {
+ Cursor block = cursor.addObject();
+ encodeAddress(mappedSubType, denseSubSpaceAddress, block.setObject("address"));
+ encodeValues(denseSubspace, block.setArray("values"), new long[denseSubspace.dimensionSizes().dimensions()], 0);
+ }
+
+ }
+ }
+
+ private static TensorAddress subAddress(TensorAddress address, TensorType subType, TensorType origType) {
+ TensorAddress.Builder builder = new TensorAddress.Builder(subType);
+ for (TensorType.Dimension dim : subType.dimensions()) {
+ builder.add(dim.name(), address.label(origType.indexOfDimension(dim.name()).
+ orElseThrow(() -> new IllegalStateException("Could not find mapped dimension index"))));
+ }
+ return builder.build();
+ }
+
+ private static Tensor sliceSubAddress(Tensor tensor, TensorAddress subAddress, TensorType subType) {
+ List<Slice.DimensionValue<Name>> sliceDims = new ArrayList<>(subAddress.size());
+ for (int i = 0; i < subAddress.size(); ++i) {
+ sliceDims.add(new Slice.DimensionValue<>(subType.dimensions().get(i).name(), subAddress.label(i)));
+ }
+ return new Slice<>(new ConstantTensor<>(tensor), sliceDims).evaluate();
+ }
+
/** Deserializes the given tensor from JSON format */
// NOTE: This must be kept in sync with com.yahoo.document.json.readers.TensorReader in the document module
public static Tensor decode(TensorType type, byte[] jsonTensorValue) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
index 87796501917..cdfd19eb5c8 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java
@@ -98,27 +98,62 @@ public class JsonFormatTestCase {
}
@Test
- public void testDenseTensorShortForm() {
+ public void testEncodeIndexedShortForm() {
assertEncodeShortForm("tensor(x[]):[1.0, 2.0]",
- "{\"type\":\"tensor(x[])\",\"value\":[1.0,2.0]}");
+ "{\"type\":\"tensor(x[])\",\"values\":[1.0,2.0]}");
assertEncodeShortForm("tensor<float>(x[]):[1.0, 2.0]",
- "{\"type\":\"tensor<float>(x[])\",\"value\":[1.0,2.0]}");
+ "{\"type\":\"tensor<float>(x[])\",\"values\":[1.0,2.0]}");
assertEncodeShortForm("tensor(x[],y[]):[[1,2,3,4]]",
- "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0,3.0,4.0]]}");
+ "{\"type\":\"tensor(x[],y[])\",\"values\":[[1.0,2.0,3.0,4.0]]}");
assertEncodeShortForm("tensor(x[],y[]):[[1,2],[3,4]]",
- "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0,2.0],[3.0,4.0]]}");
+ "{\"type\":\"tensor(x[],y[])\",\"values\":[[1.0,2.0],[3.0,4.0]]}");
assertEncodeShortForm("tensor(x[],y[]):[[1],[2],[3],[4]]",
- "{\"type\":\"tensor(x[],y[])\",\"value\":[[1.0],[2.0],[3.0],[4.0]]}");
+ "{\"type\":\"tensor(x[],y[])\",\"values\":[[1.0],[2.0],[3.0],[4.0]]}");
assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2],[3,4]]]",
- "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0],[3.0,4.0]]]}");
+ "{\"type\":\"tensor(x[],y[],z[])\",\"values\":[[[1.0,2.0],[3.0,4.0]]]}");
assertEncodeShortForm("tensor(x[],y[],z[]):[[[1],[2],[3],[4]]]",
- "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0],[2.0],[3.0],[4.0]]]}");
+ "{\"type\":\"tensor(x[],y[],z[])\",\"values\":[[[1.0],[2.0],[3.0],[4.0]]]}");
assertEncodeShortForm("tensor(x[],y[],z[]):[[[1,2,3,4]]]",
- "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0,2.0,3.0,4.0]]]}");
+ "{\"type\":\"tensor(x[],y[],z[])\",\"values\":[[[1.0,2.0,3.0,4.0]]]}");
assertEncodeShortForm("tensor(x[],y[],z[]):[[[1]],[[2]],[[3]],[[4]]]",
- "{\"type\":\"tensor(x[],y[],z[])\",\"value\":[[[1.0]],[[2.0]],[[3.0]],[[4.0]]]}");
+ "{\"type\":\"tensor(x[],y[],z[])\",\"values\":[[[1.0]],[[2.0]],[[3.0]],[[4.0]]]}");
assertEncodeShortForm("tensor(x[],y[],z[2]):[[[1, 2]],[[3, 4]]]",
- "{\"type\":\"tensor(x[],y[],z[2])\",\"value\":[[[1.0,2.0]],[[3.0,4.0]]]}");
+ "{\"type\":\"tensor(x[],y[],z[2])\",\"values\":[[[1.0,2.0]],[[3.0,4.0]]]}");
+ }
+
+ @Test
+ public void testEncodeMappedSingleDimensionShortForm() {
+ assertEncodeShortForm("tensor(x{}):{}",
+ "{\"type\":\"tensor(x{})\",\"cells\":{}}");
+ assertEncodeShortForm("tensor(x{}):{a:1,b:2}",
+ "{\"type\":\"tensor(x{})\",\"cells\":{\"a\":1.0,\"b\":2.0}}");
+ // Multiple mapped dimensions: no short form available
+ assertEncodeShortForm("tensor(x{},y{}):{{x:a,y:b}:1,{x:c,y:d}:2}",
+ "{\"type\":\"tensor(x{},y{})\",\"cells\":[{\"address\":{\"x\":\"a\",\"y\":\"b\"},\"value\":1.0},{\"address\":{\"x\":\"c\",\"y\":\"d\"},\"value\":2.0}]}");
+ }
+
+ @Test
+ public void testEncodeMixedShortForm() {
+ assertEncodeShortForm("tensor(x{},y[2]):{a:[1,2], b:[3,4] }",
+ "{\"type\":\"tensor(x{},y[2])\",\"blocks\":{\"a\":[1.0,2.0],\"b\":[3.0,4.0]}}");
+ assertEncodeShortForm("tensor(x[2],y{}):{a:[1,2], b:[3,4] }",
+ "{\"type\":\"tensor(x[2],y{})\",\"blocks\":{\"a\":[1.0,2.0],\"b\":[3.0,4.0]}}");
+ assertEncodeShortForm("tensor(x{},y[2],z[2]):{a:[[1,2],[3,4]], b:[[5,6],[7,8]] }",
+ "{\"type\":\"tensor(x{},y[2],z[2])\",\"blocks\":{\"a\":[[1.0,2.0],[3.0,4.0]],\"b\":[[5.0,6.0],[7.0,8.0]]}}");
+ assertEncodeShortForm("tensor(x[1],y{},z[4]):{a:[[1,2,3,4]], b:[[5,6,7,8]] }",
+ "{\"type\":\"tensor(x[1],y{},z[4])\",\"blocks\":{\"a\":[[1.0,2.0,3.0,4.0]],\"b\":[[5.0,6.0,7.0,8.0]]}}");
+ assertEncodeShortForm("tensor(x[4],y[1],z{}):{a:[[1],[2],[3],[4]], b:[[5],[6],[7],[8]] }",
+ "{\"type\":\"tensor(x[4],y[1],z{})\",\"blocks\":{\"a\":[[1.0],[2.0],[3.0],[4.0]],\"b\":[[5.0],[6.0],[7.0],[8.0]]}}");
+ assertEncodeShortForm("tensor(a[2],b[2],c{},d[2]):{a:[[[1,2], [3,4]], [[5,6], [7,8]]], b:[[[1,2], [3,4]], [[5,6], [7,8]]] }",
+ "{\"type\":\"tensor(a[2],b[2],c{},d[2])\",\"blocks\":{" +
+ "\"a\":[[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]," +
+ "\"b\":[[[1.0,2.0],[3.0,4.0]],[[5.0,6.0],[7.0,8.0]]]}}");
+
+ // Multiple mapped dimensions
+ assertEncodeShortForm("tensor(x{},y{},z[2]):{{x:a,y:0,z:0}:1, {x:a,y:0,z:1}:2, {x:b,y:1,z:0}:3, {x:b,y:1,z:1}:4 }",
+ "{\"type\":\"tensor(x{},y{},z[2])\",\"blocks\":[{\"address\":{\"x\":\"a\",\"y\":\"0\"},\"values\":[1.0,2.0]},{\"address\":{\"x\":\"b\",\"y\":\"1\"},\"values\":[3.0,4.0]}]}");
+ assertEncodeShortForm("tensor(x{},y[2],z{}):{{x:a,y:0,z:0}:1, {x:a,y:1,z:0}:2, {x:b,y:0,z:1}:3, {x:b,y:1,z:1}:4 }",
+ "{\"type\":\"tensor(x{},y[2],z{})\",\"blocks\":[{\"address\":{\"x\":\"a\",\"z\":\"0\"},\"values\":[1.0,2.0]},{\"address\":{\"x\":\"b\",\"z\":\"1\"},\"values\":[3.0,4.0]}]}");
}
@Test
@@ -315,7 +350,7 @@ public class JsonFormatTestCase {
}
private void assertEncodeShortForm(String tensor, String expected) {
- byte[] json = JsonFormat.encodeShortForm((IndexedTensor) Tensor.from(tensor));
+ byte[] json = JsonFormat.encodeShortForm(Tensor.from(tensor));
assertEquals(expected, new String(json, StandardCharsets.UTF_8));
}