summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2020-01-03 11:01:02 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2020-01-03 11:01:02 +0100
commitfcaf3de39b725ece9d57e3c764bf0fae36206d5d (patch)
tree4f424b7a35507da40263211a8e0aa476a6d3dc52 /vespajlib/src/main
parent869e9e83274e037cae548ac5eb3c72881e90859a (diff)
More tensor short forms in Tensor.toString()
Diffstat (limited to 'vespajlib/src/main')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java23
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java67
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java23
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java3
6 files changed, 106 insertions, 26 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index ba3a35e8eda..985dbd11bcb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -209,12 +209,18 @@ public abstract class IndexedTensor implements Tensor {
@Override
public String toString() {
if (type.rank() == 0) return Tensor.toStandardString(this);
- if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) return Tensor.toStandardString(this);
+ if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty()))
+ return Tensor.toStandardString(this);
Indexes indexes = Indexes.of(dimensionSizes);
StringBuilder b = new StringBuilder(type.toString()).append(":");
- for (int index = 0; index < size(); index++) {
+ indexedBlockToString(this, indexes, b);
+ return b.toString();
+ }
+
+ static void indexedBlockToString(IndexedTensor tensor, Indexes indexes, StringBuilder b) {
+ for (int index = 0; index < tensor.size(); index++) {
indexes.next();
// start brackets
@@ -222,20 +228,19 @@ public abstract class IndexedTensor implements Tensor {
b.append("[");
// value
- if (type.valueType() == TensorType.Value.DOUBLE)
- b.append(get(index));
- else if (type.valueType() == TensorType.Value.FLOAT)
- b.append(get(index)); // TODO: Use getFloat
+ if (tensor.type().valueType() == TensorType.Value.DOUBLE)
+ b.append(tensor.get(index));
+ else if (tensor.type().valueType() == TensorType.Value.FLOAT)
+ b.append(tensor.getFloat(index));
else
- throw new IllegalStateException("Unexpected value type " + type.valueType());
+ throw new IllegalStateException("Unexpected value type " + tensor.type().valueType());
// end bracket and comma
for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
b.append("]");
- if (index < size() - 1)
+ if (index < tensor.size() - 1)
b.append(", ");
}
- return b.toString();
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 0c4efe78113..b11b0c58a2d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -148,7 +148,14 @@ public class MixedTensor implements Tensor {
public int hashCode() { return cells.hashCode(); }
@Override
- public String toString() { return Tensor.toStandardString(this); }
+ public String toString() {
+ if (type.rank() == 0) return Tensor.toStandardString(this);
+ if (type.rank() > 1 && type.dimensions().stream().anyMatch(d -> d.size().isEmpty()))
+ return Tensor.toStandardString(this);
+ if (type.dimensions().stream().filter(d -> d.isMapped()).count() > 1) return Tensor.toStandardString(this);
+
+ return type.toString() + ":" + index.contentToString(this);
+ }
@Override
public boolean equals(Object other) {
@@ -494,7 +501,63 @@ public class MixedTensor implements Tensor {
@Override
public String toString() {
- return "indexes into " + type;
+ return "index into " + type;
+ }
+
+ private String contentToString(MixedTensor tensor) {
+ if (mappedDimensions.size() > 1) throw new IllegalStateException("Should be ensured by caller");
+ if (mappedDimensions.size() == 0) {
+ StringBuilder b = new StringBuilder();
+ denseSubspaceToString(tensor, 0, b);
+ return b.toString();
+ }
+
+ // Exactly 1 mapped dimension
+ StringBuilder b = new StringBuilder("{");
+ sparseMap.entrySet().stream().sorted(Map.Entry.comparingByKey()).forEach(entry -> {
+ b.append(TensorAddress.labelToString(entry.getKey().label(0 )));
+ b.append(":");
+ denseSubspaceToString(tensor, entry.getValue(), b);
+ b.append(",");
+ });
+ if (b.length() > 1)
+ b.setLength(b.length() - 1);
+ b.append("}");
+ return b.toString();
+ }
+
+ private void denseSubspaceToString(MixedTensor tensor, long subspaceIndex, StringBuilder b) {
+ if (denseSubspaceSize == 1) {
+ b.append(getDouble(subspaceIndex, 0, tensor));
+ return;
+ }
+
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseType);
+ for (int index = 0; index < denseSubspaceSize; index++) {
+ indexes.next();
+
+ // start brackets
+ for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
+ b.append("[");
+
+ // value
+ if (type.valueType() == TensorType.Value.DOUBLE)
+ b.append(getDouble(subspaceIndex, index, tensor));
+ else if (tensor.type().valueType() == TensorType.Value.FLOAT)
+ b.append(getDouble(subspaceIndex, index, tensor)); // TODO: Really use floats
+ else
+ throw new IllegalStateException("Unexpected value type " + type.valueType());
+
+ // end bracket and comma
+ for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
+ b.append("]");
+ if (index < denseSubspaceSize - 1)
+ b.append(", ");
+ }
+ }
+
+ private double getDouble(long indexedSubspaceIndex, long indexInIndexedSubspace, MixedTensor tensor) {
+ return tensor.cells.get((int)(indexedSubspaceIndex + indexInIndexedSubspace)).getDoubleValue();
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index cffd41905a1..b0b1c27962a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -31,6 +31,7 @@ import java.util.Set;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
+import java.util.stream.Collectors;
import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers;
@@ -319,23 +320,21 @@ public interface Tensor {
}
static String contentToString(Tensor tensor) {
- List<java.util.Map.Entry<TensorAddress, Double>> cellEntries = new ArrayList<>(tensor.cells().entrySet());
+ var cellEntries = new ArrayList<>(tensor.cells().entrySet());
if (tensor.type().dimensions().isEmpty()) {
if (cellEntries.isEmpty()) return "{}";
return "{" + cellEntries.get(0).getValue() +"}";
}
+ return "{" + cellEntries.stream().sorted(Map.Entry.comparingByKey())
+ .map(cell -> cellToString(cell, tensor.type()))
+ .collect(Collectors.joining(",")) +
+ "}";
+ }
- Collections.sort(cellEntries, java.util.Map.Entry.<TensorAddress, Double>comparingByKey());
-
- StringBuilder b = new StringBuilder("{");
- for (java.util.Map.Entry<TensorAddress, Double> cell : cellEntries) {
- b.append(cell.getKey().toString(tensor.type())).append(":").append(cell.getValue());
- b.append(",");
- }
- if (b.length() > 1)
- b.setLength(b.length() - 1);
- b.append("}");
- return b.toString();
+ private static String cellToString(Map.Entry<TensorAddress, Double> cell, TensorType type) {
+ return (type.rank() > 1 ? cell.getKey().toString(type) : TensorAddress.labelToString(cell.getKey().label(0))) +
+ ":" +
+ cell.getValue();
}
// ----------------- equality
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index a3805fb789a..4a076199846 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -91,7 +91,8 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return b.toString();
}
- private String labelToString(String label) {
+ /** Returns a label as a string with approriate quoting/escaping when necessary */
+ public static String labelToString(String label) {
if (TensorType.labelMatcher.matches(label)) return label; // no quoting
if (label.contains("'")) return "\"" + label + "\"";
return "'" + label + "'";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 9aa764a0b36..becec1a4493 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -99,7 +99,7 @@ class TensorParser {
if (type.isEmpty())
throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
- if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1)
+ if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1)
throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " +
"but got " + type.get());
@@ -310,7 +310,7 @@ class TensorParser {
}
private void parse() {
- TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get();
+ TensorType.Dimension mappedDimension = findMappedDimension();
TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension));
if (dimensionOrder != null)
dimensionOrder.remove(mappedDimension.name());
@@ -332,6 +332,15 @@ class TensorParser {
}
}
+ private TensorType.Dimension findMappedDimension() {
+ Optional<TensorType.Dimension> mappedDimension = builder.type().dimensions().stream().filter(d -> d.isMapped()).findAny();
+ if (mappedDimension.isPresent()) return mappedDimension.get();
+ if (builder.type().rank() == 1 && builder.type().dimensions().get(0).size().isEmpty())
+ return builder.type().dimensions().get(0);
+ throw new IllegalStateException("No suitable dimension in " + builder.type() +
+ " for parsing as a mixed tensor. This is a bug.");
+ }
+
private void parseDenseSubspace(TensorAddress mappedAddress, List<String> denseDimensionOrder) {
DenseValueParser denseParser = new DenseValueParser(string.substring(position),
denseDimensionOrder,
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 58cb151875e..fee623dafa2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -312,6 +312,9 @@ public class TensorType {
/** Returns true if this is an indexed bound or unbound type */
public boolean isIndexed() { return type() == Type.indexedBound || type() == Type.indexedUnbound; }
+ /** Returns true if this is of the mapped type */
+ public boolean isMapped() { return type() == Type.mapped; }
+
/**
* Returns the dimension resulting from combining two dimensions having the same name but possibly different
* types. This works by degrading to the type making the fewer promises.