diff options
35 files changed, 95 insertions, 50 deletions
diff --git a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java index 0b312d40815..fcae756eab3 100644 --- a/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java +++ b/config-model/src/main/java/com/yahoo/searchdefinition/MapEvaluationTypeContext.java @@ -48,6 +48,11 @@ public class MapEvaluationTypeContext extends FunctionReferenceContext implement } @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + + @Override public TensorType getType(Reference reference) { Optional<String> binding = boundIdentifier(reference); if (binding.isPresent()) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java index c8d90e8c4e8..6b2422d7cb2 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/RankingExpression.java @@ -244,10 +244,6 @@ public class RankingExpression implements Serializable { * @return a list of named rank properties required to implement this expression. */ public Map<String, String> getRankProperties(List<ExpressionFunction> macros) { - Map<String, ExpressionFunction> arg = new HashMap<>(); - for (ExpressionFunction function : macros) { - arg.put(function.getName(), function); - } Deque<String> path = new LinkedList<>(); SerializationContext context = new SerializationContext(macros); String serializedRoot = root.toString(context, path, null); @@ -272,7 +268,7 @@ public class RankingExpression implements Serializable { * * @throws IllegalArgumentException if this expression is not type correct in this context */ - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return root.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index 4102d4078e6..4e046df11ca 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -5,6 +5,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.rule.Arguments; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Set; @@ -25,6 +26,11 @@ public abstract class Context implements EvaluationContext<Reference> { */ public abstract Value get(String name); + @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + /** Returns a variable as a tensor */ @Override public Tensor getTensor(String name) { return get(name).asTensor(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java index 985878cfd66..2a42e2d92f7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/MapTypeContext.java @@ -23,6 +23,11 @@ public class MapTypeContext implements TypeContext<Reference> { } @Override + public TensorType getType(String reference) { + throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + } + + @Override public TensorType getType(Reference reference) { return featureTypes.get(reference); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java index 8ee4cdbf297..649c70122f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTForestNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -26,7 +27,7 @@ public class GBDTForestNode extends ExpressionNode { } @Override - public final TensorType type(TypeContext context) { return TensorType.empty; } + public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java index aac635b2545..53a286f09f6 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/gbdtoptimization/GBDTNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.evaluation.gbdtoptimization; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -51,7 +52,7 @@ public final class GBDTNode extends ExpressionNode { public final double[] values() { return values; } @Override - public final TensorType type(TypeContext context) { return TensorType.empty; } + public final TensorType type(TypeContext<Reference> context) { return TensorType.empty; } @Override public final Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java index fc6428a4c33..49c49bed9bd 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ArithmeticNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -80,7 +81,7 @@ public final class ArithmeticNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { // Compute type using tensor types as arithmetic operators are supported on tensors // and is correct also in the special case of doubles. // As all our functions are type-commutative, we don't need to take operator precedence into account diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java index 1d7d9b1ecda..cd4ddbcae55 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/BooleanNode.java @@ -5,7 +5,6 @@ package com.yahoo.searchlib.rankingexpression.rule; * A node which produces a boolean value when evaluated. * * @author bratseth - * @since 5.1.21 */ public abstract class BooleanNode extends CompositeNode { } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java index 7601c0e6180..eb328486045 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ComparisonNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -49,7 +50,7 @@ public class ComparisonNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; // by definition } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java index 1ea8d03f0eb..3ddd7223349 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ConstantNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -49,7 +50,7 @@ public final class ConstantNode extends ExpressionNode { } @Override - public TensorType type(TypeContext context) { return value.type(); } + public TensorType type(TypeContext<Reference> context) { return value.type(); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java index fd9fab99db8..47c2897e4a4 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/EmbracedNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -50,7 +51,7 @@ public final class EmbracedNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java index 477f4db4981..6bb163590de 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ExpressionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -48,7 +49,7 @@ public abstract class ExpressionNode implements Serializable { * @param context the variable type bindings to use for this evaluation * @throws IllegalArgumentException if there are variables which are not bound in the given map */ - public abstract TensorType type(TypeContext context); + public abstract TensorType type(TypeContext<Reference> context); /** * Returns the value of evaluating this expression over the given context. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java index 79515229019..1da2210a39c 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -67,7 +68,7 @@ public final class FunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { if (arguments.expressions().size() == 0) return TensorType.empty; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index e42884ecc05..c87eb0ace39 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -48,7 +49,7 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { return type; } + public TensorType type(TypeContext<Reference> context) { return type; } /** Evaluate this in a context which must have the arguments bound */ @Override diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java index 66b250736e8..ee4edac4941 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/IfNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -75,7 +76,7 @@ public final class IfNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { TensorType trueType = trueExpression.type(context); TensorType falseType = falseExpression.type(context); return trueType.dimensionwiseGeneralizationWith(falseType).orElseThrow(() -> diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java index da946228291..61086f8182a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/LambdaFunctionNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -57,7 +58,7 @@ public class LambdaFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; // by definition - no nested lambdas } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java index 759d966e10b..f1adf331630 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NameNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -33,7 +34,7 @@ public final class NameNode extends ExpressionNode { } @Override - public TensorType type(TypeContext context) { throw new RuntimeException("Named nodes can not have a type"); } + public TensorType type(TypeContext<Reference> context) { throw new RuntimeException("Named nodes can not have a type"); } @Override public Value evaluate(Context context) { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java index 9cbe5f98c72..fcc03dc4862 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NegativeNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -38,7 +39,7 @@ public class NegativeNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java index e7041600635..a539f496ff5 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/NotNode.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.TensorType; @@ -38,7 +39,7 @@ public class NotNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return value.type(context); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index ca6d8aa7104..78f53b1593d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -92,7 +92,7 @@ public final class ReferenceNode extends CompositeNode { public Reference reference() { return reference; } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { TensorType type = context.getType(reference); if (type == null) throw new IllegalArgumentException("Unknown feature '" + toString() + "'"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java index a7b82f4753f..cb31219579a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SetMembershipNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.collect.ImmutableList; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.BooleanValue; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; @@ -60,7 +61,7 @@ public class SetMembershipNode extends BooleanNode { } @Override - public TensorType type(TypeContext context) { + public TensorType type(TypeContext<Reference> context) { return TensorType.empty; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index ec6af4bb413..6c9b6bb4a98 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.rule; import com.google.common.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -64,7 +65,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorType type(TypeContext context) { return function.type(context); } + public TensorType type(TypeContext<Reference> context) { return function.type(context); } @Override public Value evaluate(Context context) { @@ -111,12 +112,13 @@ public class TensorFunctionNode extends CompositeNode { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { - return expression.type(context); + @SuppressWarnings("unchecked") // Generics awkwardness + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + return expression.type((TypeContext<Reference>)context); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return expression.evaluate((Context)context).asTensor(); } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java index fc73bcd3f79..a08d510eec4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/TypeResolutionTestCase.java @@ -37,7 +37,7 @@ public class TypeResolutionTestCase { assertIncompatibleType("if (1>0, query(x1), query(y1))", context); } - private void assertType(String type, String expression, TypeContext context) { + private void assertType(String type, String expression, TypeContext<Reference> context) { try { assertEquals(TensorType.fromSpec(type), new RankingExpression(expression).type(context)); } @@ -46,7 +46,7 @@ public class TypeResolutionTestCase { } } - private void assertIncompatibleType(String expression, TypeContext context) { + private void assertIncompatibleType(String expression, TypeContext<Reference> context) { try { new RankingExpression(expression).type(context); fail("Expected type incompatibility exception"); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index 74457a163fd..b9394da31e3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -18,6 +18,11 @@ public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override + public TensorType getType(String name) { + return getType(new Name(name)); + } + + @Override public TensorType getType(Name name) { Tensor tensor = bindings.get(name.toString()); if (tensor == null) return null; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index 3674b373db0..ff2e6318b37 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -18,6 +18,14 @@ public interface TypeContext<NAMETYPE extends TypeContext.Name> { */ TensorType getType(NAMETYPE name); + /** + * Returns the type of the tensor with this name by converting from a string name. + * + * @return returns the type of the tensor which will be returned by calling getTensor(name) + * or null if getTensor will return null. + */ + TensorType getType(String name); + /** A name which is just a string. Names are value objects. */ class Name { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 5f809a3d2b1..acb2363cba4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -44,15 +44,15 @@ public class VariableTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { - TensorType givenType = context.getType(new TypeContext.Name(name)); + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + TensorType givenType = context.getType(name); if (givenType == null) return null; verifyType(givenType); return givenType; } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = context.getTensor(name); if (tensor == null) return null; verifyType(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 2109b730e1a..bfc0938abcc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -18,10 +18,14 @@ public abstract class CompositeTensorFunction extends TensorFunction { /** Finds the type this produces by first converting it to a primitive function */ @Override - public final TensorType type(TypeContext context) { return toPrimitive().type(context); } + public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + return toPrimitive().type(context); + } /** Evaluates this by first converting it to a primitive function */ @Override - public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } + public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index c77ed1c0526..a073053bec8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -60,7 +60,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argumentA.type(context), argumentB.type(context)); } @@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); a = ensureIndexedDimension(dimension, a); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 50b479da168..a43de297b9a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -42,10 +42,10 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { return constant.type(); } + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } @Override - public Tensor evaluate(EvaluationContext context) { return constant; } + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } @Override public String toString(ToStringContext context) { return constant.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e70d1de3db7..edfa8253eb9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -61,10 +61,10 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { return type; } + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); for (int i = 0; i < indexes.size(); i++) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 7812c985091..17e1c103ea3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -95,12 +95,12 @@ public class Join extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 53504868ff2..4a338e5501e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -53,12 +53,12 @@ public class Map extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return argument.type(context); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 416b74e7f94..e045effbe7e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -101,7 +101,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } @@ -115,7 +115,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index de3d2be265a..af4492ca1e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -72,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } @@ -84,7 +84,7 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = argument.evaluate(context); TensorType renamedType = type(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 78ab09c7820..e805e9d87bb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -43,14 +43,14 @@ public abstract class TensorFunction { * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract Tensor evaluate(EvaluationContext context); + public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context); /** * Returns the type of the tensor this produces given the input types in the context * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract TensorType type(TypeContext context); + public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context); /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } @@ -58,7 +58,7 @@ public abstract class TensorFunction { /** * Return a string representation of this context. * - * @param context a context which must be passed to all nexted functions when requesting the string value + * @param context a context which must be passed to all nested functions when requesting the string value */ public abstract String toString(ToStringContext context); |