diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-01-15 18:54:03 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-01-15 18:54:03 +0100 |
commit | 250aec1b7f1156ded4ef8eed2b4f029dafe4bc8a (patch) | |
tree | ba80f21c0bba5d9b6d8d4169c57479ff1360f947 /model-integration/src | |
parent | 348ba0774b8047aeb15d8f96c189991dac4180b1 (diff) |
Avoid generic reduce and keep PAD token embedding
Diffstat (limited to 'model-integration/src')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java | 27 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java | 44 |
2 files changed, 47 insertions, 24 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java index 0bee03a65af..8c39cc8c813 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java @@ -191,10 +191,10 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { attentionMaskName, attentionMaskTensor.expand("d0")); Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); + IndexedTensor result = (IndexedTensor) tokenEmbeddings; int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue(); - if (dims != result.shape()[1]) { + if (dims != result.shape()[2]) { throw new IllegalArgumentException("Token vector dimensionality does not" + " match indexed dimensionality of " + dims); } @@ -217,9 +217,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); - IndexedTensor result = (IndexedTensor) tokenEmbeddings.reduce(Reduce.Aggregator.min, "d0"); + IndexedTensor result = (IndexedTensor) tokenEmbeddings; Tensor contextualEmbeddings; - int maxTokens = input.inputIds.size() -1; //Do not retain last PAD + int maxTokens = input.inputIds.size(); //Retain all token vectors, including PAD tokens. if (tensorType.valueType() == TensorType.Value.INT8) { contextualEmbeddings = toBitTensor(result, tensorType, maxTokens); } else { @@ -230,11 +230,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { } public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) { + if(result.shape().length != 3) + throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]"); int size = type.indexedSubtype().dimensions().size(); if (size != 1) - throw new IllegalArgumentException("Indexed tensor must have one dimension"); + throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); - int resultDimensionality = (int)result.shape()[1]; + int resultDimensionality = (int)result.shape()[2]; if (resultDimensionality != wantedDimensionality) { throw new IllegalArgumentException("Not possible to map token vector embedding with " + resultDimensionality + " + dimensions into tensor with " + wantedDimensionality); @@ -242,7 +244,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { Tensor.Builder builder = Tensor.Builder.of(type); for (int token = 0; token < nTokens; token++) { for (int d = 0; d < resultDimensionality; d++) { - var value = result.get(TensorAddress.of(token, d)); + var value = result.get(0,token,d); // batch, sequence token, dimension builder.cell(TensorAddress.of(token,d),value); } } @@ -253,11 +255,14 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { if (type.valueType() != TensorType.Value.INT8) throw new IllegalArgumentException("Only a int8 tensor type can be" + " the destination of bit packing"); + if(result.shape().length != 3) + throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]"); + int size = type.indexedSubtype().dimensions().size(); if (size != 1) - throw new IllegalArgumentException("Indexed tensor must have one dimension"); + throw new IllegalArgumentException("Target indexed sub-type must have one dimension"); int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue(); - int resultDimensionality = (int)result.shape()[1]; + int resultDimensionality = (int)result.shape()[2]; if (resultDimensionality != 8 * wantedDimensionality) { throw new IllegalArgumentException("Not possible to pack " + resultDimensionality + " + dimensions into " + wantedDimensionality + " dimensions"); @@ -266,8 +271,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder { for (int token = 0; token < nTokens; token++) { BitSet bitSet = new BitSet(8); int key = 0; - for (int d = 0; d < result.shape()[1]; d++) { - var value = result.get(TensorAddress.of(token, d)); + for (int d = 0; d < result.shape()[2]; d++) { + var value = result.get(0, token, d); // batch, sequence token, dimension int bitIndex = 7 - (d % 8); if (value > 0.0) { bitSet.set(bitIndex); diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java index f3682e45efc..0cae94c372a 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.MixedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import org.junit.Ignore; import org.junit.Test; import java.util.List; @@ -35,25 +36,25 @@ public class ColBertEmbedderTest { public void testPacking() { assertPackedRight( "" + - "tensor<float>(d1[6],d2[8]):" + - "[" + + "tensor<float>(d0[1],d1[6],d2[8]):" + + "[[" + "[0, 0, 0, 0, 0, 0, 0, 1]," + "[0, 0, 0, 0, 0, 1, 0, 1]," + "[0, 0, 0, 0, 0, 0, 1, 1]," + "[0, 1, 1, 1, 1, 1, 1, 1]," + "[1, 0, 0, 0, 0, 0, 0, 0]," + "[1, 1, 1, 1, 1, 1, 1, 1]" + - "]", + "]]", TensorType.fromSpec("tensor<int8>(dt{},x[1])"), "tensor<int8>(dt{},x[1]):{0:1.0, 1:5.0, 2:3.0, 3:127.0, 4:-128.0, 5:-1.0}", 6 ); assertPackedRight( "" + - "tensor<float>(d1[2],d2[16]):" + - "[" + + "tensor<float>(d0[1],d1[2],d2[16]):" + + "[[" + "[0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0]," + "[0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1]" + - "]", + "]]", TensorType.fromSpec("tensor<int8>(dt{},x[2])"), "tensor<int8>(dt{},x[2]):{0:[1.0, -128.0], 1:[5.0, 1.0]}",2 ); @@ -133,18 +134,35 @@ public class ColBertEmbedderTest { } String text = sb.toString(); Tensor fullFloat = assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); - assertEquals(511*128,fullFloat.size()); + assertEquals(512*128,fullFloat.size()); Tensor query = assertEmbed("tensor<float>(dt{},x[128])", text, queryContext); assertEquals(32*128,query.size()); Tensor binaryRep = assertEmbed("tensor<int8>(dt{},x[16])", text, indexingContext); - assertEquals(511*16,binaryRep.size()); + assertEquals(512*16,binaryRep.size()); Tensor shortDoc = assertEmbed("tensor<int8>(dt{},x[16])", "annoyance", indexingContext); - // 3 tokens, 16 bytes each = 48 bytes + // 4 tokens, 16 bytes each = 64 bytes //CLS [unused1] sequence - assertEquals(3*16,shortDoc.size());; + assertEquals(4*16,shortDoc.size());; + } + + @Ignore + public void testPerf() { + StringBuilder sb = new StringBuilder(); + for(int i = 0; i < 256; i++) { + sb.append("annoyance"); + sb.append(" "); + } + String text = sb.toString(); + Long now = System.currentTimeMillis(); + int n = 1000; + for (int i = 0; i < n; i++) { + assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext); + } + Long elapsed = (System.currentTimeMillis() - now); + System.out.println("Elapsed time: " + elapsed + " ms"); } static Tensor assertEmbed(String tensorSpec, String text, Embedder.Context context) { @@ -163,11 +181,11 @@ public class ColBertEmbedderTest { Tensor packed = ColBertEmbedder.toBitTensor(in, destination, size); assertEquals(expected, packed.toString()); Tensor unpacked = ColBertEmbedder.expandBitTensor(packed); - assertEquals(in.shape()[1], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue()); + assertEquals(in.shape()[2], unpacked.type().indexedSubtype().dimensions().get(0).size().get().longValue()); for (int dOuter = 0; dOuter < size; dOuter++) { - for (int dInner = 0; dInner < in.shape()[1]; dInner++) { + for (int dInner = 0; dInner < in.shape()[2]; dInner++) { var addr = TensorAddress.of(dOuter, dInner); - double oldVal = in.get(addr); + double oldVal = in.get(TensorAddress.of(0,dOuter, dInner)); if (oldVal > 0) { assertEquals(unpacked.get(addr), 1.0, 0.0); } else { |