aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2021-10-06 11:25:36 +0200
committerGitHub <noreply@github.com>2021-10-06 11:25:36 +0200
commit95c5bf68acbbc48749941b12d1754a7f6998bece (patch)
tree58664a4c3a8fde314a38ee95dc7b745485eb0686
parent4de0026c1065403d028d7157abb571830603e6c9 (diff)
parent566371a60a5cf3507f2fe2f58b5cfe0763089169 (diff)
Merge pull request #19430 from vespa-engine/lesters/add-tensor-expand
Add non-primitive tensor expand function
-rw-r--r--searchlib/abi-spec.json2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj15
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java8
-rw-r--r--vespajlib/abi-spec.json35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java48
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java16
7 files changed, 126 insertions, 6 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 2468fd0c5c7..4ebca94734f 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -898,6 +898,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorExpand()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()",
@@ -1053,6 +1054,7 @@
"public static final int ARGMAX",
"public static final int ARGMIN",
"public static final int CELL_CAST",
+ "public static final int EXPAND",
"public static final int AVG",
"public static final int COUNT",
"public static final int MAX",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 7bfbfd6c005..88eb0feeb73 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -144,6 +144,7 @@ TOKEN :
<ARGMAX: "argmax"> |
<ARGMIN: "argmin"> |
<CELL_CAST: "cell_cast"> |
+ <EXPAND: "expand"> |
<AVG: "avg" > |
<COUNT: "count"> |
@@ -384,7 +385,8 @@ TensorFunctionNode tensorFunction() :
tensorExpression = tensorXwPlusB() |
tensorExpression = tensorArgmax() |
tensorExpression = tensorArgmin() |
- tensorExpression = tensorCellCast()
+ tensorExpression = tensorCellCast() |
+ tensorExpression = tensorExpand()
)
{ return tensorExpression; }
}
@@ -581,6 +583,16 @@ TensorFunctionNode tensorXwPlusB() :
dimension)); }
}
+TensorFunctionNode tensorExpand() :
+{
+ ExpressionNode argument;
+ String dimension;
+}
+{
+ <EXPAND> <LBRACE> argument = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new Expand(TensorFunctionNode.wrap(argument), dimension)); }
+}
+
TensorFunctionNode tensorArgmax() :
{
ExpressionNode tensor;
@@ -696,6 +708,7 @@ String tensorFunctionName() :
( <ARGMAX> { return token.image; } ) |
( <ARGMIN> { return token.image; } ) |
( <CELL_CAST> { return token.image; } ) |
+ ( <EXPAND> { return token.image; } ) |
( aggregator = tensorReduceAggregator() { return aggregator.toString(); } )
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 246dbcb2b1e..10c835b05f2 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -498,7 +498,6 @@ public class EvaluationTestCase {
"tensor(d0[3],d1[2],d2[1],d3[1])(tensor0{a0:0, a1:((d0 * 2 + d1) / 3), a2:((d0 * 2 + d1) % 3) })",
"tensor(a0[1],a1[2],a2[3]):[1,2,3,4,5,6]",
"tensor(d0[4]):[3,2,-1,1]");
-
}
@Test
@@ -725,6 +724,13 @@ public class EvaluationTestCase {
tester.assertEvaluates("tensor(d0[1], d1[3]):[1, 2, 3]",
"tensor0 * tensor(d0[1])(1)",
"tensor(d1[3]):[1, 2, 3]");
+ // Add using the "expand" non-primitive function
+ tester.assertEvaluates("tensor(d0[1],d1[3]):[[1,2,3]]",
+ "expand(tensor0, d0)",
+ "tensor(d1[3]):[1, 2, 3]");
+ tester.assertEvaluates("tensor<float>(d0[1],d1[3]):[[1,2,3]]",
+ "expand(tensor0, d0)",
+ "tensor<float>(d1[3]):[1, 2, 3]");
}
@Test
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index e68a37b15b6..c426195bc37 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1180,6 +1180,7 @@
"public com.yahoo.tensor.Tensor matmul(com.yahoo.tensor.Tensor, java.lang.String)",
"public com.yahoo.tensor.Tensor softmax(java.lang.String)",
"public com.yahoo.tensor.Tensor xwPlusB(com.yahoo.tensor.Tensor, com.yahoo.tensor.Tensor, java.lang.String)",
+ "public com.yahoo.tensor.Tensor expand(java.lang.String)",
"public com.yahoo.tensor.Tensor argmax(java.lang.String)",
"public com.yahoo.tensor.Tensor argmin(java.lang.String)",
"public static com.yahoo.tensor.Tensor diag(com.yahoo.tensor.TensorType)",
@@ -1699,6 +1700,21 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.Expand": {
+ "superClass": "com.yahoo.tensor.functions.CompositeTensorFunction",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.lang.String)",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.Generate": {
"superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
"interfaces": [],
@@ -2053,6 +2069,22 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.ScalarFunctions$Constant": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "java.util.function.Function"
+ ],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(double)",
+ "public java.lang.Double apply(java.util.List)",
+ "public java.lang.String toString()",
+ "public bridge synthetic java.lang.Object apply(java.lang.Object)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.ScalarFunctions$Cos": {
"superClass": "java.lang.Object",
"interfaces": [
@@ -2603,7 +2635,8 @@
"public static java.util.function.DoubleUnaryOperator selu(double, double)",
"public static java.util.function.Function random()",
"public static java.util.function.Function equal(java.util.List)",
- "public static java.util.function.Function sum(java.util.List)"
+ "public static java.util.function.Function sum(java.util.List)",
+ "public static java.util.function.Function constant(double)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 3d4536d9249..047844113ff 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -1,7 +1,6 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
-import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.Argmin;
import com.yahoo.tensor.functions.CellCast;
@@ -20,7 +19,7 @@ import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.Softmax;
import com.yahoo.tensor.functions.XwPlusB;
-import com.yahoo.text.Ascii7BitMatcher;
+import com.yahoo.tensor.functions.Expand;
import java.util.ArrayList;
import java.util.Arrays;
@@ -35,7 +34,6 @@ import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;
-import static com.yahoo.text.Ascii7BitMatcher.charsAndNumbers;
import static com.yahoo.tensor.functions.ScalarFunctions.Hamming;
/**
@@ -210,6 +208,10 @@ public interface Tensor {
return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate();
}
+ default Tensor expand(String dimension) {
+ return new Expand<>(new ConstantTensor<>(this), dimension).evaluate();
+ }
+
default Tensor argmax(String dimension) {
return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
new file mode 100644
index 00000000000..8fc246a7d9d
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java
@@ -0,0 +1,48 @@
+// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
+
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * The <i>expand</i> tensor function returns a tensor with a new dimension of
+ * size 1 is added, equivalent to "tensor * tensor(dim_name[1])(1)".
+ *
+ * @author lesters
+ */
+public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> argument;
+ private final String dimensionName;
+
+ public Expand(TensorFunction<NAMETYPE> argument, String dimension) {
+ this.argument = argument;
+ this.dimensionName = dimension;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if (arguments.size() != 1)
+ throw new IllegalArgumentException("Expand must have 1 argument, got " + arguments.size());
+ return new Expand<>(arguments.get(0), dimensionName);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorType type = new TensorType.Builder(TensorType.Value.INT8).indexed(dimensionName, 1).build();
+ Generate<NAMETYPE> expansion = new Generate<>(type, ScalarFunctions.constant(1.0));
+ return new Join<>(expansion, argument, ScalarFunctions.multiply());
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return "expand(" + argument.toString(context) + ", " + dimensionName + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
index d6fcd17b8fb..4c2b64244e5 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java
@@ -66,6 +66,7 @@ public class ScalarFunctions {
public static Function<List<Long>, Double> random() { return new Random(); }
public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); }
public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); }
+ public static Function<List<Long>, Double> constant(double value) { return new Constant(value); }
// Binary operators -----------------------------------------------------------------------------
@@ -493,4 +494,19 @@ public class ScalarFunctions {
}
}
+ public static class Constant implements Function<List<Long>, Double> {
+ private final double value;
+
+ public Constant(double value) {
+ this.value = value;
+ }
+ @Override
+ public Double apply(List<Long> values) {
+ return value;
+ }
+ @Override
+ public String toString() { return Double.toString(value); }
+ }
+
+
}