diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-08-31 13:08:18 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2023-08-31 13:08:18 +0200 |
commit | 47ebd33df308d82e0327fe083cd8deaed9e6fb53 (patch) | |
tree | 3c3f5174afde0dd77d020ab42780344a546f407d /model-integration | |
parent | f46d67c5976e77e270002267996a559b1cb6d2c1 (diff) | |
parent | ae674d6d002ca0f99b401e1215d45d188ba81e12 (diff) |
Merge pull request #27969 from vespa-engine/bjorncs/embedder-metrics
Add generic metrics for embedders
Diffstat (limited to 'model-integration')
5 files changed, 94 insertions, 8 deletions
diff --git a/model-integration/pom.xml b/model-integration/pom.xml index 854e15298c6..d195a061c52 100644 --- a/model-integration/pom.xml +++ b/model-integration/pom.xml @@ -87,6 +87,18 @@ <scope>provided</scope> </dependency> <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>container-core</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> + <groupId>com.yahoo.vespa</groupId> + <artifactId>metrics</artifactId> + <version>${project.version}</version> + <scope>provided</scope> + </dependency> + <dependency> <groupId>net.java.dev.jna</groupId> <artifactId>jna</artifactId> <scope>provided</scope> diff --git a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java index a12424c7d12..2c4f09b3821 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/BertBaseEmbedder.java @@ -40,11 +40,13 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { private final String outputName; private final PoolingStrategy poolingStrategy; + private final Embedder.Runtime runtime; private final WordPieceEmbedder tokenizer; private final OnnxEvaluator evaluator; @Inject - public BertBaseEmbedder(OnnxRuntime onnx, BertBaseEmbedderConfig config) { + public BertBaseEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, BertBaseEmbedderConfig config) { + this.runtime = runtime; maxTokens = config.transformerMaxTokens(); startSequenceToken = config.transformerStartSequenceToken(); endSequenceToken = config.transformerEndSequenceToken(); @@ -87,11 +89,16 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { @Override public List<Integer> embed(String text, Context context) { - return tokenizer.embed(text, context); + var start = System.nanoTime(); + var tokens = tokenize(text, context); + runtime.sampleSequenceLength(tokens.size(), context); + runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); + return tokens; } @Override public Tensor embed(String text, Context context, TensorType type) { + var start = System.nanoTime(); if (type.dimensions().size() != 1) { throw new IllegalArgumentException("Error in embedding to type '" + type + "': should only have one dimension."); } @@ -99,11 +106,16 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { throw new IllegalArgumentException("Error in embedding to type '" + type + "': dimension should be indexed."); } List<Integer> tokens = embedWithSeparatorTokens(text, context, maxTokens); - return embedTokens(tokens, type); + runtime.sampleSequenceLength(tokens.size(), context); + var embedding = embedTokens(tokens, type); + runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); + return embedding; } @Override public void deconstruct() { evaluator.close(); } + private List<Integer> tokenize(String text, Context ctx) { return tokenizer.embed(text, ctx); } + Tensor embedTokens(List<Integer> tokens, TensorType type) { Tensor inputSequence = createTensorRepresentation(tokens, "d1"); Tensor attentionMask = createAttentionMask(inputSequence); @@ -129,7 +141,7 @@ public class BertBaseEmbedder extends AbstractComponent implements Embedder { private List<Integer> embedWithSeparatorTokens(String text, Context context, int maxLength) { List<Integer> tokens = new ArrayList<>(); tokens.add(startSequenceToken); - tokens.addAll(embed(text, context)); + tokens.addAll(tokenize(text, context)); tokens.add(endSequenceToken); if (tokens.size() > maxLength) { tokens = tokens.subList(0, maxLength-1); diff --git a/model-integration/src/main/java/ai/vespa/embedding/EmbedderRuntime.java b/model-integration/src/main/java/ai/vespa/embedding/EmbedderRuntime.java new file mode 100644 index 00000000000..45068db67f4 --- /dev/null +++ b/model-integration/src/main/java/ai/vespa/embedding/EmbedderRuntime.java @@ -0,0 +1,51 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +package ai.vespa.embedding; + +import ai.vespa.metrics.ContainerMetrics; +import com.yahoo.component.annotation.Inject; +import com.yahoo.language.Language; +import com.yahoo.language.process.Embedder; +import com.yahoo.metrics.simple.Gauge; +import com.yahoo.metrics.simple.MetricReceiver; +import com.yahoo.metrics.simple.Point; + +import java.util.HashMap; +import java.util.Map; + +/** + * @author bjorncs + */ +public class EmbedderRuntime implements Embedder.Runtime { + + private final Gauge embedLatency; + private final Gauge sequenceLength; + private final Map<MetricDimensions, Point> metricPointCache = new HashMap<>(); + + @Inject + public EmbedderRuntime(MetricReceiver metrics) { + embedLatency = metrics.declareGauge(ContainerMetrics.EMBEDDER_LATENCY.baseName()); + sequenceLength = metrics.declareGauge(ContainerMetrics.EMBEDDER_SEQUENCE_LENGTH.baseName()); + } + + @Override + public void sampleEmbeddingLatency(double millis, Embedder.Context ctx) { + embedLatency.sample(millis, metricPoint(ctx)); + } + + @Override + public void sampleSequenceLength(long length, Embedder.Context ctx) { + sequenceLength.sample(length, metricPoint(ctx)); + } + + private Point metricPoint(Embedder.Context ctx) { + var dimensions = new MetricDimensions(ctx.getEmbedderId(), ctx.getLanguage(), ctx.getDestination()); + return metricPointCache.computeIfAbsent( + dimensions, d -> new Point(Map.of("embedder", d.embedderId(), + "language", d.language().languageCode(), + "destination", d.destination()))); + } + + private record MetricDimensions(String embedderId, Language language, String destination) {} + +} diff --git a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java index b035541bb0f..ab8d33dbf17 100644 --- a/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java +++ b/model-integration/src/main/java/ai/vespa/embedding/huggingface/HuggingFaceEmbedder.java @@ -27,6 +27,7 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private static final Logger log = Logger.getLogger(HuggingFaceEmbedder.class.getName()); + private final Embedder.Runtime runtime; private final String inputIdsName; private final String attentionMaskName; private final String tokenTypeIdsName; @@ -37,7 +38,8 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { private final PoolingStrategy poolingStrategy; @Inject - public HuggingFaceEmbedder(OnnxRuntime onnx, HuggingFaceEmbedderConfig config) { + public HuggingFaceEmbedder(OnnxRuntime onnx, Embedder.Runtime runtime, HuggingFaceEmbedderConfig config) { + this.runtime = runtime; inputIdsName = config.transformerInputIds(); attentionMaskName = config.transformerAttentionMask(); tokenTypeIdsName = config.transformerTokenTypeIds(); @@ -87,7 +89,11 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @Override public List<Integer> embed(String s, Context context) { - return tokenizer.embed(s, context); + var start = System.nanoTime(); + var tokens = tokenizer.embed(s, context); + runtime.sampleSequenceLength(tokens.size(), context); + runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); + return tokens; } @Override @@ -98,7 +104,9 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { @Override public Tensor embed(String s, Context context, TensorType tensorType) { + var start = System.nanoTime(); var encoding = tokenizer.encode(s, context.getLanguage()); + runtime.sampleSequenceLength(encoding.ids().size(), context); Tensor inputSequence = createTensorRepresentation(encoding.ids(), "d1"); Tensor attentionMask = createTensorRepresentation(encoding.attentionMask(), "d1"); Tensor tokenTypeIds = tokenTypeIdsName.isEmpty() ? null : createTensorRepresentation(encoding.typeIds(), "d1"); @@ -117,7 +125,9 @@ public class HuggingFaceEmbedder extends AbstractComponent implements Embedder { Map<String, Tensor> outputs = evaluator.evaluate(inputs); Tensor tokenEmbeddings = outputs.get(outputName); var result = poolingStrategy.toSentenceEmbedding(tensorType, tokenEmbeddings, attentionMask); - return normalize ? normalize(result, tensorType) : result; + var normalized = normalize ? normalize(result, tensorType) : result; + runtime.sampleEmbeddingLatency((System.nanoTime() - start)/1_000_000d, context); + return normalized; } Tensor normalize(Tensor embedding, TensorType tensorType) { diff --git a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java index 329b87cacd1..a0964eb5812 100644 --- a/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java +++ b/model-integration/src/test/java/ai/vespa/embedding/BertBaseEmbedderTest.java @@ -3,6 +3,7 @@ package ai.vespa.embedding; import ai.vespa.modelintegration.evaluator.OnnxRuntime; import com.yahoo.config.ModelReference; import com.yahoo.embedding.BertBaseEmbedderConfig; +import com.yahoo.language.process.Embedder; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -69,7 +70,7 @@ public class BertBaseEmbedderTest { } private static BertBaseEmbedder newBertBaseEmbedder(BertBaseEmbedderConfig cfg) { - return new BertBaseEmbedder(new OnnxRuntime(), cfg); + return new BertBaseEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), cfg); } } |