diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 16:51:50 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 16:51:50 +0200 |
commit | f4203c3cc571722f08ee65047437c1290ed63f69 (patch) | |
tree | 7d06d17091a2e388e6771187a11cf4f4023a0c1e | |
parent | 316c941e90f39d2e9bc46f12b96ca0f87471d1bd (diff) |
Allow bound functions in tensor generate
10 files changed, 67 insertions, 28 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index d5970a4b69e..dcf42069373 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1615,8 +1615,9 @@ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", - "public static java.util.Map wrap(java.util.Map)", - "public static java.util.List wrap(java.util.List)" + "public static java.util.Map wrapScalars(java.util.Map)", + "public static java.util.List wrapScalars(java.util.List)", + "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)" ], "fields": [] }, diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index d68f8c85ad1..cf17c6465f3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -34,7 +34,7 @@ public abstract class Context implements EvaluationContext<Reference> { @Override public TensorType getType(String reference) { - throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + throw new UnsupportedOperationException("Not able to parse general references from string form"); } /** Returns a variable as a tensor */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 4ffd40f00f7..18f1fa8a78f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; +import java.sql.Ref; import java.util.ArrayList; import java.util.Collections; import java.util.Deque; @@ -81,21 +82,22 @@ public class TensorFunctionNode extends CompositeNode { return new ExpressionTensorFunction(node); } - public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) { - Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>(); + public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) { + Map<TensorAddress, ScalarFunction<Reference>> functions = new LinkedHashMap<>(); for (var entry : nodes.entrySet()) - functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue())); + functions.put(entry.getKey(), wrapScalar(entry.getValue())); return functions; } - public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) { - List<ScalarFunction> functions = new ArrayList<>(); - for (var entry : nodes) - functions.add(new ExpressionScalarFunction(entry)); - return functions; + public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) { + return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList()); + } + + public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) { + return new ExpressionScalarFunction(node); } - private static class ExpressionScalarFunction implements ScalarFunction { + private static class ExpressionScalarFunction implements ScalarFunction<Reference> { private final ExpressionNode expression; @@ -104,8 +106,8 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public Double apply(EvaluationContext<?> context) { - return expression.evaluate((Context)context).asDouble(); + public Double apply(EvaluationContext<Reference> context) { + return expression.evaluate(new ContextWrapper(context)).asDouble(); } @Override @@ -209,4 +211,26 @@ public class TensorFunctionNode extends CompositeNode { } + /** Turns an EvaluationContext into a Context */ + // TODO: We should be able to change RankingExpression.evaluate to take an EvaluationContext and then get rid of this + private static class ContextWrapper extends Context { + + private final EvaluationContext<Reference> delegate; + + public ContextWrapper(EvaluationContext<Reference> delegate) { + this.delegate = delegate; + } + + @Override + public Value get(String name) { + return new TensorValue(delegate.getTensor(name)); + } + + @Override + public TensorType getType(Reference name) { + return delegate.getType(name); + } + + } + } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 3e9649cd9c6..beab722a1eb 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -487,7 +487,7 @@ TensorFunctionNode tensorGenerateBody(TensorType type) : } { <LBRACE> generator = expression() <RBRACE> - { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); } + { return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(generator))); } } TensorFunctionNode tensorRange() : @@ -847,7 +847,7 @@ DynamicTensor mappedTensorValueBody(TensorType type) : ( tensorCell(type, cells))* ( <COMMA> tensorCell(type, cells))* <RCURLY> - { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } DynamicTensor indexedTensorValueBody(TensorType type) : @@ -860,7 +860,7 @@ DynamicTensor indexedTensorValueBody(TensorType type) : ( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* ( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* // <RSQUARE> - { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } void tensorCell(TensorType type, java.util.Map cells) : diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index a8afc230bde..05ad8c97c7f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -358,6 +358,9 @@ public class EvaluationTestCase { tester.assertEvaluates("500", "join(tensor0, tensor1, f(x,y) (x*y)){tag2}", "tensor(tag{}):{{tag:tag1}:10, {tag:tag2}:20}", "{25}"); + tester.assertEvaluates("tensor(j[3]):[3, 3, 3]", + "tensor(j[3])(tensor0[2])", + "tensor(values[5]):[1, 2, 3, 4, 5]"); // tensor result dimensions are given from argument dimensions, not the resulting values tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }"); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 15cfce09793..8cba1ccdef8 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1620,7 +1620,8 @@ ], "methods": [ "public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)", - "public void <init>(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)", + "public static com.yahoo.tensor.functions.Generate free(com.yahoo.tensor.TensorType, java.util.function.Function)", + "public static com.yahoo.tensor.functions.Generate bound(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index b8b644f8b49..a75e49c6402 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -73,6 +73,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override + @SuppressWarnings("unchecked") public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); for (var cell : cells.entrySet()) @@ -114,6 +115,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override + @SuppressWarnings("unchecked") public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); for (int i = 0; i < cells.size(); i++) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 28fc2c61426..52620814ecd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -27,28 +27,33 @@ public class Generate extends PrimitiveTensorFunction { private final Function<List<Long>, Double> freeGenerator; private final ScalarFunction boundGenerator; + /** The same as Generate.free */ + public Generate(TensorType type, Function<List<Long>, Double> generator) { + this(type, Objects.requireNonNull(generator), null); + } + /** - * Creates a generated tensor + * Creates a generated tensor from a free function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, Function<List<Long>, Double> generator) { - this(type, Objects.requireNonNull(generator), null); + public static Generate free(TensorType type, Function<List<Long>, Double> generator) { + return new Generate(type, Objects.requireNonNull(generator), null); } /** - * Creates a generated tensor + * Creates a generated tensor from a bound function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, ScalarFunction generator) { - this(type, null, Objects.requireNonNull(generator)); + public static Generate bound(TensorType type, ScalarFunction generator) { + return new Generate(type, null, Objects.requireNonNull(generator)); } private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction boundGenerator) { @@ -127,6 +132,7 @@ public class Generate extends PrimitiveTensorFunction { this.context = context; } + @SuppressWarnings("unchecked") double apply(IndexedTensor.Indexes indexes) { if (freeGenerator != null) { return freeGenerator.apply(indexes.toList()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index c6a244b64df..70e08af16b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import java.util.function.Function; @@ -10,10 +11,10 @@ import java.util.function.Function; * * @author bratseth */ -public interface ScalarFunction extends Function<EvaluationContext<?>, Double> { +public interface ScalarFunction<NAMETYPE extends TypeContext.Name> extends Function<EvaluationContext<NAMETYPE>, Double> { @Override - Double apply(EvaluationContext<?> context); + Double apply(EvaluationContext<NAMETYPE> context); default String toString(ToStringContext context) { return toString(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index 925da9d3c89..e1ae7f13c48 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -5,6 +5,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; import org.junit.Test; import java.util.Collections; @@ -34,14 +35,14 @@ public class DynamicTensorTestCase { assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } - private static class Constant implements ScalarFunction { + private static class Constant implements ScalarFunction<TypeContext.Name> { private final double value; public Constant(double value) { this.value = value; } @Override - public Double apply(EvaluationContext<?> evaluationContext) { return value; } + public Double apply(EvaluationContext<TypeContext.Name> evaluationContext) { return value; } @Override public String toString() { return String.valueOf(value); } |