diff options
Diffstat (limited to 'searchlib')
8 files changed, 272 insertions, 25 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index d058104bf1b..39642f5cb50 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1413,6 +1413,7 @@ "public void <init>(java.util.Collection, java.util.Map)", "public void <init>(java.util.Map)", "public void <init>(java.util.Map, java.util.Map)", + "public void <init>(java.util.Map, java.util.Map, com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext)", "public com.yahoo.searchlib.rankingexpression.ExpressionFunction getFunction(java.lang.String)", "protected com.google.common.collect.ImmutableMap functions()", "public java.lang.String getBinding(java.lang.String)", @@ -1568,7 +1569,7 @@ "public void <init>(java.util.Map)", "public void <init>(java.util.Collection, java.util.Map)", "public void <init>(java.util.Collection, java.util.Map, java.util.Map)", - "public void <init>(com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map)", + "public void <init>(com.google.common.collect.ImmutableMap, java.util.Map, java.util.Map, com.yahoo.searchlib.rankingexpression.rule.FunctionReferenceContext)", "public void addFunctionSerialization(java.lang.String, java.lang.String)", "public void addArgumentTypeSerialization(java.lang.String, java.lang.String, com.yahoo.tensor.TensorType)", "public void addFunctionTypeSerialization(java.lang.String, com.yahoo.tensor.TensorType)", @@ -1597,6 +1598,24 @@ ], "fields": [] }, + "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionScalarFunction": { + "superClass": "java.lang.Object", + "interfaces": [ + "com.yahoo.tensor.functions.ScalarFunction" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", + "public com.yahoo.searchlib.rankingexpression.rule.ExpressionNode getExpression()", + "public java.lang.Double apply(com.yahoo.tensor.evaluation.EvaluationContext)", + "public java.lang.String toString()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public bridge synthetic java.lang.Object apply(java.lang.Object)" + ], + "fields": [] + }, "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction": { "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", "interfaces": [], diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 674571ff73e..f2f8799b342 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -134,6 +134,8 @@ public class ExpressionFunction { for (int i = 0; i < arguments.size() && i < argumentValues.size(); ++i) { argumentBindings.put(arguments.get(i), argumentValues.get(i).toString(new StringBuilder(), context, path, null).toString()); } + String symbol = toSymbol(argumentBindings); + System.out.println("Expanding function " + symbol); return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString()); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java index 83aabada8f0..9d094ce06f4 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/FunctionReferenceContext.java @@ -22,6 +22,8 @@ public class FunctionReferenceContext { /** Mapping from argument names to the expressions they resolve to */ private final Map<String, String> bindings = new HashMap<>(); + private final FunctionReferenceContext parent; + /** Create a context for a single serialization task */ public FunctionReferenceContext() { this(Collections.emptyList()); @@ -43,9 +45,14 @@ public class FunctionReferenceContext { /** Create a context for a single serialization task */ public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings) { + this(functions, bindings, null); + } + + public FunctionReferenceContext(Map<String, ExpressionFunction> functions, Map<String, String> bindings, FunctionReferenceContext parent) { this.functions = ImmutableMap.copyOf(functions); if (bindings != null) this.bindings.putAll(bindings); + this.parent = parent; } private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) { @@ -56,16 +63,34 @@ public class FunctionReferenceContext { } /** Returns a function or null if it isn't defined in this context */ - public ExpressionFunction getFunction(String name) { return functions.get(name); } + public ExpressionFunction getFunction(String name) { + ExpressionFunction function = functions.get(name); + if (function != null) { + return function; + } + if (parent != null) { + return parent.getFunction(name); + } + return null; + } protected ImmutableMap<String, ExpressionFunction> functions() { return functions; } /** Returns the resolution of an identifier, or null if it isn't defined in this context */ - public String getBinding(String name) { return bindings.get(name); } + public String getBinding(String name) { + String binding = bindings.get(name); + if (binding != null) { + return binding; + } + if (parent != null) { + return parent.getBinding(name); + } + return null; + } /** Returns a new context with the bindings replaced by the given bindings */ public FunctionReferenceContext withBindings(Map<String, String> bindings) { - return new FunctionReferenceContext(this.functions, bindings); + return new FunctionReferenceContext(this.functions, bindings, this); } } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java index 8fec3603f3e..a994f5247b7 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/ReferenceNode.java @@ -74,20 +74,49 @@ public final class ReferenceNode extends CompositeNode { return string.append(context.getBinding(getName())); } + String name = getName(); // A reference to a function? ExpressionFunction function = context.getFunction(getName()); if (function != null && function.arguments().size() == getArguments().size() && getOutput() == null) { // a function reference: replace by the referenced function wrapped in rankingExpression - if (path == null) - path = new ArrayDeque<>(); - String myPath = getName() + getArguments().expressions(); - if (path.contains(myPath)) - throw new IllegalStateException("Cycle in ranking expression function: " + path); - path.addLast(myPath); - ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); - path.removeLast(); - context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); - return string.append("rankingExpression(").append(instance.getName()).append(')'); +// if (path == null) +// path = new ArrayDeque<>(); +// String myPath = getName() + getArguments().expressions(); +// if (path.contains(myPath)) +// throw new IllegalStateException("Cycle in ranking expression function: " + path); +// path.addLast(myPath); +// ExpressionFunction.Instance instance = function.expand(context, getArguments().expressions(), path); +// path.removeLast(); +// context.addFunctionSerialization(RankingExpression.propertyName(instance.getName()), instance.getExpressionString()); +// return string.append("rankingExpression(").append(instance.getName()).append(')'); + +// return new Instance(toSymbol(argumentBindings), body.getRoot().toString(new StringBuilder(), context.withBindings(argumentBindings), path, null).toString()); + + // hack for testing: + // So, this worked. Meaning that when expanding we could probably cut down on the context tree? +// String expression = function.getBody().toString(); +// context.addFunctionSerialization(RankingExpression.propertyName(getName()), expression); // <- actually set by deriveFunctionProperties - this will only overwrite + + String prefix = string.toString(); // incredibly ugly hack - for testing this + + // so problem here with input values + if (prefix.startsWith("attribute")) { + if (name.equals("segment_ids") || name.equals("input_mask") || name.equals("input_ids")) { + return string.append(getName()); + // TODO: divine this! + } + } + + // so, in one case +// rankprofile[2].fef.property[35].name "rankingExpression(imported_ml_function_bertsquad8_input_ids).rankingScript" +// rankprofile[2].fef.property[35].value "input_ids" + // vs +// rankprofile[2].fef.property[2].name "rankingExpression(input_ids).rankingScript" +// rankprofile[2].fef.property[2].value "attribute(input_ids)" + // uppermost is wrong, then we need the below + + return string.append("rankingExpression(").append(getName()).append(')'); + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java index d7807caa2b6..c79f5556e03 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java @@ -50,7 +50,7 @@ public class SerializationContext extends FunctionReferenceContext { */ public SerializationContext(Collection<ExpressionFunction> functions, Map<String, String> bindings, Map<String, String> serializedFunctions) { - this(toMap(functions), bindings, serializedFunctions); + this(toMap(functions), bindings, serializedFunctions, null); } private static ImmutableMap<String, ExpressionFunction> toMap(Collection<ExpressionFunction> list) { @@ -69,8 +69,8 @@ public class SerializationContext extends FunctionReferenceContext { * is <b>transferred</b> to this and will be modified in it */ public SerializationContext(ImmutableMap<String,ExpressionFunction> functions, Map<String, String> bindings, - Map<String, String> serializedFunctions) { - super(functions, bindings); + Map<String, String> serializedFunctions, FunctionReferenceContext root) { + super(functions, bindings, root); this.serializedFunctions = serializedFunctions; } @@ -92,7 +92,7 @@ public class SerializationContext extends FunctionReferenceContext { @Override public SerializationContext withBindings(Map<String, String> bindings) { - return new SerializationContext(functions(), bindings, this.serializedFunctions); + return new SerializationContext(functions(), bindings, this.serializedFunctions, this); } public Map<String, String> serializedFunctions() { return serializedFunctions; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 6e1cdf52158..1ab9702367a 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -143,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode { return new ExpressionScalarFunction(node); } - private static class ExpressionScalarFunction implements ScalarFunction<Reference> { + public static class ExpressionScalarFunction implements ScalarFunction<Reference> { private final ExpressionNode expression; @@ -151,6 +151,10 @@ public class TensorFunctionNode extends CompositeNode { this.expression = expression; } + public ExpressionNode getExpression() { + return expression; + } + @Override public Double apply(EvaluationContext<Reference> context) { return expression.evaluate(new ContextWrapper(context)).asDouble(); @@ -321,13 +325,45 @@ public class TensorFunctionNode extends CompositeNode { public ToStringContext parent() { return wrappedToStringContext; } + private int contextNodes() { + int nodes = 0; + if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) { + nodes += ((ExpressionToStringContext)wrappedToStringContext).contextNodes(); + } else if (wrappedToStringContext != null) { + nodes += 1; + } + if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) { + nodes += ((ExpressionToStringContext)wrappedSerializationContext).contextNodes(); + } else if (wrappedSerializationContext != null) { + nodes += 1; + } + return nodes + 1; + } + + private int contextDepth() { + int depth = 0; + if (wrappedToStringContext != null && wrappedToStringContext instanceof ExpressionToStringContext) { + depth += ((ExpressionToStringContext)wrappedToStringContext).contextDepth(); + } + if (wrappedSerializationContext != null && wrappedSerializationContext instanceof ExpressionToStringContext) { + int d = ((ExpressionToStringContext)wrappedSerializationContext).contextDepth(); + depth = Math.max(depth, d); + } + return depth + 1; + } + /** Returns the resolution of an identifier, or null if it isn't defined in this context */ @Override public String getBinding(String name) { - if (wrappedToStringContext != null && wrappedToStringContext.getBinding(name) != null) - return wrappedToStringContext.getBinding(name); - else - return wrappedSerializationContext.getBinding(name); +// System.out.println("getBinding for " + name + " with node count " + contextNodes() + " and max depth " + contextDepth()); + String binding; + if (wrappedToStringContext != null) { + binding = wrappedToStringContext.getBinding(name); + if (binding != null) { + return binding; + } + } + return wrappedSerializationContext.getBinding(name); } /** Returns a new context with the bindings replaced by the given bindings */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java index a541eac2421..95652bb0e15 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/ConstantDereferencer.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.transform; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; @@ -28,8 +29,16 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext return node; } + /** Returns true if the given reference is an attribute, constant or query feature */ + // TEMP: from config-model module + public static boolean isSimpleFeature(Reference reference) { + if ( ! reference.isSimple()) return false; + String name = reference.name(); + return name.equals("attribute") || name.equals("constant") || name.equals("query"); + } + private ExpressionNode transformFeature(ReferenceNode node, TransformContext context) { - if (!node.getArguments().isEmpty()) + if ( ! node.getArguments().isEmpty() && ! isSimpleFeature(node.reference())) return transformArguments(node, context); else return transformConstantReference(node, context); @@ -44,7 +53,14 @@ public class ConstantDereferencer extends ExpressionTransformer<TransformContext } private ExpressionNode transformConstantReference(ReferenceNode node, TransformContext context) { - Value value = context.constants().get(node.getName()); + String name = node.getName(); + if (node.reference().name().equals("constant")) { + ExpressionNode arg = node.getArguments().expressions().get(0); + if (arg instanceof ReferenceNode) { + name = ((ReferenceNode)arg).getName(); + } + } + Value value = context.constants().get(name); // works if "constant(...)" is added if (value == null || value.type().rank() > 0) { return node; // not a number constant reference } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 807eb3aa7ce..7b246f22cf2 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -390,6 +390,126 @@ public class EvaluationTestCase { } @Test + public void testTile() { + EvaluationTester tester = new EvaluationTester(); + + tester.assertEvaluates("tensor(d0[2],d1[4]):[1,2,1,2,3,4,3,4]", + "tensor(d0[2],d1[4])(tensor0{input0:(d0 % 2), input1:(d1 % 2) } )", + "tensor(input0[2],input1[2]):[1, 2, 3, 4]", + "tensor(repeats0[2]):[1,2]"); + + tester.assertEvaluates("tensor(d0[6],d1[2]):[1,2,3,4,1,2,3,4,1,2,3,4]", + "tensor(d0[6],d1[2])(tensor0{input0:(d0 % 2), input1:(d1 % 2) } )", + "tensor(input0[2],input1[2]):[1, 2, 3, 4]", + "tensor(repeats0[2]):[3,1]"); + } + + @Test + public void testReshape() { + EvaluationTester tester = new EvaluationTester(); + + tester.assertEvaluates("tensor(d0[4]):[1,2,3,4]", + "tensor(d0[4])(tensor0{a0:(d0 / 2), a1:(d0 % 2)})", + "tensor(a0[2],a1[2]):[1,2,3,4]", + "tensor(d0[1]):[4]"); + + tester.assertEvaluates("tensor(d0[2],d1[2]):[1,2,3,4]", + "tensor(d0[2],d1[2])(tensor0{a0:(d0), a1:(d1)})", + "tensor(a0[2],a1[2]):[1,2,3,4]", + "tensor(d0[2]):[2,2]"); + + tester.assertEvaluates("tensor(d0[2],d1[1],d2[2]):[1,2,3,4]", + "tensor(d0[2],d1[1],d2[2])(tensor0{a0:(d0), a1:(d2)})", + "tensor(a0[2],a1[2]):[1,2,3,4]", + "tensor(d0[3]):[2,1,2]"); + + tester.assertEvaluates("tensor(d0[3],d1[2]):[1,2,3,4,5,6]", + "tensor(d0[3],d1[2])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })", + "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]", + "tensor(d0[2]):[3,2]"); + + tester.assertEvaluates("tensor(d0[3],d1[2],d2[1],d3[1]):[1,2,3,4,5,6]", + "tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })", + "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]", + "tensor(d0[4]):[3,2,-1,1]"); + + } + + @Test + public void testMatmul() { + EvaluationTester tester = new EvaluationTester(); + + tester.assertEvaluates("tensor():{91}", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)", + "tensor(d0[6]):[1,2,3,4,5,6]", + "tensor(d0[6]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d1[2]):[22, 28]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)", + "tensor(d0[3]):[1,2,3]", + "tensor(d0[3],d1[2]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d1[2]):[22, 28]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d0)", + "tensor(d0[3],d1[2]):[1,2,3,4,5,6]", + "tensor(d0[3]):[1,2,3]"); + + tester.assertEvaluates("tensor(d0[2],d2[2]):[22,28,49,64]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d1)", + "tensor(d0[2],d1[3]):[1,2,3,4,5,6]", + "tensor(d1[3],d2[2]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d0[1],d1[2],d3[2]):[22,28,49,64]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d2)", + "tensor(d0[1],d1[2],d2[3]):[1,2,3,4,5,6]", + "tensor(d2[3],d3[2]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d0[1],d1[2],d3[2]):[22,28,49,64]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d2)", + "tensor(d1[2],d2[3]):[1,2,3,4,5,6]", + "tensor(d0[1],d2[3],d3[2]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d0[1],d1[4],d2[2],d4[2]):[22,28,49,64,58,64,139,154,94,100,229,244,130,136,319,334]", + "reduce(join(tensor0{d1:0}, tensor1, f(x,y)(x*y)), sum, d3)", // notice peek + "tensor(d0[1],d1[1],d2[2],d3[3]):[1,2,3,4,5,6]", + "tensor(d0[1],d1[4],d3[3],d4[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + + tester.assertEvaluates("tensor(d0[1],d1[4],d2[2],d4[2]):[22,28,49,64,220,244,301,334,634,676,769,820,1264,1324,1453,1522]", + "reduce(join(tensor0, tensor1, f(x,y)(x*y)), sum, d3)", + "tensor(d0[1],d1[4],d2[2],d3[3]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]", + "tensor(d0[1],d1[4],d3[3],d4[2]):[1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24]"); + + } + + @Test + public void testSplit() { + EvaluationTester tester = new EvaluationTester(); + + tester.assertEvaluates("tensor(d0[3]):[1,2,3]", + "tensor(d0[3])(tensor0{input0:(d0)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + tester.assertEvaluates("tensor(d0[3]):[4,5,6]", + "tensor(d0[3])(tensor0{input0:(d0+3)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + tester.assertEvaluates("tensor(d0[4]):[3,4,5,6]", + "tensor(d0[4])(tensor0{input0:(d0+2)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + tester.assertEvaluates("tensor(d0[2]):[3,4]", + "tensor(d0[2])(tensor0{input0:(d0+2)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + tester.assertEvaluates("tensor(d0[2]):[5,6]", + "tensor(d0[2])(tensor0{input0:(d0+4)} )", + "tensor(input0[6]):[1,2,3,4,5,6]"); + + tester.assertEvaluates("tensor(d0[1],d1[3]):[1,2,3]", + "tensor(d0[1],d1[3])(tensor0{input0:(d0), input1:(d1)} )", + "tensor(input0[2],input1[3]):[[1,2,3],[4,5,6]]"); + tester.assertEvaluates("tensor(d0[1],d1[3]):[4,5,6]", + "tensor(d0[1],d1[3])(tensor0{input0:(d0+1), input1:(d1)} )", + "tensor(input0[2],input1[3]):[[1,2,3],[4,5,6]]"); + } + + @Test public void testTake() { EvaluationTester tester = new EvaluationTester(); |