diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-06-26 12:29:45 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-06-26 13:06:51 +0000 |
commit | cc517d86dc886058cdc5f95a318945a6a328da28 (patch) | |
tree | cef1ca84628800fc8226ccb30626688a3efdf49b /searchlib | |
parent | 626bcc6c265229d8c97f4e0a1c996013650b335e (diff) |
add cosine_similarity
Diffstat (limited to 'searchlib')
-rw-r--r-- | searchlib/abi-spec.json | 2 | ||||
-rwxr-xr-x | searchlib/src/main/javacc/RankingExpressionParser.jj | 15 |
2 files changed, 17 insertions, 0 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 30f2cb5c6ea..7d6f2f8790c 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -947,6 +947,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL1Normalize()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL2Normalize()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorEuclideanDistance()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCosineSimilarity()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()", @@ -1100,6 +1101,7 @@ "public static final int L1_NORMALIZE", "public static final int L2_NORMALIZE", "public static final int EUCLIDEAN_DISTANCE", + "public static final int COSINE_SIMILARITY", "public static final int MATMUL", "public static final int SOFTMAX", "public static final int XW_PLUS_B", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 744e629893e..41647a5ef5b 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -139,6 +139,7 @@ TOKEN : <L1_NORMALIZE: "l1_normalize"> | <L2_NORMALIZE: "l2_normalize"> | <EUCLIDEAN_DISTANCE: "euclidean_distance"> | + <COSINE_SIMILARITY: "cosine_similarity"> | <MATMUL: "matmul"> | <SOFTMAX: "softmax"> | <XW_PLUS_B: "xw_plus_b"> | @@ -381,6 +382,7 @@ TensorFunctionNode tensorFunction() : tensorExpression = tensorL1Normalize() | tensorExpression = tensorL2Normalize() | tensorExpression = tensorEuclideanDistance() | + tensorExpression = tensorCosineSimilarity() | tensorExpression = tensorMatmul() | tensorExpression = tensorSoftmax() | tensorExpression = tensorXwPlusB() | @@ -558,6 +560,18 @@ TensorFunctionNode tensorEuclideanDistance() : dimension)); } } +TensorFunctionNode tensorCosineSimilarity() : +{ + ExpressionNode tensor1, tensor2; + String dimension; +} +{ + <COSINE_SIMILARITY> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> + { return new TensorFunctionNode(new CosineSimilarity(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + dimension)); } +} + TensorFunctionNode tensorMatmul() : { ExpressionNode tensor1, tensor2; @@ -716,6 +730,7 @@ String tensorFunctionName() : ( <L1_NORMALIZE> { return token.image; } ) | ( <L2_NORMALIZE> { return token.image; } ) | ( <EUCLIDEAN_DISTANCE> { return token.image; } ) | + ( <COSINE_SIMILARITY> { return token.image; } ) | ( <MATMUL> { return token.image; } ) | ( <SOFTMAX> { return token.image; } ) | ( <XW_PLUS_B> { return token.image; } ) | |