summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorArne H Juul <arnej27959@users.noreply.github.com>2023-06-26 14:30:53 +0200
committerGitHub <noreply@github.com>2023-06-26 14:30:53 +0200
commit626bcc6c265229d8c97f4e0a1c996013650b335e (patch)
tree8042b0897155d1049d2bbc2ea20dc68ff3bda03b /searchlib
parent0c341f8ed39b3edcd1938d964cbdf9ce7c179411 (diff)
parent9faebe628164657eaad3de625b9b799a385aea6e (diff)
Merge pull request #27544 from vespa-engine/arnej/add-euclidean-distance
add euclidean_distance
Diffstat (limited to 'searchlib')
-rw-r--r--searchlib/abi-spec.json2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj17
2 files changed, 18 insertions, 1 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index f3fe86e261f..30f2cb5c6ea 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -946,6 +946,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRandom()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL1Normalize()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL2Normalize()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorEuclideanDistance()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()",
@@ -1098,6 +1099,7 @@
"public static final int RANDOM",
"public static final int L1_NORMALIZE",
"public static final int L2_NORMALIZE",
+ "public static final int EUCLIDEAN_DISTANCE",
"public static final int MATMUL",
"public static final int SOFTMAX",
"public static final int XW_PLUS_B",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 42b5f2c191a..744e629893e 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -138,6 +138,7 @@ TOKEN :
<RANDOM: "random"> |
<L1_NORMALIZE: "l1_normalize"> |
<L2_NORMALIZE: "l2_normalize"> |
+ <EUCLIDEAN_DISTANCE: "euclidean_distance"> |
<MATMUL: "matmul"> |
<SOFTMAX: "softmax"> |
<XW_PLUS_B: "xw_plus_b"> |
@@ -379,6 +380,7 @@ TensorFunctionNode tensorFunction() :
tensorExpression = tensorRandom() |
tensorExpression = tensorL1Normalize() |
tensorExpression = tensorL2Normalize() |
+ tensorExpression = tensorEuclideanDistance() |
tensorExpression = tensorMatmul() |
tensorExpression = tensorSoftmax() |
tensorExpression = tensorXwPlusB() |
@@ -544,6 +546,18 @@ TensorFunctionNode tensorL2Normalize() :
{ return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
+TensorFunctionNode tensorEuclideanDistance() :
+{
+ ExpressionNode tensor1, tensor2;
+ String dimension;
+}
+{
+ <EUCLIDEAN_DISTANCE> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
+ { return new TensorFunctionNode(new EuclideanDistance(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ dimension)); }
+}
+
TensorFunctionNode tensorMatmul() :
{
ExpressionNode tensor1, tensor2;
@@ -701,6 +715,7 @@ String tensorFunctionName() :
( <RANDOM> { return token.image; } ) |
( <L1_NORMALIZE> { return token.image; } ) |
( <L2_NORMALIZE> { return token.image; } ) |
+ ( <EUCLIDEAN_DISTANCE> { return token.image; } ) |
( <MATMUL> { return token.image; } ) |
( <SOFTMAX> { return token.image; } ) |
( <XW_PLUS_B> { return token.image; } ) |
@@ -1041,4 +1056,4 @@ String label() :
String string() : {}
{
<STRING> { return token.image.substring(1, token.image.length() - 1); }
-} \ No newline at end of file
+}