diff options
author | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-03-30 09:52:16 +0100 |
---|---|---|
committer | Jo Kristian Bergum <bergum@yahooinc.com> | 2024-03-30 09:52:16 +0100 |
commit | 77d7068127640eed3655422d5107890787a7526b (patch) | |
tree | 27813a16f62a883466e473752a142d3aa947a9e1 | |
parent | e09d4754e8b15016b7731603a00367bf95a630b1 (diff) |
Add some more tests on the binarization
-rw-r--r-- | model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java | 2 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java | 39 |
2 files changed, 39 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index c15a12f5064..169648967d7 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -178,7 +178,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { return builder.build(); } - Tensor binarize(IndexedTensor embedding, TensorType tensorType) { + static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) { Tensor.Builder builder = Tensor.Builder.of(tensorType); BitSet bitSet = new BitSet(8); int index = 0; diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java index 8147b0af54a..1ce1d955b00 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java @@ -6,6 +6,7 @@ import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig; import com.yahoo.language.process.Embedder; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.TensorAddress; @@ -16,12 +17,42 @@ import static org.junit.Assume.assumeTrue; import static org.junit.jupiter.api.Assertions.assertEquals; import static org.junit.jupiter.api.Assertions.assertTrue; +import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; +import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode; + public class HuggingFaceEmbedderTest { static HuggingFaceEmbedder embedder = getEmbedder(); static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder(); static Embedder.Context context = new Embedder.Context("schema.indexing"); + @Test + public void testBinarization() { + TensorType typeOne = TensorType.fromSpec("tensor<int8>(x[1])"); + TensorType typeTwo = TensorType.fromSpec("tensor<int8>(x[2])"); + assertPackRight("tensor(x[8]):[0,0,0,0,0,0,0,0]", "tensor<int8>(x[1]):[0]", typeOne); + assertPackRight("tensor(x[8]):[1,1,1,1,1,1,1,1]", "tensor<int8>(x[1]):[-1]", typeOne); + assertPackRight("tensor(x[16]):[0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1]", "tensor<int8>(x[2]):[0, -1]", typeTwo); + + assertPackRight("tensor(x[8]):[0,1,0,1,0,1,0,1]", "tensor<int8>(x[1]):[85]", typeOne); + assertPackRight("tensor(x[8]):[1,0,1,0,1,0,1,0]", "tensor<int8>(x[1]):[-86]", typeOne); + assertPackRight("tensor(x[16]):[0,1,0,1,0,1,0,1,1,0,1,0,1,0,1,0]", "tensor<int8>(x[2]):[85, -86]", typeTwo); + + assertPackRight("tensor(x[8]):[1,1,1,1,0,0,0,0]", "tensor<int8>(x[1]):[-16]", typeOne); + assertPackRight("tensor(x[8]):[0,0,0,0,1,1,1,1]", "tensor<int8>(x[1]):[15]", typeOne); + assertPackRight("tensor(x[16]):[1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1]", "tensor<int8>(x[2]):[-16, 15]", typeTwo); + } + + private void assertPackRight(String input, String expected, TensorType type) { + Tensor inputTensor = Tensor.from(input); + Tensor result = HuggingFaceEmbedder.binarize((IndexedTensor) inputTensor, type); + assertEquals(expected.toString(), result.toString()); + //Verify against what is done in ranking with unpack_bits + Tensor unpacked = expandBitTensor(result); + assertEquals(inputTensor.toString(), unpacked.toString()); + } @Test public void testEmbedder() { @@ -47,7 +78,6 @@ public class HuggingFaceEmbedderTest { //49*8 > 384 embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[49])"))); }); - Tensor float16Result = embedder.embed(input, context, TensorType.fromSpec(("tensor<bfloat16>(x[1])"))); assertEquals(-0.666, float16Result.sum().asDouble(),1e-3); } @@ -86,4 +116,11 @@ public class HuggingFaceEmbedderTest { builder.normalize(true); return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build()); } + + public static Tensor expandBitTensor(Tensor packed) { + var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big"); + var context = new MapContext(); + context.put("input", new TensorValue(packed)); + return unpacker.evaluate(context).asTensor(); + } } |