diff options
author | Jon Bratseth <bratseth@oath.com> | 2020-01-06 10:25:30 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2020-01-06 10:25:30 +0100 |
commit | b95af9b717705fff28272a1ea5e0adcf97597402 (patch) | |
tree | 254fe25f613fa3727cce888e03ffcd48bbc8ab93 /vespajlib | |
parent | 234be16d4d01656ee7b9bdc0917d31bef9772f69 (diff) | |
parent | f9f76ab6dc479dfbbaa2b7520cdb0d163be9b7dd (diff) |
Merge pull request #11637 from vespa-engine/bratseth/tensor-short-form-tostring
More tensor short forms in Tensor.toString()
Diffstat (limited to 'vespajlib')
12 files changed, 118 insertions, 36 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index cc0a6dc3a14..f631b3e1c58 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1263,6 +1263,7 @@ "public int hashCode()", "public boolean equals(java.lang.Object)", "public final java.lang.String toString(com.yahoo.tensor.TensorType)", + "public static java.lang.String labelToString(java.lang.String)", "public bridge synthetic int compareTo(java.lang.Object)" ], "fields": [] @@ -1324,6 +1325,7 @@ "public abstract com.yahoo.tensor.TensorType$Dimension$Type type()", "public abstract com.yahoo.tensor.TensorType$Dimension withName(java.lang.String)", "public boolean isIndexed()", + "public boolean isMapped()", "public abstract java.lang.String toString()", "public boolean equals(java.lang.Object)", "public int hashCode()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index b255f18cdd4..ad82dd6c3ac 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -204,12 +204,18 @@ public abstract class IndexedTensor implements Tensor { @Override public String toString() { if (type.rank() == 0) return Tensor.toStandardString(this); - if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) return Tensor.toStandardString(this); + if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) + return Tensor.toStandardString(this); Indexes indexes = Indexes.of(dimensionSizes); StringBuilder b = new StringBuilder(type.toString()).append(":"); - for (int index = 0; index < size(); index++) { + indexedBlockToString(this, indexes, b); + return b.toString(); + } + + static void indexedBlockToString(IndexedTensor tensor, Indexes indexes, StringBuilder b) { + for (int index = 0; index < tensor.size(); index++) { indexes.next(); // start brackets @@ -217,20 +223,19 @@ public abstract class IndexedTensor implements Tensor { b.append("["); // value - if (type.valueType() == TensorType.Value.DOUBLE) - b.append(get(index)); - else if (type.valueType() == TensorType.Value.FLOAT) - b.append(get(index)); // TODO: Use getFloat + if (tensor.type().valueType() == TensorType.Value.DOUBLE) + b.append(tensor.get(index)); + else if (tensor.type().valueType() == TensorType.Value.FLOAT) + b.append(tensor.getFloat(index)); else - throw new IllegalStateException("Unexpected value type " + type.valueType()); + throw new IllegalStateException("Unexpected value type " + tensor.type().valueType()); // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); - if (index < size() - 1) + if (index < tensor.size() - 1) b.append(", "); } - return b.toString(); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index ad4f0fd0dfb..67c6930ce35 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -132,7 +132,14 @@ public class MixedTensor implements Tensor { public int hashCode() { return cells.hashCode(); } @Override - public String toString() { return Tensor.toStandardString(this); } + public String toString() { + if (type.rank() == 0) return Tensor.toStandardString(this); + if (type.rank() > 1 && type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) + return Tensor.toStandardString(this); + if (type.dimensions().stream().filter(d -> d.isMapped()).count() > 1) return Tensor.toStandardString(this); + + return type.toString() + ":" + index.contentToString(this); + } @Override public boolean equals(Object other) { @@ -479,7 +486,63 @@ public class MixedTensor implements Tensor { @Override public String toString() { - return "indexes into " + type; + return "index into " + type; + } + + private String contentToString(MixedTensor tensor) { + if (mappedDimensions.size() > 1) throw new IllegalStateException("Should be ensured by caller"); + if (mappedDimensions.size() == 0) { + StringBuilder b = new StringBuilder(); + denseSubspaceToString(tensor, 0, b); + return b.toString(); + } + + // Exactly 1 mapped dimension + StringBuilder b = new StringBuilder("{"); + sparseMap.entrySet().stream().sorted(Map.Entry.comparingByKey()).forEach(entry -> { + b.append(TensorAddress.labelToString(entry.getKey().label(0 ))); + b.append(":"); + denseSubspaceToString(tensor, entry.getValue(), b); + b.append(","); + }); + if (b.length() > 1) + b.setLength(b.length() - 1); + b.append("}"); + return b.toString(); + } + + private void denseSubspaceToString(MixedTensor tensor, long subspaceIndex, StringBuilder b) { + if (denseSubspaceSize == 1) { + b.append(getDouble(subspaceIndex, 0, tensor)); + return; + } + + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseType); + for (int index = 0; index < denseSubspaceSize; index++) { + indexes.next(); + + // start brackets + for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) + b.append("["); + + // value + if (type.valueType() == TensorType.Value.DOUBLE) + b.append(getDouble(subspaceIndex, index, tensor)); + else if (tensor.type().valueType() == TensorType.Value.FLOAT) + b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats + else + throw new IllegalStateException("Unexpected value type " + type.valueType()); + + // end bracket and comma + for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) + b.append("]"); + if (index < denseSubspaceSize - 1) + b.append(", "); + } + } + + private double getDouble(long indexedSubspaceIndex, long indexInIndexedSubspace, MixedTensor tensor) { + return tensor.cells.get((int)(indexedSubspaceIndex + indexInIndexedSubspace)).getDoubleValue(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 6245c26b9f4..08d4f1c08b7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -32,6 +32,7 @@ import java.util.Set; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; +import java.util.stream.Collectors; import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers; @@ -312,23 +313,21 @@ public interface Tensor { } static String contentToString(Tensor tensor) { - List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet()); + var cellEntries = new ArrayList<>(tensor.cells().entrySet()); if (tensor.type().dimensions().isEmpty()) { if (cellEntries.isEmpty()) return "{}"; return "{" + cellEntries.get(0).getValue() +"}"; } + return "{" + cellEntries.stream().sorted(Map.Entry.comparingByKey()) + .map(cell -> cellToString(cell, tensor.type())) + .collect(Collectors.joining(",")) + + "}"; + } - Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey()); - - StringBuilder b = new StringBuilder("{"); - for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) { - b.append(cell.getKey().toString(tensor.type())).append(":").append(cell.getValue()); - b.append(","); - } - if (b.length() > 1) - b.setLength(b.length() - 1); - b.append("}"); - return b.toString(); + private static String cellToString(Map.Entry<TensorAddress, Double> cell, TensorType type) { + return (type.rank() > 1 ? cell.getKey().toString(type) : TensorAddress.labelToString(cell.getKey().label(0))) + + ":" + + cell.getValue(); } // ----------------- equality diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index a3805fb789a..4a076199846 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -91,7 +91,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return b.toString(); } - private String labelToString(String label) { + /** Returns a label as a string with approriate quoting/escaping when necessary */ + public static String labelToString(String label) { if (TensorType.labelMatcher.matches(label)) return label; // no quoting if (label.contains("'")) return "\"" + label + "\""; return "'" + label + "'"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 9aa764a0b36..becec1a4493 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -99,7 +99,7 @@ class TensorParser { if (type.isEmpty()) throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); - if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1) + if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1) throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " + "but got " + type.get()); @@ -310,7 +310,7 @@ class TensorParser { } private void parse() { - TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); + TensorType.Dimension mappedDimension = findMappedDimension(); TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension)); if (dimensionOrder != null) dimensionOrder.remove(mappedDimension.name()); @@ -332,6 +332,15 @@ class TensorParser { } } + private TensorType.Dimension findMappedDimension() { + Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(d -> d.isMapped()).findAny(); + if (mappedDimension.isPresent()) return mappedDimension.get(); + if (builder.type().rank() == 1 && builder.type().dimensions().get(0).size().isEmpty()) + return builder.type().dimensions().get(0); + throw new IllegalStateException("No suitable dimension in " + builder.type() + + " for parsing as a mixed tensor. This is a bug."); + } + private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) { DenseValueParser denseParser = new DenseValueParser(string.substring(position), denseDimensionOrder, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 240681bad5a..aeed8c33093 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -312,6 +312,9 @@ public class TensorType { /** Returns true if this is an indexed bound or unbound type */ public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; } + /** Returns true if this is of the mapped type */ + public boolean isMapped() { return type() == Type.mapped; } + /** * Returns the dimension resulting from combining two dimensions having the same name but possibly different * types: diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java index 502f0270831..2699e4642e2 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java @@ -35,7 +35,7 @@ public class MappedTensorTestCase { cell().label("x", "0").value(1). cell().label("x", "1").value(2).build(); assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames()); - assertEquals("tensor(x{}):{{x:0}:1.0,{x:1}:2.0}", tensor.toString()); + assertEquals("tensor(x{}):{0:1.0,1:2.0}", tensor.toString()); } @Test diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java index 22b97bff52a..5d4417ec928 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java @@ -41,7 +41,7 @@ public class MixedTensorTestCase { // {y:2} should be 0.0 and non NaN since we specify indexed size build(); assertEquals(Sets.newHashSet("y"), tensor.type().dimensionNames()); - assertEquals("tensor(y[3]):{{y:0}:1.0,{y:1}:2.0,{y:2}:0.0}", + assertEquals("tensor(y[3]):[1.0, 2.0, 0.0]", tensor.toString()); } @@ -57,8 +57,8 @@ public class MixedTensorTestCase { cell().label("x", 1).label("y", 2).value(6). build(); assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames()); - assertEquals("tensor(x[2],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:2.0,{x:0,y:2}:0.0,{x:1,y:0}:4.0,{x:1,y:1}:5.0,{x:1,y:2}:6.0}", - tensor.toString()); + assertEquals("tensor(x[2],y[3]):[[1.0, 2.0, 0.0], [4.0, 5.0, 6.0]]", + tensor.toString()); } @Test @@ -69,8 +69,8 @@ public class MixedTensorTestCase { cell().label("x", "1").value(2). build(); assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames()); - assertEquals("tensor(x{}):{{x:0}:1.0,{x:1}:2.0}", - tensor.toString()); + assertEquals("tensor(x{}):{0:1.0,1:2.0}", + tensor.toString()); } @Test diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index 5a68df6c7df..431e4b06263 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -43,7 +43,7 @@ public class TensorParserTestCase { assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor()")).cell(1.3).build(), "tensor():{1.3}"); assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[])")).cell(1.0, 0).build(), - "tensor(x[]):{{x:0}:1.0}"); + "tensor(x[]):{0:1.0}"); assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(), "tensor(x[1]):[1.0]"); assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[2])")).cell(1.0, 0).cell(2.0, 1).build(), diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 040111a6fbb..b1851b5f120 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -37,7 +37,7 @@ public class TensorTestCase { assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString()); assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); assertEquals("Labels are quoted when necessary", - "tensor(d1{}):{{d1:\"'''\"}:6.0,{d1:'[[\":\"]]'}:5.0}", + "tensor(d1{}):{\"'''\":6.0,'[[\":\"]]':5.0}", Tensor.from("{ {d1:'[[\":\"]]'}: 5, {d1:\"'''\"}:6.0 }").toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index e6560242d5c..625d5d44b19 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -15,11 +15,11 @@ public class TensorFunctionTestCase { @Test public void testTranslation() { - assertTranslated("join(tensor(x{}):{{x:1}:1.0}, reduce(tensor(x{}):{{x:1}:1.0}, sum, x), f(a,b)(a / b))", + assertTranslated("join(tensor(x{}):{1:1.0}, reduce(tensor(x{}):{1:1.0}, sum, x), f(a,b)(a / b))", new L1Normalize<>(new ConstantTensor<>("{{x:1}:1.0}"), "x")); assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); - assertTranslated("join(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))", + assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max, x), f(a,b)(a==b))", new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); } |