diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-20 12:52:49 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-20 12:52:49 +0100 |
commit | 173e95cac59728ca14a2a44902255c72ad982ca3 (patch) | |
tree | a07fee4b5e4e3eecc9e34ed92e3908408349595a /model-integration | |
parent | c18b5805006b83efbeb9fc881e1658a57be28e56 (diff) |
- Put the inner loops in separate methods. This improves ability to inline.
- Use Buffer.get(int index) instead of Buffer.get(). That avoids a write to the buffer's position on every read.
- Use int as loop variable.
- This brings the splade performance test down from 8s to 7s.
- TensorConverter.toVespaTensor is now more than twice as fast.
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java | 104 | ||||
-rw-r--r-- | model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java | 2 |
2 files changed, 52 insertions, 54 deletions
diff --git a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java index 07f2aea4ab6..d1a06d8c7ff 100644 --- a/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java +++ b/model-integration/src/main/java/ai/vespa/modelintegration/evaluator/TensorConverter.java @@ -28,6 +28,7 @@ import java.nio.ShortBuffer; import java.util.HashMap; import java.util.Map; import java.util.Set; +import java.util.function.Function; import java.util.stream.Collectors; @@ -101,70 +102,67 @@ class TensorConverter { throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensorInfo.type); } + interface Short2Float { + float convert(short value); + } + + private static void extractTensor(FloatBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(DoubleBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ByteBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(ShortBuffer buffer, Short2Float converter, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, converter.convert(buffer.get(i))); + } + private static void extractTensor(IntBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < 
totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + private static void extractTensor(LongBuffer buffer, IndexedTensor.BoundBuilder builder, int totalSize) { + for (int i = 0; i < totalSize; i++) + builder.cellByDirectIndex(i, buffer.get(i)); + } + static Tensor toVespaTensor(OnnxValue onnxValue) { if ( ! (onnxValue instanceof OnnxTensor onnxTensor)) { throw new IllegalArgumentException("ONNX value is not a tensor: maps and sequences are not yet supported"); } TensorInfo tensorInfo = onnxTensor.getInfo(); - TensorType type = toVespaType(onnxTensor.getInfo()); - DimensionSizes sizes = sizesFromType(type); - + DimensionSizes sizes = DimensionSizes.of(type); IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) Tensor.Builder.of(type, sizes); - long totalSize = sizes.totalSize(); - if (tensorInfo.type == OnnxJavaType.FLOAT) { - FloatBuffer buffer = onnxTensor.getFloatBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.DOUBLE) { - DoubleBuffer buffer = onnxTensor.getDoubleBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT8) { - ByteBuffer buffer = onnxTensor.getByteBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT32) { - IntBuffer buffer = onnxTensor.getIntBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == OnnxJavaType.INT64) { - LongBuffer buffer = onnxTensor.getLongBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, buffer.get()); - } - else if (tensorInfo.type == 
OnnxJavaType.FLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, Fp16Conversions.fp16ToFloat(buffer.get())); - } - else if (tensorInfo.type == OnnxJavaType.BFLOAT16) { - ShortBuffer buffer = onnxTensor.getShortBuffer(); - for (long i = 0; i < totalSize; i++) - builder.cellByDirectIndex(i, Fp16Conversions.bf16ToFloat((buffer.get()))); - } - else { - throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); + long totalSizeAsLong = sizes.totalSize(); + if (totalSizeAsLong > Integer.MAX_VALUE) { + throw new IllegalArgumentException("TotalSize=" + totalSizeAsLong + " currently limited at INTEGER.MAX_VALUE"); + } + + int totalSize = (int) totalSizeAsLong; + switch (tensorInfo.type) { + case FLOAT -> extractTensor(onnxTensor.getFloatBuffer(), builder, totalSize); + case DOUBLE -> extractTensor(onnxTensor.getDoubleBuffer(), builder, totalSize); + case INT8 -> extractTensor(onnxTensor.getByteBuffer(), builder, totalSize); + case INT16 -> extractTensor(onnxTensor.getShortBuffer(), builder, totalSize); + case INT32 -> extractTensor(onnxTensor.getIntBuffer(), builder, totalSize); + case INT64 -> extractTensor(onnxTensor.getLongBuffer(), builder, totalSize); + case FLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::fp16ToFloat, builder, totalSize); + case BFLOAT16 -> extractTensor(onnxTensor.getShortBuffer(), Fp16Conversions::bf16ToFloat, builder, totalSize); + default -> throw new IllegalArgumentException("OnnxEvaluator does not currently support value type " + onnxTensor.getInfo().type); } return builder.build(); } - static private DimensionSizes sizesFromType(TensorType type) { - DimensionSizes.Builder builder = new DimensionSizes.Builder(type.dimensions().size()); - for (int i = 0; i < type.dimensions().size(); i++) - builder.set(i, type.dimensions().get(i).size().get()); - return builder.build(); - } - static 
Map<String, TensorType> toVespaTypes(Map<String, NodeInfo> infoMap) { return infoMap.entrySet().stream().collect(Collectors.toMap(e -> asValidName(e.getKey()), e -> toVespaType(e.getValue().getInfo()))); diff --git a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java index 82998b56fb5..b48051814ab 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/SpladeEmbedderTest.java @@ -49,7 +49,7 @@ public class SpladeEmbedderTest { String text = "what was the manhattan project in this context it was a secret project to develop a nuclear weapon in world war" + " ii the project was led by the united states with the support of the united kingdom and canada"; Long now = System.currentTimeMillis(); - int n = 1000; // Takes around 8s on Intel core i9 2.4Ghz (macbook pro, 2019) + int n = 1000; // Takes around 7s on Intel core i9 2.4Ghz (macbook pro, 2019) for (int i = 0; i < n; i++) { assertEmbed("tensor<float>(t{})", text, indexingContext); } |