diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-02-16 16:43:20 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-02-16 16:43:20 +0100 |
commit | e25d913b884339afc4f8e3073e4e4b795e55d930 (patch) | |
tree | 408e9fded165a07fae202fd691f6f2864680ac63 /vespajlib/src | |
parent | 6f99bd502132cd378124a40060ac1d74d54f5e92 (diff) |
Resolve slice dimension
Diffstat (limited to 'vespajlib/src')
28 files changed, 88 insertions, 59 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java index 457cfcbfa5f..3b12b6bdba1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java @@ -57,7 +57,7 @@ public class TypeResolver { static public TensorType peek(TensorType inputType, List<String> peekDimensions) { if (peekDimensions.isEmpty()) { - throw new IllegalArgumentException("peeking no dimensions makes no sense"); + throw new IllegalArgumentException("Peeking no dimensions makes no sense"); } Map<String, Dimension> map = new HashMap<>(); for (Dimension dim : inputType.dimensions()) { @@ -67,7 +67,7 @@ public class TypeResolver { if (map.containsKey(name)) { map.remove(name); } else { - throw new IllegalArgumentException("peeking non-existing dimension "+name+" in type "+inputType); + throw new IllegalArgumentException("Peeking non-existing dimension '" + name + "'"); } } if (map.isEmpty()) { @@ -79,10 +79,10 @@ public class TypeResolver { static public TensorType rename(TensorType inputType, List<String> from, List<String> to) { if (from.isEmpty()) { - throw new IllegalArgumentException("renaming no dimensions"); + throw new IllegalArgumentException("Renaming no dimensions"); } if (from.size() != to.size()) { - throw new IllegalArgumentException("bad rename, from size "+from.size()+" != to.size "+to.size()); + throw new IllegalArgumentException("Bad rename, from size "+from.size()+" != to.size "+to.size()); } Map<String,Dimension> oldDims = new HashMap<>(); for (Dimension dim : inputType.dimensions()) { @@ -96,7 +96,7 @@ public class TypeResolver { var dim = oldDims.remove(oldName); newDims.put(newName, dim.withName(newName)); } else { - logger.log(Level.WARNING, "renaming non-existing dimension "+oldName+" in type "+inputType); + logger.log(Level.WARNING, "Renaming non-existing dimension "+oldName+" in type "+inputType); // throw new IllegalArgumentException("bad rename, dimension "+oldName+" not found"); } } @@ -106,13 +106,13 @@ public class TypeResolver { if (inputType.dimensions().size() == newDims.size()) { return new TensorType(inputType.valueType(), newDims.values()); } else { - throw new IllegalArgumentException("bad rename, lost some dimenions"); + throw new IllegalArgumentException("Bad rename, lost some dimensions"); } } static public TensorType cell_cast(TensorType inputType, Value toCellType) { if (toCellType != Value.DOUBLE && inputType.dimensions().isEmpty()) { - throw new IllegalArgumentException("cannot cast "+inputType+" to valueType"+toCellType); + throw new IllegalArgumentException("Cannot cast "+inputType+" to valueType"+toCellType); } return new TensorType(toCellType, inputType.dimensions()); } @@ -188,7 +188,7 @@ public class TypeResolver { if (allOk) { return join(lhs, rhs); } else { - throw new IllegalArgumentException("types in merge() dimensions mismatch: "+lhs+" != "+rhs); + throw new IllegalArgumentException("Types in merge() dimensions mismatch: "+lhs+" != "+rhs); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index a376536015a..dbc8396d701 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -58,7 +58,7 @@ public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return name; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 16ca7104f8d..55dd8a7bc8a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -48,7 +48,7 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index fcdc1233550..f1f0b9d67b0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -48,7 +48,7 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index 8c6c27e171a..09f84e6747e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -107,7 +107,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "cell_cast(" + argument.toString(context) + ", " + valueType + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 32a4c8cd2ff..6d4b15be991 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -285,7 +285,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 1544369ba2f..a0fd9272f54 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -47,6 +47,6 @@ public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } @Override - public String toString(ToStringContext context) { return constant.toString(); } + public String toString(ToStringContext<NAMETYPE> context) { return constant.toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index 2c0fa483021..92d89ec68f7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -41,7 +41,7 @@ public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYP } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 97126ad88a7..46992115c23 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -46,11 +46,11 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens TensorType type() { return type; } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return type().toString() + ":" + contentToString(context); } - abstract String contentToString(ToStringContext context); + abstract String contentToString(ToStringContext<NAMETYPE> context); /** Creates a dynamic tensor function. The cell addresses must match the type. */ public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { @@ -80,7 +80,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens } @Override - String contentToString(ToStringContext context) { + String contentToString(ToStringContext<NAMETYPE> context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; return "{" + cells.values().iterator().next().toString(context) + "}"; @@ -121,7 +121,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens } @Override - String contentToString(ToStringContext context) { + String contentToString(ToStringContext<NAMETYPE> context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; return "{" + cells.get(0).toString(context) + "}"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java index 8fc246a7d9d..c049e5d41da 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java @@ -41,7 +41,7 @@ public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "expand(" + argument.toString(context) + ", " + dimensionName + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 89e981df49e..54e83fa472f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -117,9 +117,9 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } @Override - public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; } + public String toString(ToStringContext<NAMETYPE> context) { return type + "(" + generatorToString(context) + ")"; } - private String generatorToString(ToStringContext context) { + private String generatorToString(ToStringContext<NAMETYPE> context) { if (freeGenerator != null) return freeGenerator.toString(); else @@ -183,11 +183,11 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } /** A context which adds the bindings of the generate dimension names to the given context. */ - private class GenerateToStringContext implements ToStringContext { + private class GenerateToStringContext implements ToStringContext<NAMETYPE> { - private final ToStringContext context; + private final ToStringContext<NAMETYPE> context; - public GenerateToStringContext(ToStringContext context) { + public GenerateToStringContext(ToStringContext<NAMETYPE> context) { this.context = context; } @@ -200,7 +200,7 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } @Override - public ToStringContext parent() { return context; } + public ToStringContext<NAMETYPE> parent() { return context; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 0d4aeb5c37d..52bef482fb4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -75,7 +75,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index 903d0b2dcd9..f47202d1b9f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -39,7 +39,7 @@ public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction< } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index c862aa4eaf6..8f4e2f466d4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -41,7 +41,7 @@ public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction< } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 40620cb95fe..46772d8cbff 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -71,7 +71,7 @@ public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "map(" + argument.toString(context) + ", " + mapper + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 810e01011fe..8ac6d711c48 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -45,7 +45,7 @@ public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index bd42e95a59e..adc84225a63 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -70,7 +70,7 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 8cf1964585a..18c5db8e3a7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -38,7 +38,7 @@ public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index 7d5b11d6672..45b827db900 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -42,7 +42,7 @@ public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETY } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 79209fd8f09..8841cff15e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -86,7 +86,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index 9b56fefb5f0..7505355beed 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -314,7 +314,7 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "reduce_join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ", " + @@ -324,8 +324,8 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N private static class MultiDimensionIterator { - private long[] bounds; - private long[] iterator; + private final long[] bounds; + private final long[] iterator; private int remaining; MultiDimensionIterator(TensorType type) { @@ -364,9 +364,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N remaining -= 1; } + @Override public String toString() { return Arrays.toString(iterator); } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 67ede7f6540..a434ecba5cc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -128,7 +128,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index 11e52aad73e..0e0dc9a9aa8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -20,6 +20,6 @@ public interface ScalarFunction<NAMETYPE extends Name> extends Function<Evaluati /** Returns this as a tensor function, or empty if it cannot be represented as a tensor function */ default Optional<TensorFunction<NAMETYPE>> asTensorFunction() { return Optional.empty(); } - default String toString(ToStringContext context) { return toString(); } + default String toString(ToStringContext<NAMETYPE> context) { return toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 09bfb8b996b..da7581c39f9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -121,8 +121,8 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY private TensorType resultType(TensorType argumentType) { List<String> peekDimensions; - // Special case where a single indexed or mapped dimension is sliced if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + // Special case where a single indexed or mapped dimension is sliced if (subspaceAddress.get(0).index().isPresent()) { peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed); if (peekDimensions.size() > 1) { @@ -140,22 +140,28 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY else { // general slicing peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList()); } - if (peekDimensions.isEmpty()) - throw new IllegalArgumentException(this + " cannot slice " + argumentType + ": No dimensions to slice"); - return TypeResolver.peek(argumentType, peekDimensions); + try { + return TypeResolver.peek(argumentType, peekDimensions); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException(this + " cannot slice type " + argumentType, e); + } } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { StringBuilder b = new StringBuilder(argument.toString(context)); - if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + if (context.typeContext().isEmpty() + && subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { // use short forms if (subspaceAddress.get(0).index().isPresent()) b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]"); else b.append("{").append(subspaceAddress.get(0).label().get()).append("}"); } - else { - b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); + else { // general form + b.append("{").append(subspaceAddress.stream() + .map(i -> i.toString(context, this)) + .collect(Collectors.joining(", "))).append("}"); } return b.toString(); } @@ -222,12 +228,22 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY @Override public String toString() { - return toString(ToStringContext.empty()); + return toString(null, null); } - public String toString(ToStringContext context) { + String toString(ToStringContext<NAMETYPE> context, Slice<NAMETYPE> owner) { StringBuilder b = new StringBuilder(); - dimension.ifPresent(d -> b.append(d).append(":")); + Optional<String> dimensionName = dimension; + if (context != null && dimensionName.isEmpty()) { // This isn't just toString(): Output canonical form or fail + TensorType type = context.typeContext().isPresent() ? owner.argument.type(context.typeContext().get()) : null; + if (type == null || type.dimensions().size() != 1) + throw new IllegalArgumentException("The tensor dimension name being sliced by " + owner + + " cannot be uniquely resolved. Use the full form " + + "slice{myDimensionName: ..."); + else + dimensionName = Optional.of(type.dimensions().get(0).name()); + } + dimensionName.ifPresent(d -> b.append(d).append(":")); if (label != null) b.append(label); else diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index 13420a12e8f..9ea9040831b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -46,7 +46,7 @@ public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAME } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "softmax(" + argument.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 81d3692bd94..1e1d1d3b5b9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -60,7 +60,7 @@ public abstract class TensorFunction<NAMETYPE extends Name> { * * @param context a context which must be passed to all nested functions when requesting the string value */ - public abstract String toString(ToStringContext context); + public abstract String toString(ToStringContext<NAMETYPE> context); /** Returns this as a scalar function, or empty if it cannot be represented as a scalar function */ public Optional<ScalarFunction<NAMETYPE>> asScalarFunction() { return Optional.empty(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java index 1c8da9a1dca..233779fcebe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java @@ -1,31 +1,42 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Optional; + /** * A context which is passed down to all nested functions when returning a string representation. * * @author bratseth */ -public interface ToStringContext { +public interface ToStringContext<NAMETYPE extends Name> { - static ToStringContext empty() { return new EmptyStringContext(); } + static <NAMETYPE extends Name> ToStringContext<NAMETYPE> empty() { return new EmptyStringContext<NAMETYPE>(); } /** Returns the name an identifier is bound to, or null if not bound in this context */ String getBinding(String name); /** + * Returns the context used to resolve types in this, if present. + * In some functions serialization depends on type information. + */ + default Optional<TypeContext<NAMETYPE>> typeContext() { return Optional.empty(); } + + /** * Returns the parent context of this (the context we're in scope of when this is created), * or null if this is the root. */ - ToStringContext parent(); + ToStringContext<NAMETYPE> parent(); - class EmptyStringContext implements ToStringContext { + class EmptyStringContext<NAMETYPE extends Name> implements ToStringContext<NAMETYPE> { @Override public String getBinding(String name) { return null; } @Override - public ToStringContext parent() { return null; } + public ToStringContext<NAMETYPE> parent() { return null; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 112a0d43796..0223ad4d588 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -44,7 +44,7 @@ public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAME } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "xw_plus_b(" + x.toString(context) + ", " + w.toString(context) + ", " + b.toString(context) + ", " + |