diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-03-29 12:21:56 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-03-29 12:21:56 +0200 |
commit | e69d6e8f3d8a6504135f6d2733a3a42f6a041ed4 (patch) | |
tree | 046483fb628977f62a66cb660d4a09fcd4302e0d /vespajlib/src/main/java/com/yahoo/tensor | |
parent | 13100e8dcc72b7c879727e5d96e1fdfceb2d3bcc (diff) |
Validate query feature tensor types
- Validate tensor feature types when a tensor is set programmatically.
- Add a toShortString method for use in messages containing tensors.
- Use consistent and nicer spacing in tensor string forms.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor')
4 files changed, 89 insertions, 37 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 0e919d828ed..89eefeced56 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -219,21 +219,31 @@ public abstract class IndexedTensor implements Tensor { } @Override - public String toString() { - if (type.rank() == 0) return Tensor.toStandardString(this); + public String toString() { return toString(Long.MAX_VALUE); } + + @Override + public String toShortString() { + return toString(Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + } + + private String toString(long maxCells) { + if (type.rank() == 0) return Tensor.toStandardString(this, maxCells); if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) - return Tensor.toStandardString(this); + return Tensor.toStandardString(this, maxCells); Indexes indexes = Indexes.of(dimensionSizes); StringBuilder b = new StringBuilder(type.toString()).append(":"); - indexedBlockToString(this, indexes, b); + indexedBlockToString(this, indexes, maxCells, b); return b.toString(); } - static void indexedBlockToString(IndexedTensor tensor, Indexes indexes, StringBuilder b) { - for (int index = 0; index < tensor.size(); index++) { + static void indexedBlockToString(IndexedTensor tensor, Indexes indexes, long maxCells, StringBuilder b) { + int index = 0; + for (; index < tensor.size() && index < maxCells; index++) { indexes.next(); + if (index > 0) + b.append(", "); // start brackets for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) @@ -252,9 +262,9 @@ public abstract class IndexedTensor implements Tensor { // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); - if (index < tensor.size() - 1) - b.append(", "); } + if (index == maxCells && index < tensor.size()) + b.append(", ...]"); } @Override diff --git 
a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 09e93d80bd9..ad945ed18bf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -50,7 +50,7 @@ public class MappedTensor implements Tensor { public Tensor withType(TensorType other) { if (!this.type.isRenamableTo(type)) { throw new IllegalArgumentException("MappedTensor.withType: types are not compatible. Current type: '" + - this.type.toString() + "', requested type: '" + type.toString() + "'"); + this.type + "', requested type: '" + type.toString() + "'"); } return new MappedTensor(other, cells); } @@ -72,7 +72,12 @@ public class MappedTensor implements Tensor { public int hashCode() { return cells.hashCode(); } @Override - public String toString() { return Tensor.toStandardString(this); } + public String toString() { return Tensor.toStandardString(this, Long.MAX_VALUE); } + + @Override + public String toShortString() { + return Tensor.toStandardString(this, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + } @Override public boolean equals(Object other) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 418e9efdffb..56bd94a86e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -5,6 +5,7 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -116,7 +117,7 @@ public class MixedTensor implements Tensor { public Tensor withType(TensorType other) { if (!this.type.isRenamableTo(type)) { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. 
Current type: '" + - this.type.toString() + "', requested type: '" + type.toString() + "'"); + this.type + "', requested type: '" + type + "'"); } return new MixedTensor(other, cells, index); } @@ -144,12 +145,23 @@ public class MixedTensor implements Tensor { @Override public String toString() { - if (type.rank() == 0) return Tensor.toStandardString(this); + return toString(Long.MAX_VALUE); + } + + @Override + public String toShortString() { + return toString(Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + } + + private String toString(long maxCells) { + if (type.rank() == 0) + return Tensor.toStandardString(this, maxCells); if (type.rank() > 1 && type.dimensions().stream().filter(d -> d.isIndexed()).anyMatch(d -> d.size().isEmpty())) - return Tensor.toStandardString(this); - if (type.dimensions().stream().filter(d -> d.isMapped()).count() > 1) return Tensor.toStandardString(this); + return Tensor.toStandardString(this, maxCells); + if (type.dimensions().stream().filter(d -> d.isMapped()).count() > 1) + return Tensor.toStandardString(this, maxCells); - return type.toString() + ":" + index.contentToString(this); + return type + ":" + index.contentToString(this, maxCells); } @Override @@ -503,37 +515,50 @@ public class MixedTensor implements Tensor { return "index into " + type; } - private String contentToString(MixedTensor tensor) { + private String contentToString(MixedTensor tensor, long maxCells) { if (mappedDimensions.size() > 1) throw new IllegalStateException("Should be ensured by caller"); if (mappedDimensions.size() == 0) { StringBuilder b = new StringBuilder(); - denseSubspaceToString(tensor, 0, b); + int cellsWritten = denseSubspaceToString(tensor, 0, maxCells, b); + if (cellsWritten == maxCells && cellsWritten < tensor.size()) + b.append("...]"); return b.toString(); } // Exactly 1 mapped dimension StringBuilder b = new StringBuilder("{"); - 
sparseMap.entrySet().stream().sorted(Map.Entry.comparingByKey()).forEach(entry -> { - b.append(TensorAddress.labelToString(entry.getKey().label(0 ))); + var cellEntries = new ArrayList<>(sparseMap.entrySet()); + cellEntries.sort(Map.Entry.comparingByKey()); + int cellsWritten = 0; + for (int index = 0; index < cellEntries.size() && cellsWritten < maxCells; index++) { + if (index > 0) + b.append(", "); + b.append(TensorAddress.labelToString(cellEntries.get(index).getKey().label(0 ))); b.append(":"); - denseSubspaceToString(tensor, entry.getValue(), b); - b.append(","); - }); - if (b.length() > 1) - b.setLength(b.length() - 1); + cellsWritten += denseSubspaceToString(tensor, cellEntries.get(index).getValue(), maxCells - cellsWritten, b); + } + if (cellsWritten >= maxCells && cellsWritten < tensor.size()) + b.append(", ..."); b.append("}"); return b.toString(); } - private void denseSubspaceToString(MixedTensor tensor, long subspaceIndex, StringBuilder b) { + private int denseSubspaceToString(MixedTensor tensor, long subspaceIndex, long maxCells, StringBuilder b) { + if (maxCells <= 0) { + return 0; + } + if (denseSubspaceSize == 1) { b.append(getDouble(subspaceIndex, 0, tensor)); - return; + return 1; } IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseType); - for (int index = 0; index < denseSubspaceSize; index++) { + int index = 0; + for (; index < denseSubspaceSize && index < maxCells; index++) { indexes.next(); + if (index > 0) + b.append(", "); // start brackets for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) @@ -549,12 +574,11 @@ public class MixedTensor implements Tensor { throw new IllegalStateException("Unexpected value type " + type.valueType()); } - // end bracket and comma + // end bracket for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); - if (index < denseSubspaceSize - 1) - b.append(", "); } + return index; } private double getDouble(long indexedSubspaceIndex, long indexInIndexedSubspace, MixedTensor tensor) 
{ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index ca396ae5bf2..06e7b010a7a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -316,28 +316,41 @@ public interface Tensor { @Override String toString(); + /** Returns an abbreviated string representation of this tensor suitable for human-readable messages */ + String toShortString(); + /** * Call this from toString in implementations to return this tensor on the * <a href="https://docs.vespa.ai/en/reference/tensor.html#tensor-literal-form">tensor literal form</a>. * (toString cannot be a default method because default methods cannot override super methods). * * @param tensor the tensor to return the standard string format of + * @param maxCells the max number of cells to output, after which just , "..." is output to represent the rest + * of the cells * @return the tensor on the standard string format */ - static String toStandardString(Tensor tensor) { - return tensor.type() + ":" + contentToString(tensor); + static String toStandardString(Tensor tensor, long maxCells) { + return tensor.type() + ":" + contentToString(tensor, maxCells); } - static String contentToString(Tensor tensor) { + static String contentToString(Tensor tensor, long maxCells) { var cellEntries = new ArrayList<>(tensor.cells().entrySet()); + cellEntries.sort(Map.Entry.comparingByKey()); if (tensor.type().dimensions().isEmpty()) { if (cellEntries.isEmpty()) return "{}"; return "{" + cellEntries.get(0).getValue() +"}"; } - return "{" + cellEntries.stream().sorted(Map.Entry.comparingByKey()) - .map(cell -> cellToString(cell, tensor.type())) - .collect(Collectors.joining(",")) + - "}"; + StringBuilder b = new StringBuilder("{"); + int i = 0; + for (; i < cellEntries.size() && i < maxCells; i++) { + if (i > 0) + b.append(", "); + b.append(cellToString(cellEntries.get(i), tensor.type())); + } + if (i 
== maxCells && i < tensor.size()) + b.append(", ..."); + b.append("}"); + return b.toString(); } private static String cellToString(Map.Entry<TensorAddress, Double> cell, TensorType type) { |