summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-11-10 10:02:02 +0000
committerArne Juul <arnej@yahooinc.com>2023-11-10 10:14:23 +0000
commit1ec05cfed625e6395b8d3346c4f15b5bc7507dcf (patch)
treecbe16302842983afd9192b78a745532e419cd04a
parent27c62ead49c88382f128034c48e22d77a83f8104 (diff)
unpack_bits_from_int8 -> unpack_bits
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd2
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java4
-rw-r--r--searchlib/abi-spec.json8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java (renamed from searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java)11
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj10
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java4
6 files changed, 21 insertions, 18 deletions
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index 81230e5c54c..3a5fda3ac5d 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -128,7 +128,7 @@ schema tensor {
query(qvec) tensor<float>(x[40])
}
function myunpack() {
- expression: unpack_bits_from_int8(attribute(f7))
+ expression: unpack_bits(attribute(f7))
}
first-phase {
expression: sum(query(para)*myunpack*query(qvec))
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index a3273abff57..d42ec629bf7 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -13,7 +13,7 @@ import com.yahoo.language.process.Embedder;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
-import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsFromInt8;
+import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode;
import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
@@ -284,7 +284,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
public static Tensor expandBitTensor(Tensor packed) {
- var unpacker = new UnpackBitsFromInt8(new ReferenceNode("input"), TensorType.Value.FLOAT, "big");
+ var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.FLOAT, "big");
var context = new MapContext();
context.put("input", new TensorValue(packed));
return unpacker.evaluate(context).asTensor();
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 4f0a99a117d..0b1cb7a103c 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -954,7 +954,7 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()",
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMacro()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorUnpackBitsFromInt8()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorUnpackBits()",
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
"public final com.yahoo.tensor.TensorType tensorType(java.util.List)",
@@ -1088,7 +1088,7 @@
"public static final int HAMMING",
"public static final int MAP",
"public static final int MAP_SUBSPACES",
- "public static final int UNPACK_BITS_FROM_INT8",
+ "public static final int UNPACK_BITS",
"public static final int REDUCE",
"public static final int JOIN",
"public static final int MERGE",
@@ -1711,7 +1711,7 @@
],
"fields" : [ ]
},
- "com.yahoo.searchlib.rankingexpression.rule.UnpackBitsFromInt8" : {
+ "com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode" : {
"superClass" : "com.yahoo.searchlib.rankingexpression.rule.CompositeNode",
"interfaces" : [ ],
"attributes" : [
@@ -1728,4 +1728,4 @@
],
"fields" : [ ]
}
-} \ No newline at end of file
+}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java
index 84203da4a7e..467a7860053 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsNode.java
@@ -24,9 +24,9 @@ import java.util.Objects;
* @author arnej
*/
@Beta
-public class UnpackBitsFromInt8 extends CompositeNode {
+public class UnpackBitsNode extends CompositeNode {
- private static String operationName = "unpack_bits_from_int8";
+ private static String operationName = "unpack_bits";
private enum EndianNess {
BIG_ENDIAN("big"), LITTLE_ENDIAN("little");
@@ -47,7 +47,7 @@ public class UnpackBitsFromInt8 extends CompositeNode {
final TensorType.Value targetCellType;
final EndianNess endian;
- public UnpackBitsFromInt8(ExpressionNode input, TensorType.Value targetCellType, String endianNess) {
+ public UnpackBitsNode(ExpressionNode input, TensorType.Value targetCellType, String endianNess) {
this.input = input;
this.targetCellType = targetCellType;
this.endian = EndianNess.fromId(endianNess);
@@ -141,6 +141,9 @@ public class UnpackBitsFromInt8 extends CompositeNode {
}
private Meta analyze(TensorType inputType) {
+ if (inputType.valueType() != TensorType.Value.INT8) {
+ throw new IllegalArgumentException("bad " + operationName + "; input must have cell-type int8, but it was: " + inputType.valueType());
+ }
TensorType inputDenseType = inputType.indexedSubtype();
if (inputDenseType.rank() == 0) {
throw new IllegalArgumentException("bad " + operationName + "; input must have indexed dimension, but type was: " + inputType);
@@ -174,7 +177,7 @@ public class UnpackBitsFromInt8 extends CompositeNode {
public CompositeNode setChildren(List<ExpressionNode> newChildren) {
if (newChildren.size() != 1)
throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size());
- return new UnpackBitsFromInt8(newChildren.get(0), targetCellType, endian.toString());
+ return new UnpackBitsNode(newChildren.get(0), targetCellType, endian.toString());
}
@Override
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 1da8a5ece89..42f8f846199 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -129,7 +129,7 @@ TOKEN :
<MAP: "map"> |
<MAP_SUBSPACES: "map_subspaces"> |
- <UNPACK_BITS_FROM_INT8: "unpack_bits_from_int8"> |
+ <UNPACK_BITS: "unpack_bits"> |
<REDUCE: "reduce"> |
<JOIN: "join"> |
<MERGE: "merge"> |
@@ -676,23 +676,23 @@ ExpressionNode tensorMacro() :
}
{
(
- tensorExpression = tensorUnpackBitsFromInt8()
+ tensorExpression = tensorUnpackBits()
)
{ return tensorExpression; }
}
-ExpressionNode tensorUnpackBitsFromInt8() :
+ExpressionNode tensorUnpackBits() :
{
ExpressionNode tensor;
String targetCellType = "float";
String endianNess = "big";
}
{
- <UNPACK_BITS_FROM_INT8> <LBRACE> tensor = expression() (
+ <UNPACK_BITS> <LBRACE> tensor = expression() (
<COMMA> targetCellType = identifier() (
<COMMA> endianNess = identifier() )? )? <RBRACE>
{
- return new UnpackBitsFromInt8(tensor, TensorType.Value.fromId(targetCellType), endianNess);
+ return new UnpackBitsNode(tensor, TensorType.Value.fromId(targetCellType), endianNess);
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 90b143e8f7f..9626059a42e 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -185,7 +185,7 @@ public class EvaluationTestCase {
"],bar:[" +
"0,0,0,0, 0,0,0,1," +
"1,1,1,1, 1,0,0,0]}",
- "unpack_bits_from_int8(tensor0, float, big)",
+ "unpack_bits(tensor0, float, big)",
"tensor<int8>(a{},x[2]):{foo:[0,-1],bar:[1,-8]}");
tester.assertEvaluates("tensor<int8>(a{},x[16]):{foo:[" +
@@ -194,7 +194,7 @@ public class EvaluationTestCase {
"],bar:[" +
"1,0,0,0, 0,0,0,0," +
"0,0,0,1, 1,1,1,1]}",
- "unpack_bits_from_int8(tensor0, int8, little)",
+ "unpack_bits(tensor0, int8, little)",
"tensor<int8>(a{},x[2]):{foo:[0,-1],bar:[1,-8]}");
}