From e69d6e8f3d8a6504135f6d2733a3a42f6a041ed4 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 29 Mar 2022 12:21:56 +0200 Subject: Validate query feature tensor types - Validate tensor feature types when a tensor is set programmatically. - Add a toShortString for messages containing tensors. - Consistent and nicer spacing in tensor string forms. --- .../com/yahoo/processing/request/Properties.java | 2 +- container-search/abi-spec.json | 2 + .../java/com/yahoo/search/query/Properties.java | 12 ++++ .../query/profile/QueryProfileProperties.java | 39 +++++++++---- .../query/profile/types/TensorFieldType.java | 10 +++- .../search/query/properties/QueryProperties.java | 1 + .../yahoo/search/query/ranking/RankFeatures.java | 7 +++ .../types/test/QueryProfileTypeTestCase.java | 12 ++-- vespajlib/abi-spec.json | 8 ++- .../main/java/com/yahoo/tensor/IndexedTensor.java | 26 ++++++--- .../main/java/com/yahoo/tensor/MappedTensor.java | 9 ++- .../main/java/com/yahoo/tensor/MixedTensor.java | 64 +++++++++++++------- .../src/main/java/com/yahoo/tensor/Tensor.java | 27 ++++++--- .../com/yahoo/tensor/MappedTensorTestCase.java | 4 +- .../java/com/yahoo/tensor/MixedTensorTestCase.java | 16 +++-- .../test/java/com/yahoo/tensor/TensorTestCase.java | 68 ++++++++++++++++++---- .../tensor/functions/TensorFunctionTestCase.java | 4 +- 17 files changed, 232 insertions(+), 79 deletions(-) diff --git a/container-core/src/main/java/com/yahoo/processing/request/Properties.java b/container-core/src/main/java/com/yahoo/processing/request/Properties.java index cc53442f4d3..08072e83ce4 100644 --- a/container-core/src/main/java/com/yahoo/processing/request/Properties.java +++ b/container-core/src/main/java/com/yahoo/processing/request/Properties.java @@ -251,7 +251,7 @@ public class Properties implements Cloneable { * @throws RuntimeException if no instance in the chain accepted this name-value pair */ public final void set(String name, Object value) { - set(new CompoundName(name), value, Collections.emptyMap()); + set(new CompoundName(name), value, Map.of()); } /** diff --git a/container-search/abi-spec.json b/container-search/abi-spec.json index b1b80eac3a4..4b2c81e9943 100644 --- a/container-search/abi-spec.json +++ b/container-search/abi-spec.json @@ -5325,6 +5325,7 @@ "public com.yahoo.search.query.Properties clone()", "public com.yahoo.search.Query getParentQuery()", "public void setParentQuery(com.yahoo.search.Query)", + "public void requireSettable(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)", "public bridge synthetic com.yahoo.processing.request.Properties clone()", "public bridge synthetic com.yahoo.processing.request.Properties chained()", "public bridge synthetic java.lang.Object clone()" @@ -6030,6 +6031,7 @@ "public com.yahoo.search.query.profile.compiled.CompiledQueryProfile getQueryProfile()", "public java.lang.Object get(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)", "public void set(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)", + "public void requireSettable(com.yahoo.processing.request.CompoundName, java.lang.Object, java.util.Map)", "public void clearAll(com.yahoo.processing.request.CompoundName, java.util.Map)", "public java.util.Map listProperties(com.yahoo.processing.request.CompoundName, java.util.Map, com.yahoo.processing.request.Properties)", "public boolean isComplete(java.lang.StringBuilder, java.util.Map)", diff --git a/container-search/src/main/java/com/yahoo/search/query/Properties.java b/container-search/src/main/java/com/yahoo/search/query/Properties.java index 12a82afc7bd..d4fc4d57cd6 100644 --- a/container-search/src/main/java/com/yahoo/search/query/Properties.java +++ b/container-search/src/main/java/com/yahoo/search/query/Properties.java @@ -1,8 +1,11 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.search.query; +import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; +import java.util.Map; + /** * Object properties keyed by name which can be looked up using default values and * with conversion to various primitive wrapper types. @@ -50,4 +53,13 @@ public abstract class Properties extends com.yahoo.processing.request.Properties chained().setParentQuery(query); } + /** + * Throws IllegalInputException if the given key cannot be set to the given value. + * This default implementation just passes to the chained properties, if any. + */ + public void requireSettable(CompoundName name, Object value, Map context) { + if (chained() != null) + chained().requireSettable(name, value, context); + } + } diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java index 5b3758f103d..19e0e441359 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/QueryProfileProperties.java @@ -14,6 +14,7 @@ import com.yahoo.search.query.profile.types.ConversionContext; import com.yahoo.search.query.profile.types.FieldDescription; import com.yahoo.search.query.profile.types.QueryProfileFieldType; import com.yahoo.search.query.profile.types.QueryProfileType; +import com.yahoo.tensor.Tensor; import java.util.ArrayList; import java.util.Collections; @@ -91,6 +92,15 @@ public class QueryProfileProperties extends Properties { */ @Override public void set(CompoundName name, Object value, Map context) { + setOrCheckSettable(name, value, context, true); + } + + @Override + public void requireSettable(CompoundName name, Object value, Map context) { + setOrCheckSettable(name, value, context, false); + } + + private void setOrCheckSettable(CompoundName name, Object value, Map context, boolean set) { try { name = unalias(name, context); @@ -110,29 +120,36 @@ public class QueryProfileProperties extends Properties { if (value instanceof String && value.toString().startsWith("ref:")) { if (profile.getRegistry() == null) throw new IllegalInputException("Runtime query profile references does not work when the " + - "QueryProfileProperties are constructed without a registry"); + "QueryProfileProperties are constructed without a registry"); String queryProfileId = value.toString().substring(4); value = profile.getRegistry().findQueryProfile(queryProfileId); if (value == null) throw new IllegalInputException("Query profile '" + queryProfileId + "' is not found"); } - if (value instanceof CompiledQueryProfile) { // this will be due to one of the two clauses above - if (references == null) - references = new ArrayList<>(); - references.add(0, new Pair<>(name, (CompiledQueryProfile)value)); // references set later has precedence - put first - } - else { - if (values == null) - values = new HashMap<>(); - values.put(name, value); + if (set) { + if (value instanceof CompiledQueryProfile) { // this will be due to one of the two clauses above + if (references == null) + references = new ArrayList<>(); + // references set later has precedence - put first + references.add(0, new Pair<>(name, (CompiledQueryProfile) value)); + } else { + if (values == null) + values = new HashMap<>(); + values.put(name, value); + } } } catch (IllegalArgumentException e) { - throw new IllegalInputException("Could not set '" + name + "' to '" + value + "'", e); + throw new IllegalInputException("Could not set '" + name + "' to '" + toShortString(value) + "'", e); } } + private String toShortString(Object value) { + if ( ! (value instanceof Tensor)) return value.toString(); + return ((Tensor)value).toShortString(); + } + private Object convertByType(CompoundName name, Object value, Map context) { QueryProfileType type; QueryProfileType explicitTypeFromField = null; diff --git a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java index 6f1cfccc16b..db6a58a4dd3 100644 --- a/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java +++ b/container-search/src/main/java/com/yahoo/search/query/profile/types/TensorFieldType.java @@ -51,7 +51,15 @@ public class TensorFieldType extends FieldType { @Override public Object convertFrom(Object o, ConversionContext context) { - if (o instanceof Tensor) return o; + Tensor tensor = toTensor(o, context); + if (tensor == null) return null; + if (! tensor.type().isAssignableTo(type)) + throw new IllegalArgumentException("Require a tensor of type " + type); + return tensor; + } + + private Tensor toTensor(Object o, ConversionContext context) { + if (o instanceof Tensor) return (Tensor)o; if (o instanceof String && ((String)o).startsWith("embed(")) return encode((String)o, context); if (o instanceof String) return Tensor.from(type, (String)o); return null; diff --git a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java index 98b65c6edd9..2c0f5dc8bea 100644 --- a/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java +++ b/container-search/src/main/java/com/yahoo/search/query/properties/QueryProperties.java @@ -322,6 +322,7 @@ public class QueryProperties extends Properties { } } else if (key.first().equals("rankfeature") || key.first().equals("featureoverride") ) { // featureoverride is deprecated + chained().requireSettable(key, value, context); setRankingFeature(query, key.rest().toString(), toSpecifiedType(key.rest().toString(), value, profileRegistry.getTypeRegistry().getComponent("features"), diff --git a/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java b/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java index 1a4ecb4ecd8..807d70739cc 100644 --- a/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java +++ b/container-search/src/main/java/com/yahoo/search/query/ranking/RankFeatures.java @@ -2,8 +2,10 @@ package com.yahoo.search.query.ranking; import com.yahoo.fs4.MapEncoder; +import com.yahoo.processing.request.CompoundName; import com.yahoo.search.Query; import com.yahoo.search.query.Ranking; +import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.tensor.Tensor; import com.yahoo.text.JSON; @@ -47,9 +49,14 @@ public class RankFeatures implements Cloneable { /** Sets a tensor rank feature */ public void put(String name, Tensor value) { + verifyType(name, value); features.put(name, value); } + private void verifyType(String name, Object value) { + parent.getParent().properties().requireSettable(new CompoundName(List.of("ranking", "features", name)), value, Map.of()); + } + /** * Sets a rank feature to a value represented as a string. * diff --git a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java index 20678f3b7bb..a77de954b3a 100644 --- a/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java +++ b/container-search/src/test/java/com/yahoo/search/query/profile/types/test/QueryProfileTypeTestCase.java @@ -20,7 +20,6 @@ import com.yahoo.search.query.profile.types.FieldType; import com.yahoo.search.query.profile.types.QueryProfileType; import com.yahoo.search.query.profile.types.QueryProfileTypeRegistry; import org.junit.Before; -import org.junit.Ignore; import org.junit.Test; import java.net.URLEncoder; @@ -439,7 +438,6 @@ public class QueryProfileTypeTestCase { } @Test - @Ignore public void testTensorRankFeatureSetProgrammaticallyWithWrongType() { QueryProfile profile = new QueryProfile("test"); profile.setType(testtype); @@ -454,16 +452,18 @@ public class QueryProfileTypeTestCase { fail("Expected exception"); } catch (IllegalArgumentException e) { - assertEquals("'query(myTensor1)' must be of type tensor(a{},b{}) but was of type tensor(x[3])", - e.getMessage()); + assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " + + "Require a tensor of type tensor(a{},b{})", + Exceptions.toMessageString(e)); } try { query.properties().set("ranking.features.query(myTensor1)", Tensor.from(tensorString)); fail("Expected exception"); } catch (IllegalArgumentException e) { - assertEquals("'query(myTensor1)' must be of type tensor(a{},b{}) but was of type tensor(x[3])", - e.getMessage()); + assertEquals("Could not set 'ranking.features.query(myTensor1)' to 'tensor(x[3]):[0.1, 0.2, 0.3]': " + + "Require a tensor of type tensor(a{},b{})", + Exceptions.toMessageString(e)); } } diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 4e25d8ab0e0..bc7888aaac6 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -902,6 +902,7 @@ "public java.util.Map cells()", "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public java.lang.String toString()", + "public java.lang.String toShortString()", "public boolean equals(java.lang.Object)", "public bridge synthetic com.yahoo.tensor.Tensor withType(com.yahoo.tensor.TensorType)" ], @@ -952,6 +953,7 @@ "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", + "public java.lang.String toShortString()", "public boolean equals(java.lang.Object)" ], "fields": [] @@ -1043,6 +1045,7 @@ "public com.yahoo.tensor.Tensor remove(java.util.Set)", "public int hashCode()", "public java.lang.String toString()", + "public java.lang.String toShortString()", "public boolean equals(java.lang.Object)", "public long denseSubspaceSize()", "public static com.yahoo.tensor.TensorType createPartialType(com.yahoo.tensor.TensorType$Value, java.util.List)" @@ -1229,8 +1232,9 @@ "public java.util.List largest()", "public java.util.List smallest()", "public abstract java.lang.String toString()", - "public static java.lang.String toStandardString(com.yahoo.tensor.Tensor)", - "public static java.lang.String contentToString(com.yahoo.tensor.Tensor)", + "public abstract java.lang.String toShortString()", + "public static java.lang.String toStandardString(com.yahoo.tensor.Tensor, long)", + "public static java.lang.String contentToString(com.yahoo.tensor.Tensor, long)", "public abstract boolean equals(java.lang.Object)", "public abstract int hashCode()", "public static boolean equals(com.yahoo.tensor.Tensor, com.yahoo.tensor.Tensor)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 0e919d828ed..89eefeced56 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -219,21 +219,31 @@ public abstract class IndexedTensor implements Tensor { } @Override - public String toString() { - if (type.rank() == 0) return Tensor.toStandardString(this); + public String toString() { return toString(Long.MAX_VALUE); } + + @Override + public String toShortString() { + return toString(Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + } + + private String toString(long maxCells) { + if (type.rank() == 0) return Tensor.toStandardString(this, maxCells); if (type.dimensions().stream().anyMatch(d -> d.size().isEmpty())) - return Tensor.toStandardString(this); + return Tensor.toStandardString(this, maxCells); Indexes indexes = Indexes.of(dimensionSizes); StringBuilder b = new StringBuilder(type.toString()).append(":"); - indexedBlockToString(this, indexes, b); + indexedBlockToString(this, indexes, maxCells, b); return b.toString(); } - static void indexedBlockToString(IndexedTensor tensor, Indexes indexes, StringBuilder b) { - for (int index = 0; index < tensor.size(); index++) { + static void indexedBlockToString(IndexedTensor tensor, Indexes indexes, long maxCells, StringBuilder b) { + int index = 0; + for (; index < tensor.size() && index < maxCells; index++) { indexes.next(); + if (index > 0) + b.append(", "); // start brackets for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) @@ -252,9 +262,9 @@ public abstract class IndexedTensor implements Tensor { // end bracket and comma for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); - if (index < tensor.size() - 1) - b.append(", "); } + if (index == maxCells && index < tensor.size()) + b.append(", ...]"); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index 09e93d80bd9..ad945ed18bf 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -50,7 +50,7 @@ public class MappedTensor implements Tensor { public Tensor withType(TensorType other) { if (!this.type.isRenamableTo(type)) { throw new IllegalArgumentException("MappedTensor.withType: types are not compatible. Current type: '" + - this.type.toString() + "', requested type: '" + type.toString() + "'"); + this.type + "', requested type: '" + type.toString() + "'"); } return new MappedTensor(other, cells); } @@ -72,7 +72,12 @@ public class MappedTensor implements Tensor { public int hashCode() { return cells.hashCode(); } @Override - public String toString() { return Tensor.toStandardString(this); } + public String toString() { return Tensor.toStandardString(this, Long.MAX_VALUE); } + + @Override + public String toShortString() { + return Tensor.toStandardString(this, Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + } @Override public boolean equals(Object other) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 418e9efdffb..56bd94a86e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -5,6 +5,7 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; import java.util.HashMap; import java.util.Iterator; import java.util.List; @@ -116,7 +117,7 @@ public class MixedTensor implements Tensor { public Tensor withType(TensorType other) { if (!this.type.isRenamableTo(type)) { throw new IllegalArgumentException("MixedTensor.withType: types are not compatible. Current type: '" + - this.type.toString() + "', requested type: '" + type.toString() + "'"); + this.type + "', requested type: '" + type + "'"); } return new MixedTensor(other, cells, index); } @@ -144,12 +145,23 @@ public class MixedTensor implements Tensor { @Override public String toString() { - if (type.rank() == 0) return Tensor.toStandardString(this); + return toString(Long.MAX_VALUE); + } + + @Override + public String toShortString() { + return toString(Math.max(2, 10 / (type().dimensions().stream().filter(d -> d.isMapped()).count() + 1))); + } + + private String toString(long maxCells) { + if (type.rank() == 0) + return Tensor.toStandardString(this, maxCells); if (type.rank() > 1 && type.dimensions().stream().filter(d -> d.isIndexed()).anyMatch(d -> d.size().isEmpty())) - return Tensor.toStandardString(this); - if (type.dimensions().stream().filter(d -> d.isMapped()).count() > 1) return Tensor.toStandardString(this); + return Tensor.toStandardString(this, maxCells); + if (type.dimensions().stream().filter(d -> d.isMapped()).count() > 1) + return Tensor.toStandardString(this, maxCells); - return type.toString() + ":" + index.contentToString(this); + return type + ":" + index.contentToString(this, maxCells); } @Override @@ -503,37 +515,50 @@ public class MixedTensor implements Tensor { return "index into " + type; } - private String contentToString(MixedTensor tensor) { + private String contentToString(MixedTensor tensor, long maxCells) { if (mappedDimensions.size() > 1) throw new IllegalStateException("Should be ensured by caller"); if (mappedDimensions.size() == 0) { StringBuilder b = new StringBuilder(); - denseSubspaceToString(tensor, 0, b); + int cellsWritten = denseSubspaceToString(tensor, 0, maxCells, b); + if (cellsWritten == maxCells && cellsWritten < tensor.size()) + b.append("...]"); return b.toString(); } // Exactly 1 mapped dimension StringBuilder b = new StringBuilder("{"); - sparseMap.entrySet().stream().sorted(Map.Entry.comparingByKey()).forEach(entry -> { - b.append(TensorAddress.labelToString(entry.getKey().label(0 ))); + var cellEntries = new ArrayList<>(sparseMap.entrySet()); + cellEntries.sort(Map.Entry.comparingByKey()); + int cellsWritten = 0; + for (int index = 0; index < cellEntries.size() && cellsWritten < maxCells; index++) { + if (index > 0) + b.append(", "); + b.append(TensorAddress.labelToString(cellEntries.get(index).getKey().label(0 ))); b.append(":"); - denseSubspaceToString(tensor, entry.getValue(), b); - b.append(","); - }); - if (b.length() > 1) - b.setLength(b.length() - 1); + cellsWritten += denseSubspaceToString(tensor, cellEntries.get(index).getValue(), maxCells - cellsWritten, b); + } + if (cellsWritten >= maxCells && cellsWritten < tensor.size()) + b.append(", ..."); b.append("}"); return b.toString(); } - private void denseSubspaceToString(MixedTensor tensor, long subspaceIndex, StringBuilder b) { + private int denseSubspaceToString(MixedTensor tensor, long subspaceIndex, long maxCells, StringBuilder b) { + if (maxCells <= 0) { + return 0; + } + if (denseSubspaceSize == 1) { b.append(getDouble(subspaceIndex, 0, tensor)); - return; + return 1; } IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseType); - for (int index = 0; index < denseSubspaceSize; index++) { + int index = 0; + for (; index < denseSubspaceSize && index < maxCells; index++) { indexes.next(); + if (index > 0) + b.append(", "); // start brackets for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) @@ -549,12 +574,11 @@ public class MixedTensor implements Tensor { throw new IllegalStateException("Unexpected value type " + type.valueType()); } - // end bracket and comma + // end bracket for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); - if (index < denseSubspaceSize - 1) - b.append(", "); } + return index; } private double getDouble(long indexedSubspaceIndex, long indexInIndexedSubspace, MixedTensor tensor) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index ca396ae5bf2..06e7b010a7a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -316,28 +316,41 @@ public interface Tensor { @Override String toString(); + /** Returns an abbreviated string representation of this tensor suitable for human-readable messages */ + String toShortString(); + /** * Call this from toString in implementations to return this tensor on the * tensor literal form. * (toString cannot be a default method because default methods cannot override super methods). * * @param tensor the tensor to return the standard string format of + * @param maxCells the max number of cells to output, after which just , "..." is output to represent the rest + * of the cells * @return the tensor on the standard string format */ - static String toStandardString(Tensor tensor) { - return tensor.type() + ":" + contentToString(tensor); + static String toStandardString(Tensor tensor, long maxCells) { + return tensor.type() + ":" + contentToString(tensor, maxCells); } - static String contentToString(Tensor tensor) { + static String contentToString(Tensor tensor, long maxCells) { var cellEntries = new ArrayList<>(tensor.cells().entrySet()); + cellEntries.sort(Map.Entry.comparingByKey()); if (tensor.type().dimensions().isEmpty()) { if (cellEntries.isEmpty()) return "{}"; return "{" + cellEntries.get(0).getValue() +"}"; } - return "{" + cellEntries.stream().sorted(Map.Entry.comparingByKey()) - .map(cell -> cellToString(cell, tensor.type())) - .collect(Collectors.joining(",")) + - "}"; + StringBuilder b = new StringBuilder("{"); + int i = 0; + for (; i < cellEntries.size() && i < maxCells; i++) { + if (i > 0) + b.append(", "); + b.append(cellToString(cellEntries.get(i), tensor.type())); + } + if (i == maxCells && i < tensor.size()) + b.append(", ..."); + b.append("}"); + return b.toString(); } private static String cellToString(Map.Entry cell, TensorType type) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java index 7bb02f03735..ba814f7ad54 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MappedTensorTestCase.java @@ -35,7 +35,7 @@ public class MappedTensorTestCase { cell().label("x", "0").value(1). cell().label("x", "1").value(2).build(); assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames()); - assertEquals("tensor(x{}):{0:1.0,1:2.0}", tensor.toString()); + assertEquals("tensor(x{}):{0:1.0, 1:2.0}", tensor.toString()); } @Test @@ -45,7 +45,7 @@ public class MappedTensorTestCase { cell().label("x", "0").label("y", "0").value(1). cell().label("x", "1").label("y", "0").value(2).build(); assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames()); - assertEquals("tensor(x{},y{}):{{x:0,y:0}:1.0,{x:1,y:0}:2.0}", tensor.toString()); + assertEquals("tensor(x{},y{}):{{x:0,y:0}:1.0, {x:1,y:0}:2.0}", tensor.toString()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java index 50f2bc5efff..a26e56c4468 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MixedTensorTestCase.java @@ -69,7 +69,7 @@ public class MixedTensorTestCase { cell().label("x", "1").value(2). build(); assertEquals(Sets.newHashSet("x"), tensor.type().dimensionNames()); - assertEquals("tensor(x{}):{0:1.0,1:2.0}", + assertEquals("tensor(x{}):{0:1.0, 1:2.0}", tensor.toString()); } @@ -84,7 +84,7 @@ public class MixedTensorTestCase { cell().label("x", "1").label("y", "2").value(6). build(); assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames()); - assertEquals("tensor(x{},y{}):{{x:0,y:0}:1.0,{x:0,y:1}:2.0,{x:1,y:0}:4.0,{x:1,y:1}:5.0,{x:1,y:2}:6.0}", + assertEquals("tensor(x{},y{}):{{x:0,y:0}:1.0, {x:0,y:1}:2.0, {x:1,y:0}:4.0, {x:1,y:1}:5.0, {x:1,y:2}:6.0}", tensor.toString()); } @@ -100,7 +100,7 @@ public class MixedTensorTestCase { cell().label("x", "2").label("y", 2).value(6). build(); assertEquals(Sets.newHashSet("x", "y"), tensor.type().dimensionNames()); - assertEquals("tensor(x{},y[3]):{1:[1.0, 2.0, 0.0],2:[4.0, 5.0, 6.0]}", + assertEquals("tensor(x{},y[3]):{1:[1.0, 2.0, 0.0], 2:[4.0, 5.0, 6.0]}", tensor.toString()); } @@ -122,7 +122,9 @@ public class MixedTensorTestCase { cell().label("x", "x2").label("y", 2).label("z","z2").value(16). build(); assertEquals(Sets.newHashSet("x", "y", "z"), tensor.type().dimensionNames()); - assertEquals("tensor(x{},y[3],z{}):{{x:x1,y:0,z:z1}:1.0,{x:x1,y:0,z:z2}:2.0,{x:x1,y:1,z:z1}:3.0,{x:x1,y:1,z:z2}:4.0,{x:x1,y:2,z:z1}:5.0,{x:x1,y:2,z:z2}:6.0,{x:x2,y:0,z:z1}:11.0,{x:x2,y:0,z:z2}:12.0,{x:x2,y:1,z:z1}:13.0,{x:x2,y:1,z:z2}:14.0,{x:x2,y:2,z:z1}:15.0,{x:x2,y:2,z:z2}:16.0}", + assertEquals("tensor(x{},y[3],z{}):{{x:x1,y:0,z:z1}:1.0, {x:x1,y:0,z:z2}:2.0, {x:x1,y:1,z:z1}:3.0, " + + "{x:x1,y:1,z:z2}:4.0, {x:x1,y:2,z:z1}:5.0, {x:x1,y:2,z:z2}:6.0, {x:x2,y:0,z:z1}:11.0, " + + "{x:x2,y:0,z:z2}:12.0, {x:x2,y:1,z:z1}:13.0, {x:x2,y:1,z:z2}:14.0, {x:x2,y:2,z:z1}:15.0, {x:x2,y:2,z:z2}:16.0}", tensor.toString()); } @@ -148,7 +150,11 @@ public class MixedTensorTestCase { cell().label("i", "b").label("k","d").label("j",1).label("l",1).value(16). build(); assertEquals(Sets.newHashSet("i", "j", "k", "l"), tensor.type().dimensionNames()); - assertEquals("tensor(i{},j[2],k{},l[2]):{{i:a,j:0,k:c,l:0}:1.0,{i:a,j:0,k:c,l:1}:2.0,{i:a,j:0,k:d,l:0}:5.0,{i:a,j:0,k:d,l:1}:6.0,{i:a,j:1,k:c,l:0}:3.0,{i:a,j:1,k:c,l:1}:4.0,{i:a,j:1,k:d,l:0}:7.0,{i:a,j:1,k:d,l:1}:8.0,{i:b,j:0,k:c,l:0}:9.0,{i:b,j:0,k:c,l:1}:10.0,{i:b,j:0,k:d,l:0}:13.0,{i:b,j:0,k:d,l:1}:14.0,{i:b,j:1,k:c,l:0}:11.0,{i:b,j:1,k:c,l:1}:12.0,{i:b,j:1,k:d,l:0}:15.0,{i:b,j:1,k:d,l:1}:16.0}", + assertEquals("tensor(i{},j[2],k{},l[2]):{{i:a,j:0,k:c,l:0}:1.0, {i:a,j:0,k:c,l:1}:2.0, " + + "{i:a,j:0,k:d,l:0}:5.0, {i:a,j:0,k:d,l:1}:6.0, {i:a,j:1,k:c,l:0}:3.0, {i:a,j:1,k:c,l:1}:4.0, " + + "{i:a,j:1,k:d,l:0}:7.0, {i:a,j:1,k:d,l:1}:8.0, {i:b,j:0,k:c,l:0}:9.0, {i:b,j:0,k:c,l:1}:10.0, " + + "{i:b,j:0,k:d,l:0}:13.0, {i:b,j:0,k:d,l:1}:14.0, {i:b,j:1,k:c,l:0}:11.0, {i:b,j:1,k:c,l:1}:12.0, "+ + "{i:b,j:1,k:d,l:0}:15.0, {i:b,j:1,k:d,l:1}:16.0}", tensor.toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index fd33cf97220..2067d7a8492 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -31,16 +31,60 @@ import static org.junit.Assert.fail; public class TensorTestCase { @Test - public void testStringForm() { - assertEquals("tensor():{5.7}", Tensor.from("{5.7}").toString()); + public void testFactory() { assertTrue(Tensor.from("tensor():{5.7}") instanceof IndexedTensor); - assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l2}:6.0}", Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString()); - assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3,{d1:l1,d2:l2}:0.0}", Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); + } + + @Test + public void testToString() { + assertEquals("tensor():{5.7}", Tensor.from("{5.7}").toString()); + assertEquals("tensor(x[3]):[0.1, 0.2, 0.3]", + Tensor.from("tensor(x[3]):[0.1, 0.2, 0.3]").toString()); + assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:6.0}", + Tensor.from("{ {d1:l1,d2:l1}: 5, {d2:l2, d1:l1}:6.0} ").toString()); + assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:-5.3, {d1:l1,d2:l2}:0.0}", + Tensor.from("{ {d1:l1,d2:l1}:-5.3, {d2:l2, d1:l1}:0}").toString()); + assertEquals("tensor(m{},x[3]):{k1:[0.0, 1.0, 2.0], k2:[0.0, 1.0, 2.0], k3:[0.0, 1.0, 2.0], k4:[0.0, 1.0, 2.0]}", + Tensor.from("tensor(m{},x[3]):{k1:[0,1,2], k2:[0,1,2], k3:[0,1,2], k4:[0,1,2]}").toString()); + assertEquals("tensor(m{},n{},x[3]):" + + "{{m:k1,n:k1,x:0}:0.0, {m:k1,n:k1,x:1}:1.0, {m:k1,n:k1,x:2}:2.0," + + " {m:k2,n:k1,x:0}:0.0, {m:k2,n:k1,x:1}:1.0, {m:k2,n:k1,x:2}:2.0," + + " {m:k3,n:k1,x:0}:0.0, {m:k3,n:k1,x:1}:1.0, {m:k3,n:k1,x:2}:2.0}", + Tensor.from("tensor(m{},n{},x[3]):" + + "{{m:k1,n:k1,x:0}:0, {m:k1,n:k1,x:1}:1, {m:k1,n:k1,x:2}:2, " + + " {m:k2,n:k1,x:0}:0, {m:k2,n:k1,x:1}:1, {m:k2,n:k1,x:2}:2, " + + " {m:k3,n:k1,x:0}:0, {m:k3,n:k1,x:1}:1, {m:k3,n:k1,x:2}:2}").toString()); + assertEquals("tensor(m{},x[2],y[2]):" + + "{k1:[[0.0, 1.0], [2.0, 3.0]], k2:[[0.0, 1.0], [2.0, 3.0]], k3:[[0.0, 1.0], [2.0, 3.0]]}", + Tensor.from("tensor(m{},x[2],y[2]):{k1:[[0,1],[2,3]], k2:[[0,1],[2,3]], k3:[[0,1],[2,3]]}").toString()); assertEquals("Labels are quoted when necessary", - "tensor(d1{}):{\"'''\":6.0,'[[\":\"]]':5.0}", + "tensor(d1{}):{\"'''\":6.0, '[[\":\"]]':5.0}", Tensor.from("{ {d1:'[[\":\"]]'}: 5, {d1:\"'''\"}:6.0 }").toString()); } + @Test + public void testToShortString() { + assertEquals("tensor(x[10]):[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0]", + Tensor.from("tensor(x[10]):[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]").toShortString()); + assertEquals("tensor(x[14]):[0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, ...]", + Tensor.from("tensor(x[14]):[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13]").toShortString()); + assertEquals("tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0, {d1:l1,d2:l3}:6.0, ...}", + Tensor.from("{{d1:l1,d2:l1}:6, {d2:l2,d1:l1}:6, {d2:l3,d1:l1}:6, {d2:l4,d1:l1}:6, {d2:l5,d1:l1}:6," + + " {d2:l6,d1:l1}:6, {d2:l7,d1:l1}:6, {d2:l8,d1:l1}:6, {d2:l9,d1:l1}:6, {d2:l2,d1:l2}:6," + + " {d2:l2,d1:l3}:6, {d2:l2,d1:l4}:6}").toShortString()); + assertEquals("tensor(m{},x[3]):{k1:[0.0, 1.0, 2.0], k2:[0.0, 1.0, ...}", + Tensor.from("tensor(m{},x[3]):{k1:[0,1,2], k2:[0,1,2], k3:[0,1,2], k4:[0,1,2]}").toShortString()); + assertEquals("tensor(m{},x[3]):{k1:[0.0, 1.0, 2.0], k2:[0.0, 1.0, ...}", + Tensor.from("tensor(m{},x[3]):{k1:[0,1,2], k2:[0,1,2], k3:[0,1,2], k4:[0,1,2]}").toShortString()); + assertEquals("tensor(m{},n{},x[3]):{{m:k1,n:k1,x:0}:0.0, {m:k1,n:k1,x:1}:1.0, {m:k1,n:k1,x:2}:2.0, ...}", + Tensor.from("tensor(m{},n{},x[3]):" + + "{{m:k1,n:k1,x:0}:0, {m:k1,n:k1,x:1}:1, {m:k1,n:k1,x:2}:2, " + + " {m:k2,n:k1,x:0}:0, {m:k2,n:k1,x:1}:1, {m:k2,n:k1,x:2}:2, " + + " {m:k3,n:k1,x:0}:0, {m:k3,n:k1,x:1}:1, {m:k3,n:k1,x:2}:2}").toShortString()); + assertEquals("tensor(m{},x[2],y[2]):{k1:[[0.0, 1.0], [2.0, 3.0]], k2:[[0.0, ...}", + Tensor.from("tensor(m{},x[2],y[2]):{k1:[[0,1],[2,3]], k2:[[0,1],[2,3]], k3:[[0,1],[2,3]]}").toShortString()); + } + @Test public void testValueTypes() { assertEquals(Tensor.from("tensor(x[1]):{{x:0}:5}").getClass(), IndexedDoubleTensor.class); @@ -60,13 +104,6 @@ public class TensorTestCase { IndexedFloatTensor.class); } - private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { - Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }"); - Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }"); - assertEquals(valueType, t1.multiply(t2).type().valueType()); - assertEquals(valueType, t2.multiply(t1).type().valueType()); - } - @Test public void testValueTypeResolving() { assertCellTypeResult(TensorType.Value.DOUBLE, "double", "double"); @@ -319,6 +356,13 @@ public class TensorTestCase { "tensor(x[2],y[2]):[[4,2],[3,4]]"); } + private void assertCellTypeResult(TensorType.Value valueType, String type1, String type2) { + Tensor t1 = Tensor.from("tensor<" + type1 + ">(x[1]):[3] }"); + Tensor t2 = Tensor.from("tensor<" + type2 + ">(x[1]):[5] }"); + assertEquals(valueType, t1.multiply(t2).type().valueType()); + assertEquals(valueType, t2.multiply(t1).type().valueType()); + } + private void assertLargest(String expectedCells, String tensorString) { Tensor tensor = Tensor.from(tensorString); assertEquals(expectedCells, asString(tensor.largest(), tensor.type())); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index ce165474a53..738213ecb97 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -19,9 +19,9 @@ public class TensorFunctionTestCase { new L1Normalize<>(new ConstantTensor<>("{{x:1}:1.0}"), "x")); assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); - assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max, x), f(a,b)(a==b))", + assertTranslated("join(tensor(x{}):{1:1.0, 3:5.0, 9:3.0}, reduce(tensor(x{}):{1:1.0, 3:5.0, 9:3.0}, max, x), f(a,b)(a==b))", new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); - assertTranslated("join(tensor(x{}):{1:1.0,3:5.0,9:3.0}, reduce(tensor(x{}):{1:1.0,3:5.0,9:3.0}, max), f(a,b)(a==b))", + assertTranslated("join(tensor(x{}):{1:1.0, 3:5.0, 9:3.0}, reduce(tensor(x{}):{1:1.0, 3:5.0, 9:3.0}, max), f(a,b)(a==b))", new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"))); } -- cgit v1.2.3