diff options
author | Jon Bratseth <bratseth@gmail.com> | 2022-02-16 16:43:20 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2022-02-16 16:43:20 +0100 |
commit | e25d913b884339afc4f8e3073e4e4b795e55d930 (patch) | |
tree | 408e9fded165a07fae202fd691f6f2864680ac63 | |
parent | 6f99bd502132cd378124a40060ac1d74d54f5e92 (diff) |
Resolve slice dimension
40 files changed, 195 insertions, 94 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index fef7ff56763..08475813317 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -104,7 +104,6 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement @Override public TensorType getType(Reference reference) { // computeIfAbsent without concurrent modification due to resolve adding more resolved entries: - boolean canBeResolvedGlobally = referenceCanBeResolvedGlobally(reference); TensorType resolvedType = resolvedTypes.get(reference); diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java index 5b842b002bd..49c52b21907 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/RankProfile.java @@ -998,7 +998,8 @@ public class RankProfile implements Cloneable { return featureTypes; } - public MapEvaluationTypeContext typeContext(QueryProfileRegistry queryProfiles, Map<Reference, TensorType> featureTypes) { + public MapEvaluationTypeContext typeContext(QueryProfileRegistry queryProfiles, + Map<Reference, TensorType> featureTypes) { MapEvaluationTypeContext context = new MapEvaluationTypeContext(getExpressionFunctions(), featureTypes); // Add small and large constants, respectively diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java index 75703c33f07..9370075bcf3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/derived/RawRankProfile.java @@ -63,8 +63,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer { QueryProfileRegistry queryProfiles, ImportedMlModels importedModels, AttributeFields attributeFields, ModelContext.Properties deployProperties) { this.name = rankProfile.name(); - compressedProperties = compress(new Deriver(rankProfile.compile(queryProfiles, importedModels), - attributeFields, deployProperties).derive(largeExpressions)); + compressedProperties = compress(new Deriver(rankProfile.compile(queryProfiles, importedModels), attributeFields, deployProperties, queryProfiles) + .derive(largeExpressions)); } private Compressor.Compression compress(List<Pair<String, String>> properties) { @@ -156,7 +156,10 @@ public class RawRankProfile implements RankProfilesConfig.Producer { /** * Creates a raw rank profile from the given rank profile */ - Deriver(RankProfile compiled, AttributeFields attributeFields, ModelContext.Properties deployProperties) { + Deriver(RankProfile compiled, + AttributeFields attributeFields, + ModelContext.Properties deployProperties, + QueryProfileRegistry queryProfiles) { rankprofileName = compiled.name(); attributeTypes = compiled.getAttributeTypes(); queryFeatureTypes = compiled.getQueryFeatureTypes(); @@ -179,7 +182,9 @@ public class RawRankProfile implements RankProfilesConfig.Producer { Map<String, RankProfile.RankingExpressionFunction> functions = compiled.getFunctions(); List<ExpressionFunction> functionExpressions = functions.values().stream().map(f -> f.function()).collect(Collectors.toList()); Map<String, String> functionProperties = new LinkedHashMap<>(); - SerializationContext functionSerializationContext = new SerializationContext(functionExpressions); + SerializationContext functionSerializationContext = new SerializationContext(functionExpressions, + Map.of(), + compiled.typeContext(queryProfiles)); if (firstPhaseRanking != null) { functionProperties.putAll(firstPhaseRanking.getRankProperties(functionSerializationContext)); @@ -201,8 +206,8 @@ public class RawRankProfile implements RankProfilesConfig.Producer { } private void derivePropertiesAndFeaturesFromFunctions(Map<String, RankProfile.RankingExpressionFunction> functions, - Map<String, String> functionProperties, - SerializationContext functionContext) { + Map<String, String> functionProperties, + SerializationContext functionContext) { if (functions.isEmpty()) return; replaceFunctionFeatures(summaryFeatures, functionContext); diff --git a/config-model/src/test/derived/slice/rank-profiles.cfg b/config-model/src/test/derived/slice/rank-profiles.cfg new file mode 100644 index 00000000000..75725b81ecf --- /dev/null +++ b/config-model/src/test/derived/slice/rank-profiles.cfg @@ -0,0 +1,25 @@ +rankprofile[].name "default" +rankprofile[].fef.property[].name "vespa.type.query.myTensor" +rankprofile[].fef.property[].value "tensor<float>(key{})" +rankprofile[].name "unranked" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "value(0)" +rankprofile[].fef.property[].name "vespa.hitcollector.heapsize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.hitcollector.arraysize" +rankprofile[].fef.property[].value "0" +rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures" +rankprofile[].fef.property[].value "true" +rankprofile[].fef.property[].name "vespa.type.query.myTensor" +rankprofile[].fef.property[].value "tensor<float>(key{})" +rankprofile[].name "parent" +rankprofile[].fef.property[].name "rankingExpression(mySlice@77dee0712164ce73).rankingScript" +rankprofile[].fef.property[].value "query(myTensor){key:MY_KEY2}" +rankprofile[].fef.property[].name "rankingExpression(myFunction).rankingScript" +rankprofile[].fef.property[].value "4 * query(myTensor){key:MY_KEY1} * rankingExpression(mySlice@77dee0712164ce73)" +rankprofile[].fef.property[].name "rankingExpression(myValue).rankingScript" +rankprofile[].fef.property[].value "4" +rankprofile[].fef.property[].name "rankingExpression(mySlice).rankingScript" +rankprofile[].fef.property[].value "myTensor{key:MY_KEY2}" +rankprofile[].fef.property[].name "vespa.type.query.myTensor" +rankprofile[].fef.property[].value "tensor<float>(key{})" diff --git a/config-model/src/test/derived/slice/test.sd b/config-model/src/test/derived/slice/test.sd index fbb581d1b1d..c2060300785 100644 --- a/config-model/src/test/derived/slice/test.sd +++ b/config-model/src/test/derived/slice/test.sd @@ -5,8 +5,8 @@ search test { rank-profile parent { - function inline cpmScore() { - expression: myValue * mySlice(query(myTensor)) + function inline myFunction() { + expression: myValue * query(myTensor){MY_KEY1} * mySlice(query(myTensor)) } function inline myValue() { @@ -14,7 +14,9 @@ search test { } function inline mySlice(myTensor) { - expression: myTensor{"NULL"} + # TODO: We are missing type resolving across function calls in serialization, + # so using the short form (without 'key') here will fail + expression: myTensor{key:MY_KEY2} } } diff --git a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java index df87de2a12b..b3d2e6451f3 100644 --- a/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java +++ b/container-search/src/main/java/com/yahoo/prelude/fastsearch/DocsumDefinitionSet.java @@ -51,7 +51,7 @@ public final class DocsumDefinitionSet { } if (ds == null) { throw new ConfigurationException("Fetched hit with summary class " + summaryClass + - ", but this summary class is not in current summary config (" + toString() + ")" + + ", but this summary class is not in current summary config (" + this + ")" + " (that is, you asked for something unknown, and no default was found)"); } return ds; @@ -66,7 +66,7 @@ public final class DocsumDefinitionSet { * @return Error message or null on success. * @throws ConfigurationException if the summary class of this hit is missing */ - public final String lazyDecode(String summaryClass, byte[] data, FastHit hit) { + public String lazyDecode(String summaryClass, byte[] data, FastHit hit) { ByteBuffer buffer = ByteBuffer.wrap(data); buffer.order(ByteOrder.LITTLE_ENDIAN); long docsumClassId = buffer.getInt(); @@ -83,6 +83,7 @@ public final class DocsumDefinitionSet { return null; } + @Override public String toString() { StringBuilder sb = new StringBuilder(); for (Map.Entry<String, DocsumDefinition> e : definitionsByName.entrySet() ) { diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 2d7daf2300e..5d7e281df87 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1611,8 +1611,10 @@ "public void <init>(java.util.Collection)", "public void <init>(java.util.Map)", "public void <init>(java.util.Collection, java.util.Map)", - "public void <init>(java.util.Collection, java.util.Map, java.util.Map)", + "public void <init>(java.util.Collection, java.util.Map, com.yahoo.tensor.evaluation.TypeContext)", + "public java.util.Optional typeContext()", "public void <init>(java.util.Map, java.util.Map, java.util.Map)", + "public void <init>(java.util.Map, java.util.Map, java.util.Optional, java.util.Map)", "public void <init>(com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map)", "public void addFunctionSerialization(java.lang.String, java.lang.String)", "public void addArgumentTypeSerialization(java.lang.String, java.lang.String, com.yahoo.tensor.TensorType)", diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java index fe22a5b1267..e770e6ac038 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/Arguments.java @@ -15,7 +15,9 @@ import java.util.List; * @author bratseth */ public final class Arguments implements Serializable { + public static final Arguments EMPTY = new Arguments(); + private final ImmutableList<ExpressionNode> expressions; public Arguments() { @@ -47,9 +49,9 @@ public final class Arguments implements Serializable { /** Evaluate all arguments in this */ public Value[] evaluate(Context context) { - Value[] values=new Value[expressions.size()]; - for (int i=0; i<expressions.size(); i++) - values[i]=expressions.get(i).evaluate(context); + Value[] values = new Value[expressions.size()]; + for (int i = 0; i < expressions.size(); i++) + values[i] = expressions.get(i).evaluate(context); return values; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 93b8b8aca5e..8ac1829b16b 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -42,9 +42,11 @@ public final class ReferenceNode extends CompositeNode { return reference.name(); } + @Override public int hashCode() { return reference.hashCode(); } + /** Returns the arguments, never null */ public Arguments getArguments() { return reference.arguments(); } @@ -118,7 +120,7 @@ public final class ReferenceNode extends CompositeNode { throw new IllegalArgumentException(reference + " is invalid", e); } if (type == null) - throw new IllegalArgumentException("Unknown feature '" + toString() + "'"); + throw new IllegalArgumentException("Unknown feature '" + this + "'"); return type; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index 535ad013caf..1f3203f2e35 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -4,13 +4,16 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableMap; import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collection; import java.util.Collections; import java.util.HashMap; import java.util.LinkedHashMap; import java.util.Map; +import java.util.Optional; /** * Context needed to serialize an expression to a string. This has the lifetime of a single serialization @@ -22,6 +25,8 @@ public class SerializationContext extends FunctionReferenceContext { /** Serialized form of functions indexed by name */ private final Map<String, String> serializedFunctions; + private final Optional<TypeContext<Reference>> typeContext; + /** Create a context for a single serialization task */ public SerializationContext() { this(Collections.emptyList()); @@ -29,7 +34,7 @@ public class SerializationContext extends FunctionReferenceContext { /** Create a context for a single serialization task */ public SerializationContext(Collection<ExpressionFunction> functions) { - this(functions, Collections.emptyMap(), new LinkedHashMap<>()); + this(functions, Collections.emptyMap(), Optional.empty(), new LinkedHashMap<>()); } /** Create a context for a single serialization task */ @@ -39,7 +44,13 @@ public class SerializationContext extends FunctionReferenceContext { /** Create a context for a single serialization task */ public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings) { - this(functions, bindings, new LinkedHashMap<>()); + this(functions, bindings, Optional.empty(), new LinkedHashMap<>()); + } + + /** Create a context for a single serialization task */ + public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings, + TypeContext<Reference> typeContext) { + this(functions, bindings, Optional.of(typeContext), new LinkedHashMap<>()); } /** @@ -47,14 +58,19 @@ public class SerializationContext extends FunctionReferenceContext { * * @param functions the functions of this * @param bindings the arguments of this + * @param typeContext the type context of this: Serialization may depend on type resolution * @param serializedFunctions a cache of serializedFunctions - the ownership of this map * is <b>transferred</b> to this and will be modified in it */ - public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings, + private SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings, + Optional<TypeContext<Reference>> typeContext, Map<String, String> serializedFunctions) { - this(toMap(functions), bindings, serializedFunctions); + this(toMap(functions), bindings, typeContext, serializedFunctions); } + /** Returns the type context of this, if it is able to resolve types. */ + public Optional<TypeContext<Reference>> typeContext() { return typeContext; } + private static Map<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) { Map<String,ExpressionFunction> mapBuilder = new HashMap<>(); for (ExpressionFunction function : list) @@ -70,9 +86,16 @@ public class SerializationContext extends FunctionReferenceContext { * @param serializedFunctions a cache of serializedFunctions - the ownership of this map * is <b>transferred</b> to this and will be modified in it */ + public SerializationContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings, + Map<String, String> serializedFunctions) { + this(functions, bindings, Optional.empty(), serializedFunctions); + } + public SerializationContext(Map<String,ExpressionFunction> functions, Map<String, String> bindings, + Optional<TypeContext<Reference>> typeContext, Map<String, String> serializedFunctions) { super(functions, bindings); + this.typeContext = typeContext; this.serializedFunctions = serializedFunctions; } @@ -88,7 +111,7 @@ public class SerializationContext extends FunctionReferenceContext { serializedFunctions.put(name, expressionString); } - /** Adds the serialization of the an argument type to a function */ + /** Adds the serialization of the argument type to a function */ public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) { serializedFunctions.put("rankingExpression(" + functionName + ")." + argumentName + ".type", type.toString()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index d873963bb6e..52d54c9163e 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -168,8 +168,8 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public String toString(ToStringContext c) { - ToStringContext outermost = c; + public String toString(ToStringContext<Reference> c) { + ToStringContext<Reference> outermost = c; while (outermost.parent() != null) outermost = outermost.parent(); @@ -251,15 +251,17 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public String toString(ToStringContext c) { - ToStringContext outermost = c; + public String toString(ToStringContext<Reference> c) { + ToStringContext<Reference> outermost = c; while (outermost.parent() != null) outermost = outermost.parent(); if (outermost instanceof ExpressionToStringContext) { ExpressionToStringContext context = (ExpressionToStringContext)outermost; return expression.toString(new StringBuilder(), - new ExpressionToStringContext(context.wrappedSerializationContext, c, context.path, context.parent), + new ExpressionToStringContext(context.wrappedSerializationContext, c, + context.path, + context.parent), context.path, context.parent) .toString(); @@ -281,9 +283,9 @@ public class TensorFunctionNode extends CompositeNode { * to add more context, we need to keep track of both these contexts here separately and map between * contexts as seen in the toString methods of the function classes above. */ - private static class ExpressionToStringContext extends SerializationContext implements ToStringContext { + private static class ExpressionToStringContext extends SerializationContext implements ToStringContext<Reference> { - private final ToStringContext wrappedToStringContext; + private final ToStringContext<Reference> wrappedToStringContext; private final SerializationContext wrappedSerializationContext; private final Deque<String> path; private final CompositeNode parent; @@ -297,7 +299,7 @@ public class TensorFunctionNode extends CompositeNode { } ExpressionToStringContext(SerializationContext wrappedSerializationContext, - ToStringContext wrappedToStringContext, + ToStringContext<Reference> wrappedToStringContext, Deque<String> path, CompositeNode parent) { this.wrappedSerializationContext = wrappedSerializationContext; @@ -328,6 +330,12 @@ public class TensorFunctionNode extends CompositeNode { /** Returns a function or null if it isn't defined in this context */ public ExpressionFunction getFunction(String name) { return wrappedSerializationContext.getFunction(name); } + /** Returns the type context of this, or empty if none. */ + @Override + public Optional<TypeContext<Reference>> typeContext() { + return wrappedSerializationContext.typeContext(); + } + /** @deprecated Use {@link #getFunctions()} instead */ @SuppressWarnings("removal") @Deprecated(forRemoval = true, since = "7") @@ -335,9 +343,10 @@ public class TensorFunctionNode extends CompositeNode { return ImmutableMap.copyOf(wrappedSerializationContext.getFunctions()); } - @Override protected Map<String, ExpressionFunction> getFunctions() { return wrappedSerializationContext.getFunctions(); } + @Override + protected Map<String, ExpressionFunction> getFunctions() { return wrappedSerializationContext.getFunctions(); } - public ToStringContext parent() { return wrappedToStringContext; } + public ToStringContext<Reference> parent() { return wrappedToStringContext; } /** Returns the resolution of an identifier, or null if it isn't defined in this context */ @Override @@ -361,6 +370,7 @@ public class TensorFunctionNode extends CompositeNode { SerializationContext serializationContext = new SerializationContext(getFunctions(), null, serializedFunctions()); return new ExpressionToStringContext(serializationContext, null, path, parent); } + } /** Turns an EvaluationContext into a Context */ diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 71e6c0f28f4..a30ee055538 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -2658,8 +2658,7 @@ "public java.util.Optional dimension()", "public java.util.Optional label()", "public java.util.Optional index()", - "public java.lang.String toString()", - "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)" + "public java.lang.String toString()" ], "fields": [] }, @@ -2744,6 +2743,7 @@ "methods": [ "public static com.yahoo.tensor.functions.ToStringContext empty()", "public abstract java.lang.String getBinding(java.lang.String)", + "public java.util.Optional typeContext()", "public abstract com.yahoo.tensor.functions.ToStringContext parent()" ], "fields": [] diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java index 457cfcbfa5f..3b12b6bdba1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java @@ -57,7 +57,7 @@ public class TypeResolver { static public TensorType peek(TensorType inputType, List<String> peekDimensions) { if (peekDimensions.isEmpty()) { - throw new IllegalArgumentException("peeking no dimensions makes no sense"); + throw new IllegalArgumentException("Peeking no dimensions makes no sense"); } Map<String, Dimension> map = new HashMap<>(); for (Dimension dim : inputType.dimensions()) { @@ -67,7 +67,7 @@ public class TypeResolver { if (map.containsKey(name)) { map.remove(name); } else { - throw new IllegalArgumentException("peeking non-existing dimension "+name+" in type "+inputType); + throw new IllegalArgumentException("Peeking non-existing dimension '" + name + "'"); } } if (map.isEmpty()) { @@ -79,10 +79,10 @@ public class TypeResolver { static public TensorType rename(TensorType inputType, List<String> from, List<String> to) { if (from.isEmpty()) { - throw new IllegalArgumentException("renaming no dimensions"); + throw new IllegalArgumentException("Renaming no dimensions"); } if (from.size() != to.size()) { - throw new IllegalArgumentException("bad rename, from size "+from.size()+" != to.size "+to.size()); + throw new IllegalArgumentException("Bad rename, from size "+from.size()+" != to.size "+to.size()); } Map<String,Dimension> oldDims = new HashMap<>(); for (Dimension dim : inputType.dimensions()) { @@ -96,7 +96,7 @@ public class TypeResolver { var dim = oldDims.remove(oldName); newDims.put(newName, dim.withName(newName)); } else { - logger.log(Level.WARNING, "renaming non-existing dimension "+oldName+" in type "+inputType); + logger.log(Level.WARNING, "Renaming non-existing dimension "+oldName+" in type "+inputType); // throw new IllegalArgumentException("bad rename, dimension "+oldName+" not found"); } } @@ -106,13 +106,13 @@ public class TypeResolver { if (inputType.dimensions().size() == newDims.size()) { return new TensorType(inputType.valueType(), newDims.values()); } else { - throw new IllegalArgumentException("bad rename, lost some dimenions"); + throw new IllegalArgumentException("Bad rename, lost some dimensions"); } } static public TensorType cell_cast(TensorType inputType, Value toCellType) { if (toCellType != Value.DOUBLE && inputType.dimensions().isEmpty()) { - throw new IllegalArgumentException("cannot cast "+inputType+" to valueType"+toCellType); + throw new IllegalArgumentException("Cannot cast "+inputType+" to valueType"+toCellType); } return new TensorType(toCellType, inputType.dimensions()); } @@ -188,7 +188,7 @@ public class TypeResolver { if (allOk) { return join(lhs, rhs); } else { - throw new IllegalArgumentException("types in merge() dimensions mismatch: "+lhs+" != "+rhs); + throw new IllegalArgumentException("Types in merge() dimensions mismatch: "+lhs+" != "+rhs); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index a376536015a..dbc8396d701 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -58,7 +58,7 @@ public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return name; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 16ca7104f8d..55dd8a7bc8a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -48,7 +48,7 @@ public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "argmax(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index fcdc1233550..f1f0b9d67b0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -48,7 +48,7 @@ public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "argmin(" + argument.toString(context) + Reduce.commaSeparated(dimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index 8c6c27e171a..09f84e6747e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -107,7 +107,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "cell_cast(" + argument.toString(context) + ", " + valueType + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 32a4c8cd2ff..6d4b15be991 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -285,7 +285,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "concat(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 1544369ba2f..a0fd9272f54 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -47,6 +47,6 @@ public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFuncti public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } @Override - public String toString(ToStringContext context) { return constant.toString(); } + public String toString(ToStringContext<NAMETYPE> context) { return constant.toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index 2c0fa483021..92d89ec68f7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -41,7 +41,7 @@ public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYP } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "diag(" + dimensionNames().collect(Collectors.joining(",")) + ")" + diagFunction; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 97126ad88a7..46992115c23 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -46,11 +46,11 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens TensorType type() { return type; } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return type().toString() + ":" + contentToString(context); } - abstract String contentToString(ToStringContext context); + abstract String contentToString(ToStringContext<NAMETYPE> context); /** Creates a dynamic tensor function. The cell addresses must match the type. */ public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { @@ -80,7 +80,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens } @Override - String contentToString(ToStringContext context) { + String contentToString(ToStringContext<NAMETYPE> context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; return "{" + cells.values().iterator().next().toString(context) + "}"; @@ -121,7 +121,7 @@ public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTens } @Override - String contentToString(ToStringContext context) { + String contentToString(ToStringContext<NAMETYPE> context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; return "{" + cells.get(0).toString(context) + "}"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java index 8fc246a7d9d..c049e5d41da 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java @@ -41,7 +41,7 @@ public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "expand(" + argument.toString(context) + ", " + dimensionName + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 89e981df49e..54e83fa472f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -117,9 +117,9 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } @Override - public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; } + public String toString(ToStringContext<NAMETYPE> context) { return type + "(" + generatorToString(context) + ")"; } - private String generatorToString(ToStringContext context) { + private String generatorToString(ToStringContext<NAMETYPE> context) { if (freeGenerator != null) return freeGenerator.toString(); else @@ -183,11 +183,11 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } /** A context which adds the bindings of the generate dimension names to the given context. */ - private class GenerateToStringContext implements ToStringContext { + private class GenerateToStringContext implements ToStringContext<NAMETYPE> { - private final ToStringContext context; + private final ToStringContext<NAMETYPE> context; - public GenerateToStringContext(ToStringContext context) { + public GenerateToStringContext(ToStringContext<NAMETYPE> context) { this.context = context; } @@ -200,7 +200,7 @@ public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM } @Override - public ToStringContext parent() { return context; } + public ToStringContext<NAMETYPE> parent() { return context; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 0d4aeb5c37d..52bef482fb4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -75,7 +75,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index 903d0b2dcd9..f47202d1b9f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -39,7 +39,7 @@ public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction< } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index c862aa4eaf6..8f4e2f466d4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -41,7 +41,7 @@ public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction< } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 40620cb95fe..46772d8cbff 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -71,7 +71,7 @@ public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "map(" + argument.toString(context) + ", " + mapper + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 810e01011fe..8ac6d711c48 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -45,7 +45,7 @@ public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "matmul(" + argument1.toString(context) + ", " + argument2.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index bd42e95a59e..adc84225a63 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -70,7 +70,7 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "merge(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + merger + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 8cf1964585a..18c5db8e3a7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -38,7 +38,7 @@ public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "random(" + dimensionNames().collect(Collectors.joining(",")) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index 7d5b11d6672..45b827db900 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -42,7 +42,7 @@ public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETY } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "range(" + dimensionNames().collect(Collectors.joining(",")) + ")" + rangeFunction; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 79209fd8f09..8841cff15e9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -86,7 +86,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "reduce(" + argument.toString(context) + ", " + aggregator + commaSeparated(dimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index 9b56fefb5f0..7505355beed 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -314,7 +314,7 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "reduce_join(" + argumentA.toString(context) + ", " + argumentB.toString(context) + ", " + combinator + ", " + @@ -324,8 +324,8 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N private static class MultiDimensionIterator { - private long[] bounds; - private long[] iterator; + private final long[] bounds; + private final long[] iterator; private int remaining; MultiDimensionIterator(TensorType type) { @@ -364,9 +364,11 @@ public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<N remaining -= 1; } + @Override public String toString() { return Arrays.toString(iterator); } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 67ede7f6540..a434ecba5cc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -128,7 +128,7 @@ public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "rename(" + argument.toString(context) + ", " + toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index 11e52aad73e..0e0dc9a9aa8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -20,6 +20,6 @@ public interface ScalarFunction<NAMETYPE extends Name> extends Function<Evaluati /** Returns this as a tensor function, or empty if it cannot be represented as a tensor function */ default Optional<TensorFunction<NAMETYPE>> asTensorFunction() { return Optional.empty(); } - default String toString(ToStringContext context) { return toString(); } + default String toString(ToStringContext<NAMETYPE> context) { return toString(); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java index 09bfb8b996b..da7581c39f9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java @@ -121,8 +121,8 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY private TensorType resultType(TensorType argumentType) { List<String> peekDimensions; - // Special case where a single indexed or mapped dimension is sliced if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + // Special case where a single indexed or mapped dimension is sliced if (subspaceAddress.get(0).index().isPresent()) { peekDimensions = findDimensions(argumentType.dimensions(), TensorType.Dimension::isIndexed); if (peekDimensions.size() > 1) { @@ -140,22 +140,28 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY else { // general slicing peekDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toList()); } - if (peekDimensions.isEmpty()) - throw new IllegalArgumentException(this + " cannot slice " + argumentType + ": No dimensions to slice"); - return TypeResolver.peek(argumentType, peekDimensions); + try { + return TypeResolver.peek(argumentType, peekDimensions); + } + catch (IllegalArgumentException e) { + throw new IllegalArgumentException(this + " cannot slice type " + argumentType, e); + } } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { StringBuilder b = new StringBuilder(argument.toString(context)); - if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { + if (context.typeContext().isEmpty() + && subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) { // use short forms if (subspaceAddress.get(0).index().isPresent()) b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]"); else b.append("{").append(subspaceAddress.get(0).label().get()).append("}"); } - else { - b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}"); + else { // general form + b.append("{").append(subspaceAddress.stream() + .map(i -> i.toString(context, this)) + .collect(Collectors.joining(", "))).append("}"); } return b.toString(); } @@ -222,12 +228,22 @@ public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY @Override public String toString() { - return toString(ToStringContext.empty()); + return toString(null, null); } - public String toString(ToStringContext context) { + String toString(ToStringContext<NAMETYPE> context, Slice<NAMETYPE> owner) { StringBuilder b = new StringBuilder(); - dimension.ifPresent(d -> b.append(d).append(":")); + Optional<String> dimensionName = dimension; + if (context != null && dimensionName.isEmpty()) { // This isn't just toString(): Output canonical form or fail + TensorType type = context.typeContext().isPresent() ? owner.argument.type(context.typeContext().get()) : null; + if (type == null || type.dimensions().size() != 1) + throw new IllegalArgumentException("The tensor dimension name being sliced by " + owner + + " cannot be uniquely resolved. Use the full form " + + "slice{myDimensionName: ..."); + else + dimensionName = Optional.of(type.dimensions().get(0).name()); + } + dimensionName.ifPresent(d -> b.append(d).append(":")); if (label != null) b.append(label); else diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index 13420a12e8f..9ea9040831b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -46,7 +46,7 @@ public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAME } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "softmax(" + argument.toString(context) + ", " + dimension + ")"; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 81d3692bd94..1e1d1d3b5b9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -60,7 +60,7 @@ public abstract class TensorFunction<NAMETYPE extends Name> { * * @param context a context which must be passed to all nested functions when requesting the string value */ - public abstract String toString(ToStringContext context); + public abstract String toString(ToStringContext<NAMETYPE> context); /** Returns this as a scalar function, or empty if it cannot be represented as a scalar function */ public Optional<ScalarFunction<NAMETYPE>> asScalarFunction() { return Optional.empty(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java index 1c8da9a1dca..233779fcebe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ToStringContext.java @@ -1,31 +1,42 @@ // Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Optional; + /** * A context which is passed down to all nested functions when returning a string representation. * * @author bratseth */ -public interface ToStringContext { +public interface ToStringContext<NAMETYPE extends Name> { - static ToStringContext empty() { return new EmptyStringContext(); } + static <NAMETYPE extends Name> ToStringContext<NAMETYPE> empty() { return new EmptyStringContext<NAMETYPE>(); } /** Returns the name an identifier is bound to, or null if not bound in this context */ String getBinding(String name); /** + * Returns the context used to resolve types in this, if present. + * In some functions serialization depends on type information. + */ + default Optional<TypeContext<NAMETYPE>> typeContext() { return Optional.empty(); } + + /** * Returns the parent context of this (the context we're in scope of when this is created), * or null if this is the root. */ - ToStringContext parent(); + ToStringContext<NAMETYPE> parent(); - class EmptyStringContext implements ToStringContext { + class EmptyStringContext<NAMETYPE extends Name> implements ToStringContext<NAMETYPE> { @Override public String getBinding(String name) { return null; } @Override - public ToStringContext parent() { return null; } + public ToStringContext<NAMETYPE> parent() { return null; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 112a0d43796..0223ad4d588 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -44,7 +44,7 @@ public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAME } @Override - public String toString(ToStringContext context) { + public String toString(ToStringContext<NAMETYPE> context) { return "xw_plus_b(" + x.toString(context) + ", " + w.toString(context) + ", " + b.toString(context) + ", " + |