diff options
author | Arne Juul <arnej@vespa.ai> | 2023-11-05 12:20:09 +0000 |
---|---|---|
committer | Arne Juul <arnej@yahooinc.com> | 2023-11-10 09:55:58 +0000 |
commit | 65047e9ad4d6138570e141159941ad9b81fdd41b (patch) | |
tree | 8c7a5a87be48dae2ce862061e87209129b432ebd | |
parent | 83b1ccd36dd5df2e43307aab19adc07b41c94c9f (diff) |
add "unpack_bits_from_int8" function
10 files changed, 361 insertions, 1 deletions
diff --git a/config-model/src/test/derived/tensor/attributes.cfg b/config-model/src/test/derived/tensor/attributes.cfg index 752e236bb19..d050d553416 100644 --- a/config-model/src/test/derived/tensor/attributes.cfg +++ b/config-model/src/test/derived/tensor/attributes.cfg @@ -143,3 +143,32 @@ attribute[].index.hnsw.enabled false attribute[].index.hnsw.maxlinkspernode 16 attribute[].index.hnsw.neighborstoexploreatinsert 200 attribute[].index.hnsw.multithreadedindexing true +attribute[].name "f7" +attribute[].datatype TENSOR +attribute[].collectiontype SINGLE +attribute[].dictionary.type BTREE +attribute[].dictionary.match UNCASED +attribute[].match UNCASED +attribute[].removeifzero false +attribute[].createifnonexistent false +attribute[].fastsearch false +attribute[].paged false +attribute[].ismutable false +attribute[].sortascending true +attribute[].sortfunction UCA +attribute[].sortstrength PRIMARY +attribute[].sortlocale "" +attribute[].enableonlybitvector false +attribute[].fastaccess false +attribute[].arity 8 +attribute[].lowerbound -9223372036854775808 +attribute[].upperbound 9223372036854775807 +attribute[].densepostinglistthreshold 0.4 +attribute[].tensortype "tensor<int8>(p{},x[5])" +attribute[].imported false +attribute[].maxuncommittedmemory 77777 +attribute[].distancemetric EUCLIDEAN +attribute[].index.hnsw.enabled false +attribute[].index.hnsw.maxlinkspernode 16 +attribute[].index.hnsw.neighborstoexploreatinsert 200 +attribute[].index.hnsw.multithreadedindexing true diff --git a/config-model/src/test/derived/tensor/documentmanager.cfg b/config-model/src/test/derived/tensor/documentmanager.cfg index bae2db34040..5676f60ef46 100644 --- a/config-model/src/test/derived/tensor/documentmanager.cfg +++ b/config-model/src/test/derived/tensor/documentmanager.cfg @@ -49,6 +49,7 @@ doctype[].fieldsets{[document]}.fields[] "f3" doctype[].fieldsets{[document]}.fields[] "f4" doctype[].fieldsets{[document]}.fields[] "f5" doctype[].fieldsets{[document]}.fields[] "f6" +doctype[].fieldsets{[document]}.fields[] "f7" doctype[].tensortype[].idx 10017 doctype[].tensortype[].detailedtype "tensor(x[3])" doctype[].tensortype[].idx 10018 @@ -59,6 +60,8 @@ doctype[].tensortype[].idx 10020 doctype[].tensortype[].detailedtype "tensor(x[10],y[10])" doctype[].tensortype[].idx 10021 doctype[].tensortype[].detailedtype "tensor<float>(x[10])" +doctype[].tensortype[].idx 10022 +doctype[].tensortype[].detailedtype "tensor<int8>(p{},x[5])" doctype[].structtype[].idx 10016 doctype[].structtype[].name "tensor.header" doctype[].structtype[].field[].name "f1" @@ -79,3 +82,6 @@ doctype[].structtype[].field[].type 10021 doctype[].structtype[].field[].name "f6" doctype[].structtype[].field[].internalid 596352344 doctype[].structtype[].field[].type 10005 +doctype[].structtype[].field[].name "f7" +doctype[].structtype[].field[].internalid 981728328 +doctype[].structtype[].field[].type 10022 diff --git a/config-model/src/test/derived/tensor/documenttypes.cfg b/config-model/src/test/derived/tensor/documenttypes.cfg index d10ecd37c8f..d069b3764ba 100644 --- a/config-model/src/test/derived/tensor/documenttypes.cfg +++ b/config-model/src/test/derived/tensor/documenttypes.cfg @@ -54,6 +54,7 @@ doctype[].fieldsets{[document]}.fields[] "f3" doctype[].fieldsets{[document]}.fields[] "f4" doctype[].fieldsets{[document]}.fields[] "f5" doctype[].fieldsets{[document]}.fields[] "f6" +doctype[].fieldsets{[document]}.fields[] "f7" doctype[].tensortype[].idx 10017 doctype[].tensortype[].detailedtype "tensor(x[3])" doctype[].tensortype[].idx 10018 @@ -64,6 +65,8 @@ doctype[].tensortype[].idx 10020 doctype[].tensortype[].detailedtype "tensor(x[10],y[10])" doctype[].tensortype[].idx 10021 doctype[].tensortype[].detailedtype "tensor<float>(x[10])" +doctype[].tensortype[].idx 10022 +doctype[].tensortype[].detailedtype "tensor<int8>(p{},x[5])" doctype[].structtype[].idx 10016 doctype[].structtype[].name "tensor.header" doctype[].structtype[].field[].name "f1" @@ -84,4 +87,7 @@ doctype[].structtype[].field[].type 10021 doctype[].structtype[].field[].name "f6" doctype[].structtype[].field[].internalid 596352344 doctype[].structtype[].field[].type 10005 +doctype[].structtype[].field[].name "f7" +doctype[].structtype[].field[].internalid 981728328 +doctype[].structtype[].field[].type 10022 doctype[].structtype[].internalid 2125927172 diff --git a/config-model/src/test/derived/tensor/index-info.cfg b/config-model/src/test/derived/tensor/index-info.cfg index 4d8ce8d1372..c9ce2433e17 100644 --- a/config-model/src/test/derived/tensor/index-info.cfg +++ b/config-model/src/test/derived/tensor/index-info.cfg @@ -35,3 +35,9 @@ indexinfo[].command[].indexname "f6" indexinfo[].command[].command "numerical" indexinfo[].command[].indexname "f6" indexinfo[].command[].command "type float" +indexinfo[].command[].indexname "f7" +indexinfo[].command[].command "attribute" +indexinfo[].command[].indexname "f7" +indexinfo[].command[].command "type tensor<int8>(p{},x[5])" +indexinfo[].command[].indexname "f7" +indexinfo[].command[].command "word" diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 1ec3d67cb47..cd8375cb68d 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -1,4 +1,6 @@ rankprofile[].name "default" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -16,6 +18,8 @@ rankprofile[].fef.property[].name "vespa.hitcollector.arraysize" rankprofile[].fef.property[].value "0" rankprofile[].fef.property[].name "vespa.dump.ignoredefaultfeatures" rankprofile[].fef.property[].value "true" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -29,6 +33,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(map(attribute(f4), f(x)(x * x)) + reduce(tensor(x[2],y[3])(random), count) * rename(attribute(f4), (x, y), (y, x)), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -42,6 +48,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(reduce(join(attribute(f4), tensor(x[10],y[10],z[3])((x==y)*(y==z)), f(a,b)(a * b)), sum, x), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -59,6 +67,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(rankingExpression(joinedtensors), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -72,6 +82,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(attribute(f5), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -85,6 +97,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(tensor<float>(d0[1],x[2]):{{d0:0,x:0}:(attribute(f6)),{d0:0,x:1}:(reduce(attribute(f5), sum))}, sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -102,6 +116,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(tensor<float>(d0[1],x[2]):{{d0:0,x:0}:(attribute(f6)),{d0:0,x:1}:(reduce(rankingExpression(joinedtensors), sum))}, sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -123,6 +139,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(rankingExpression(reshaped) * rankingExpression(literal), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -138,6 +156,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(tensor(d0[1])((attribute(f3){x:(rankingExpression(functionNotLabel))})), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -153,6 +173,8 @@ rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" rankprofile[].fef.property[].value "reduce(tensor(shadow[1])((attribute(f3){x:(shadow + rankingExpression(shadow))})), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" @@ -161,3 +183,26 @@ rankprofile[].fef.property[].name "vespa.type.attribute.f4" rankprofile[].fef.property[].value "tensor(x[10],y[10])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" +rankprofile[].name "with-unpack" +rankprofile[].fef.property[].name "rankingExpression(myunpack).rankingScript" +rankprofile[].fef.property[].value "map_subspaces(attribute(f7), f(denseSubspaceInput)(tensor<float>(x[40])(bit(denseSubspaceInput{x:(x/8)}, 7-(x % 8)))))" +rankprofile[].fef.property[].name "rankingExpression(myunpack).type" +rankprofile[].fef.property[].value "tensor<float>(p{},x[40])" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(query(para) * rankingExpression(myunpack) * query(qvec), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f7" +rankprofile[].fef.property[].value "tensor<int8>(p{},x[5])" +rankprofile[].fef.property[].name "vespa.type.attribute.f2" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" +rankprofile[].fef.property[].name "vespa.type.attribute.f3" +rankprofile[].fef.property[].value "tensor(x{})" +rankprofile[].fef.property[].name "vespa.type.attribute.f4" +rankprofile[].fef.property[].value "tensor(x[10],y[10])" +rankprofile[].fef.property[].name "vespa.type.attribute.f5" +rankprofile[].fef.property[].value "tensor<float>(x[10])" +rankprofile[].fef.property[].name "vespa.type.query.para" +rankprofile[].fef.property[].value "tensor<float>(p{})" +rankprofile[].fef.property[].name "vespa.type.query.qvec" +rankprofile[].fef.property[].value "tensor<float>(x[40])" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index f2fc57a2018..81230e5c54c 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -20,6 +20,9 @@ schema tensor { field f6 type float { indexing: attribute } + field f7 type tensor<int8>(p{},x[5]) { + indexing: attribute + } } rank-profile profile1 { @@ -119,4 +122,17 @@ schema tensor { } + rank-profile with-unpack { + inputs { + query(para) tensor<float>(p{}) + query(qvec) tensor<float>(x[40]) + } + function myunpack() { + expression: unpack_bits_from_int8(attribute(f7)) + } + first-phase { + expression: sum(query(para)*myunpack*query(qvec)) + } + } + } diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index ca475a77b6c..4f0a99a117d 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -953,6 +953,8 @@ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorCellCast()", + "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMacro()", + "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorUnpackBitsFromInt8()", "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", "public final com.yahoo.tensor.TensorType tensorType(java.util.List)", @@ -1086,6 +1088,7 @@ "public static final int HAMMING", "public static final int MAP", "public static final int MAP_SUBSPACES", + "public static final int UNPACK_BITS_FROM_INT8", "public static final int REDUCE", "public static final int JOIN", "public static final int MERGE", @@ -1707,5 +1710,22 @@ "public int hashCode()" ], "fields" : [ ] + }, + "com.yahoo.searchlib.rankingexpression.rule.UnpackBitsFromInt8" : { + "superClass" : "com.yahoo.searchlib.rankingexpression.rule.CompositeNode", + "interfaces" : [ ], + "attributes" : [ + "public" + ], + "methods" : [ + "public void <init>(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode, com.yahoo.tensor.TensorType$Value, java.lang.String)", + "public java.util.List children()", + "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", + "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public com.yahoo.searchlib.rankingexpression.rule.CompositeNode setChildren(java.util.List)", + "public int hashCode()" + ], + "fields" : [ ] } }
\ No newline at end of file diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java new file mode 100644 index 00000000000..84203da4a7e --- /dev/null +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/UnpackBitsFromInt8.java @@ -0,0 +1,183 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.searchlib.rankingexpression.rule; + +import com.yahoo.api.annotations.Beta; +import com.yahoo.searchlib.rankingexpression.Reference; +import com.yahoo.searchlib.rankingexpression.evaluation.Context; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Deque; +import java.util.List; +import java.util.Optional; +import java.util.Objects; + +/** + * Macro that expands to the appropriate map_subspaces magic incantation + * + * @author arnej + */ +@Beta +public class UnpackBitsFromInt8 extends CompositeNode { + + private static String operationName = "unpack_bits_from_int8"; + private enum EndianNess { + BIG_ENDIAN("big"), LITTLE_ENDIAN("little"); + + private final String id; + EndianNess(String id) { this.id = id; } + public String toString() { return id; } + public static EndianNess fromId(String id) { + for (EndianNess value : values()) { + if (value.id.equals(id)) { + return value; + } + } + throw new IllegalArgumentException("EndianNess must be either 'big' or 'little', but was '" + id + "'"); + } + }; + + final ExpressionNode input; + final TensorType.Value targetCellType; + final EndianNess endian; + + public UnpackBitsFromInt8(ExpressionNode input, TensorType.Value targetCellType, String endianNess) { + this.input = input; + this.targetCellType = targetCellType; + this.endian = EndianNess.fromId(endianNess); + } + + @Override + public List<ExpressionNode> children() { + return Collections.singletonList(input); + } + + private static record Meta(TensorType outputType, TensorType outputDenseType, String unpackDimension) {} + + @Override + public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { + var optTC = context.typeContext(); + if (optTC.isPresent()) { + TensorType inputType = input.type(optTC.get()); + var meta = analyze(inputType); + string.append("map_subspaces").append("("); + input.toString(string, context, path, this); + string.append(", f(denseSubspaceInput)("); + string.append(meta.outputDenseType()).append("("); // generate + string.append("bit(denseSubspaceInput{"); + for (var dim : meta.outputDenseType().dimensions()) { + String dName = dim.name(); + boolean last = dName.equals(meta.unpackDimension); + string.append(dName); + string.append(":("); + string.append(dName); + if (last) { + string.append("/8"); + } + string.append(")"); + if (! last) { + string.append(", "); + } + } + if (endian.equals(EndianNess.BIG_ENDIAN)) { + string.append("}, 7-("); + } else { + string.append("}, ("); + } + string.append(meta.unpackDimension); + string.append(" % 8)"); + string.append("))))"); // bit, generate, f, map_subspaces + } else { + string.append(operationName); + string.append("("); + input.toString(string, context, path, this); + string.append(","); + string.append(targetCellType); + string.append(","); + string.append(endian); + string.append(")"); + } + return string; + } + + @Override + public Value evaluate(Context context) { + Tensor inputTensor = input.evaluate(context).asTensor(); + TensorType inputType = inputTensor.type(); + var meta = analyze(inputType); + var builder = Tensor.Builder.of(meta.outputType()); + for (var iter = inputTensor.cellIterator(); iter.hasNext(); ) { + var cell = iter.next(); + var oldAddr = cell.getKey(); + for (int bitIdx = 0; bitIdx < 8; bitIdx++) { + var addrBuilder = new TensorAddress.Builder(meta.outputType()); + for (int i = 0; i < inputType.dimensions().size(); i++) { + var dim = inputType.dimensions().get(i); + if (dim.name().equals(meta.unpackDimension())) { + long newIdx = oldAddr.numericLabel(i) * 8 + bitIdx; + addrBuilder.add(dim.name(), String.valueOf(newIdx)); + } else { + addrBuilder.add(dim.name(), oldAddr.label(i)); + } + } + var newAddr = addrBuilder.build(); + int oldValue = (int)(cell.getValue().doubleValue()); + if (endian.equals(EndianNess.BIG_ENDIAN)) { + float newCellValue = 1 & (oldValue >>> (7 - bitIdx)); + builder.cell(newAddr, newCellValue); + } else { + float newCellValue = 1 & (oldValue >>> bitIdx); + builder.cell(newAddr, newCellValue); + } + } + } + return new TensorValue(builder.build()); + } + + private Meta analyze(TensorType inputType) { + TensorType inputDenseType = inputType.indexedSubtype(); + if (inputDenseType.rank() == 0) { + throw new IllegalArgumentException("bad " + operationName + "; input must have indexed dimension, but type was: " + inputType); + } + var lastDim = inputDenseType.dimensions().get(inputDenseType.rank() - 1); + if (lastDim.size().isEmpty()) { + throw new IllegalArgumentException("bad " + operationName + "; last indexed dimension must be bound, but type was: " + inputType); + } + List<TensorType.Dimension> outputDims = new ArrayList<>(); + var ttBuilder = new TensorType.Builder(targetCellType); + for (var dim : inputType.dimensions()) { + if (dim.name().equals(lastDim.name())) { + long sz = dim.size().get(); + ttBuilder.indexed(dim.name(), sz * 8); + } else { + ttBuilder.set(dim); + } + } + TensorType outputType = ttBuilder.build(); + return new Meta(outputType, outputType.indexedSubtype(), lastDim.name()); + } + + @Override + public TensorType type(TypeContext<Reference> context) { + TensorType inputType = input.type(context); + var meta = analyze(inputType); + return meta.outputType(); + } + + @Override + public CompositeNode setChildren(List<ExpressionNode> newChildren) { + if (newChildren.size() != 1) + throw new IllegalArgumentException("Expected 1 child but got " + newChildren.size()); + return new UnpackBitsFromInt8(newChildren.get(0), targetCellType, endian.toString()); + } + + @Override + public int hashCode() { return Objects.hash(operationName, input, targetCellType); } + +} diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 6cd01151dc1..1da8a5ece89 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -129,6 +129,7 @@ TOKEN : <MAP: "map"> | <MAP_SUBSPACES: "map_subspaces"> | + <UNPACK_BITS_FROM_INT8: "unpack_bits_from_int8"> | <REDUCE: "reduce"> | <JOIN: "join"> | <MERGE: "merge"> | @@ -344,7 +345,7 @@ ExpressionNode function() : ExpressionNode function; } { - ( LOOKAHEAD(2) function = scalarOrTensorFunction() | function = tensorFunction() ) + ( LOOKAHEAD(2) function = scalarOrTensorFunction() | function = tensorFunction() | function = tensorMacro() ) { return function; } } @@ -669,6 +670,32 @@ TensorFunctionNode tensorCellCast() : { return new TensorFunctionNode(new CellCast(TensorFunctionNode.wrap(tensor), TensorType.Value.fromId(valueType)));} } +ExpressionNode tensorMacro() : +{ + ExpressionNode tensorExpression; +} +{ + ( + tensorExpression = tensorUnpackBitsFromInt8() + ) + { return tensorExpression; } +} + +ExpressionNode tensorUnpackBitsFromInt8() : +{ + ExpressionNode tensor; + String targetCellType = "float"; + String endianNess = "big"; +} +{ + <UNPACK_BITS_FROM_INT8> <LBRACE> tensor = expression() ( + <COMMA> targetCellType = identifier() ( + <COMMA> endianNess = identifier() )? )? <RBRACE> + { + return new UnpackBitsFromInt8(tensor, TensorType.Value.fromId(targetCellType), endianNess); + } +} + LambdaFunctionNode lambdaFunction() : { List<String> variables; diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 637e9be5fc3..6ac734e8771 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -177,6 +177,28 @@ public class EvaluationTestCase { } @Test + public void testUnpack() { + EvaluationTester tester = new EvaluationTester(); + tester.assertEvaluates("tensor<float>(a{},x[16]):{foo:[" + + "0,0,0,0, 0,0,0,0," + + "1,1,1,1, 1,1,1,1" + + "],bar:[" + + "0,0,0,0, 0,0,0,1," + + "1,1,1,1, 1,0,0,0]}", + "unpack_bits_from_int8(tensor0, float, big)", + "tensor<int8>(a{},x[2]):{foo:[0,255],bar:[1,-8]}"); + + tester.assertEvaluates("tensor<int8>(a{},x[16]):{foo:[" + + "0,0,0,0, 0,0,0,0," + + "1,1,1,1, 1,1,1,1" + + "],bar:[" + + "1,0,0,0, 0,0,0,0," + + "0,0,0,1, 1,1,1,1]}", + "unpack_bits_from_int8(tensor0, int8, little)", + "tensor<int8>(a{},x[2]):{foo:[0,255],bar:[1,-8]}"); + } + + @Test public void testMapSubspaces() { EvaluationTester tester = new EvaluationTester(); tester.assertEvaluates("tensor<float>(a{},x[2]):{foo:[2,3],bar:[7,10]}", |