diff options
Diffstat (limited to 'config-model')
6 files changed, 33 insertions, 30 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/Schema.java b/config-model/src/main/java/com/yahoo/schema/Schema.java index 6548907000a..279b5729ea1 100644 --- a/config-model/src/main/java/com/yahoo/schema/Schema.java +++ b/config-model/src/main/java/com/yahoo/schema/Schema.java @@ -293,8 +293,7 @@ public class Schema implements ImmutableSchema { @Override public List<ImmutableSDField> allFieldsList() { - List<ImmutableSDField> all = new ArrayList<>(); - all.addAll(extraFieldList()); + List<ImmutableSDField> all = new ArrayList<>(extraFieldList()); for (Field field : documentType.fieldSet()) { all.add((ImmutableSDField) field); } @@ -668,11 +667,10 @@ public class Schema implements ImmutableSchema { @Override public boolean equals(Object o) { - if (!(o instanceof Schema)) { + if (!(o instanceof Schema other)) { return false; } - Schema other = (Schema)o; return getName().equals(other.getName()); } @@ -688,7 +686,7 @@ public class Schema implements ImmutableSchema { public boolean isAccessingDiskSummary(SummaryField field) { if (!field.getTransform().isInMemory()) return true; - if (field.getSources().size() == 0) return isAccessingDiskSummary(getName()); + if (field.getSources().isEmpty()) return isAccessingDiskSummary(getName()); for (SummaryField.Source source : field.getSources()) { if (isAccessingDiskSummary(source.getName())) return true; @@ -717,22 +715,6 @@ public class Schema implements ImmutableSchema { return owner.schemas().get(inherited.get()); } - /** - * For adding structs defined in document scope - * - * @param dt the struct to add - * @return self, for chaining - */ - public Schema addType(SDDocumentType dt) { - documentType.addType(dt); // TODO This is a very very dirty thing. It must go - return this; - } - - public Schema addAnnotation(SDAnnotationType dt) { - documentType.addAnnotation(dt); - return this; - } - public void validate(DeployLogger logger) { if (inherited.isPresent()) { if (! owner.schemas().containsKey(inherited.get())) diff --git a/config-model/src/main/java/com/yahoo/schema/document/SDField.java b/config-model/src/main/java/com/yahoo/schema/document/SDField.java index 972d3a57040..5cc51f9fedc 100644 --- a/config-model/src/main/java/com/yahoo/schema/document/SDField.java +++ b/config-model/src/main/java/com/yahoo/schema/document/SDField.java @@ -599,9 +599,9 @@ public class SDField extends Field implements TypedKey, ImmutableSDField { * per field, not per index) */ public void setRankType(RankType rankType) { - this.rankType=rankType; + this.rankType = rankType; for (Index index : getIndices().values()) { - if (index.getRankType()==null) + if (index.getRankType() == null) index.setRankType(rankType); } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java index abca3290a31..aa8f97784e1 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java @@ -33,10 +33,16 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo private final Integer transformerStartSequenceToken; private final Integer transformerEndSequenceToken; private final Integer transformerMaskToken; + + private final Integer transformerPadToken; private final Integer maxTokens; private final String transformerInputIds; private final String transformerAttentionMask; + private final Integer queryTokenId; + + private final Integer documentTokenId; + private final String transformerOutput; public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) { @@ -55,6 +61,9 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null); transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null); transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null); + transformerPadToken = getChildValue(xml, "transformer-pad-token").map(Integer::parseInt).orElse(null); + queryTokenId = getChildValue(xml, "query-token-id").map(Integer::parseInt).orElse(null); + documentTokenId = getChildValue(xml, "document-token-id").map(Integer::parseInt).orElse(null); transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null); transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null); transformerOutput = getChildValue(xml, "transformer-output").orElse(null); @@ -73,10 +82,13 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken); if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken); if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken); - onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); - onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); - onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); - onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); + if (transformerPadToken != null) b.transformerPadToken(transformerPadToken); + if (queryTokenId != null) b.queryTokenId(queryTokenId); + if (documentTokenId != null) b.documentTokenId(documentTokenId); + onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value))); + onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads); + onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads); + onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber())); } } diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc index 919253977ca..14fae90678d 100644 --- a/config-model/src/main/resources/schema/common.rnc +++ b/config-model/src/main/resources/schema/common.rnc @@ -138,9 +138,10 @@ ColBertEmbedder = element transformer-mask-token { xsd:integer }? & element transformer-input-ids { xsd:string }? & element transformer-attention-mask { xsd:string }? & - element transformer-token-type-ids { xsd:string }? & + element transformer-pad-token { xsd:integer }? & + element query-token-id { xsd:integer }? & + element document-token-id { xsd:integer }? & element transformer-output { xsd:string }? & - element normalize { xsd:boolean }? & OnnxModelExecutionParams & StartOfSequence & EndOfSequence diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml index 59c29aefc6a..e92679e3c96 100644 --- a/config-model/src/test/cfg/application/embed/services.xml +++ b/config-model/src/test/cfg/application/embed/services.xml @@ -67,9 +67,9 @@ <transformer-start-sequence-token>101</transformer-start-sequence-token> <transformer-end-sequence-token>102</transformer-end-sequence-token> <transformer-mask-token>103</transformer-mask-token> + <transformer-pad-token>0</transformer-pad-token> <transformer-input-ids>my_input_ids</transformer-input-ids> <transformer-attention-mask>my_attention_mask</transformer-attention-mask> - <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids> <transformer-output>my_output</transformer-output> <onnx-execution-mode>parallel</onnx-execution-mode> <onnx-intraop-threads>10</onnx-intraop-threads> diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java index 2532a5be863..4efffc8310a 100644 --- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java +++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java @@ -130,6 +130,10 @@ public class EmbedderTestCase { var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); assertEquals(-1, tokenizerCfg.maxLength()); + assertEquals(1, embedderCfg.queryTokenId()); + assertEquals(2, embedderCfg.documentTokenId()); + assertEquals(0, embedderCfg.transformerPadToken()); + assertEquals(103, embedderCfg.transformerMaskToken()); } @Test @@ -143,6 +147,10 @@ public class EmbedderTestCase { var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster); assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value()); assertEquals(-1, tokenizerCfg.maxLength()); + assertEquals(1, embedderCfg.queryTokenId()); + assertEquals(2, embedderCfg.documentTokenId()); + assertEquals(0, embedderCfg.transformerPadToken()); + assertEquals(103, embedderCfg.transformerMaskToken()); } @Test |