From bcd003d3253a5e51c19149dcc8fa44e8fd526adb Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 6 Oct 2021 10:52:56 +0200 Subject: Add non-primitive tensor expand function --- searchlib/abi-spec.json | 2 + .../src/main/javacc/RankingExpressionParser.jj | 15 ++++++- .../evaluation/EvaluationTestCase.java | 16 +++++++- vespajlib/abi-spec.json | 34 ++++++++++++++- .../java/com/yahoo/tensor/functions/Expand.java | 48 ++++++++++++++++++++++ .../yahoo/tensor/functions/ScalarFunctions.java | 16 ++++++++ 6 files changed, 128 insertions(+), 3 deletions(-) create mode 100644 vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 2468fd0c5c7..4ebca94734f 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -898,6 +898,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorExpand()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()", @@ -1053,6 +1054,7 @@ "public static final int ARGMAX", "public static final int ARGMIN", "public static final int CELL_CAST", + "public static final int EXPAND", "public static final int AVG", "public static final int COUNT", "public static final int MAX", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 7bfbfd6c005..88eb0feeb73 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -144,6 +144,7 @@ TOKEN : | | | + | | | @@ -384,7 +385,8 @@ TensorFunctionNode tensorFunction() : tensorExpression = tensorXwPlusB() | tensorExpression = tensorArgmax() | tensorExpression = tensorArgmin() | - tensorExpression = tensorCellCast() + tensorExpression = tensorCellCast() | + tensorExpression = tensorExpand() ) { return tensorExpression; } } @@ -581,6 +583,16 @@ TensorFunctionNode tensorXwPlusB() : dimension)); } } +TensorFunctionNode tensorExpand() : +{ + ExpressionNode argument; + String dimension; +} +{ + argument = expression() dimension = identifier() + { return new TensorFunctionNode(new Expand(TensorFunctionNode.wrap(argument), dimension)); } +} + TensorFunctionNode tensorArgmax() : { ExpressionNode tensor; @@ -696,6 +708,7 @@ String tensorFunctionName() : ( { return token.image; } ) | ( { return token.image; } ) | ( { return token.image; } ) | + ( { return token.image; } ) | ( aggregator = tensorReduceAggregator() { return aggregator.toString(); } ) } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 246dbcb2b1e..ed8a15ad989 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -498,7 +498,6 @@ public class EvaluationTestCase { "tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })", "tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]", "tensor(d0[4]):[3,2,-1,1]"); - } @Test @@ -725,6 +724,21 @@ public class EvaluationTestCase { tester.assertEvaluates("tensor(d0[1], d1[3]):[1, 2, 3]", "tensor0 * tensor(d0[1])(1)", "tensor(d1[3]):[1, 2, 3]"); + // Add using the "expand" non-primitive function + tester.assertEvaluates("tensor(d0[1],d1[3]):[[1,2,3]]", + "expand(tensor0, d0)", + "tensor(d1[3]):[1, 2, 3]"); + tester.assertEvaluates("tensor(d0[1],d1[3]):[[1,2,3]]", + "expand(tensor0, d0)", + "tensor(d1[3]):[1, 2, 3]"); + } + + @Test + public void test() throws ParseException { + RankingExpression expr = new RankingExpression("expand(tensor(d1[3]):[1,2,3], d0)"); + System.out.println(expr); + Tensor t = expr.evaluate(new MapContext()).asTensor(); + System.out.println(t); } @Test diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index e68a37b15b6..8387a3d1a83 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1699,6 +1699,21 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.Expand": { + "superClass": "com.yahoo.tensor.functions.CompositeTensorFunction", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void (com.yahoo.tensor.functions.TensorFunction, java.lang.String)", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.Generate": { "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", "interfaces": [], @@ -2053,6 +2068,22 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.ScalarFunctions$Constant": { + "superClass": "java.lang.Object", + "interfaces": [ + "java.util.function.Function" + ], + "attributes": [ + "public" + ], + "methods": [ + "public void (double)", + "public java.lang.Double apply(java.util.List)", + "public java.lang.String toString()", + "public bridge synthetic java.lang.Object apply(java.lang.Object)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.ScalarFunctions$Cos": { "superClass": "java.lang.Object", "interfaces": [ @@ -2603,7 +2634,8 @@ "public static java.util.function.DoubleUnaryOperator selu(double, double)", "public static java.util.function.Function random()", "public static java.util.function.Function equal(java.util.List)", - "public static java.util.function.Function sum(java.util.List)" + "public static java.util.function.Function sum(java.util.List)", + "public static java.util.function.Function constant(double)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java new file mode 100644 index 00000000000..8fc246a7d9d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java @@ -0,0 +1,48 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; + +import java.util.Collections; +import java.util.List; + +/** + * The expand tensor function returns a tensor with a new dimension of + * size 1 is added, equivalent to "tensor * tensor(dim_name[1])(1)". + * + * @author lesters + */ +public class Expand extends CompositeTensorFunction { + + private final TensorFunction argument; + private final String dimensionName; + + public Expand(TensorFunction argument, String dimension) { + this.argument = argument; + this.dimensionName = dimension; + } + + @Override + public List> arguments() { return Collections.singletonList(argument); } + + @Override + public TensorFunction withArguments(List> arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("Expand must have 1 argument, got " + arguments.size()); + return new Expand<>(arguments.get(0), dimensionName); + } + + @Override + public PrimitiveTensorFunction toPrimitive() { + TensorType type = new TensorType.Builder(TensorType.Value.INT8).indexed(dimensionName, 1).build(); + Generate expansion = new Generate<>(type, ScalarFunctions.constant(1.0)); + return new Join<>(expansion, argument, ScalarFunctions.multiply()); + } + + @Override + public String toString(ToStringContext context) { + return "expand(" + argument.toString(context) + ", " + dimensionName + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index d6fcd17b8fb..4c2b64244e5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -66,6 +66,7 @@ public class ScalarFunctions { public static Function, Double> random() { return new Random(); } public static Function, Double> equal(List argumentNames) { return new EqualElements(argumentNames); } public static Function, Double> sum(List argumentNames) { return new SumElements(argumentNames); } + public static Function, Double> constant(double value) { return new Constant(value); } // Binary operators ----------------------------------------------------------------------------- @@ -493,4 +494,19 @@ public class ScalarFunctions { } } + public static class Constant implements Function, Double> { + private final double value; + + public Constant(double value) { + this.value = value; + } + @Override + public Double apply(List values) { + return value; + } + @Override + public String toString() { return Double.toString(value); } + } + + } -- cgit v1.2.3 From 1fdf04b093030529563020ccea7894f24848d2a0 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 6 Oct 2021 10:55:35 +0200 Subject: Remove temporary test case --- .../rankingexpression/evaluation/EvaluationTestCase.java | 8 -------- 1 file changed, 8 deletions(-) diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index ed8a15ad989..10c835b05f2 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -733,14 +733,6 @@ public class EvaluationTestCase { "tensor(d1[3]):[1, 2, 3]"); } - @Test - public void test() throws ParseException { - RankingExpression expr = new RankingExpression("expand(tensor(d1[3]):[1,2,3], d0)"); - System.out.println(expr); - Tensor t = expr.evaluate(new MapContext()).asTensor(); - System.out.println(t); - } - @Test public void testProgrammaticBuildingAndPrecedence() { RankingExpression standardPrecedence = new RankingExpression(new ArithmeticNode(constant(2), ArithmeticOperator.PLUS, new ArithmeticNode(constant(3), ArithmeticOperator.MULTIPLY, constant(4)))); -- cgit v1.2.3 From 393b8284e615e251d1477ccbeee88398ed3388fd Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 6 Oct 2021 11:07:14 +0200 Subject: Add expand function to tensor class --- vespajlib/abi-spec.json | 1 + .../src/main/java/com/yahoo/tensor/Tensor.java | 24 ++++++---------------- 2 files changed, 7 insertions(+), 18 deletions(-) diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 8387a3d1a83..c426195bc37 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1180,6 +1180,7 @@ "public com.yahoo.tensor.Tensor matmul(com.yahoo.tensor.Tensor, java.lang.String)", "public com.yahoo.tensor.Tensor softmax(java.lang.String)", "public com.yahoo.tensor.Tensor xwPlusB(com.yahoo.tensor.Tensor, com.yahoo.tensor.Tensor, java.lang.String)", + "public com.yahoo.tensor.Tensor expand(java.lang.String)", "public com.yahoo.tensor.Tensor argmax(java.lang.String)", "public com.yahoo.tensor.Tensor argmin(java.lang.String)", "public static com.yahoo.tensor.Tensor diag(com.yahoo.tensor.TensorType)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 3d4536d9249..c9be90bf20b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -1,25 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Argmax; -import com.yahoo.tensor.functions.Argmin; -import com.yahoo.tensor.functions.CellCast; -import com.yahoo.tensor.functions.Concat; -import com.yahoo.tensor.functions.ConstantTensor; -import com.yahoo.tensor.functions.Diag; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.L1Normalize; -import com.yahoo.tensor.functions.L2Normalize; -import com.yahoo.tensor.functions.Matmul; -import com.yahoo.tensor.functions.Merge; -import com.yahoo.tensor.functions.Random; -import com.yahoo.tensor.functions.Range; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.Softmax; -import com.yahoo.tensor.functions.XwPlusB; +import com.yahoo.tensor.functions.*; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -210,6 +194,10 @@ public interface Tensor { return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate(); } + default Tensor expand(String dimension) { + return new Expand<>(new ConstantTensor<>(this), dimension).evaluate(); + } + default Tensor argmax(String dimension) { return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate(); } -- cgit v1.2.3 From 566371a60a5cf3507f2fe2f58b5cfe0763089169 Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 6 Oct 2021 11:09:24 +0200 Subject: Organize imports --- .../src/main/java/com/yahoo/tensor/Tensor.java | 24 +++++++++++++++++----- 1 file changed, 19 insertions(+), 5 deletions(-) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index c9be90bf20b..047844113ff 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -1,10 +1,25 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; -import com.yahoo.tensor.evaluation.Name; -import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.*; -import com.yahoo.text.Ascii7BitMatcher; +import com.yahoo.tensor.functions.Argmax; +import com.yahoo.tensor.functions.Argmin; +import com.yahoo.tensor.functions.CellCast; +import com.yahoo.tensor.functions.Concat; +import com.yahoo.tensor.functions.ConstantTensor; +import com.yahoo.tensor.functions.Diag; +import com.yahoo.tensor.functions.Generate; +import com.yahoo.tensor.functions.Join; +import com.yahoo.tensor.functions.L1Normalize; +import com.yahoo.tensor.functions.L2Normalize; +import com.yahoo.tensor.functions.Matmul; +import com.yahoo.tensor.functions.Merge; +import com.yahoo.tensor.functions.Random; +import com.yahoo.tensor.functions.Range; +import com.yahoo.tensor.functions.Reduce; +import com.yahoo.tensor.functions.Rename; +import com.yahoo.tensor.functions.Softmax; +import com.yahoo.tensor.functions.XwPlusB; +import com.yahoo.tensor.functions.Expand; import java.util.ArrayList; import java.util.Arrays; @@ -19,7 +34,6 @@ import java.util.function.DoubleUnaryOperator; import java.util.function.Function; import java.util.stream.Collectors; -import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers; import static com.yahoo.tensor.functions.ScalarFunctions.Hamming; /** -- cgit v1.2.3