aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2020-01-06 10:25:30 +0100
committerGitHub <noreply@github.com>2020-01-06 10:25:30 +0100
commitb95af9b717705fff28272a1ea5e0adcf97597402 (patch)
tree254fe25f613fa3727cce888e03ffcd48bbc8ab93 /vespajlib
parent234be16d4d01656ee7b9bdc0917d31bef9772f69 (diff)
parentf9f76ab6dc479dfbbaa2b7520cdb0d163be9b7dd (diff)
Merge pull request #11637 from vespa-engine/bratseth/tensor-short-form-tostring
More tensor short forms in Tensor.toString()
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java23
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java67
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java23
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java3
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java2
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java4
12 files changed, 118 insertions, 36 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index cc0a6dc3a14..f631b3e1c58 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1263,6 +1263,7 @@
"public int hashCode()",
"public boolean equals(java.lang.Object)",
"public final java.lang.String toString(com.yahoo.tensor.TensorType)",
+ "public static java.lang.String labelToString(java.lang.String)",
"public bridge synthetic int compareTo(java.lang.Object)"
],
"fields": []
@@ -1324,6 +1325,7 @@
"public abstract com.yahoo.tensor.TensorType$Dimension$Type type()",
"public abstract com.yahoo.tensor.TensorType$Dimension withName(java.lang.String)",
"public boolean isIndexed()",
+ "public boolean isMapped()",
"public abstract java.lang.String toString()",
"public boolean equals(java.lang.Object)",
"public int hashCode()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index b255f18cdd4..ad82dd6c3ac 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -204,12 +204,18 @@ public abstract class IndexedTensor implements Tensor {
@Override
public String toString() {
if (type.rank() == 0) return Tensor.toStandardString(this);
- if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) return Tensor.toStandardString(this);
+ if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty()))
+ return Tensor.toStandardString(this);
Indexes indexes = Indexes.of(dimensionSizes);
StringBuilder b = new StringBuilder(type.toString()).append(":");
- for (int index = 0; index < size(); index++) {
+ indexedBlockToString(this, indexes, b);
+ return b.toString();
+ }
+
+ static void indexedBlockToString(IndexedTensor tensor, Indexes indexes, StringBuilder b) {
+ for (int index = 0; index < tensor.size(); index++) {
indexes.next();
// start brackets
@@ -217,20 +223,19 @@ public abstract class IndexedTensor implements Tensor {
b.append("[");
// value
- if (type.valueType() == TensorType.Value.DOUBLE)
- b.append(get(index));
- else if (type.valueType() == TensorType.Value.FLOAT)
- b.append(get(index)); // TODO: Use getFloat
+ if (tensor.type().valueType() == TensorType.Value.DOUBLE)
+ b.append(tensor.get(index));
+ else if (tensor.type().valueType() == TensorType.Value.FLOAT)
+ b.append(tensor.getFloat(index));
else
- throw new IllegalStateException("Unexpected value type " + type.valueType());
+ throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
// end bracket and comma
for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
b.append("]");
- if (index < size() - 1)
+ if (index < tensor.size() - 1)
b.append(", ");
}
- return b.toString();
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index ad4f0fd0dfb..67c6930ce35 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -132,7 +132,14 @@ public class MixedTensor implements Tensor {
public int hashCode() { return cells.hashCode(); }
@Override
- public String toString() { return Tensor.toStandardString(this); }
+ public String toString() {
+ if (type.rank() == 0) return Tensor.toStandardString(this);
+ if (type.rank() > 1 && type.dimensions().stream().anyMatch(d -> d.size().isEmpty()))
+ return Tensor.toStandardString(this);
+ if (type.dimensions().stream().filter(d -> d.isMapped()).count() > 1) return Tensor.toStandardString(this);
+
+ return type.toString() + ":" + index.contentToString(this);
+ }
@Override
public boolean equals(Object other) {
@@ -479,7 +486,63 @@ public class MixedTensor implements Tensor {
@Override
public String toString() {
- return "indexes into " + type;
+ return "index into " + type;
+ }
+
+ private String contentToString(MixedTensor tensor) {
+ if (mappedDimensions.size() > 1) throw new IllegalStateException("Should be ensured by caller");
+ if (mappedDimensions.size() == 0) {
+ StringBuilder b = new StringBuilder();
+ denseSubspaceToString(tensor, 0, b);
+ return b.toString();
+ }
+
+ // Exactly 1 mapped dimension
+ StringBuilder b = new StringBuilder("{");
+ sparseMap.entrySet().stream().sorted(Map.Entry.comparingByKey()).forEach(entry -> {
+ b.append(TensorAddress.labelToString(entry.getKey().label(0 )));
+ b.append(":");
+ denseSubspaceToString(tensor, entry.getValue(), b);
+ b.append(",");
+ });
+ if (b.length() > 1)
+ b.setLength(b.length() - 1);
+ b.append("}");
+ return b.toString();
+ }
+
+ private void denseSubspaceToString(MixedTensor tensor, long subspaceIndex, StringBuilder b) {
+ if (denseSubspaceSize == 1) {
+ b.append(getDouble(subspaceIndex, 0, tensor));
+ return;
+ }
+
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseType);
+ for (int index = 0; index < denseSubspaceSize; index++) {
+ indexes.next();
+
+ // start brackets
+ for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
+ b.append("[");
+
+ // value
+ if (type.valueType() == TensorType.Value.DOUBLE)
+ b.append(getDouble(subspaceIndex, index, tensor));
+ else if (tensor.type().valueType() == TensorType.Value.FLOAT)
+ b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats
+ else
+ throw new IllegalStateException("Unexpected value type " + type.valueType());
+
+ // end bracket and comma
+ for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
+ b.append("]");
+ if (index < denseSubspaceSize - 1)
+ b.append(", ");
+ }
+ }
+
+ private double getDouble(long indexedSubspaceIndex, long indexInIndexedSubspace, MixedTensor tensor) {
+ return tensor.cells.get((int)(indexedSubspaceIndex + indexInIndexedSubspace)).getDoubleValue();
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 6245c26b9f4..08d4f1c08b7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -32,6 +32,7 @@ import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
+import java.util.stream.Collectors;
import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers;
@@ -312,23 +313,21 @@ public interface Tensor {
}
static String contentToString(Tensor tensor) {
- List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
+ var cellEntries = new ArrayList<>(tensor.cells().entrySet());
if (tensor.type().dimensions().isEmpty()) {
if (cellEntries.isEmpty()) return "{}";
return "{" + cellEntries.get(0).getValue() +"}";
}
+ return "{" + cellEntries.stream().sorted(Map.Entry.comparingByKey())
+ .map(cell -> cellToString(cell, tensor.type()))
+ .collect(Collectors.joining(",")) +
+ "}";
+ }
- Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
-
- StringBuilder b = new StringBuilder("{");
- for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) {
- b.append(cell.getKey().toString(tensor.type())).append(":").append(cell.getValue());
- b.append(",");
- }
- if (b.length() > 1)
- b.setLength(b.length() - 1);
- b.append("}");
- return b.toString();
+ private static String cellToString(Map.Entry<TensorAddress, Double> cell, TensorType type) {
+ return (type.rank() > 1 ? cell.getKey().toString(type) : TensorAddress.labelToString(cell.getKey().label(0))) +
+ ":" +
+ cell.getValue();
}
// ----------------- equality
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index a3805fb789a..4a076199846 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -91,7 +91,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
- private String labelToString(String label) {
+ /** Returns a label as a string with approriate quoting/escaping when necessary */
+ public static String labelToString(String label) {
if (TensorType.labelMatcher.matches(label)) return label; // no quoting
if (label.contains("'")) return "\"" + label + "\"";
return "'" + label + "'";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 9aa764a0b36..becec1a4493 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -99,7 +99,7 @@ class TensorParser {
if (type.isEmpty())
throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
- if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1)
+ if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1)
throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " +
"but got " + type.get());
@@ -310,7 +310,7 @@ class TensorParser {
}
private void parse() {
- TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get();
+ TensorType.Dimension mappedDimension = findMappedDimension();
TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension));
if (dimensionOrder != null)
dimensionOrder.remove(mappedDimension.name());
@@ -332,6 +332,15 @@ class TensorParser {
}
}
+ private TensorType.Dimension findMappedDimension() {
+ Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(d -> d.isMapped()).findAny();
+ if (mappedDimension.isPresent()) return mappedDimension.get();
+ if (builder.type().rank() == 1 && builder.type().dimensions().get(0).size().isEmpty())
+ return builder.type().dimensions().get(0);
+ throw new IllegalStateException("No suitable dimension in " + builder.type() +
+ " for parsing as a mixed tensor. This is a bug.");
+ }
+
private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) {
DenseValueParser denseParser = new DenseValueParser(string.substring(position),
denseDimensionOrder,
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 240681bad5a..aeed8c33093 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -312,6 +312,9 @@ public class TensorType {
/** Returns true if this is an indexed bound or unbound type */
public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; }
+ /** Returns true if this is of the mapped type */
+ public boolean isMapped() { return type() == Type.mapped; }
+
/**
* Returns the dimension resulting from combining two dimensions having the same name but possibly different
* types:
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
index 502f0270831..2699e4642e2 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java
@@ -35,7 +35,7 @@ public class MappedTensorTestCase {
cell().label("x", "0").value(1).
cell().label("x", "1").value(2).build();
assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames());
- assertEquals("tensor(x{}):{{x:0}:1.0,{x:1}:2.0}", tensor.toString());
+ assertEquals("tensor(x{}):{0:1.0,1:2.0}", tensor.toString());
}
@Test
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java
index 22b97bff52a..5d4417ec928 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java
@@ -41,7 +41,7 @@ public class MixedTensorTestCase {
// {y:2} should be 0.0 and non NaN since we specify indexed size
build();
assertEquals(Sets.newHashSet("y"), tensor.type().dimensionNames());
- assertEquals("tensor(y[3]):{{y:0}:1.0,{y:1}:2.0,{y:2}:0.0}",
+ assertEquals("tensor(y[3]):[1.0, 2.0, 0.0]",
tensor.toString());
}
@@ -57,8 +57,8 @@ public class MixedTensorTestCase {
cell().label("x", 1).label("y", 2).value(6).
build();
assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames());
- assertEquals("tensor(x[2],y[3]):{{x:0,y:0}:1.0,{x:0,y:1}:2.0,{x:0,y:2}:0.0,{x:1,y:0}:4.0,{x:1,y:1}:5.0,{x:1,y:2}:6.0}",
- tensor.toString());
+ assertEquals("tensor(x[2],y[3]):[[1.0, 2.0, 0.0], [4.0, 5.0, 6.0]]",
+ tensor.toString());
}
@Test
@@ -69,8 +69,8 @@ public class MixedTensorTestCase {
cell().label("x", "1").value(2).
build();
assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames());
- assertEquals("tensor(x{}):{{x:0}:1.0,{x:1}:2.0}",
- tensor.toString());
+ assertEquals("tensor(x{}):{0:1.0,1:2.0}",
+ tensor.toString());
}
@Test
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
index 5a68df6c7df..431e4b06263 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -43,7 +43,7 @@ public class TensorParserTestCase {
assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor()")).cell(1.3).build(),
"tensor():{1.3}");
assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[])")).cell(1.0, 0).build(),
- "tensor(x[]):{{x:0}:1.0}");
+ "tensor(x[]):{0:1.0}");
assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(),
"tensor(x[1]):[1.0]");
assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[2])")).cell(1.0, 0).cell(2.0, 1).build(),
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 040111a6fbb..b1851b5f120 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -37,7 +37,7 @@ public class TensorTestCase {
assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString());
assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString());
assertEquals("Labels are quoted when necessary",
- "tensor(d1{}):{{d1:\"'''\"}:6.0,{d1:'[[\":\"]]'}:5.0}",
+ "tensor(d1{}):{\"'''\":6.0,'[[\":\"]]':5.0}",
Tensor.from("{ {d1:'[[\":\"]]'}: 5, {d1:\"'''\"}:6.0 }").toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index e6560242d5c..625d5d44b19 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -15,11 +15,11 @@ public class TensorFunctionTestCase {
@Test
public void testTranslation() {
- assertTranslated("join(tensor(x{}):{{x:1}:1.0}, reduce(tensor(x{}):{{x:1}:1.0}, sum, x), f(a,b)(a / b))",
+ assertTranslated("join(tensor(x{}):{1:1.0}, reduce(tensor(x{}):{1:1.0}, sum, x), f(a,b)(a / b))",
new L1Normalize<>(new ConstantTensor<>("{{x:1}:1.0}"), "x"));
assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
- assertTranslated("join(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))",
+ assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max, x), f(a,b)(a==b))",
new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
}