summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
author Arne Juul <arnej@yahooinc.com> 2023-06-26 12:29:45 +0000
committer Arne Juul <arnej@yahooinc.com> 2023-06-26 13:06:51 +0000
commit cc517d86dc886058cdc5f95a318945a6a328da28 (patch)
tree cef1ca84628800fc8226ccb30626688a3efdf49b /searchlib
parent 626bcc6c265229d8c97f4e0a1c996013650b335e (diff)
add cosine_similarity
Diffstat (limited to 'searchlib')
-rw-r--r-- searchlib/abi-spec.json 2
-rwxr-xr-x searchlib/src/main/javacc/RankingExpressionParser.jj 15
2 files changed, 17 insertions, 0 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 30f2cb5c6ea..7d6f2f8790c 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -947,6 +947,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL1Normalize()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL2Normalize()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorEuclideanDistance()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCosineSimilarity()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()",
@@ -1100,6 +1101,7 @@
"public static final int L1_NORMALIZE",
"public static final int L2_NORMALIZE",
"public static final int EUCLIDEAN_DISTANCE",
+ "public static final int COSINE_SIMILARITY",
"public static final int MATMUL",
"public static final int SOFTMAX",
"public static final int XW_PLUS_B",
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 744e629893e..41647a5ef5b 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -139,6 +139,7 @@ TOKEN :
<L1_NORMALIZE: "l1_normalize"> |
<L2_NORMALIZE: "l2_normalize"> |
<EUCLIDEAN_DISTANCE: "euclidean_distance"> |
+ <COSINE_SIMILARITY: "cosine_similarity"> |
<MATMUL: "matmul"> |
<SOFTMAX: "softmax"> |
<XW_PLUS_B: "xw_plus_b"> |
@@ -381,6 +382,7 @@ TensorFunctionNode tensorFunction() :
tensorExpression = tensorL1Normalize() |
tensorExpression = tensorL2Normalize() |
tensorExpression = tensorEuclideanDistance() |
+ tensorExpression = tensorCosineSimilarity() |
tensorExpression = tensorMatmul() |
tensorExpression = tensorSoftmax() |
tensorExpression = tensorXwPlusB() |
@@ -558,6 +560,18 @@ TensorFunctionNode tensorEuclideanDistance() :
dimension)); }
}
+TensorFunctionNode tensorCosineSimilarity() : // Production for the "cosine_similarity" function call form.
+{
+ ExpressionNode tensor1, tensor2; // first and second operand, in the order they appear in the input
+ String dimension; // third argument, parsed as an identifier rather than as an expression
+}
+{
+ <COSINE_SIMILARITY> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> // keyword, then exactly three comma-separated arguments between LBRACE and RBRACE
+ { return new TensorFunctionNode(new CosineSimilarity(TensorFunctionNode.wrap(tensor1), // both operands go through TensorFunctionNode.wrap before being handed to CosineSimilarity
+ TensorFunctionNode.wrap(tensor2),
+ dimension)); } // NOTE(review): presumably names the dimension the similarity is computed over - confirm against CosineSimilarity
+}
+
TensorFunctionNode tensorMatmul() :
{
ExpressionNode tensor1, tensor2;
@@ -716,6 +730,7 @@ String tensorFunctionName() :
( <L1_NORMALIZE> { return token.image; } ) |
( <L2_NORMALIZE> { return token.image; } ) |
( <EUCLIDEAN_DISTANCE> { return token.image; } ) |
+ ( <COSINE_SIMILARITY> { return token.image; } ) |
( <MATMUL> { return token.image; } ) |
( <SOFTMAX> { return token.image; } ) |
( <XW_PLUS_B> { return token.image; } ) |