aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2024-03-31 23:15:59 +0200
committerGitHub <noreply@github.com>2024-03-31 23:15:59 +0200
commit74b0626d7e9f9b2b007a454587ef6df78d85b61e (patch)
tree0eb37eb2edf6bdf70efc29865572e93ebc99a284
parent1d0e06dc7db74cfc5a13d1ca8094ab583a6907ce (diff)
parent77d7068127640eed3655422d5107890787a7526b (diff)
Merge pull request #30755 from vespa-engine/jobergum/add-support-for-binarization-and-matryoshka
Jobergum/add support for binarization and matryoshka
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java60
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java126
-rw-r--r--model-integration/src/test/models/onnx/transformer/embedding_model.onnxbin0 -> 17409774 bytes
3 files changed, 181 insertions, 5 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index 35645deffa4..169648967d7 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -17,6 +17,7 @@ import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import java.nio.file.Paths;
+import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.logging.Logger;
@@ -124,18 +125,44 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
}
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
- Tensor tokenEmbeddings = outputs.get(outputName);
- var result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
- var normalized = normalize ? normalize(result, tensorType) : result;
+ IndexedTensor tokenEmbeddings = (IndexedTensor) outputs.get(outputName);
+ long[] resultShape = tokenEmbeddings.shape();
+ //shape batch, sequence, embedding dimensionality
+ if (resultShape.length != 3) {
+ throw new IllegalArgumentException("" +
+ "Expected 3 output dimensions for output name '" +
+ outputName + "': [batch, sequence, embedding], got " + resultShape.length);
+ }
+ Tensor result;
+ if (tensorType.valueType() == TensorType.Value.INT8) {
+ long outputDimensions = resultShape[2];
+ long targetDim = tensorType.dimensions().get(0).size().get();
+
+ if(targetDim * 8 > outputDimensions) {
+ throw new IllegalArgumentException("Cannot pack " + outputDimensions + " into " + targetDim + " int8s");
+ }
+ //Dimensionality flexibility 🪆 - packing only the first 8*targetDim values from the model output
+ long firstDimensions = 8 * targetDim;
+ String name = tensorType.indexedSubtype().dimensions().get(0).name();
+ //perform pooling and normalizing using floating point embeddings before binarizing
+ //using the firstDimensions as the target dimensionality
+ TensorType poolingType = new TensorType.Builder(TensorType.Value.FLOAT).indexed(name, firstDimensions).build();
+ result = poolingStrategy.toSentenceEmbedding(poolingType, tokenEmbeddings, attentionMask);
+ result = normalize? normalize(result, poolingType) : result;
+ result = binarize((IndexedTensor) result, tensorType);
+
+ } else { // regular floating points embeddings
+ result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask);
+ result = normalize ? normalize(result, tensorType) : result;
+ }
runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context);
- return normalized;
+ return result;
}
Tensor normalize(Tensor embedding, TensorType tensorType) {
double sumOfSquares = 0.0;
Tensor.Builder builder = Tensor.Builder.of(tensorType);
-
for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) {
double item = embedding.get(TensorAddress.of(i));
sumOfSquares += item * item;
@@ -151,6 +178,29 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
+ static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
+ Tensor.Builder builder = Tensor.Builder.of(tensorType);
+ BitSet bitSet = new BitSet(8);
+ int index = 0;
+ for (int d = 0; d < embedding.sizeAsInt(); d++) {
+ var value = embedding.get(d);
+ int bitIndex = 7 - (d % 8);
+ if (value > 0.0) {
+ bitSet.set(bitIndex);
+ } else {
+ bitSet.clear(bitIndex);
+ }
+ if ((d + 1) % 8 == 0) {
+ byte[] bytes = bitSet.toByteArray();
+ byte packed = (bytes.length == 0) ? 0 : bytes[0];
+ builder.cell(TensorAddress.of(index), packed);
+ index++;
+ bitSet = new BitSet(8);
+ }
+ }
+ return builder.build();
+ }
+
private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
int size = input.size();
TensorType type = new TensorType.Builder(TensorType.Value.FLOAT).indexed(dimension, size).build();
diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
new file mode 100644
index 00000000000..1ce1d955b00
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
@@ -0,0 +1,126 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.embedding;
+
+import ai.vespa.embedding.huggingface.HuggingFaceEmbedder;
+import ai.vespa.modelintegration.evaluator.OnnxRuntime;
+import com.yahoo.config.ModelReference;
+import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
+import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TensorAddress;
+import org.junit.Test;
+
+import static org.junit.Assert.assertThrows;
+import static org.junit.Assume.assumeTrue;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode;
+
+public class HuggingFaceEmbedderTest {
+
+ static HuggingFaceEmbedder embedder = getEmbedder();
+ static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder();
+ static Embedder.Context context = new Embedder.Context("schema.indexing");
+
+ @Test
+ public void testBinarization() {
+ TensorType typeOne = TensorType.fromSpec("tensor<int8>(x[1])");
+ TensorType typeTwo = TensorType.fromSpec("tensor<int8>(x[2])");
+ assertPackRight("tensor(x[8]):[0,0,0,0,0,0,0,0]", "tensor<int8>(x[1]):[0]", typeOne);
+ assertPackRight("tensor(x[8]):[1,1,1,1,1,1,1,1]", "tensor<int8>(x[1]):[-1]", typeOne);
+ assertPackRight("tensor(x[16]):[0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1]", "tensor<int8>(x[2]):[0, -1]", typeTwo);
+
+ assertPackRight("tensor(x[8]):[0,1,0,1,0,1,0,1]", "tensor<int8>(x[1]):[85]", typeOne);
+ assertPackRight("tensor(x[8]):[1,0,1,0,1,0,1,0]", "tensor<int8>(x[1]):[-86]", typeOne);
+ assertPackRight("tensor(x[16]):[0,1,0,1,0,1,0,1,1,0,1,0,1,0,1,0]", "tensor<int8>(x[2]):[85, -86]", typeTwo);
+
+ assertPackRight("tensor(x[8]):[1,1,1,1,0,0,0,0]", "tensor<int8>(x[1]):[-16]", typeOne);
+ assertPackRight("tensor(x[8]):[0,0,0,0,1,1,1,1]", "tensor<int8>(x[1]):[15]", typeOne);
+ assertPackRight("tensor(x[16]):[1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1]", "tensor<int8>(x[2]):[-16, 15]", typeTwo);
+ }
+
+ private void assertPackRight(String input, String expected, TensorType type) {
+ Tensor inputTensor = Tensor.from(input);
+ Tensor result = HuggingFaceEmbedder.binarize((IndexedTensor) inputTensor, type);
+ assertEquals(expected.toString(), result.toString());
+ //Verify against what is done in ranking with unpack_bits
+ Tensor unpacked = expandBitTensor(result);
+ assertEquals(inputTensor.toString(), unpacked.toString());
+ }
+
+ @Test
+ public void testEmbedder() {
+ String input = "This is a test";
+
+ Tensor expected = Tensor.from("tensor<float>(x[8]):[-0.666, 0.335, 0.227, 0.0919, -0.069, 0.323, 0.422, 0.270]");
+ Tensor result = embedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])")));
+ for(int i = 0; i < 8; i++) {
+ assertEquals(expected.get(TensorAddress.of(i)), result.get(TensorAddress.of(i)), 1e-2);
+ }
+ // Thresholding on the above gives [0, 1, 1, 1, 0, 1, 1, 1] which is packed into 119 (int8)
+ Tensor binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[1])")));
+ assertEquals("tensor<int8>(x[1]):[119]", binarizedResult.toString());
+
+ binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])")));
+ assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString());
+
+ binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[48])")));
+ assertTrue(binarizedResult.toAbbreviatedString().startsWith("tensor<int8>(x[48]):[119, 44"));
+
+ assertThrows(IllegalArgumentException.class, () -> {
+ // throws because the target tensor type is not compatible with the model output
+ //49*8 > 384
+ embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[49])")));
+ });
+ Tensor float16Result = embedder.embed(input, context, TensorType.fromSpec(("tensor<bfloat16>(x[1])")));
+ assertEquals(-0.666, float16Result.sum().asDouble(),1e-3);
+ }
+
+ @Test
+ public void testEmbedderWithNormalization() {
+ String input = "This is a test";
+
+ Tensor result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[8])")));
+ assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
+
+ result = normalizedEmbedder.embed(input, context, TensorType.fromSpec(("tensor<float>(x[16])")));
+ assertEquals(1.0, result.multiply(result).sum().asDouble(), 1e-3);
+ Tensor binarizedResult = embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[2])")));
+ assertEquals("tensor<int8>(x[2]):[119, 44]", binarizedResult.toAbbreviatedString());
+ }
+
+ private static HuggingFaceEmbedder getEmbedder() {
+ String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
+ String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
+ HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
+ builder.tokenizerPath(ModelReference.valueOf(vocabPath));
+ builder.transformerModel(ModelReference.valueOf(modelPath));
+ builder.transformerGpuDevice(-1);
+ return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ }
+ private static HuggingFaceEmbedder getNormalizedEmbedder() {
+ String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
+ String modelPath = "src/test/models/onnx/transformer/embedding_model.onnx";
+ assumeTrue(OnnxRuntime.isRuntimeAvailable(modelPath));
+ HuggingFaceEmbedderConfig.Builder builder = new HuggingFaceEmbedderConfig.Builder();
+ builder.tokenizerPath(ModelReference.valueOf(vocabPath));
+ builder.transformerModel(ModelReference.valueOf(modelPath));
+ builder.transformerGpuDevice(-1);
+ builder.normalize(true);
+ return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
+ }
+
+ public static Tensor expandBitTensor(Tensor packed) {
+ var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big");
+ var context = new MapContext();
+ context.put("input", new TensorValue(packed));
+ return unpacker.evaluate(context).asTensor();
+ }
+}
diff --git a/model-integration/src/test/models/onnx/transformer/embedding_model.onnx b/model-integration/src/test/models/onnx/transformer/embedding_model.onnx
new file mode 100644
index 00000000000..266ed567344
--- /dev/null
+++ b/model-integration/src/test/models/onnx/transformer/embedding_model.onnx
Binary files differ