summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-20 12:52:49 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-20 12:52:49 +0100
commit173e95cac59728ca14a2a44902255c72ad982ca3 (patch)
treea07fee4b5e4e3eecc9e34ed92e3908408349595a /model-integration
parentc18b5805006b83efbeb9fc881e1658a57be28e56 (diff)
- Put the inner loops in separate methods. This improves ability to inline.
- Use Buffer.get(int index) instead of Buffer.get(). That avoids a write. - Use int as loop variable. - This brings the splade perfoamnce test down from 8s to 7s - TensorConverter.toVespaTensor more than doubled speed.
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java104
-rw-r--r--model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java2
2 files changed, 52 insertions, 54 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
index 07f2aea4ab6..d1a06d8c7ff 100644
--- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
+++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java
@@ -28,6 +28,7 @@ import java.nio.ShortBuffer;
import java.util.HashMap;
import java.util.Map;
import java.util.Set;
+import java.util.function.Function;
import java.util.stream.Collectors;
@@ -101,70 +102,67 @@ class TensorConverter {
throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type);
}
+ interface Short2Float {
+ float convert(short value);
+ }
+
+ private static void extractTensor(FloatBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
+ for (int i = 0; i < totalSize; i++)
+ builder.cellByDirectIndex(i, buffer.get(i));
+ }
+ private static void extractTensor(DoubleBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
+ for (int i = 0; i < totalSize; i++)
+ builder.cellByDirectIndex(i, buffer.get(i));
+ }
+ private static void extractTensor(ByteBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
+ for (int i = 0; i < totalSize; i++)
+ builder.cellByDirectIndex(i, buffer.get(i));
+ }
+ private static void extractTensor(ShortBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
+ for (int i = 0; i < totalSize; i++)
+ builder.cellByDirectIndex(i, buffer.get(i));
+ }
+ private static void extractTensor(ShortBuffer buffer, Short2Float converter, IndexedTensor.BoundBuilder builder, int totalSize) {
+ for (int i = 0; i < totalSize; i++)
+ builder.cellByDirectIndex(i, converter.convert(buffer.get(i)));
+ }
+ private static void extractTensor(IntBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
+ for (int i = 0; i < totalSize; i++)
+ builder.cellByDirectIndex(i, buffer.get(i));
+ }
+ private static void extractTensor(LongBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) {
+ for (int i = 0; i < totalSize; i++)
+ builder.cellByDirectIndex(i, buffer.get(i));
+ }
+
static Tensor toVespaTensor(OnnxValue onnxValue) {
if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) {
throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported");
}
TensorInfo tensorInfo = onnxTensor.getInfo();
-
TensorType type = toVespaType(onnxTensor.getInfo());
- DimensionSizes sizes = sizesFromType(type);
-
+ DimensionSizes sizes = DimensionSizes.of(type);
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes);
- long totalSize = sizes.totalSize();
- if (tensorInfo.type == OnnxJavaType.FLOAT) {
- FloatBuffer buffer = onnxTensor.getFloatBuffer();
- for (long i = 0; i < totalSize; i++)
- builder.cellByDirectIndex(i, buffer.get());
- }
- else if (tensorInfo.type == OnnxJavaType.DOUBLE) {
- DoubleBuffer buffer = onnxTensor.getDoubleBuffer();
- for (long i = 0; i < totalSize; i++)
- builder.cellByDirectIndex(i, buffer.get());
- }
- else if (tensorInfo.type == OnnxJavaType.INT8) {
- ByteBuffer buffer = onnxTensor.getByteBuffer();
- for (long i = 0; i < totalSize; i++)
- builder.cellByDirectIndex(i, buffer.get());
- }
- else if (tensorInfo.type == OnnxJavaType.INT16) {
- ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < totalSize; i++)
- builder.cellByDirectIndex(i, buffer.get());
- }
- else if (tensorInfo.type == OnnxJavaType.INT32) {
- IntBuffer buffer = onnxTensor.getIntBuffer();
- for (long i = 0; i < totalSize; i++)
- builder.cellByDirectIndex(i, buffer.get());
- }
- else if (tensorInfo.type == OnnxJavaType.INT64) {
- LongBuffer buffer = onnxTensor.getLongBuffer();
- for (long i = 0; i < totalSize; i++)
- builder.cellByDirectIndex(i, buffer.get());
- }
- else if (tensorInfo.type == OnnxJavaType.FLOAT16) {
- ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < totalSize; i++)
- builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get()));
- }
- else if (tensorInfo.type == OnnxJavaType.BFLOAT16) {
- ShortBuffer buffer = onnxTensor.getShortBuffer();
- for (long i = 0; i < totalSize; i++)
- builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get())));
- }
- else {
- throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type);
+ long totalSizeAsLong = sizes.totalSize();
+ if (totalSizeAsLong > Integer.MAX_VALUE) {
+ throw new IllegalArgumentException("TotalSize=" + totalSizeAsLong + " currently limited at INTEGER.MAX_VALUE");
+ }
+
+ int totalSize = (int) totalSizeAsLong;
+ switch (tensorInfo.type) {
+ case FLOAT -> extractTensor(onnxTensor.getFloatBuffer(), builder, totalSize);
+ case DOUBLE -> extractTensor(onnxTensor.getDoubleBuffer(), builder, totalSize);
+ case INT8 -> extractTensor(onnxTensor.getByteBuffer(), builder, totalSize);
+ case INT16 -> extractTensor(onnxTensor.getShortBuffer(), builder, totalSize);
+ case INT32 -> extractTensor(onnxTensor.getIntBuffer(), builder, totalSize);
+ case INT64 -> extractTensor(onnxTensor.getLongBuffer(), builder, totalSize);
+ case FLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::fp16ToFloat, builder, totalSize);
+ case BFLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::bf16ToFloat, builder, totalSize);
+ default -> throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type);
}
return builder.build();
}
- static private DimensionSizes sizesFromType(TensorType type) {
- DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size());
- for (int i = 0; i < type.dimensions().size(); i++)
- builder.set(i, type.dimensions().get(i).size().get());
- return builder.build();
- }
-
static Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) {
return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()),
e -> toVespaType(e.getValue().getInfo())));
diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
index 82998b56fb5..b48051814ab 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java
@@ -49,7 +49,7 @@ public class SpladeEmbedderTest {
String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" +
" ii the project was led by the united states with the support of the united kingdom and canada";
Long now = System.currentTimeMillis();
- int n = 1000; // Takes around 8s on Intel core i9 2.4Ghz (macbook pro, 2019)
+ int n = 1000; // Takes around 7s on Intel core i9 2.4Ghz (macbook pro, 2019)
for (int i = 0; i < n; i++) {
assertEmbed("tensor<float>(t{})", text, indexingContext);
}