diff options
3 files changed, 147 insertions, 5 deletions
diff --git a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java index 5b1c7a478b1..df89919a76e 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/handler/ModelsEvaluationHandlerTest.java @@ -174,6 +174,15 @@ public class ModelsEvaluationHandlerTest { } @Test + public void testMnistSoftmaxEvaluateSpecificFunctionWithBindingsShortForm() { + Map<String, String> properties = new HashMap<>(); + properties.put("Placeholder", inputTensorShortForm()); + String url = "http://localhost/model-evaluation/v1/mnist_softmax/default.add/eval"; + String expected = "{\"cells\":[{\"address\":{\"d0\":\"0\",\"d1\":\"0\"},\"value\":-0.3546536862850189},{\"address\":{\"d0\":\"0\",\"d1\":\"1\"},\"value\":0.3759574592113495},{\"address\":{\"d0\":\"0\",\"d1\":\"2\"},\"value\":0.06054411828517914},{\"address\":{\"d0\":\"0\",\"d1\":\"3\"},\"value\":-0.251544713973999},{\"address\":{\"d0\":\"0\",\"d1\":\"4\"},\"value\":0.017951013520359993},{\"address\":{\"d0\":\"0\",\"d1\":\"5\"},\"value\":1.2899067401885986},{\"address\":{\"d0\":\"0\",\"d1\":\"6\"},\"value\":-0.10389615595340729},{\"address\":{\"d0\":\"0\",\"d1\":\"7\"},\"value\":0.6367976665496826},{\"address\":{\"d0\":\"0\",\"d1\":\"8\"},\"value\":-1.4136744737625122},{\"address\":{\"d0\":\"0\",\"d1\":\"9\"},\"value\":-0.2573896050453186}]}"; + handler.assertResponse(url, properties, 200, expected); + } + + @Test public void testMnistSavedDetails() { String url = "http://localhost:8080/model-evaluation/v1/mnist_saved"; String expected = "{\"model\":\"mnist_saved\",\"functions\":[{\"function\":\"serving_default.y\",\"info\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y\",\"eval\":\"http://localhost:8080/model-evaluation/v1/mnist_saved/serving_default.y/eval\",\"arguments\":[{\"name\":\"input\",\"type\":\"tensor(d0[],d1[784])\"}]}]}"; @@ -224,4 +233,17 @@ public class ModelsEvaluationHandlerTest { return b.build().toString(); } + private String inputTensorShortForm() { + StringBuilder sb = new StringBuilder(); + sb.append("[["); + for (int i = 0; i < 784; i++) { + sb.append("0.0"); + if (i < 783) { + sb.append(","); + } + } + sb.append("]]"); + return sb.toString(); + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 0a1d9b6cf6e..eaa6e50e87f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -124,15 +124,30 @@ class TensorParser { if (type.isEmpty()) throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); - if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty()))) - throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " + - "only dense dimensions with a given size"); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get()); - new DenseValueParser(valueString, dimensionOrder, builder).parse(); + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type.get()); + + if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty()))) { + new UnboundDenseValueParser(valueString, builder).parse(); + return checkBoundDimensionSizes(builder.build()); + } + + new DenseValueParser(valueString, dimensionOrder, (IndexedTensor.BoundBuilder) builder).parse(); return builder.build(); } + private static Tensor checkBoundDimensionSizes(IndexedTensor tensor) { + TensorType type = tensor.type(); + for (int i = 0; i < type.dimensions().size(); ++i) { + TensorType.Dimension dimension = type.dimensions().get(i); + if (dimension.size().isPresent() && dimension.size().get() != tensor.dimensionSizes().size(i)) { + throw new IllegalArgumentException("Unexpected size " + tensor.dimensionSizes().size(i) + + " for dimension " + dimension.name() + " for type " + type); + } + } + return tensor; + } + private static abstract class ValueParser { protected final String string; @@ -299,6 +314,64 @@ class TensorParser { } /** + * Parses unbound tensor short forms - e.g. tensor(x[],y[]):[[1,2,3],[4,5,6]] + */ + private static class UnboundDenseValueParser extends ValueParser { + + private final IndexedTensor.Builder builder; + private final long[] indexes; + + public UnboundDenseValueParser(String string, IndexedTensor.Builder builder) { + super(string); + this.builder = builder; + this.indexes = new long[builder.type().dimensions().size()]; + } + + public void parse() { + consumeList(0); + } + + private void consumeList(int dimension) { + consume('['); + indexes[dimension] = 0; + while ( ! atListEnd() ) { + if (isInnerMostDimension(dimension)) { + consumeNumber(); + } else { + consumeList(dimension + 1); + } + indexes[dimension]++; + consumeOptional(','); + } + consume(']'); + } + + private void consumeNumber() { + Number number = consumeNumber(builder.type().valueType()); + switch (builder.type().valueType()) { + case DOUBLE: builder.cell((Double)number, indexes); break; + case FLOAT: builder.cell((Float)number, indexes); break; + case BFLOAT16: builder.cell((Float)number, indexes); break; + case INT8: builder.cell((Float)number, indexes); break; + } + } + + private boolean isInnerMostDimension(int dimension) { + return dimension == (indexes.length - 1); + } + + protected boolean atListEnd() { + skipSpace(); + if (position >= string.length()) { + throw new IllegalArgumentException("At value position " + position + ": Expected a ']'" + + " but got the end of the string"); + } + return string.charAt(position) == ']'; + } + + } + + /** * Parses mixed tensor short forms {a:[1,2], ...} AND 1d mapped tensor short form {a:b, ...}. */ private static class MixedValueParser extends ValueParser { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index 431e4b06263..b869107e744 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -150,6 +150,53 @@ public class TensorParserTestCase { Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}")); } + @Test + public void testUnboundShortFormParsing() { + assertEquals(Tensor.from("tensor(x[]):[1.0, 2.0]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[])")).cell(1.0, 0).cell(2.0, 1).build()); + assertEquals(Tensor.from("tensor<float>(x[]):[1.0, 2.0]"), + Tensor.Builder.of(TensorType.fromSpec("tensor<float>(x[])")).cell(1.0, 0).cell(2.0, 1).build()); + assertEquals(Tensor.from("tensor<int8>(x[]):[1.0, 2.0]"), + Tensor.Builder.of(TensorType.fromSpec("tensor<int8>(x[])")).cell(1.0, 0).cell(2.0, 1).build()); + assertEquals(Tensor.from("tensor<bfloat16>(x[]):[1.0, 2.0]"), + Tensor.Builder.of(TensorType.fromSpec("tensor<bfloat16>(x[])")).cell(1.0, 0).cell(2.0, 1).build()); + + assertEquals(Tensor.from("tensor(x[],y[]):[[1,2,3,4]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[])")) + .cell(1.0, 0, 0).cell(2.0, 0, 1).cell(3.0, 0, 2).cell(4.0, 0, 3).build()); + assertEquals(Tensor.from("tensor(x[],y[]):[[1,2],[3,4]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[])")) + .cell(1.0, 0, 0).cell(2.0, 0, 1).cell(3.0, 1, 0).cell(4.0, 1, 1).build()); + assertEquals(Tensor.from("tensor(x[],y[]):[[1],[2],[3],[4]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[])")) + .cell(1.0, 0, 0).cell(2.0, 1, 0).cell(3.0, 2, 0).cell(4.0, 3, 0).build()); + assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1,2],[3,4]]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])")) + .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 0, 1, 0).cell(4.0, 0, 1, 1).build()); + assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1],[2],[3],[4]]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])")) + .cell(1.0, 0, 0, 0).cell(2.0, 0, 1, 0).cell(3.0, 0, 2, 0).cell(4.0, 0, 3, 0).build()); + assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1,2,3,4]]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])")) + .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 0, 0, 2).cell(4.0, 0, 0, 3).build()); + assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1]],[[2]],[[3]],[[4]]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])")) + .cell(1.0, 0, 0, 0).cell(2.0, 1, 0, 0).cell(3.0, 2, 0, 0).cell(4.0, 3, 0, 0).build()); + assertEquals(Tensor.from("tensor(x[],y[],z[]):[[[1, 2]],[[3, 4]]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])")) + .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 1, 0, 0).cell(4.0, 1, 0, 1).build()); + + assertEquals(Tensor.from("tensor(x[],y[],z[4]):[[[1,2,3,4]]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])")) + .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 0, 0, 2).cell(4.0, 0, 0, 3).build()); + assertEquals(Tensor.from("tensor(x[2],y[],z[2]):[[[1, 2]],[[3, 4]]]"), + Tensor.Builder.of(TensorType.fromSpec("tensor(x[],y[],z[])")) + .cell(1.0, 0, 0, 0).cell(2.0, 0, 0, 1).cell(3.0, 1, 0, 0).cell(4.0, 1, 0, 1).build()); + + assertIllegal("Unexpected size 2 for dimension y for type tensor(x[],y[3])", + "tensor(x[],y[3]):[[1,2],[3,4]]"); + } + private void assertDense(Tensor expectedTensor, String denseFormat) { assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat)); assertEquals(denseFormat, expectedTensor.toString()); |