aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJo Kristian Bergum <bergum@yahooinc.com>2024-03-30 09:52:16 +0100
committerJo Kristian Bergum <bergum@yahooinc.com>2024-03-30 09:52:16 +0100
commit77d7068127640eed3655422d5107890787a7526b (patch)
tree27813a16f62a883466e473752a142d3aa947a9e1
parente09d4754e8b15016b7731603a00367bf95a630b1 (diff)
Add some more tests on the binarization
-rw-r--r--model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java2
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java39
2 files changed, 39 insertions, 2 deletions
diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
index c15a12f5064..169648967d7 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java
@@ -178,7 +178,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder {
return builder.build();
}
- Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
+ static public Tensor binarize(IndexedTensor embedding, TensorType tensorType) {
Tensor.Builder builder = Tensor.Builder.of(tensorType);
BitSet bitSet = new BitSet(8);
int index = 0;
diff --git a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
index 8147b0af54a..1ce1d955b00 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/HuggingFaceEmbedderTest.java
@@ -6,6 +6,7 @@ import ai.vespa.modelintegration.evaluator.OnnxRuntime;
import com.yahoo.config.ModelReference;
import com.yahoo.embedding.huggingface.HuggingFaceEmbedderConfig;
import com.yahoo.language.process.Embedder;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.TensorAddress;
@@ -16,12 +17,42 @@ import static org.junit.Assume.assumeTrue;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;
+import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
+import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.rule.UnpackBitsNode;
+
public class HuggingFaceEmbedderTest {
static HuggingFaceEmbedder embedder = getEmbedder();
static HuggingFaceEmbedder normalizedEmbedder = getNormalizedEmbedder();
static Embedder.Context context = new Embedder.Context("schema.indexing");
+ @Test
+ public void testBinarization() {
+ TensorType typeOne = TensorType.fromSpec("tensor<int8>(x[1])");
+ TensorType typeTwo = TensorType.fromSpec("tensor<int8>(x[2])");
+ assertPackRight("tensor(x[8]):[0,0,0,0,0,0,0,0]", "tensor<int8>(x[1]):[0]", typeOne);
+ assertPackRight("tensor(x[8]):[1,1,1,1,1,1,1,1]", "tensor<int8>(x[1]):[-1]", typeOne);
+ assertPackRight("tensor(x[16]):[0,0,0,0,0,0,0,0,1,1,1,1,1,1,1,1]", "tensor<int8>(x[2]):[0, -1]", typeTwo);
+
+ assertPackRight("tensor(x[8]):[0,1,0,1,0,1,0,1]", "tensor<int8>(x[1]):[85]", typeOne);
+ assertPackRight("tensor(x[8]):[1,0,1,0,1,0,1,0]", "tensor<int8>(x[1]):[-86]", typeOne);
+ assertPackRight("tensor(x[16]):[0,1,0,1,0,1,0,1,1,0,1,0,1,0,1,0]", "tensor<int8>(x[2]):[85, -86]", typeTwo);
+
+ assertPackRight("tensor(x[8]):[1,1,1,1,0,0,0,0]", "tensor<int8>(x[1]):[-16]", typeOne);
+ assertPackRight("tensor(x[8]):[0,0,0,0,1,1,1,1]", "tensor<int8>(x[1]):[15]", typeOne);
+ assertPackRight("tensor(x[16]):[1,1,1,1,0,0,0,0,0,0,0,0,1,1,1,1]", "tensor<int8>(x[2]):[-16, 15]", typeTwo);
+ }
+
+ private void assertPackRight(String input, String expected, TensorType type) {
+ Tensor inputTensor = Tensor.from(input);
+ Tensor result = HuggingFaceEmbedder.binarize((IndexedTensor) inputTensor, type);
+ assertEquals(expected.toString(), result.toString());
+ //Verify against what is done in ranking with unpack_bits
+ Tensor unpacked = expandBitTensor(result);
+ assertEquals(inputTensor.toString(), unpacked.toString());
+ }
@Test
public void testEmbedder() {
@@ -47,7 +78,6 @@ public class HuggingFaceEmbedderTest {
//49*8 > 384
embedder.embed(input, context, TensorType.fromSpec(("tensor<int8>(x[49])")));
});
-
Tensor float16Result = embedder.embed(input, context, TensorType.fromSpec(("tensor<bfloat16>(x[1])")));
assertEquals(-0.666, float16Result.sum().asDouble(),1e-3);
}
@@ -86,4 +116,11 @@ public class HuggingFaceEmbedderTest {
builder.normalize(true);
return new HuggingFaceEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
}
+
+ public static Tensor expandBitTensor(Tensor packed) {
+ var unpacker = new UnpackBitsNode(new ReferenceNode("input"), TensorType.Value.DOUBLE, "big");
+ var context = new MapContext();
+ context.put("input", new TensorValue(packed));
+ return unpacker.evaluate(context).asTensor();
+ }
}