summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2021-08-31 15:59:49 +0200
committerGitHub <noreply@github.com>2021-08-31 15:59:49 +0200
commit371836c15a63e120ac5d0ba6c44cc49d85593eea (patch)
treef6c35a4bf642fff9cc41117d535bb6e956d6957f
parentfc0255c85f91c4b80384909aa56a9ec83a9a1613 (diff)
parent522578bb7392c95ac0c8ef2b599f65d6d79df987 (diff)
Merge pull request #18918 from vespa-engine/lesters/parse-unbound-tensor-in-short-form
Parse unbound tensors in short form
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java83
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java47
3 files changed, 147 insertions, 5 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
index 5b1c7a478b1..df89919a76e 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java
@@ -174,6 +174,15 @@ public class ModelsEvaluationHandlerTest {
}
@Test
+ public void testMnistSoftmaxEvaluateSpecificFunctionWithBindingsShortForm() {
+ Map<String, String> properties = new HashMap<>();
+ properties.put("Placeholder", inputTensorShortForm());
+ String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval";
+ String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}";
+ handler.assertResponse(url, properties, 200, expected);
+ }
+
+ @Test
public void testMnistSavedDetails() {
String url = "http://localhost:8080/model-evaluation/v1/mnist_saved";
String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}";
@@ -224,4 +233,17 @@ public class ModelsEvaluationHandlerTest {
return b.build().toString();
}
+ private String inputTensorShortForm() {
+ StringBuilder sb = new StringBuilder();
+ sb.append("[[");
+ for (int i = 0; i < 784; i++) {
+ sb.append("0.0");
+ if (i < 783) {
+ sb.append(",");
+ }
+ }
+ sb.append("]]");
+ return sb.toString();
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 0a1d9b6cf6e..eaa6e50e87f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -124,15 +124,30 @@ class TensorParser {
if (type.isEmpty())
throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
- if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty())))
- throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " +
- "only dense dimensions with a given size");
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get());
- new DenseValueParser(valueString, dimensionOrder, builder).parse();
+ IndexedTensor.Builder builder = IndexedTensor.Builder.of(type.get());
+
+ if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty()))) {
+ new UnboundDenseValueParser(valueString, builder).parse();
+ return checkBoundDimensionSizes(builder.build());
+ }
+
+ new DenseValueParser(valueString, dimensionOrder, (IndexedTensor.BoundBuilder) builder).parse();
return builder.build();
}
+ private static Tensor checkBoundDimensionSizes(IndexedTensor tensor) {
+ TensorType type = tensor.type();
+ for (int i = 0; i < type.dimensions().size(); ++i) {
+ TensorType.Dimension dimension = type.dimensions().get(i);
+ if (dimension.size().isPresent() && dimension.size().get() != tensor.dimensionSizes().size(i)) {
+ throw new IllegalArgumentException("Unexpected size " + tensor.dimensionSizes().size(i) +
+ " for dimension " + dimension.name() + " for type " + type);
+ }
+ }
+ return tensor;
+ }
+
private static abstract class ValueParser {
protected final String string;
@@ -299,6 +314,64 @@ class TensorParser {
}
/**
+ * Parses unbound tensor short forms - e.g. tensor(x[],y[]):[[1,2,3],[4,5,6]]
+ */
+ private static class UnboundDenseValueParser extends ValueParser {
+
+ private final IndexedTensor.Builder builder;
+ private final long[] indexes;
+
+ public UnboundDenseValueParser(String string, IndexedTensor.Builder builder) {
+ super(string);
+ this.builder = builder;
+ this.indexes = new long[builder.type().dimensions().size()];
+ }
+
+ public void parse() {
+ consumeList(0);
+ }
+
+ private void consumeList(int dimension) {
+ consume('[');
+ indexes[dimension] = 0;
+ while ( ! atListEnd() ) {
+ if (isInnerMostDimension(dimension)) {
+ consumeNumber();
+ } else {
+ consumeList(dimension + 1);
+ }
+ indexes[dimension]++;
+ consumeOptional(',');
+ }
+ consume(']');
+ }
+
+ private void consumeNumber() {
+ Number number = consumeNumber(builder.type().valueType());
+ switch (builder.type().valueType()) {
+ case DOUBLE: builder.cell((Double)number, indexes); break;
+ case FLOAT: builder.cell((Float)number, indexes); break;
+ case BFLOAT16: builder.cell((Float)number, indexes); break;
+ case INT8: builder.cell((Float)number, indexes); break;
+ }
+ }
+
+ private boolean isInnerMostDimension(int dimension) {
+ return dimension == (indexes.length - 1);
+ }
+
+ protected boolean atListEnd() {
+ skipSpace();
+ if (position >= string.length()) {
+ throw new IllegalArgumentException("At value position " + position + ": Expected a ']'" +
+ " but got the end of the string");
+ }
+ return string.charAt(position) == ']';
+ }
+
+ }
+
+ /**
* Parses mixed tensor short forms {a:[1,2], ...} AND 1d mapped tensor short form {a:b, ...}.
*/
private static class MixedValueParser extends ValueParser {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
index 431e4b06263..b869107e744 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -150,6 +150,53 @@ public class TensorParserTestCase {
Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}"));
}
+ @Test
+ public void testUnboundShortFormParsing() {
+ assertEquals(Tensor.from("tensor(x[]):[1.0, 2.0]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[])")).cell(1.0, 0).cell(2.0, 1).build());
+ assertEquals(Tensor.from("tensor<float>(x[]):[1.0, 2.0]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[])")).cell(1.0, 0).cell(2.0, 1).build());
+ assertEquals(Tensor.from("tensor<int8>(x[]):[1.0, 2.0]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[])")).cell(1.0, 0).cell(2.0, 1).build());
+ assertEquals(Tensor.from("tensor<bfloat16>(x[]):[1.0, 2.0]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[])")).cell(1.0, 0).cell(2.0, 1).build());
+
+ assertEquals(Tensor.from("tensor(x[],y[]):[[1,2,3,4]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[])"))
+ .cell(1.0, 0, 0).cell(2.0, 0, 1).cell(3.0, 0, 2).cell(4.0, 0, 3).build());
+ assertEquals(Tensor.from("tensor(x[],y[]):[[1,2],[3,4]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[])"))
+ .cell(1.0, 0, 0).cell(2.0, 0, 1).cell(3.0, 1, 0).cell(4.0, 1, 1).build());
+ assertEquals(Tensor.from("tensor(x[],y[]):[[1],[2],[3],[4]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[])"))
+ .cell(1.0, 0, 0).cell(2.0, 1, 0).cell(3.0, 2, 0).cell(4.0, 3, 0).build());
+ assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1,2],[3,4]]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])"))
+ .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 0, 1, 0).cell(4.0, 0, 1, 1).build());
+ assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1],[2],[3],[4]]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])"))
+ .cell(1.0, 0, 0, 0).cell(2.0, 0, 1, 0).cell(3.0, 0, 2, 0).cell(4.0, 0, 3, 0).build());
+ assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1,2,3,4]]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])"))
+ .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 0, 0, 2).cell(4.0, 0, 0, 3).build());
+ assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1]],[[2]],[[3]],[[4]]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])"))
+ .cell(1.0, 0, 0, 0).cell(2.0, 1, 0, 0).cell(3.0, 2, 0, 0).cell(4.0, 3, 0, 0).build());
+ assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1, 2]],[[3, 4]]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])"))
+ .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 1, 0, 0).cell(4.0, 1, 0, 1).build());
+
+ assertEquals(Tensor.from("tensor(x[],y[],z[4]):[[[1,2,3,4]]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])"))
+ .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 0, 0, 2).cell(4.0, 0, 0, 3).build());
+ assertEquals(Tensor.from("tensor(x[2],y[],z[2]):[[[1, 2]],[[3, 4]]]"),
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])"))
+ .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 1, 0, 0).cell(4.0, 1, 0, 1).build());
+
+ assertIllegal("Unexpected size 2 for dimension y for type tensor(x[],y[3])",
+ "tensor(x[],y[3]):[[1,2],[3,4]]");
+ }
+
private void assertDense(Tensor expectedTensor, String denseFormat) {
assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat));
assertEquals(denseFormat, expectedTensor.toString());