summaryrefslogtreecommitdiffstats
path: root/config-model
diff options
context:
space:
mode:
Diffstat (limited to 'config-model')
-rw-r--r--config-model/src/main/java/com/yahoo/schema/Schema.java24
-rw-r--r--config-model/src/main/java/com/yahoo/schema/document/SDField.java4
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java20
-rw-r--r--config-model/src/main/resources/schema/common.rnc5
-rw-r--r--config-model/src/test/cfg/application/embed/services.xml2
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java8
6 files changed, 33 insertions, 30 deletions
diff --git a/config-model/src/main/java/com/yahoo/schema/Schema.java b/config-model/src/main/java/com/yahoo/schema/Schema.java
index 6548907000a..279b5729ea1 100644
--- a/config-model/src/main/java/com/yahoo/schema/Schema.java
+++ b/config-model/src/main/java/com/yahoo/schema/Schema.java
@@ -293,8 +293,7 @@ public class Schema implements ImmutableSchema {
@Override
public List<ImmutableSDField> allFieldsList() {
- List<ImmutableSDField> all = new ArrayList<>();
- all.addAll(extraFieldList());
+ List<ImmutableSDField> all = new ArrayList<>(extraFieldList());
for (Field field : documentType.fieldSet()) {
all.add((ImmutableSDField) field);
}
@@ -668,11 +667,10 @@ public class Schema implements ImmutableSchema {
@Override
public boolean equals(Object o) {
- if (!(o instanceof Schema)) {
+ if (!(o instanceof Schema other)) {
return false;
}
- Schema other = (Schema)o;
return getName().equals(other.getName());
}
@@ -688,7 +686,7 @@ public class Schema implements ImmutableSchema {
public boolean isAccessingDiskSummary(SummaryField field) {
if (!field.getTransform().isInMemory()) return true;
- if (field.getSources().size() == 0) return isAccessingDiskSummary(getName());
+ if (field.getSources().isEmpty()) return isAccessingDiskSummary(getName());
for (SummaryField.Source source : field.getSources()) {
if (isAccessingDiskSummary(source.getName()))
return true;
@@ -717,22 +715,6 @@ public class Schema implements ImmutableSchema {
return owner.schemas().get(inherited.get());
}
- /**
- * For adding structs defined in document scope
- *
- * @param dt the struct to add
- * @return self, for chaining
- */
- public Schema addType(SDDocumentType dt) {
- documentType.addType(dt); // TODO This is a very very dirty thing. It must go
- return this;
- }
-
- public Schema addAnnotation(SDAnnotationType dt) {
- documentType.addAnnotation(dt);
- return this;
- }
-
public void validate(DeployLogger logger) {
if (inherited.isPresent()) {
if (! owner.schemas().containsKey(inherited.get()))
diff --git a/config-model/src/main/java/com/yahoo/schema/document/SDField.java b/config-model/src/main/java/com/yahoo/schema/document/SDField.java
index 972d3a57040..5cc51f9fedc 100644
--- a/config-model/src/main/java/com/yahoo/schema/document/SDField.java
+++ b/config-model/src/main/java/com/yahoo/schema/document/SDField.java
@@ -599,9 +599,9 @@ public class SDField extends Field implements TypedKey, ImmutableSDField {
* per field, not per index)
*/
public void setRankType(RankType rankType) {
- this.rankType=rankType;
+ this.rankType = rankType;
for (Index index : getIndices().values()) {
- if (index.getRankType()==null)
+ if (index.getRankType() == null)
index.setRankType(rankType);
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
index abca3290a31..aa8f97784e1 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/ColBertEmbedder.java
@@ -33,10 +33,16 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
private final Integer transformerStartSequenceToken;
private final Integer transformerEndSequenceToken;
private final Integer transformerMaskToken;
+
+ private final Integer transformerPadToken;
private final Integer maxTokens;
private final String transformerInputIds;
private final String transformerAttentionMask;
+ private final Integer queryTokenId;
+
+ private final Integer documentTokenId;
+
private final String transformerOutput;
public ColBertEmbedder(ApplicationContainerCluster cluster, Element xml, DeployState state) {
@@ -55,6 +61,9 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
transformerStartSequenceToken = getChildValue(xml, "transformer-start-sequence-token").map(Integer::parseInt).orElse(null);
transformerEndSequenceToken = getChildValue(xml, "transformer-end-sequence-token").map(Integer::parseInt).orElse(null);
transformerMaskToken = getChildValue(xml, "transformer-mask-token").map(Integer::parseInt).orElse(null);
+ transformerPadToken = getChildValue(xml, "transformer-pad-token").map(Integer::parseInt).orElse(null);
+ queryTokenId = getChildValue(xml, "query-token-id").map(Integer::parseInt).orElse(null);
+ documentTokenId = getChildValue(xml, "document-token-id").map(Integer::parseInt).orElse(null);
transformerInputIds = getChildValue(xml, "transformer-input-ids").orElse(null);
transformerAttentionMask = getChildValue(xml, "transformer-attention-mask").orElse(null);
transformerOutput = getChildValue(xml, "transformer-output").orElse(null);
@@ -73,10 +82,13 @@ public class ColBertEmbedder extends TypedComponent implements ColBertEmbedderCo
if (transformerStartSequenceToken != null) b.transformerStartSequenceToken(transformerStartSequenceToken);
if (transformerEndSequenceToken != null) b.transformerEndSequenceToken(transformerEndSequenceToken);
if (transformerMaskToken != null) b.transformerMaskToken(transformerMaskToken);
- onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
- onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
- onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
- onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
+ if (transformerPadToken != null) b.transformerPadToken(transformerPadToken);
+ if (queryTokenId != null) b.queryTokenId(queryTokenId);
+ if (documentTokenId != null) b.documentTokenId(documentTokenId);
+ onnxModelOptions.executionMode().ifPresent(value -> b.transformerExecutionMode(TransformerExecutionMode.Enum.valueOf(value)));
+ onnxModelOptions.interOpThreads().ifPresent(b::transformerInterOpThreads);
+ onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
+ onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
}
}
diff --git a/config-model/src/main/resources/schema/common.rnc b/config-model/src/main/resources/schema/common.rnc
index 919253977ca..14fae90678d 100644
--- a/config-model/src/main/resources/schema/common.rnc
+++ b/config-model/src/main/resources/schema/common.rnc
@@ -138,9 +138,10 @@ ColBertEmbedder =
element transformer-mask-token { xsd:integer }? &
element transformer-input-ids { xsd:string }? &
element transformer-attention-mask { xsd:string }? &
- element transformer-token-type-ids { xsd:string }? &
+ element transformer-pad-token { xsd:integer }? &
+ element query-token-id { xsd:integer }? &
+ element document-token-id { xsd:integer }? &
element transformer-output { xsd:string }? &
- element normalize { xsd:boolean }? &
OnnxModelExecutionParams &
StartOfSequence &
EndOfSequence
diff --git a/config-model/src/test/cfg/application/embed/services.xml b/config-model/src/test/cfg/application/embed/services.xml
index 59c29aefc6a..e92679e3c96 100644
--- a/config-model/src/test/cfg/application/embed/services.xml
+++ b/config-model/src/test/cfg/application/embed/services.xml
@@ -67,9 +67,9 @@
<transformer-start-sequence-token>101</transformer-start-sequence-token>
<transformer-end-sequence-token>102</transformer-end-sequence-token>
<transformer-mask-token>103</transformer-mask-token>
+ <transformer-pad-token>0</transformer-pad-token>
<transformer-input-ids>my_input_ids</transformer-input-ids>
<transformer-attention-mask>my_attention_mask</transformer-attention-mask>
- <transformer-token-type-ids>my_token_type_ids</transformer-token-type-ids>
<transformer-output>my_output</transformer-output>
<onnx-execution-mode>parallel</onnx-execution-mode>
<onnx-intraop-threads>10</onnx-intraop-threads>
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
index 2532a5be863..4efffc8310a 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/xml/EmbedderTestCase.java
@@ -130,6 +130,10 @@ public class EmbedderTestCase {
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://my/url/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
assertEquals(-1, tokenizerCfg.maxLength());
+ assertEquals(1, embedderCfg.queryTokenId());
+ assertEquals(2, embedderCfg.documentTokenId());
+ assertEquals(0, embedderCfg.transformerPadToken());
+ assertEquals(103, embedderCfg.transformerMaskToken());
}
@Test
@@ -143,6 +147,10 @@ public class EmbedderTestCase {
var tokenizerCfg = assertHuggingfaceTokenizerComponentPresent(cluster);
assertEquals("https://data.vespa.oath.cloud/onnx_models/multilingual-e5-base/tokenizer.json", modelReference(tokenizerCfg.model().get(0), "path").url().orElseThrow().value());
assertEquals(-1, tokenizerCfg.maxLength());
+ assertEquals(1, embedderCfg.queryTokenId());
+ assertEquals(2, embedderCfg.documentTokenId());
+ assertEquals(0, embedderCfg.transformerPadToken());
+ assertEquals(103, embedderCfg.transformerMaskToken());
}
@Test