summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2022-02-17 12:31:15 +0100
committerGitHub <noreply@github.com>2022-02-17 12:31:15 +0100
commitc17968cbd98d30b977a4f32b15f5711d18dbef4c (patch)
treea80f20cae59d80d926373a28141b21688d9dd9c0 /vespajlib
parent63c3b54619e9e3fbb18e0954555075f026dbc22d (diff)
parente25d913b884339afc4f8e3073e4e4b795e55d930 (diff)
Merge pull request #21228 from vespa-engine/bratseth/resolve-slice-dimension
Resolve slice dimension
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java2
29 files changed, 90 insertions, 61 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 71e6c0f28f4..a30ee055538 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2658,8 +2658,7 @@
"public java.util.Optional dimension()",
"public java.util.Optional label()",
"public java.util.Optional index()",
- "public java.lang.String toString()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString()"
],
"fields": []
},
@@ -2744,6 +2743,7 @@
"methods": [
"public static com.yahoo.tensor.functions.ToStringContext empty()",
"public abstract java.lang.String getBinding(java.lang.String)",
+ "public java.util.Optional typeContext()",
"public abstract com.yahoo.tensor.functions.ToStringContext parent()"
],
"fields": []
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
index 457cfcbfa5f..3b12b6bdba1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
@@ -57,7 +57,7 @@ public class TypeResolver {
static public TensorType peek(TensorType inputType, List<String> peekDimensions) {
if (peekDimensions.isEmpty()) {
- throw new IllegalArgumentException("peeking no dimensions makes no sense");
+ throw new IllegalArgumentException("Peeking no dimensions makes no sense");
}
Map<String, Dimension> map = new HashMap<>();
for (Dimension dim : inputType.dimensions()) {
@@ -67,7 +67,7 @@ public class TypeResolver {
if (map.containsKey(name)) {
map.remove(name);
} else {
- throw new IllegalArgumentException("peeking non-existing dimension "+name+" in type "+inputType);
+ throw new IllegalArgumentException("Peeking non-existing dimension '" + name + "'");
}
}
if (map.isEmpty()) {
@@ -79,10 +79,10 @@ public class TypeResolver {
static public TensorType rename(TensorType inputType, List<String> from, List<String> to) {
if (from.isEmpty()) {
- throw new IllegalArgumentException("renaming no dimensions");
+ throw new IllegalArgumentException("Renaming no dimensions");
}
if (from.size() != to.size()) {
- throw new IllegalArgumentException("bad rename, from size "+from.size()+" != to.size "+to.size());
+ throw new IllegalArgumentException("Bad rename, from size "+from.size()+" != to.size "+to.size());
}
Map<String,Dimension> oldDims = new HashMap<>();
for (Dimension dim : inputType.dimensions()) {
@@ -96,7 +96,7 @@ public class TypeResolver {
var dim = oldDims.remove(oldName);
newDims.put(newName, dim.withName(newName));
} else {
- logger.log(Level.WARNING, "renaming non-existing dimension "+oldName+" in type "+inputType);
+ logger.log(Level.WARNING, "Renaming non-existing dimension "+oldName+" in type "+inputType);
// throw new IllegalArgumentException("bad rename, dimension "+oldName+" not found");
}
}
@@ -106,13 +106,13 @@ public class TypeResolver {
if (inputType.dimensions().size() == newDims.size()) {
return new TensorType(inputType.valueType(), newDims.values());
} else {
- throw new IllegalArgumentException("bad rename, lost some dimenions");
+ throw new IllegalArgumentException("Bad rename, lost some dimensions");
}
}
static public TensorType cell_cast(TensorType inputType, Value toCellType) {
if (toCellType != Value.DOUBLE && inputType.dimensions().isEmpty()) {
- throw new IllegalArgumentException("cannot cast "+inputType+" to valueType"+toCellType);
+ throw new IllegalArgumentException("Cannot cast "+inputType+" to valueType"+toCellType);
}
return new TensorType(toCellType, inputType.dimensions());
}
@@ -188,7 +188,7 @@ public class TypeResolver {
if (allOk) {
return join(lhs, rhs);
} else {
- throw new IllegalArgumentException("types in merge() dimensions mismatch: "+lhs+" != "+rhs);
+ throw new IllegalArgumentException("Types in merge() dimensions mismatch: "+lhs+" != "+rhs);
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index a376536015a..dbc8396d701 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -58,7 +58,7 @@ public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return name;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index 16ca7104f8d..55dd8a7bc8a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -48,7 +48,7 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index fcdc1233550..f1f0b9d67b0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -48,7 +48,7 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index 8c6c27e171a..09f84e6747e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -107,7 +107,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "cell_cast(" + argument.toString(context) + ", " + valueType + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 32a4c8cd2ff..6d4b15be991 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -285,7 +285,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 1544369ba2f..a0fd9272f54 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -47,6 +47,6 @@ public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti
public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
@Override
- public String toString(ToStringContext context) { return constant.toString(); }
+ public String toString(ToStringContext<NAMETYPE> context) { return constant.toString(); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index 2c0fa483021..92d89ec68f7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -41,7 +41,7 @@ public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYP
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 97126ad88a7..46992115c23 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -46,11 +46,11 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
TensorType type() { return type; }
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return type().toString() + ":" + contentToString(context);
}
- abstract String contentToString(ToStringContext context);
+ abstract String contentToString(ToStringContext<NAMETYPE> context);
/** Creates a dynamic tensor function. The cell addresses must match the type. */
public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) {
@@ -80,7 +80,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
}
@Override
- String contentToString(ToStringContext context) {
+ String contentToString(ToStringContext<NAMETYPE> context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.values().iterator().next().toString(context) + "}";
@@ -121,7 +121,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
}
@Override
- String contentToString(ToStringContext context) {
+ String contentToString(ToStringContext<NAMETYPE> context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.get(0).toString(context) + "}";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
index 8fc246a7d9d..c049e5d41da 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
@@ -41,7 +41,7 @@ public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "expand(" + argument.toString(context) + ", " + dimensionName + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 89e981df49e..54e83fa472f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -117,9 +117,9 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
@Override
- public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; }
+ public String toString(ToStringContext<NAMETYPE> context) { return type + "(" + generatorToString(context) + ")"; }
- private String generatorToString(ToStringContext context) {
+ private String generatorToString(ToStringContext<NAMETYPE> context) {
if (freeGenerator != null)
return freeGenerator.toString();
else
@@ -183,11 +183,11 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
/** A context which adds the bindings of the generate dimension names to the given context. */
- private class GenerateToStringContext implements ToStringContext {
+ private class GenerateToStringContext implements ToStringContext<NAMETYPE> {
- private final ToStringContext context;
+ private final ToStringContext<NAMETYPE> context;
- public GenerateToStringContext(ToStringContext context) {
+ public GenerateToStringContext(ToStringContext<NAMETYPE> context) {
this.context = context;
}
@@ -200,7 +200,7 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
@Override
- public ToStringContext parent() { return context; }
+ public ToStringContext<NAMETYPE> parent() { return context; }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 0d4aeb5c37d..52bef482fb4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -75,7 +75,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index 903d0b2dcd9..f47202d1b9f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -39,7 +39,7 @@ public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index c862aa4eaf6..8f4e2f466d4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -41,7 +41,7 @@ public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 40620cb95fe..46772d8cbff 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -71,7 +71,7 @@ public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "map(" + argument.toString(context) + ", " + mapper + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 810e01011fe..8ac6d711c48 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -45,7 +45,7 @@ public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
index bd42e95a59e..adc84225a63 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -70,7 +70,7 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 8cf1964585a..18c5db8e3a7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -38,7 +38,7 @@ public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index 7d5b11d6672..45b827db900 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -42,7 +42,7 @@ public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETY
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 79209fd8f09..8841cff15e9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -86,7 +86,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 9b56fefb5f0..7505355beed 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -314,7 +314,7 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "reduce_join(" + argumentA.toString(context) + ", " +
argumentB.toString(context) + ", " +
combinator + ", " +
@@ -324,8 +324,8 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
private static class MultiDimensionIterator {
- private long[] bounds;
- private long[] iterator;
+ private final long[] bounds;
+ private final long[] iterator;
private int remaining;
MultiDimensionIterator(TensorType type) {
@@ -364,9 +364,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
remaining -= 1;
}
+ @Override
public String toString() {
return Arrays.toString(iterator);
}
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 67ede7f6540..a434ecba5cc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -128,7 +128,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
index 11e52aad73e..0e0dc9a9aa8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -20,6 +20,6 @@ public interface ScalarFunction<NAMETYPE extends Name> extends Function<Evaluati
/** Returns this as a tensor function, or empty if it cannot be represented as a tensor function */
default Optional<TensorFunction<NAMETYPE>> asTensorFunction() { return Optional.empty(); }
- default String toString(ToStringContext context) { return toString(); }
+ default String toString(ToStringContext<NAMETYPE> context) { return toString(); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 09bfb8b996b..da7581c39f9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -121,8 +121,8 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
private TensorType resultType(TensorType argumentType) {
List<String> peekDimensions;
- // Special case where a single indexed or mapped dimension is sliced
if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
+ // Special case where a single indexed or mapped dimension is sliced
if (subspaceAddress.get(0).index().isPresent()) {
peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed);
if (peekDimensions.size() > 1) {
@@ -140,22 +140,28 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
else { // general slicing
peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList());
}
- if (peekDimensions.isEmpty())
- throw new IllegalArgumentException(this + " cannot slice " + argumentType + ": No dimensions to slice");
- return TypeResolver.peek(argumentType, peekDimensions);
+ try {
+ return TypeResolver.peek(argumentType, peekDimensions);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException(this + " cannot slice type " + argumentType, e);
+ }
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
StringBuilder b = new StringBuilder(argument.toString(context));
- if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
+ if (context.typeContext().isEmpty()
+ && subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { // use short forms
if (subspaceAddress.get(0).index().isPresent())
b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]");
else
b.append("{").append(subspaceAddress.get(0).label().get()).append("}");
}
- else {
- b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}");
+ else { // general form
+ b.append("{").append(subspaceAddress.stream()
+ .map(i -> i.toString(context, this))
+ .collect(Collectors.joining(", "))).append("}");
}
return b.toString();
}
@@ -222,12 +228,22 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
@Override
public String toString() {
- return toString(ToStringContext.empty());
+ return toString(null, null);
}
- public String toString(ToStringContext context) {
+ String toString(ToStringContext<NAMETYPE> context, Slice<NAMETYPE> owner) {
StringBuilder b = new StringBuilder();
- dimension.ifPresent(d -> b.append(d).append(":"));
+ Optional<String> dimensionName = dimension;
+ if (context != null && dimensionName.isEmpty()) { // This isn't just toString(): Output canonical form or fail
+ TensorType type = context.typeContext().isPresent() ? owner.argument.type(context.typeContext().get()) : null;
+ if (type == null || type.dimensions().size() != 1)
+ throw new IllegalArgumentException("The tensor dimension name being sliced by " + owner +
+ " cannot be uniquely resolved. Use the full form " +
+ "slice{myDimensionName: ...");
+ else
+ dimensionName = Optional.of(type.dimensions().get(0).name());
+ }
+ dimensionName.ifPresent(d -> b.append(d).append(":"));
if (label != null)
b.append(label);
else
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index 13420a12e8f..9ea9040831b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -46,7 +46,7 @@ public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAME
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "softmax(" + argument.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 81d3692bd94..1e1d1d3b5b9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -60,7 +60,7 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
*
* @param context a context which must be passed to all nested functions when requesting the string value
*/
- public abstract String toString(ToStringContext context);
+ public abstract String toString(ToStringContext<NAMETYPE> context);
/** Returns this as a scalar function, or empty if it cannot be represented as a scalar function */
public Optional<ScalarFunction<NAMETYPE>> asScalarFunction() { return Optional.empty(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
index 1c8da9a1dca..233779fcebe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
@@ -1,31 +1,42 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Optional;
+
/**
* A context which is passed down to all nested functions when returning a string representation.
*
* @author bratseth
*/
-public interface ToStringContext {
+public interface ToStringContext<NAMETYPE extends Name> {
- static ToStringContext empty() { return new EmptyStringContext(); }
+ static <NAMETYPE extends Name> ToStringContext<NAMETYPE> empty() { return new EmptyStringContext<NAMETYPE>(); }
/** Returns the name an identifier is bound to, or null if not bound in this context */
String getBinding(String name);
/**
+ * Returns the context used to resolve types in this, if present.
+ * In some functions serialization depends on type information.
+ */
+ default Optional<TypeContext<NAMETYPE>> typeContext() { return Optional.empty(); }
+
+ /**
* Returns the parent context of this (the context we're in scope of when this is created),
* or null if this is the root.
*/
- ToStringContext parent();
+ ToStringContext<NAMETYPE> parent();
- class EmptyStringContext implements ToStringContext {
+ class EmptyStringContext<NAMETYPE extends Name> implements ToStringContext<NAMETYPE> {
@Override
public String getBinding(String name) { return null; }
@Override
- public ToStringContext parent() { return null; }
+ public ToStringContext<NAMETYPE> parent() { return null; }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 112a0d43796..0223ad4d588 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -44,7 +44,7 @@ public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAME
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "xw_plus_b(" + x.toString(context) + ", " +
w.toString(context) + ", " +
b.toString(context) + ", " +