aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2022-02-16 16:43:20 +0100
committerJon Bratseth <bratseth@gmail.com>2022-02-16 16:43:20 +0100
commite25d913b884339afc4f8e3073e4e4b795e55d930 (patch)
tree408e9fded165a07fae202fd691f6f2864680ac63
parent6f99bd502132cd378124a40060ac1d74d54f5e92 (diff)
Resolve slice dimension
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java1
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java3
-rw-r--r--config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java17
-rw-r--r--config-model/src/test/derived/slice/rank-profiles.cfg25
-rw-r--r--config-model/src/test/derived/slice/test.sd8
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java5
-rw-r--r--searchlib/abi-spec.json4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java8
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java33
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java30
-rw-r--r--vespajlib/abi-spec.json4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java38
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java2
40 files changed, 195 insertions, 94 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
index fef7ff56763..08475813317 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java
@@ -104,7 +104,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement
@Override
public TensorType getType(Reference reference) {
// computeIfAbsent without concurrent modification due to resolve adding more resolved entries:
-
boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference);
TensorType resolvedType = resolvedTypes.get(reference);
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
index 5b842b002bd..49c52b21907 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java
@@ -998,7 +998,8 @@ public class RankProfile implements Cloneable {
return featureTypes;
}
- public MapEvaluationTypeContext typeContext(QueryProfileRegistry queryProfiles, Map<Reference, TensorType> featureTypes) {
+ public MapEvaluationTypeContext typeContext(QueryProfileRegistry queryProfiles,
+ Map<Reference, TensorType> featureTypes) {
MapEvaluationTypeContext context = new MapEvaluationTypeContext(getExpressionFunctions(), featureTypes);
// Add small and large constants, respectively
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
index 75703c33f07..9370075bcf3 100644
--- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
+++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java
@@ -63,8 +63,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
QueryProfileRegistry queryProfiles, ImportedMlModels importedModels,
AttributeFields attributeFields, ModelContext.Properties deployProperties) {
this.name = rankProfile.name();
- compressedProperties = compress(new Deriver(rankProfile.compile(queryProfiles, importedModels),
- attributeFields, deployProperties).derive(largeExpressions));
+ compressedProperties = compress(new Deriver(rankProfile.compile(queryProfiles, importedModels), attributeFields, deployProperties, queryProfiles)
+ .derive(largeExpressions));
}
private Compressor.Compression compress(List<Pair<String, String>> properties) {
@@ -156,7 +156,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
/**
* Creates a raw rank profile from the given rank profile
*/
- Deriver(RankProfile compiled, AttributeFields attributeFields, ModelContext.Properties deployProperties) {
+ Deriver(RankProfile compiled,
+ AttributeFields attributeFields,
+ ModelContext.Properties deployProperties,
+ QueryProfileRegistry queryProfiles) {
rankprofileName = compiled.name();
attributeTypes = compiled.getAttributeTypes();
queryFeatureTypes = compiled.getQueryFeatureTypes();
@@ -179,7 +182,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions();
List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList());
Map<String, String> functionProperties = new LinkedHashMap<>();
- SerializationContext functionSerializationContext = new SerializationContext(functionExpressions);
+ SerializationContext functionSerializationContext = new SerializationContext(functionExpressions,
+ Map.of(),
+ compiled.typeContext(queryProfiles));
if (firstPhaseRanking != null) {
functionProperties.putAll(firstPhaseRanking.getRankProperties(functionSerializationContext));
@@ -201,8 +206,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer {
}
private void derivePropertiesAndFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions,
- Map<String, String> functionProperties,
- SerializationContext functionContext) {
+ Map<String, String> functionProperties,
+ SerializationContext functionContext) {
if (functions.isEmpty()) return;
replaceFunctionFeatures(summaryFeatures, functionContext);
diff --git a/config-model/src/test/derived/slice/rank-profiles.cfg b/config-model/src/test/derived/slice/rank-profiles.cfg
new file mode 100644
index 00000000000..75725b81ecf
--- /dev/null
+++ b/config-model/src/test/derived/slice/rank-profiles.cfg
@@ -0,0 +1,25 @@
+rankprofile[].name "default"
+rankprofile[].fef.property[].name "vespa.type.query.myTensor"
+rankprofile[].fef.property[].value "tensor<float>(key{})"
+rankprofile[].name "unranked"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "value(0)"
+rankprofile[].fef.property[].name "vespa.hitcollector.heapsize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.hitcollector.arraysize"
+rankprofile[].fef.property[].value "0"
+rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures"
+rankprofile[].fef.property[].value "true"
+rankprofile[].fef.property[].name "vespa.type.query.myTensor"
+rankprofile[].fef.property[].value "tensor<float>(key{})"
+rankprofile[].name "parent"
+rankprofile[].fef.property[].name "rankingExpression(mySlice@77dee0712164ce73).rankingScript"
+rankprofile[].fef.property[].value "query(myTensor){key:MY_KEY2}"
+rankprofile[].fef.property[].name "rankingExpression(myFunction).rankingScript"
+rankprofile[].fef.property[].value "4 * query(myTensor){key:MY_KEY1} * rankingExpression(mySlice@77dee0712164ce73)"
+rankprofile[].fef.property[].name "rankingExpression(myValue).rankingScript"
+rankprofile[].fef.property[].value "4"
+rankprofile[].fef.property[].name "rankingExpression(mySlice).rankingScript"
+rankprofile[].fef.property[].value "myTensor{key:MY_KEY2}"
+rankprofile[].fef.property[].name "vespa.type.query.myTensor"
+rankprofile[].fef.property[].value "tensor<float>(key{})"
diff --git a/config-model/src/test/derived/slice/test.sd b/config-model/src/test/derived/slice/test.sd
index fbb581d1b1d..c2060300785 100644
--- a/config-model/src/test/derived/slice/test.sd
+++ b/config-model/src/test/derived/slice/test.sd
@@ -5,8 +5,8 @@ search test {
rank-profile parent {
- function inline cpmScore() {
- expression: myValue * mySlice(query(myTensor))
+ function inline myFunction() {
+ expression: myValue * query(myTensor){MY_KEY1} * mySlice(query(myTensor))
}
function inline myValue() {
@@ -14,7 +14,9 @@ search test {
}
function inline mySlice(myTensor) {
- expression: myTensor{"NULL"}
+ # TODO: We are missing type resolving across function calls in serialization,
+ # so using the short form (without 'key') here will fail
+ expression: myTensor{key:MY_KEY2}
}
}
diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java
index df87de2a12b..b3d2e6451f3 100644
--- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java
+++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java
@@ -51,7 +51,7 @@ public final class DocsumDefinitionSet {
}
if (ds == null) {
throw new ConfigurationException("Fetched hit with summary class " + summaryClass +
- ", but this summary class is not in current summary config (" + toString() + ")" +
+ ", but this summary class is not in current summary config (" + this + ")" +
" (that is, you asked for something unknown, and no default was found)");
}
return ds;
@@ -66,7 +66,7 @@ public final class DocsumDefinitionSet {
* @return Error message or null on success.
* @throws ConfigurationException if the summary class of this hit is missing
*/
- public final String lazyDecode(String summaryClass, byte[] data, FastHit hit) {
+ public String lazyDecode(String summaryClass, byte[] data, FastHit hit) {
ByteBuffer buffer = ByteBuffer.wrap(data);
buffer.order(ByteOrder.LITTLE_ENDIAN);
long docsumClassId = buffer.getInt();
@@ -83,6 +83,7 @@ public final class DocsumDefinitionSet {
return null;
}
+ @Override
public String toString() {
StringBuilder sb = new StringBuilder();
for (Map.Entry<String, DocsumDefinition> e : definitionsByName.entrySet() ) {
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 2d7daf2300e..5d7e281df87 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1611,8 +1611,10 @@
"public void <init>(java.util.Collection)",
"public void <init>(java.util.Map)",
"public void <init>(java.util.Collection, java.util.Map)",
- "public void <init>(java.util.Collection, java.util.Map, java.util.Map)",
+ "public void <init>(java.util.Collection, java.util.Map, com.yahoo.tensor.evaluation.TypeContext)",
+ "public java.util.Optional typeContext()",
"public void <init>(java.util.Map, java.util.Map, java.util.Map)",
+ "public void <init>(java.util.Map, java.util.Map, java.util.Optional, java.util.Map)",
"public void <init>(com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map)",
"public void addFunctionSerialization(java.lang.String, java.lang.String)",
"public void addArgumentTypeSerialization(java.lang.String, java.lang.String, com.yahoo.tensor.TensorType)",
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
index fe22a5b1267..e770e6ac038 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java
@@ -15,7 +15,9 @@ import java.util.List;
* @author bratseth
*/
public final class Arguments implements Serializable {
+
public static final Arguments EMPTY = new Arguments();
+
private final ImmutableList<ExpressionNode> expressions;
public Arguments() {
@@ -47,9 +49,9 @@ public final class Arguments implements Serializable {
/** Evaluate all arguments in this */
public Value[] evaluate(Context context) {
- Value[] values=new Value[expressions.size()];
- for (int i=0; i<expressions.size(); i++)
- values[i]=expressions.get(i).evaluate(context);
+ Value[] values = new Value[expressions.size()];
+ for (int i = 0; i < expressions.size(); i++)
+ values[i] = expressions.get(i).evaluate(context);
return values;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
index 93b8b8aca5e..8ac1829b16b 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java
@@ -42,9 +42,11 @@ public final class ReferenceNode extends CompositeNode {
return reference.name();
}
+ @Override
public int hashCode() {
return reference.hashCode();
}
+
/** Returns the arguments, never null */
public Arguments getArguments() { return reference.arguments(); }
@@ -118,7 +120,7 @@ public final class ReferenceNode extends CompositeNode {
throw new IllegalArgumentException(reference + " is invalid", e);
}
if (type == null)
- throw new IllegalArgumentException("Unknown feature '" + toString() + "'");
+ throw new IllegalArgumentException("Unknown feature '" + this + "'");
return type;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index 535ad013caf..1f3203f2e35 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -4,13 +4,16 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.Map;
+import java.util.Optional;
/**
* Context needed to serialize an expression to a string. This has the lifetime of a single serialization
@@ -22,6 +25,8 @@ public class SerializationContext extends FunctionReferenceContext {
/** Serialized form of functions indexed by name */
private final Map<String, String> serializedFunctions;
+ private final Optional<TypeContext<Reference>> typeContext;
+
/** Create a context for a single serialization task */
public SerializationContext() {
this(Collections.emptyList());
@@ -29,7 +34,7 @@ public class SerializationContext extends FunctionReferenceContext {
/** Create a context for a single serialization task */
public SerializationContext(Collection<ExpressionFunction> functions) {
- this(functions, Collections.emptyMap(), new LinkedHashMap<>());
+ this(functions, Collections.emptyMap(), Optional.empty(), new LinkedHashMap<>());
}
/** Create a context for a single serialization task */
@@ -39,7 +44,13 @@ public class SerializationContext extends FunctionReferenceContext {
/** Create a context for a single serialization task */
public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) {
- this(functions, bindings, new LinkedHashMap<>());
+ this(functions, bindings, Optional.empty(), new LinkedHashMap<>());
+ }
+
+ /** Create a context for a single serialization task */
+ public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings,
+ TypeContext<Reference> typeContext) {
+ this(functions, bindings, Optional.of(typeContext), new LinkedHashMap<>());
}
/**
@@ -47,14 +58,19 @@ public class SerializationContext extends FunctionReferenceContext {
*
* @param functions the functions of this
* @param bindings the arguments of this
+ * @param typeContext the type context of this: Serialization may depend on type resolution
* @param serializedFunctions a cache of serializedFunctions - the ownership of this map
* is <b>transferred</b> to this and will be modified in it
*/
- public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings,
+ private SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings,
+ Optional<TypeContext<Reference>> typeContext,
Map<String, String> serializedFunctions) {
- this(toMap(functions), bindings, serializedFunctions);
+ this(toMap(functions), bindings, typeContext, serializedFunctions);
}
+ /** Returns the type context of this, if it is able to resolve types. */
+ public Optional<TypeContext<Reference>> typeContext() { return typeContext; }
+
private static Map<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) {
Map<String,ExpressionFunction> mapBuilder = new HashMap<>();
for (ExpressionFunction function : list)
@@ -70,9 +86,16 @@ public class SerializationContext extends FunctionReferenceContext {
* @param serializedFunctions a cache of serializedFunctions - the ownership of this map
* is <b>transferred</b> to this and will be modified in it
*/
+ public SerializationContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings,
+ Map<String, String> serializedFunctions) {
+ this(functions, bindings, Optional.empty(), serializedFunctions);
+ }
+
public SerializationContext(Map<String,ExpressionFunction> functions, Map<String, String> bindings,
+ Optional<TypeContext<Reference>> typeContext,
Map<String, String> serializedFunctions) {
super(functions, bindings);
+ this.typeContext = typeContext;
this.serializedFunctions = serializedFunctions;
}
@@ -88,7 +111,7 @@ public class SerializationContext extends FunctionReferenceContext {
serializedFunctions.put(name, expressionString);
}
- /** Adds the serialization of the an argument type to a function */
+ /** Adds the serialization of the argument type to a function */
public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) {
serializedFunctions.put("rankingExpression(" + functionName + ")." + argumentName + ".type", type.toString());
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index d873963bb6e..52d54c9163e 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -168,8 +168,8 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public String toString(ToStringContext c) {
- ToStringContext outermost = c;
+ public String toString(ToStringContext<Reference> c) {
+ ToStringContext<Reference> outermost = c;
while (outermost.parent() != null)
outermost = outermost.parent();
@@ -251,15 +251,17 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public String toString(ToStringContext c) {
- ToStringContext outermost = c;
+ public String toString(ToStringContext<Reference> c) {
+ ToStringContext<Reference> outermost = c;
while (outermost.parent() != null)
outermost = outermost.parent();
if (outermost instanceof ExpressionToStringContext) {
ExpressionToStringContext context = (ExpressionToStringContext)outermost;
return expression.toString(new StringBuilder(),
- new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent),
+ new ExpressionToStringContext(context.wrappedSerializationContext, c,
+ context.path,
+ context.parent),
context.path,
context.parent)
.toString();
@@ -281,9 +283,9 @@ public class TensorFunctionNode extends CompositeNode {
* to add more context, we need to keep track of both these contexts here separately and map between
* contexts as seen in the toString methods of the function classes above.
*/
- private static class ExpressionToStringContext extends SerializationContext implements ToStringContext {
+ private static class ExpressionToStringContext extends SerializationContext implements ToStringContext<Reference> {
- private final ToStringContext wrappedToStringContext;
+ private final ToStringContext<Reference> wrappedToStringContext;
private final SerializationContext wrappedSerializationContext;
private final Deque<String> path;
private final CompositeNode parent;
@@ -297,7 +299,7 @@ public class TensorFunctionNode extends CompositeNode {
}
ExpressionToStringContext(SerializationContext wrappedSerializationContext,
- ToStringContext wrappedToStringContext,
+ ToStringContext<Reference> wrappedToStringContext,
Deque<String> path,
CompositeNode parent) {
this.wrappedSerializationContext = wrappedSerializationContext;
@@ -328,6 +330,12 @@ public class TensorFunctionNode extends CompositeNode {
/** Returns a function or null if it isn't defined in this context */
public ExpressionFunction getFunction(String name) { return wrappedSerializationContext.getFunction(name); }
+ /** Returns the type context of this, or empty if none. */
+ @Override
+ public Optional<TypeContext<Reference>> typeContext() {
+ return wrappedSerializationContext.typeContext();
+ }
+
/** @deprecated Use {@link #getFunctions()} instead */
@SuppressWarnings("removal")
@Deprecated(forRemoval = true, since = "7")
@@ -335,9 +343,10 @@ public class TensorFunctionNode extends CompositeNode {
return ImmutableMap.copyOf(wrappedSerializationContext.getFunctions());
}
- @Override protected Map<String, ExpressionFunction> getFunctions() { return wrappedSerializationContext.getFunctions(); }
+ @Override
+ protected Map<String, ExpressionFunction> getFunctions() { return wrappedSerializationContext.getFunctions(); }
- public ToStringContext parent() { return wrappedToStringContext; }
+ public ToStringContext<Reference> parent() { return wrappedToStringContext; }
/** Returns the resolution of an identifier, or null if it isn't defined in this context */
@Override
@@ -361,6 +370,7 @@ public class TensorFunctionNode extends CompositeNode {
SerializationContext serializationContext = new SerializationContext(getFunctions(), null, serializedFunctions());
return new ExpressionToStringContext(serializationContext, null, path, parent);
}
+
}
/** Turns an EvaluationContext into a Context */
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 71e6c0f28f4..a30ee055538 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -2658,8 +2658,7 @@
"public java.util.Optional dimension()",
"public java.util.Optional label()",
"public java.util.Optional index()",
- "public java.lang.String toString()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ "public java.lang.String toString()"
],
"fields": []
},
@@ -2744,6 +2743,7 @@
"methods": [
"public static com.yahoo.tensor.functions.ToStringContext empty()",
"public abstract java.lang.String getBinding(java.lang.String)",
+ "public java.util.Optional typeContext()",
"public abstract com.yahoo.tensor.functions.ToStringContext parent()"
],
"fields": []
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
index 457cfcbfa5f..3b12b6bdba1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
@@ -57,7 +57,7 @@ public class TypeResolver {
static public TensorType peek(TensorType inputType, List<String> peekDimensions) {
if (peekDimensions.isEmpty()) {
- throw new IllegalArgumentException("peeking no dimensions makes no sense");
+ throw new IllegalArgumentException("Peeking no dimensions makes no sense");
}
Map<String, Dimension> map = new HashMap<>();
for (Dimension dim : inputType.dimensions()) {
@@ -67,7 +67,7 @@ public class TypeResolver {
if (map.containsKey(name)) {
map.remove(name);
} else {
- throw new IllegalArgumentException("peeking non-existing dimension "+name+" in type "+inputType);
+ throw new IllegalArgumentException("Peeking non-existing dimension '" + name + "'");
}
}
if (map.isEmpty()) {
@@ -79,10 +79,10 @@ public class TypeResolver {
static public TensorType rename(TensorType inputType, List<String> from, List<String> to) {
if (from.isEmpty()) {
- throw new IllegalArgumentException("renaming no dimensions");
+ throw new IllegalArgumentException("Renaming no dimensions");
}
if (from.size() != to.size()) {
- throw new IllegalArgumentException("bad rename, from size "+from.size()+" != to.size "+to.size());
+ throw new IllegalArgumentException("Bad rename, from size "+from.size()+" != to.size "+to.size());
}
Map<String,Dimension> oldDims = new HashMap<>();
for (Dimension dim : inputType.dimensions()) {
@@ -96,7 +96,7 @@ public class TypeResolver {
var dim = oldDims.remove(oldName);
newDims.put(newName, dim.withName(newName));
} else {
- logger.log(Level.WARNING, "renaming non-existing dimension "+oldName+" in type "+inputType);
+ logger.log(Level.WARNING, "Renaming non-existing dimension "+oldName+" in type "+inputType);
// throw new IllegalArgumentException("bad rename, dimension "+oldName+" not found");
}
}
@@ -106,13 +106,13 @@ public class TypeResolver {
if (inputType.dimensions().size() == newDims.size()) {
return new TensorType(inputType.valueType(), newDims.values());
} else {
- throw new IllegalArgumentException("bad rename, lost some dimenions");
+ throw new IllegalArgumentException("Bad rename, lost some dimensions");
}
}
static public TensorType cell_cast(TensorType inputType, Value toCellType) {
if (toCellType != Value.DOUBLE && inputType.dimensions().isEmpty()) {
- throw new IllegalArgumentException("cannot cast "+inputType+" to valueType"+toCellType);
+ throw new IllegalArgumentException("Cannot cast "+inputType+" to valueType"+toCellType);
}
return new TensorType(toCellType, inputType.dimensions());
}
@@ -188,7 +188,7 @@ public class TypeResolver {
if (allOk) {
return join(lhs, rhs);
} else {
- throw new IllegalArgumentException("types in merge() dimensions mismatch: "+lhs+" != "+rhs);
+ throw new IllegalArgumentException("Types in merge() dimensions mismatch: "+lhs+" != "+rhs);
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index a376536015a..dbc8396d701 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -58,7 +58,7 @@ public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return name;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index 16ca7104f8d..55dd8a7bc8a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -48,7 +48,7 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index fcdc1233550..f1f0b9d67b0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -48,7 +48,7 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index 8c6c27e171a..09f84e6747e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -107,7 +107,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "cell_cast(" + argument.toString(context) + ", " + valueType + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 32a4c8cd2ff..6d4b15be991 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -285,7 +285,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 1544369ba2f..a0fd9272f54 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -47,6 +47,6 @@ public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti
public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
@Override
- public String toString(ToStringContext context) { return constant.toString(); }
+ public String toString(ToStringContext<NAMETYPE> context) { return constant.toString(); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index 2c0fa483021..92d89ec68f7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -41,7 +41,7 @@ public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYP
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 97126ad88a7..46992115c23 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -46,11 +46,11 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
TensorType type() { return type; }
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return type().toString() + ":" + contentToString(context);
}
- abstract String contentToString(ToStringContext context);
+ abstract String contentToString(ToStringContext<NAMETYPE> context);
/** Creates a dynamic tensor function. The cell addresses must match the type. */
public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) {
@@ -80,7 +80,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
}
@Override
- String contentToString(ToStringContext context) {
+ String contentToString(ToStringContext<NAMETYPE> context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.values().iterator().next().toString(context) + "}";
@@ -121,7 +121,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens
}
@Override
- String contentToString(ToStringContext context) {
+ String contentToString(ToStringContext<NAMETYPE> context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.get(0).toString(context) + "}";
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
index 8fc246a7d9d..c049e5d41da 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
@@ -41,7 +41,7 @@ public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "expand(" + argument.toString(context) + ", " + dimensionName + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 89e981df49e..54e83fa472f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -117,9 +117,9 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
@Override
- public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; }
+ public String toString(ToStringContext<NAMETYPE> context) { return type + "(" + generatorToString(context) + ")"; }
- private String generatorToString(ToStringContext context) {
+ private String generatorToString(ToStringContext<NAMETYPE> context) {
if (freeGenerator != null)
return freeGenerator.toString();
else
@@ -183,11 +183,11 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
/** A context which adds the bindings of the generate dimension names to the given context. */
- private class GenerateToStringContext implements ToStringContext {
+ private class GenerateToStringContext implements ToStringContext<NAMETYPE> {
- private final ToStringContext context;
+ private final ToStringContext<NAMETYPE> context;
- public GenerateToStringContext(ToStringContext context) {
+ public GenerateToStringContext(ToStringContext<NAMETYPE> context) {
this.context = context;
}
@@ -200,7 +200,7 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
}
@Override
- public ToStringContext parent() { return context; }
+ public ToStringContext<NAMETYPE> parent() { return context; }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 0d4aeb5c37d..52bef482fb4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -75,7 +75,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index 903d0b2dcd9..f47202d1b9f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -39,7 +39,7 @@ public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index c862aa4eaf6..8f4e2f466d4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -41,7 +41,7 @@ public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 40620cb95fe..46772d8cbff 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -71,7 +71,7 @@ public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "map(" + argument.toString(context) + ", " + mapper + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 810e01011fe..8ac6d711c48 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -45,7 +45,7 @@ public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
index bd42e95a59e..adc84225a63 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -70,7 +70,7 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 8cf1964585a..18c5db8e3a7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -38,7 +38,7 @@ public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index 7d5b11d6672..45b827db900 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -42,7 +42,7 @@ public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETY
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction;
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 79209fd8f09..8841cff15e9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -86,7 +86,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 9b56fefb5f0..7505355beed 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -314,7 +314,7 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "reduce_join(" + argumentA.toString(context) + ", " +
argumentB.toString(context) + ", " +
combinator + ", " +
@@ -324,8 +324,8 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
private static class MultiDimensionIterator {
- private long[] bounds;
- private long[] iterator;
+ private final long[] bounds;
+ private final long[] iterator;
private int remaining;
MultiDimensionIterator(TensorType type) {
@@ -364,9 +364,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N
remaining -= 1;
}
+ @Override
public String toString() {
return Arrays.toString(iterator);
}
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 67ede7f6540..a434ecba5cc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -128,7 +128,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "rename(" + argument.toString(context) + ", " +
toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
index 11e52aad73e..0e0dc9a9aa8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -20,6 +20,6 @@ public interface ScalarFunction<NAMETYPE extends Name> extends Function<Evaluati
/** Returns this as a tensor function, or empty if it cannot be represented as a tensor function */
default Optional<TensorFunction<NAMETYPE>> asTensorFunction() { return Optional.empty(); }
- default String toString(ToStringContext context) { return toString(); }
+ default String toString(ToStringContext<NAMETYPE> context) { return toString(); }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
index 09bfb8b996b..da7581c39f9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -121,8 +121,8 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
private TensorType resultType(TensorType argumentType) {
List<String> peekDimensions;
- // Special case where a single indexed or mapped dimension is sliced
if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
+ // Special case where a single indexed or mapped dimension is sliced
if (subspaceAddress.get(0).index().isPresent()) {
peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed);
if (peekDimensions.size() > 1) {
@@ -140,22 +140,28 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
else { // general slicing
peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList());
}
- if (peekDimensions.isEmpty())
- throw new IllegalArgumentException(this + " cannot slice " + argumentType + ": No dimensions to slice");
- return TypeResolver.peek(argumentType, peekDimensions);
+ try {
+ return TypeResolver.peek(argumentType, peekDimensions);
+ }
+ catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException(this + " cannot slice type " + argumentType, e);
+ }
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
StringBuilder b = new StringBuilder(argument.toString(context));
- if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
+ if (context.typeContext().isEmpty()
+ && subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { // use short forms
if (subspaceAddress.get(0).index().isPresent())
b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]");
else
b.append("{").append(subspaceAddress.get(0).label().get()).append("}");
}
- else {
- b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}");
+ else { // general form
+ b.append("{").append(subspaceAddress.stream()
+ .map(i -> i.toString(context, this))
+ .collect(Collectors.joining(", "))).append("}");
}
return b.toString();
}
@@ -222,12 +228,22 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
@Override
public String toString() {
- return toString(ToStringContext.empty());
+ return toString(null, null);
}
- public String toString(ToStringContext context) {
+ String toString(ToStringContext<NAMETYPE> context, Slice<NAMETYPE> owner) {
StringBuilder b = new StringBuilder();
- dimension.ifPresent(d -> b.append(d).append(":"));
+ Optional<String> dimensionName = dimension;
+ if (context != null && dimensionName.isEmpty()) { // This isn't just toString(): Output canonical form or fail
+ TensorType type = context.typeContext().isPresent() ? owner.argument.type(context.typeContext().get()) : null;
+ if (type == null || type.dimensions().size() != 1)
+ throw new IllegalArgumentException("The tensor dimension name being sliced by " + owner +
+ " cannot be uniquely resolved. Use the full form " +
+ "slice{myDimensionName: ...");
+ else
+ dimensionName = Optional.of(type.dimensions().get(0).name());
+ }
+ dimensionName.ifPresent(d -> b.append(d).append(":"));
if (label != null)
b.append(label);
else
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index 13420a12e8f..9ea9040831b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -46,7 +46,7 @@ public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAME
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "softmax(" + argument.toString(context) + ", " + dimension + ")";
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 81d3692bd94..1e1d1d3b5b9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -60,7 +60,7 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
*
* @param context a context which must be passed to all nested functions when requesting the string value
*/
- public abstract String toString(ToStringContext context);
+ public abstract String toString(ToStringContext<NAMETYPE> context);
/** Returns this as a scalar function, or empty if it cannot be represented as a scalar function */
public Optional<ScalarFunction<NAMETYPE>> asScalarFunction() { return Optional.empty(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
index 1c8da9a1dca..233779fcebe 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java
@@ -1,31 +1,42 @@
// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Optional;
+
/**
* A context which is passed down to all nested functions when returning a string representation.
*
* @author bratseth
*/
-public interface ToStringContext {
+public interface ToStringContext<NAMETYPE extends Name> {
- static ToStringContext empty() { return new EmptyStringContext(); }
+ static <NAMETYPE extends Name> ToStringContext<NAMETYPE> empty() { return new EmptyStringContext<NAMETYPE>(); }
/** Returns the name an identifier is bound to, or null if not bound in this context */
String getBinding(String name);
/**
+ * Returns the context used to resolve types in this, if present.
+ * In some functions serialization depends on type information.
+ */
+ default Optional<TypeContext<NAMETYPE>> typeContext() { return Optional.empty(); }
+
+ /**
* Returns the parent context of this (the context we're in scope of when this is created),
* or null if this is the root.
*/
- ToStringContext parent();
+ ToStringContext<NAMETYPE> parent();
- class EmptyStringContext implements ToStringContext {
+ class EmptyStringContext<NAMETYPE extends Name> implements ToStringContext<NAMETYPE> {
@Override
public String getBinding(String name) { return null; }
@Override
- public ToStringContext parent() { return null; }
+ public ToStringContext<NAMETYPE> parent() { return null; }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 112a0d43796..0223ad4d588 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -44,7 +44,7 @@ public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAME
}
@Override
- public String toString(ToStringContext context) {
+ public String toString(ToStringContext<NAMETYPE> context) {
return "xw_plus_b(" + x.toString(context) + ", " +
w.toString(context) + ", " +
b.toString(context) + ", " +