diff options
author | Arne Juul <arnej@yahooinc.com> | 2023-11-10 10:02:02 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-11-10 10:14:23 +0000 |
commit | 1ec05cfed625e6395b8d3346c4f15b5bc7507dcf (patch) | |
tree | cbe16302842983afd9192b78a745532e419cd04a | |
parent | 27c62ead49c88382f128034c48e22d77a83f8104 (diff) |
unpack_bits_from_int8 -> unpack_bits
-rw-r--r-- | config-model/src/test/derived/tensor/tensor.sd | 2 | ||||
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 4 | ||||
-rw-r--r-- | searchlib/abi-spec.json | 8 | ||||
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java) | 11 | ||||
-rwxr-xr-x | searchlib/src/main/javacc/RankingExpressionParser.jj | 10 | ||||
-rw-r--r-- | searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java | 4 |
6 files changed, 21 insertions, 18 deletions
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index 81230e5c54c..3a5fda3ac5d 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -128,7 +128,7 @@ schema tensor { query(qvec) tensor<float>(x[40]) } function myunpack() { - expression: unpack_bits_from_int8(attribute(f7)) + expression: unpack_bits(attribute(f7)) } first-phase { expression: sum(query(para)*myunpack*query(qvec)) diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index a3273abff57..d42ec629bf7 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -13,7 +13,7 @@ import com.yahoo.language.process.Embedder; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; -import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsFromInt8; +import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode; import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; @@ -284,7 +284,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } public static Tensor expandBitTensor(Tensor packed) { - var unpacker = new UnpackBitsFromInt8(new ReferenceNode("input"), TensorType.Value.FLOAT, "big"); + var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.FLOAT, "big"); var context = new MapContext(); context.put("input", new TensorValue(packed)); return unpacker.evaluate(context).asTensor(); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 4f0a99a117d..0b1cb7a103c 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -954,7 +954,7 @@ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMacro()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorUnpackBitsFromInt8()", + "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorUnpackBits()", "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", "public final com.yahoo.tensor.TensorType tensorType(java.util.List)", @@ -1088,7 +1088,7 @@ "public static final int HAMMING", "public static final int MAP", "public static final int MAP_SUBSPACES", - "public static final int UNPACK_BITS_FROM_INT8", + "public static final int UNPACK_BITS", "public static final int REDUCE", "public static final int JOIN", "public static final int MERGE", @@ -1711,7 +1711,7 @@ ], "fields" : [ ] }, - "com.yahoo.searchlib.rankingexpression.rule.UnpackBitsFromInt8" : { + "com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode" : { "superClass" : "com.yahoo.searchlib.rankingexpression.rule.CompositeNode", "interfaces" : [ ], "attributes" : [ @@ -1728,4 +1728,4 @@ ], "fields" : [ ] } -}
\ No newline at end of file +} diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java index 84203da4a7e..467a7860053 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java @@ -24,9 +24,9 @@ import java.util.Objects; * @author arnej */ @Beta -public class UnpackBitsFromInt8 extends CompositeNode { +public class UnpackBitsNode extends CompositeNode { - private static String operationName = "unpack_bits_from_int8"; + private static String operationName = "unpack_bits"; private enum EndianNess { BIG_ENDIAN("big"), LITTLE_ENDIAN("little"); @@ -47,7 +47,7 @@ public class UnpackBitsFromInt8 extends CompositeNode { final TensorType.Value targetCellType; final EndianNess endian; - public UnpackBitsFromInt8(ExpressionNode input, TensorType.Value targetCellType, String endianNess) { + public UnpackBitsNode(ExpressionNode input, TensorType.Value targetCellType, String endianNess) { this.input = input; this.targetCellType = targetCellType; this.endian = EndianNess.fromId(endianNess); @@ -141,6 +141,9 @@ public class UnpackBitsFromInt8 extends CompositeNode { } private Meta analyze(TensorType inputType) { + if (inputType.valueType() != TensorType.Value.INT8) { + throw new IllegalArgumentException("bad " + operationName + "; input must have cell-type int8, but it was: " + inputType.valueType()); + } TensorType inputDenseType = inputType.indexedSubtype(); if (inputDenseType.rank() == 0) { throw new IllegalArgumentException("bad " + operationName + "; input must have indexed dimension, but type was: " + inputType); @@ -174,7 +177,7 @@ public class UnpackBitsFromInt8 extends CompositeNode { public CompositeNode setChildren(List<ExpressionNode> newChildren) { if (newChildren.size() != 1) throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size()); - return new UnpackBitsFromInt8(newChildren.get(0), targetCellType, endian.toString()); + return new UnpackBitsNode(newChildren.get(0), targetCellType, endian.toString()); } @Override diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 1da8a5ece89..42f8f846199 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -129,7 +129,7 @@ TOKEN : <MAP: "map"> | <MAP_SUBSPACES: "map_subspaces"> | - <UNPACK_BITS_FROM_INT8: "unpack_bits_from_int8"> | + <UNPACK_BITS: "unpack_bits"> | <REDUCE: "reduce"> | <JOIN: "join"> | <MERGE: "merge"> | @@ -676,23 +676,23 @@ ExpressionNode tensorMacro() : } { ( - tensorExpression = tensorUnpackBitsFromInt8() + tensorExpression = tensorUnpackBits() ) { return tensorExpression; } } -ExpressionNode tensorUnpackBitsFromInt8() : +ExpressionNode tensorUnpackBits() : { ExpressionNode tensor; String targetCellType = "float"; String endianNess = "big"; } { - <UNPACK_BITS_FROM_INT8> <LBRACE> tensor = expression() ( + <UNPACK_BITS> <LBRACE> tensor = expression() ( <COMMA> targetCellType = identifier() ( <COMMA> endianNess = identifier() )? )? <RBRACE> { - return new UnpackBitsFromInt8(tensor, TensorType.Value.fromId(targetCellType), endianNess); + return new UnpackBitsNode(tensor, TensorType.Value.fromId(targetCellType), endianNess); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 90b143e8f7f..9626059a42e 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -185,7 +185,7 @@ public class EvaluationTestCase { "],bar:[" + "0,0,0,0, 0,0,0,1," + "1,1,1,1, 1,0,0,0]}", - "unpack_bits_from_int8(tensor0, float, big)", + "unpack_bits(tensor0, float, big)", "tensor<int8>(a{},x[2]):{foo:[0,-1],bar:[1,-8]}"); tester.assertEvaluates("tensor<int8>(a{},x[16]):{foo:[" + @@ -194,7 +194,7 @@ public class EvaluationTestCase { "],bar:[" + "1,0,0,0, 0,0,0,0," + "0,0,0,1, 1,1,1,1]}", - "unpack_bits_from_int8(tensor0, int8, little)", + "unpack_bits(tensor0, int8, little)", "tensor<int8>(a{},x[2]):{foo:[0,-1],bar:[1,-8]}"); } |