summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorArne Juul <arnej@vespa.ai>2023-11-05 12:20:09 +0000
committerArne Juul <arnej@yahooinc.com>2023-11-10 09:55:58 +0000
commit65047e9ad4d6138570e141159941ad9b81fdd41b (patch)
tree8c7a5a87be48dae2ce862061e87209129b432ebd /searchlib/src/main
parent83b1ccd36dd5df2e43307aab19adc07b41c94c9f (diff)
add "unpack_bits_from_int8" function
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java183
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj29
2 files changed, 211 insertions, 1 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java
new file mode 100644
index 00000000000..84203da4a7e
--- /dev/null
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java
@@ -0,0 +1,183 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.searchlib.rankingexpression.rule;
+
+import com.yahoo.api.annotations.Beta;
+import com.yahoo.searchlib.rankingexpression.Reference;
+import com.yahoo.searchlib.rankingexpression.evaluation.Context;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Deque;
+import java.util.List;
+import java.util.Optional;
+import java.util.Objects;
+
+/**
+ * Macro that expands to the appropriate map_subspaces magic incantation
+ *
+ * @author arnej
+ */
+@Beta
+public class UnpackBitsFromInt8 extends CompositeNode {
+
+ private static String operationName = "unpack_bits_from_int8";
+ private enum EndianNess {
+ BIG_ENDIAN("big"), LITTLE_ENDIAN("little");
+
+ private final String id;
+ EndianNess(String id) { this.id = id; }
+ public String toString() { return id; }
+ public static EndianNess fromId(String id) {
+ for (EndianNess value : values()) {
+ if (value.id.equals(id)) {
+ return value;
+ }
+ }
+ throw new IllegalArgumentException("EndianNess must be either 'big' or 'little', but was '" + id + "'");
+ }
+ };
+
+ final ExpressionNode input;
+ final TensorType.Value targetCellType;
+ final EndianNess endian;
+
+ public UnpackBitsFromInt8(ExpressionNode input, TensorType.Value targetCellType, String endianNess) {
+ this.input = input;
+ this.targetCellType = targetCellType;
+ this.endian = EndianNess.fromId(endianNess);
+ }
+
+ @Override
+ public List<ExpressionNode> children() {
+ return Collections.singletonList(input);
+ }
+
+ private static record Meta(TensorType outputType, TensorType outputDenseType, String unpackDimension) {}
+
+ @Override
+ public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
+ var optTC = context.typeContext();
+ if (optTC.isPresent()) {
+ TensorType inputType = input.type(optTC.get());
+ var meta = analyze(inputType);
+ string.append("map_subspaces").append("(");
+ input.toString(string, context, path, this);
+ string.append(", f(denseSubspaceInput)(");
+ string.append(meta.outputDenseType()).append("("); // generate
+ string.append("bit(denseSubspaceInput{");
+ for (var dim : meta.outputDenseType().dimensions()) {
+ String dName = dim.name();
+ boolean last = dName.equals(meta.unpackDimension);
+ string.append(dName);
+ string.append(":(");
+ string.append(dName);
+ if (last) {
+ string.append("/8");
+ }
+ string.append(")");
+ if (! last) {
+ string.append(", ");
+ }
+ }
+ if (endian.equals(EndianNess.BIG_ENDIAN)) {
+ string.append("}, 7-(");
+ } else {
+ string.append("}, (");
+ }
+ string.append(meta.unpackDimension);
+ string.append(" % 8)");
+ string.append("))))"); // bit, generate, f, map_subspaces
+ } else {
+ string.append(operationName);
+ string.append("(");
+ input.toString(string, context, path, this);
+ string.append(",");
+ string.append(targetCellType);
+ string.append(",");
+ string.append(endian);
+ string.append(")");
+ }
+ return string;
+ }
+
+ @Override
+ public Value evaluate(Context context) {
+ Tensor inputTensor = input.evaluate(context).asTensor();
+ TensorType inputType = inputTensor.type();
+ var meta = analyze(inputType);
+ var builder = Tensor.Builder.of(meta.outputType());
+ for (var iter = inputTensor.cellIterator(); iter.hasNext(); ) {
+ var cell = iter.next();
+ var oldAddr = cell.getKey();
+ for (int bitIdx = 0; bitIdx < 8; bitIdx++) {
+ var addrBuilder = new TensorAddress.Builder(meta.outputType());
+ for (int i = 0; i < inputType.dimensions().size(); i++) {
+ var dim = inputType.dimensions().get(i);
+ if (dim.name().equals(meta.unpackDimension())) {
+ long newIdx = oldAddr.numericLabel(i) * 8 + bitIdx;
+ addrBuilder.add(dim.name(), String.valueOf(newIdx));
+ } else {
+ addrBuilder.add(dim.name(), oldAddr.label(i));
+ }
+ }
+ var newAddr = addrBuilder.build();
+ int oldValue = (int)(cell.getValue().doubleValue());
+ if (endian.equals(EndianNess.BIG_ENDIAN)) {
+ float newCellValue = 1 & (oldValue >>> (7 - bitIdx));
+ builder.cell(newAddr, newCellValue);
+ } else {
+ float newCellValue = 1 & (oldValue >>> bitIdx);
+ builder.cell(newAddr, newCellValue);
+ }
+ }
+ }
+ return new TensorValue(builder.build());
+ }
+
+ private Meta analyze(TensorType inputType) {
+ TensorType inputDenseType = inputType.indexedSubtype();
+ if (inputDenseType.rank() == 0) {
+ throw new IllegalArgumentException("bad " + operationName + "; input must have indexed dimension, but type was: " + inputType);
+ }
+ var lastDim = inputDenseType.dimensions().get(inputDenseType.rank() - 1);
+ if (lastDim.size().isEmpty()) {
+ throw new IllegalArgumentException("bad " + operationName + "; last indexed dimension must be bound, but type was: " + inputType);
+ }
+ List<TensorType.Dimension> outputDims = new ArrayList<>();
+ var ttBuilder = new TensorType.Builder(targetCellType);
+ for (var dim : inputType.dimensions()) {
+ if (dim.name().equals(lastDim.name())) {
+ long sz = dim.size().get();
+ ttBuilder.indexed(dim.name(), sz * 8);
+ } else {
+ ttBuilder.set(dim);
+ }
+ }
+ TensorType outputType = ttBuilder.build();
+ return new Meta(outputType, outputType.indexedSubtype(), lastDim.name());
+ }
+
+ @Override
+ public TensorType type(TypeContext<Reference> context) {
+ TensorType inputType = input.type(context);
+ var meta = analyze(inputType);
+ return meta.outputType();
+ }
+
+ @Override
+ public CompositeNode setChildren(List<ExpressionNode> newChildren) {
+ if (newChildren.size() != 1)
+ throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size());
+ return new UnpackBitsFromInt8(newChildren.get(0), targetCellType, endian.toString());
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash(operationName, input, targetCellType); }
+
+}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 6cd01151dc1..1da8a5ece89 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -129,6 +129,7 @@ TOKEN :
<MAP: "map"> |
<MAP_SUBSPACES: "map_subspaces"> |
+ <UNPACK_BITS_FROM_INT8: "unpack_bits_from_int8"> |
<REDUCE: "reduce"> |
<JOIN: "join"> |
<MERGE: "merge"> |
@@ -344,7 +345,7 @@ ExpressionNode function() :
ExpressionNode function;
}
{
- ( LOOKAHEAD(2) function = scalarOrTensorFunction() | function = tensorFunction() )
+ ( LOOKAHEAD(2) function = scalarOrTensorFunction() | function = tensorFunction() | function = tensorMacro() )
{ return function; }
}
@@ -669,6 +670,32 @@ TensorFunctionNode tensorCellCast() :
{ return new TensorFunctionNode(new CellCast(TensorFunctionNode.wrap(tensor), TensorType.Value.fromId(valueType)));}
}
+ExpressionNode tensorMacro() :
+{
+ ExpressionNode tensorExpression;
+}
+{
+ (
+ tensorExpression = tensorUnpackBitsFromInt8()
+ )
+ { return tensorExpression; }
+}
+
+ExpressionNode tensorUnpackBitsFromInt8() :
+{
+ ExpressionNode tensor;
+ String targetCellType = "float";
+ String endianNess = "big";
+}
+{
+ <UNPACK_BITS_FROM_INT8> <LBRACE> tensor = expression() (
+ <COMMA> targetCellType = identifier() (
+ <COMMA> endianNess = identifier() )? )? <RBRACE>
+ {
+ return new UnpackBitsFromInt8(tensor, TensorType.Value.fromId(targetCellType), endianNess);
+ }
+}
+
LambdaFunctionNode lambdaFunction() :
{
List<String> variables;