diff options
Diffstat (limited to 'vespajlib')
7 files changed, 47 insertions, 15 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/slime/Type.java b/vespajlib/src/main/java/com/yahoo/slime/Type.java index 5c5524d8b88..de188b687c5 100644 --- a/vespajlib/src/main/java/com/yahoo/slime/Type.java +++ b/vespajlib/src/main/java/com/yahoo/slime/Type.java @@ -5,6 +5,7 @@ package com.yahoo.slime; * Enumeration of all possibly Slime data types. **/ public enum Type { + NIX(0), BOOL(1), LONG(2), @@ -15,8 +16,9 @@ public enum Type { OBJECT(7); public final byte ID; - private Type(int id) { this.ID = (byte)id; } + Type(int id) { this.ID = (byte)id; } private static final Type[] types = values(); static Type asType(int id) { return types[id]; } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java index e644244178d..7f1351cc42b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedDoubleTensor.java @@ -74,7 +74,7 @@ class IndexedDoubleTensor extends IndexedTensor { @Override public Builder cell(TensorAddress address, double value) { - values[(int)toValueIndex(address, sizes())] = value; + values[(int)toValueIndex(address, sizes(), type)] = value; return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java index 30157d9791a..d7eeaa4f96e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedFloatTensor.java @@ -78,7 +78,7 @@ class IndexedFloatTensor extends IndexedTensor { @Override public Builder cell(TensorAddress address, float value) { - values[(int)toValueIndex(address, sizes())] = value; + values[(int)toValueIndex(address, sizes(), type)] = value; return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index a03131f3ec9..3c15d7540b2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -112,7 +112,7 @@ public abstract class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return get((int)toValueIndex(address, dimensionSizes)); + return get((int)toValueIndex(address, dimensionSizes, type)); } catch (IndexOutOfBoundsException e) { return Double.NaN; @@ -151,14 +151,13 @@ public abstract class IndexedTensor implements Tensor { return valueIndex; } - static long toValueIndex(TensorAddress address, DimensionSizes sizes) { + static long toValueIndex(TensorAddress address, DimensionSizes sizes, TensorType type) { if (address.isEmpty()) return 0; long valueIndex = 0; for (int i = 0; i < address.size(); i++) { - if (address.numericLabel(i) >= sizes.size(i)) { - throw new IndexOutOfBoundsException(); - } + if (address.numericLabel(i) >= sizes.size(i)) + throw new IllegalArgumentException(address + " is not within bounds of " + type); valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i); } return valueIndex; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 2a713611307..30f7185959c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,10 +1,13 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.google.common.base.Joiner; + import java.util.Arrays; import java.util.Objects; import java.util.Optional; import java.util.regex.Pattern; +import java.util.stream.Collectors; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -118,9 +121,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return new StringTensorAddress(labels); } + @Override public String toString() { - return Arrays.toString(labels); + return "cell address (" + String.join(",", labels) + ")"; } } @@ -151,7 +155,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { @Override public String toString() { - return Arrays.toString(labels); + return "cell address (" + Arrays.stream(labels).mapToObj(String::valueOf).collect(Collectors.joining(",")) + ")"; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java index 6382361f187..52635905d72 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/JsonFormat.java @@ -7,6 +7,7 @@ import com.yahoo.slime.Inspector; import com.yahoo.slime.JsonDecoder; import com.yahoo.slime.ObjectTraverser; import com.yahoo.slime.Slime; +import com.yahoo.slime.Type; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -41,18 +42,26 @@ public class JsonFormat { } /** Deserializes the given tensor from JSON format */ - // TODO: Add explicit validation (valid() checks) below public static Tensor decode(TensorType type, byte[] jsonTensorValue) { Tensor.Builder tensorBuilder = Tensor.Builder.of(type); Inspector root = new JsonDecoder().decode(new Slime(), jsonTensorValue).get(); Inspector cells = root.field("cells"); + if ( cells.type() != Type.ARRAY) + throw new IllegalArgumentException("Excepted an array item named 'cells' at the top level"); cells.traverse((ArrayTraverser) (__, cell) -> decodeCell(cell, tensorBuilder.cell())); return tensorBuilder.build(); } private static void decodeCell(Inspector cell, Tensor.Builder.CellBuilder cellBuilder) { - cell.field("address").traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString())); - cellBuilder.value(cell.field("value").asDouble()); + Inspector address = cell.field("address"); + if ( address.type() != Type.OBJECT) + throw new IllegalArgumentException("Excepted a cell to contain an object called 'address'"); + address.traverse((ObjectTraverser) (dimension, label) -> cellBuilder.label(dimension, label.asString())); + + Inspector value = cell.field("value"); + if (value.type() != Type.LONG && value.type() != Type.DOUBLE) + throw new IllegalArgumentException("Excepted a cell to contain a numeric value called 'value'"); + cellBuilder.value(value.asDouble()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java index 5a025b6eb96..8c652f5aa27 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/serialization/JsonFormatTestCase.java @@ -8,6 +8,7 @@ import org.junit.Test; import java.nio.charset.StandardCharsets; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * @author bratseth @@ -15,7 +16,7 @@ import static org.junit.Assert.assertEquals; public class JsonFormatTestCase { @Test - public void testJsonEncodingOfSparseTensor() { + public void testSparseTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); builder.cell().label("x", "a").label("y", "b").value(2.0); builder.cell().label("x", "c").label("y", "d").value(3.0); @@ -31,7 +32,7 @@ public class JsonFormatTestCase { } @Test - public void testJsonEncodingOfDenseTensor() { + public void testDenseTensor() { Tensor.Builder builder = Tensor.Builder.of(TensorType.fromSpec("tensor(x{},y{})")); builder.cell().label("x", 0).label("y", 0).value(2.0); builder.cell().label("x", 0).label("y", 1).value(3.0); @@ -50,4 +51,21 @@ public class JsonFormatTestCase { assertEquals(tensor, decoded); } + @Test + public void testTooManyCells() { + TensorType x2 = TensorType.fromSpec("tensor(x[2])"); + String json = "{\"cells\":[" + + "{\"address\":{\"x\":\"0\"},\"value\":2.0}," + + "{\"address\":{\"x\":\"1\"},\"value\":3.0}," + + "{\"address\":{\"x\":\"2\"},\"value\":5.0}" + + "]}"; + try { + JsonFormat.decode(x2, json.getBytes(StandardCharsets.UTF_8)); + fail("Excpected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("cell address (2) is not within bounds of tensor(x[2])", e.getMessage()); + } + } + } |