-rw-r--r--  config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java          |   1
-rw-r--r--  indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java  |  73
-rw-r--r--  indexinglanguage/src/main/javacc/IndexingParser.jj                                                 |  12
-rw-r--r--  indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java               | 250
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java                            |  34
-rw-r--r--  model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java                             |  12
-rw-r--r--  model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java                        |  25
-rw-r--r--  vespajlib/src/main/java/com/yahoo/tensor/Tensor.java                                               |   1
8 files changed, 314 insertions, 94 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
index 9e0a3a0ba5c..92c930e16e0 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/container/component/SpladeEmbedder.java
@@ -61,5 +61,6 @@ public class SpladeEmbedder extends TypedComponent implements SpladeEmbedderConf
onnxModelOptions.intraOpThreads().ifPresent(b::transformerIntraOpThreads);
onnxModelOptions.gpuDevice().ifPresent(value -> b.transformerGpuDevice(value.deviceNumber()));
}
+
}
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
index 7c5e8912e49..5daf74a9723 100644
--- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
+++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java
@@ -27,6 +27,7 @@ public class EmbedExpression extends Expression {
private final Embedder embedder;
private final String embedderId;
+ private final List<String> embedderArguments;
/** The destination the embedding will be written to on the form [schema name].[field name] */
private String destination;
@@ -34,22 +35,23 @@ public class EmbedExpression extends Expression {
/** The target type we are embedding into. */
private TensorType targetType;
- public EmbedExpression(Map<String, Embedder> embedders, String embedderId) {
+ public EmbedExpression(Map<String, Embedder> embedders, String embedderId, List<String> embedderArguments) {
super(null);
this.embedderId = embedderId;
+ this.embedderArguments = List.copyOf(embedderArguments);
- boolean embedderIdProvided = embedderId != null && embedderId.length() > 0;
+ boolean embedderIdProvided = embedderId != null && !embedderId.isEmpty();
if (embedders.size() == 0) {
throw new IllegalStateException("No embedders provided"); // should never happen
}
+ else if (embedders.size() == 1 && ! embedderIdProvided) {
+ this.embedder = embedders.entrySet().stream().findFirst().get().getValue();
+ }
else if (embedders.size() > 1 && ! embedderIdProvided) {
this.embedder = new Embedder.FailingEmbedder("Multiple embedders are provided but no embedder id is given. " +
"Valid embedders are " + validEmbedders(embedders));
}
- else if (embedders.size() == 1 && ! embedderIdProvided) {
- this.embedder = embedders.entrySet().stream().findFirst().get().getValue();
- }
else if ( ! embedders.containsKey(embedderId)) {
this.embedder = new Embedder.FailingEmbedder("Can't find embedder '" + embedderId + "'. " +
"Valid embedders are " + validEmbedders(embedders));
@@ -91,17 +93,51 @@ public class EmbedExpression extends Expression {
private Tensor embedArrayValue(ExecutionContext context) {
var input = (Array<StringFieldValue>)context.getValue();
var builder = Tensor.Builder.of(targetType);
+ if (targetType.rank() == 2)
+ embedArrayValueToRank2Tensor(input, builder, context);
+ else
+ embedArrayValueToRank3Tensor(input, builder, context);
+ return builder.build();
+ }
+
+ private void embedArrayValueToRank2Tensor(Array<StringFieldValue> input,
+ Tensor.Builder builder,
+ ExecutionContext context) {
+ String mappedDimension = targetType.mappedSubtype().dimensions().get(0).name();
+ String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name();
for (int i = 0; i < input.size(); i++) {
Tensor tensor = embed(input.get(i).getString(), targetType.indexedSubtype(), context);
for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
Tensor.Cell cell = cells.next();
builder.cell()
- .label(targetType.mappedSubtype().dimensions().get(0).name(), i)
- .label(targetType.indexedSubtype().dimensions().get(0).name(), cell.getKey().numericLabel(0))
+ .label(mappedDimension, i)
+ .label(indexedDimension, cell.getKey().numericLabel(0))
+ .value(cell.getValue());
+ }
+ }
+ }
+
+ private void embedArrayValueToRank3Tensor(Array<StringFieldValue> input,
+ Tensor.Builder builder,
+ ExecutionContext context) {
+ String outerMappedDimension = embedderArguments.get(0);
+ String innerMappedDimension = targetType.mappedSubtype().dimensionNames().stream().filter(d -> !d.equals(outerMappedDimension)).findFirst().get();
+ String indexedDimension = targetType.indexedSubtype().dimensions().get(0).name();
+ long indexedDimensionSize = targetType.indexedSubtype().dimensions().get(0).size().get();
+ var innerType = new TensorType.Builder().mapped(innerMappedDimension).indexed(indexedDimension,indexedDimensionSize).build();
+ int innerMappedDimensionIndex = innerType.indexOfDimensionAsInt(innerMappedDimension);
+ int indexedDimensionIndex = innerType.indexOfDimensionAsInt(indexedDimension);
+ for (int i = 0; i < input.size(); i++) {
+ Tensor tensor = embed(input.get(i).getString(), innerType, context);
+ for (Iterator<Tensor.Cell> cells = tensor.cellIterator(); cells.hasNext(); ) {
+ Tensor.Cell cell = cells.next();
+ builder.cell()
+ .label(outerMappedDimension, i)
+ .label(innerMappedDimension, cell.getKey().label(innerMappedDimensionIndex))
+ .label(indexedDimension, cell.getKey().numericLabel(indexedDimensionIndex))
.value(cell.getValue());
}
}
- return builder.build();
}
private Tensor embed(String input, TensorType targetType, ExecutionContext context) {
@@ -120,7 +156,17 @@ public class EmbedExpression extends Expression {
targetType = toTargetTensor(context.getInputType(this, outputField));
if ( ! validTarget(targetType))
throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," +
- "an array of dense 1d tensors, or a mixed 2d tensor");
+ "an array of dense 1d tensors, or a mixed 2d or 3d tensor");
+ if (targetType.rank() == 3) {
+ if (embedderArguments.size() != 1)
+ throw new VerificationException(this, "When the embedding target field is a 3d tensor " +
+ "the name of the tensor dimension that corresponds to the input array elements must " +
+ "be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...");
+ if ( ! targetType.mappedSubtype().dimensionNames().contains(embedderArguments.get(0)))
+ throw new VerificationException(this, "The dimension '" + embedderArguments.get(0) + "' given to embed " +
+ "is not a sparse dimension of the target type " + targetType);
+ }
+
context.setValueType(createdOutputType());
}
@@ -137,11 +183,12 @@ public class EmbedExpression extends Expression {
}
private boolean validTarget(TensorType target) {
- if (target.dimensions().size() == 1) //indexed or mapped 1d tensor
+ if (target.rank() == 1) // indexed or mapped 1d tensor
return true;
- if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1
- && target.mappedSubtype().rank() == 1)
- return true; //mixed mapped-indexed 2d tensor
+ if (target.rank() == 2 && target.indexedSubtype().rank() == 1)
+ return true; // mixed 2d tensor
+ if (target.rank() == 3 && target.indexedSubtype().rank() == 1)
+ return true; // mixed 3d tensor
return false;
}
diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj
index 42bbd26cee6..a3b4039408a 100644
--- a/indexinglanguage/src/main/javacc/IndexingParser.jj
+++ b/indexinglanguage/src/main/javacc/IndexingParser.jj
@@ -37,7 +37,6 @@ import com.yahoo.language.Linguistics;
/**
* @author Simon Thoresen Hult
- * @version $Id$
*/
public class IndexingParser {
@@ -386,11 +385,16 @@ Expression echoExp() : { }
Expression embedExp() :
{
- String val = "";
+ String embedderId = "";
+ String embedderArgument;
+ List<String> embedderArguments = new ArrayList<String>();
}
{
- ( <EMBED> [ LOOKAHEAD(2) val = identifier() ] )
- { return new EmbedExpression(embedders, val); }
+ (
+ <EMBED> [ LOOKAHEAD(2) embedderId = identifier() ]
+ ( LOOKAHEAD(2) embedderArgument = identifier() { embedderArguments.add(embedderArgument); } )*
+ )
+ { return new EmbedExpression(embedders, embedderId, embedderArguments); }
}
Expression exactExp() : { }
diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
index 6206c2efe7a..7fe55b738df 100644
--- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
+++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java
@@ -181,7 +181,7 @@ public class ScriptTestCase {
Expression.fromString(exp, new SimpleLinguistics(), Embedder.throwsOnUse.asMap());
Map<String, Embedder> embedder = Map.of(
- "emb1", new MockEmbedder("myDocument.myTensor")
+ "emb1", new MockIndexedEmbedder("myDocument.myTensor")
);
testEmbedStatement("input myText | embed | attribute 'myTensor'", embedder,
"input text", "[105, 110, 112, 117]");
@@ -193,8 +193,8 @@ public class ScriptTestCase {
null, null);
Map<String, Embedder> embedders = Map.of(
- "emb1", new MockEmbedder("myDocument.myTensor"),
- "emb2", new MockEmbedder("myDocument.myTensor", 1)
+ "emb1", new MockIndexedEmbedder("myDocument.myTensor"),
+ "emb2", new MockIndexedEmbedder("myDocument.myTensor", 1)
);
testEmbedStatement("input myText | embed emb1 | attribute 'myTensor'", embedders,
"my input", "[109.0, 121.0, 32.0, 105.0]");
@@ -243,7 +243,7 @@ public class ScriptTestCase {
@SuppressWarnings("unchecked")
@Test
public void testArrayEmbed() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.myTensorArray"));
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.myTensorArray"));
TensorType tensorType = TensorType.fromSpec("tensor(d[4])");
var expression = Expression.fromString("input myTextArray | for_each { embed } | attribute 'myTensorArray'",
@@ -277,7 +277,7 @@ public class ScriptTestCase {
@Test
public void testArrayEmbedWithConcatenation() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor"));
TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
var expression = Expression.fromString("input myTextArray | for_each { input title . \" \" . _ } | embed | attribute 'mySparseTensor'",
@@ -314,9 +314,10 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ /** Multiple paragraphs */
@Test
- public void testArrayEmbedToMixedTensor() throws ParseException {
- Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor"));
+ public void testArrayEmbedTo2dMixedTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockIndexedEmbedder("myDocument.mySparseTensor"));
TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])");
var expression = Expression.fromString("input myTextArray | embed | attribute 'mySparseTensor'",
@@ -348,17 +349,125 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */
+ @Test
+ public void testArrayEmbedTo3dMixedTensor() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 passage | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType));
+ adapter.createField(tensorField);
+
+ var array = new Array<StringFieldValue>(new ArrayDataType(DataType.STRING));
+ array.add(new StringFieldValue("first"));
+ array.add(new StringFieldValue("sec"));
+ adapter.setValue("myTextArray", array);
+ expression.setStatementOutput(new DocumentType("myDocument"), tensorField);
+
+ assertEquals(new TensorDataType(tensorType), expression.verify(new VerificationContext(adapter)));
+
+ ExecutionContext context = new ExecutionContext(adapter);
+ context.setValue(array);
+ expression.execute(context);
+ assertTrue(adapter.values.containsKey("mySparseTensor"));
+ var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor");
+ // The two "passages" are [first, sec], the middle (d=1) token encodes those letters
+ assertEquals(Tensor.from(tensorType,
+ """
+ {
+ {passage:0, token:0, d:0}: 101,
+ {passage:0, token:0, d:1}: 102,
+ {passage:0, token:0, d:2}: 103,
+ {passage:0, token:1, d:0}: 104,
+ {passage:0, token:1, d:1}: 105,
+ {passage:0, token:1, d:2}: 106,
+ {passage:0, token:2, d:0}: 113,
+ {passage:0, token:2, d:1}: 114,
+ {passage:0, token:2, d:2}: 115,
+ {passage:0, token:3, d:0}: 114,
+ {passage:0, token:3, d:1}: 115,
+ {passage:0, token:3, d:2}: 116,
+ {passage:0, token:4, d:0}: 115,
+ {passage:0, token:4, d:1}: 116,
+ {passage:0, token:4, d:2}: 117,
+ {passage:1, token:0, d:0}: 114,
+ {passage:1, token:0, d:1}: 115,
+ {passage:1, token:0, d:2}: 116,
+ {passage:1, token:1, d:0}: 100,
+ {passage:1, token:1, d:1}: 101,
+ {passage:1, token:1, d:2}: 102,
+ {passage:1, token:2, d:0}: 98,
+ {passage:1, token:2, d:1}: 99,
+ {passage:1, token:2, d:2}: 100
+ }
+ """),
+ sparseTensor.getTensor().get());
+ }
+
+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */
+ @Test
+ public void testArrayEmbedTo3dMixedTensor_missingDimensionArgument() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType)));
+
+ try {
+ expression.verify(new VerificationContext(adapter));
+ fail("Expected exception");
+ }
+ catch (VerificationException e) {
+ assertEquals("When the embedding target field is a 3d tensor the name of the tensor dimension that corresponds to the input array elements must be given as a second argument to embed, e.g: ... | embed colbert paragraph | ...",
+ e.getMessage());
+ }
+ }
+
+ /** Multiple paragraphs, and each paragraph leading to multiple vectors (ColBert style) */
+ @Test
+ public void testArrayEmbedTo3dMixedTensor_wrongDimensionArgument() throws ParseException {
+ Map<String, Embedder> embedders = Map.of("emb1", new MockMixedEmbedder("myDocument.mySparseTensor"));
+
+ TensorType tensorType = TensorType.fromSpec("tensor(passage{}, token{}, d[3])");
+ var expression = Expression.fromString("input myTextArray | embed emb1 d | attribute 'mySparseTensor'",
+ new SimpleLinguistics(),
+ embedders);
+
+ SimpleTestAdapter adapter = new SimpleTestAdapter();
+ adapter.createField(new Field("myTextArray", new ArrayDataType(DataType.STRING)));
+ adapter.createField(new Field("mySparseTensor", new TensorDataType(tensorType)));
+
+ try {
+ expression.verify(new VerificationContext(adapter));
+ fail("Expected exception");
+ }
+ catch (VerificationException e) {
+ assertEquals("The dimension 'd' given to embed is not a sparse dimension of the target type tensor(d[3],passage{},token{})",
+ e.getMessage());
+ }
+ }
+
@SuppressWarnings("OptionalGetWithoutIsPresent")
@Test
public void testEmbedToSparseTensor() throws ParseException {
-
- Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true);
+ Embedder mappedEmbedder = new MockMappedEmbedder("myDocument.mySparseTensor", 0);
Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder);
TensorType tensorType = TensorType.fromSpec("tensor(t{})");
var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'",
- new SimpleLinguistics(),
- embedders);
+ new SimpleLinguistics(),
+ embedders);
SimpleTestAdapter adapter = new SimpleTestAdapter();
adapter.createField(new Field("text", DataType.STRING));
@@ -383,30 +492,23 @@ public class ScriptTestCase {
sparseTensor.getTensor().get());
}
- // An embedder which returns the char value of each letter in the input. */
- private static class MockEmbedder implements Embedder {
-
- private final String expectedDestination;
- private final int addition;
-
- private final boolean mappedTensor;
-
-
- public MockEmbedder(String expectedDestination) {
- this(expectedDestination, 0, false);
- }
- public MockEmbedder(String expectedDestination, boolean mapped) {
- this(expectedDestination, 0,mapped);
+ private void assertThrows(Runnable r, String msg) {
+ try {
+ r.run();
+ fail();
+ } catch (IllegalStateException e) {
+ assertEquals(e.getMessage(), msg);
}
+ }
- public MockEmbedder(String expectedDestination,int addition) {
- this(expectedDestination, addition,false);
- }
+ private static abstract class MockEmbedder implements Embedder {
- public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) {
+ final String expectedDestination;
+ final int addition;
+
+ public MockEmbedder(String expectedDestination, int addition) {
this.expectedDestination = expectedDestination;
this.addition = addition;
- this.mappedTensor = mappedTensor;
}
@Override
@@ -414,32 +516,84 @@ public class ScriptTestCase {
return null;
}
+ void verifyDestination(Embedder.Context context) {
+ assertEquals(expectedDestination, context.getDestination());
+ }
+
+ }
+
+ /** An embedder which returns the char value of each letter in the input as a 1d indexed tensor. */
+ private static class MockIndexedEmbedder extends MockEmbedder {
+
+ public MockIndexedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockIndexedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
@Override
public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
- assertEquals(expectedDestination, context.getDestination());
+ verifyDestination(context);
var b = Tensor.Builder.of(tensorType);
- if (mappedTensor) {
- for(int i = 0; i < text.length(); i++) {
- var value = text.charAt(i) + addition;
- b.cell().
- label(tensorType.dimensions().get(0).name(), text.charAt(i))
- .value(value);
- }
- } else {
- for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
- b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+ for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++)
+ b.cell(i < text.length() ? text.charAt(i) + addition : 0, i);
+ return b.build();
+ }
- }
+ }
+
+ /** An embedder which returns the char value of each letter in the input as a 1d mapped tensor. */
+ private static class MockMappedEmbedder extends MockEmbedder {
+
+ public MockMappedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockMappedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ verifyDestination(context);
+ var b = Tensor.Builder.of(tensorType);
+ for (int i = 0; i < text.length(); i++)
+ b.cell().label(tensorType.dimensions().get(0).name(), text.charAt(i)).value(text.charAt(i) + addition);
return b.build();
}
+
}
- private void assertThrows(Runnable r, String msg) {
- try {
- r.run();
- fail();
- } catch (IllegalStateException e) {
- assertEquals(e.getMessage(), msg);
+ /**
+ * An embedder which returns the char value of each letter in the input as a 2d mixed tensor where each input
+ * char becomes an indexed dimension containing input-1, input, input+1.
+ */
+ private static class MockMixedEmbedder extends MockEmbedder {
+
+ public MockMixedEmbedder(String expectedDestination) {
+ this(expectedDestination, 0);
+ }
+
+ public MockMixedEmbedder(String expectedDestination, int addition) {
+ super(expectedDestination, addition);
+ }
+
+ @Override
+ public Tensor embed(String text, Embedder.Context context, TensorType tensorType) {
+ verifyDestination(context);
+ var b = Tensor.Builder.of(tensorType);
+ String mappedDimension = tensorType.mappedSubtype().dimensions().get(0).name();
+ String indexedDimension = tensorType.indexedSubtype().dimensions().get(0).name();
+ for (int i = 0; i < text.length(); i++) {
+ for (int j = 0; j < 3; j++) {
+ b.cell().label(mappedDimension, i)
+ .label(indexedDimension, j)
+ .value(text.charAt(i) + addition + j - 1);
+ }
+ }
+ return b.build();
}
}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
index 8c39cc8c813..f76bfd28abf 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/ColBertEmbedder.java
@@ -18,7 +18,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.functions.Reduce;
+
import java.nio.file.Paths;
import java.util.Map;
import java.util.List;
@@ -34,10 +34,14 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES
* This embedder uses a HuggingFace tokenizer to produce a token sequence that is then input to a transformer model.
*
* See col-bert-embedder.def for configurable parameters.
+ *
* @author bergum
*/
@Beta
public class ColBertEmbedder extends AbstractComponent implements Embedder {
+
+ private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";
+
private final Embedder.Runtime runtime;
private final String inputIdsName;
private final String attentionMaskName;
@@ -117,7 +121,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
private void validateName(Map<String, TensorType> types, String name, String type) {
if (!types.containsKey(name)) {
throw new IllegalArgumentException("Model does not contain required " + type + ": '" + name + "'. " +
- "Model contains: " + String.join(",", types.keySet()));
+ "Model contains: " + String.join(",", types.keySet()));
}
}
@@ -128,9 +132,9 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
@Override
public Tensor embed(String text, Context context, TensorType tensorType) {
- if (!verifyTensorType(tensorType)) {
+ if ( ! validTensorType(tensorType)) {
throw new IllegalArgumentException("Invalid colbert embedder tensor target destination. " +
- "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
+ "Wanted a mixed 2-d mapped-indexed tensor, got " + tensorType);
}
if (context.getDestination().startsWith("query")) {
return embedQuery(text, context, tensorType);
@@ -196,7 +200,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
int dims = tensorType.indexedSubtype().dimensions().get(0).size().get().intValue();
if (dims != result.shape()[2]) {
throw new IllegalArgumentException("Token vector dimensionality does not" +
- " match indexed dimensionality of " + dims);
+ " match indexed dimensionality of " + dims);
}
Tensor resultTensor = toFloatTensor(result, tensorType, input.inputIds.size());
runtime.sampleEmbeddingLatency((System.nanoTime() - start) / 1_000_000d, context);
@@ -213,13 +217,13 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
Tensor attentionMaskTensor = createTensorRepresentation(input.attentionMask, "d1");
var inputs = Map.of(inputIdsName, inputIdsTensor.expand("d0"),
- attentionMaskName, attentionMaskTensor.expand("d0"));
+ attentionMaskName, attentionMaskTensor.expand("d0"));
Map<String, Tensor> outputs = evaluator.evaluate(inputs);
Tensor tokenEmbeddings = outputs.get(outputName);
IndexedTensor result = (IndexedTensor) tokenEmbeddings;
Tensor contextualEmbeddings;
- int maxTokens = input.inputIds.size(); //Retain all token vectors, including PAD tokens.
+ int maxTokens = input.inputIds.size(); // Retain all token vectors, including PAD tokens.
if (tensorType.valueType() == TensorType.Value.INT8) {
contextualEmbeddings = toBitTensor(result, tensorType, maxTokens);
} else {
@@ -230,7 +234,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
public static Tensor toFloatTensor(IndexedTensor result, TensorType type, int nTokens) {
- if(result.shape().length != 3)
+ if (result.shape().length != 3)
throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
int size = type.indexedSubtype().dimensions().size();
if (size != 1)
@@ -253,8 +257,7 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
public static Tensor toBitTensor(IndexedTensor result, TensorType type, int nTokens) {
if (type.valueType() != TensorType.Value.INT8)
- throw new IllegalArgumentException("Only a int8 tensor type can be" +
- " the destination of bit packing");
+ throw new IllegalArgumentException("Only a int8 tensor type can be the destination of bit packing");
if(result.shape().length != 3)
throw new IllegalArgumentException("Expected onnx result to have 3-dimensions [batch, sequence, dim]");
@@ -264,8 +267,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
int wantedDimensionality = type.indexedSubtype().dimensions().get(0).size().get().intValue();
int resultDimensionality = (int)result.shape()[2];
if (resultDimensionality != 8 * wantedDimensionality) {
- throw new IllegalArgumentException("Not possible to pack " + resultDimensionality
- + " + dimensions into " + wantedDimensionality + " dimensions");
+ throw new IllegalArgumentException("Not possible to pack " + resultDimensionality +
+ " + dimensions into " + wantedDimensionality + " dimensions");
}
Tensor.Builder builder = Tensor.Builder.of(type);
for (int token = 0; token < nTokens; token++) {
@@ -302,9 +305,8 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
return unpacker.evaluate(context).asTensor();
}
- protected boolean verifyTensorType(TensorType target) {
- return target.dimensions().size() == 2 &&
- target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1;
+ protected boolean validTensorType(TensorType target) {
+ return target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1;
}
private IndexedTensor createTensorRepresentation(List<Long> input, String dimension) {
@@ -316,5 +318,5 @@ public class ColBertEmbedder extends AbstractComponent implements Embedder {
}
return builder.build();
}
- private static final String PUNCTUATION = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~";
+
}
diff --git a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
index 3a64083c623..58bd4deb659 100644
--- a/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
+++ b/model-integration/src/main/java/ai/vespa/embedding/SpladeEmbedder.java
@@ -25,9 +25,12 @@ import static com.yahoo.language.huggingface.ModelInfo.TruncationStrategy.LONGES
/**
* A SPLADE embedder that is embedding text to a 1-d mapped tensor. For interpretability, the tensor labels
* are the subword strings from the wordpiece vocabulary that has a score above a threshold (default 0.0).
+ *
+ * @author bergum
*/
@Beta
public class SpladeEmbedder extends AbstractComponent implements Embedder {
+
private final Embedder.Runtime runtime;
private final String inputIdsName;
private final String attentionMaskName;
@@ -110,7 +113,7 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
public Tensor embed(String text, Context context, TensorType tensorType) {
if (!verifyTensorType(tensorType)) {
throw new IllegalArgumentException("Invalid splade embedder tensor destination. " +
- "Wanted a mapped 1-d tensor, got " + tensorType);
+ "Wanted a mapped 1-d tensor, got " + tensorType);
}
var start = System.nanoTime();
@@ -132,17 +135,17 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
return spladeTensor;
}
-
/**
* Sparsify the output tensor by applying a threshold on the log of the relu of the output.
* This uses generic tensor reduce+map, and is slightly slower than a custom unrolled variant.
+ *
* @param modelOutput the model output tensor of shape d1,dim where d1 is the sequence length and dim is size
- * of the vocabulary
+ * of the vocabulary
* @param tensorType the type of the destination tensor
* @return A mapped tensor with the terms from the vocab that has a score above the threshold
*/
private Tensor sparsifyReduce(Tensor modelOutput, TensorType tensorType) {
- //Remove batch dim, batch size of 1
+ // Remove batch dim, batch size of 1
Tensor output = modelOutput.reduce(Reduce.Aggregator.max, "d0", "d1");
Tensor logOfRelu = output.map((x) -> Math.log(1 + (x > 0 ? x : 0)));
IndexedTensor vocab = (IndexedTensor) logOfRelu;
@@ -227,6 +230,7 @@ public class SpladeEmbedder extends AbstractComponent implements Embedder {
}
return builder.build();
}
+
@Override
public void deconstruct() {
evaluator.close();
diff --git a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
index 0cae94c372a..be75c4d3351 100644
--- a/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
+++ b/model-integration/src/test/java/ai/vespa/embedding/ColBertEmbedderTest.java
@@ -19,6 +19,9 @@ import java.util.Set;
import static org.junit.Assert.*;
import static org.junit.Assume.assumeTrue;
+/**
+ * @author bergum
+ */
public class ColBertEmbedderTest {
@Test
@@ -67,23 +70,24 @@ public class ColBertEmbedderTest {
assertEmbed("tensor<float>(qt{},x[128])", "this is a query", queryContext);
assertThrows(IllegalArgumentException.class, () -> {
- //throws because int8 is not supported for query context
+ // throws because int8 is not supported for query context
assertEmbed("tensor<int8>(qt{},x[16])", "this is a query", queryContext);
});
+
assertThrows(IllegalArgumentException.class, () -> {
- //throws because 16 is less than model output (128) and we want float
+ // throws because 16 is less than model output (128) and we want float
assertEmbed("tensor<float>(qt{},x[16])", "this is a query", queryContext);
});
assertThrows(IllegalArgumentException.class, () -> {
- //throws because 128/8 does not fit into 15
+ // throws because 128/8 does not fit into 15
assertEmbed("tensor<int8>(qt{},x[15])", "this is a query", indexingContext);
});
}
@Test
public void testInputTensorsWordPiece() {
- //wordPiece tokenizer("this is a query !") -> [2023, 2003, 1037, 23032, 999]
+ // wordPiece tokenizer("this is a query !") -> [2023, 2003, 1037, 23032, 999]
List<Long> tokens = List.of(2023L, 2003L, 1037L, 23032L, 999L);
ColBertEmbedder.TransformerInput input = embedder.buildTransformerInput(tokens,10,true);
assertEquals(10,input.inputIds().size());
@@ -100,7 +104,7 @@ public class ColBertEmbedderTest {
@Test
public void testInputTensorsSentencePiece() {
- //Sentencepiece tokenizer("this is a query !") -> [903, 83, 10, 41, 1294, 711]
+ // Sentencepiece tokenizer("this is a query !") -> [903, 83, 10, 41, 1294, 711]
// ! is mapped to 711 and is a punctuation character
List<Long> tokens = List.of(903L, 83L, 10L, 41L, 1294L, 711L);
ColBertEmbedder.TransformerInput input = multiLingualEmbedder.buildTransformerInput(tokens,10,true);
@@ -109,7 +113,7 @@ public class ColBertEmbedderTest {
assertEquals(List.of(0L, 3L, 903L, 83L, 10L, 41L, 1294L, 711L, 2L, 250001L),input.inputIds());
assertEquals(List.of(1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 1L, 0L),input.attentionMask());
- //NO padding for document side and 711 (punctuation) is now filtered out
+ // NO padding for document side and 711 (punctuation) is now filtered out
input = multiLingualEmbedder.buildTransformerInput(tokens,10,false);
assertEquals(8,input.inputIds().size());
assertEquals(8,input.attentionMask().size());
@@ -156,12 +160,12 @@ public class ColBertEmbedderTest {
sb.append(" ");
}
String text = sb.toString();
- Long now = System.currentTimeMillis();
+ long now = System.currentTimeMillis();
int n = 1000;
for (int i = 0; i < n; i++) {
assertEmbed("tensor<float>(dt{},x[128])", text, indexingContext);
}
- Long elapsed = (System.currentTimeMillis() - now);
+ long elapsed = (System.currentTimeMillis() - now);
System.out.println("Elapsed time: " + elapsed + " ms");
}
@@ -170,7 +174,7 @@ public class ColBertEmbedderTest {
Tensor result = embedder.embed(text, context, destType);
assertEquals(destType,result.type());
MixedTensor mixedTensor = (MixedTensor) result;
- if(context == queryContext) {
+ if (context == queryContext) {
assertEquals(32*mixedTensor.denseSubspaceSize(),mixedTensor.size());
}
return result;
@@ -200,12 +204,14 @@ public class ColBertEmbedderTest {
static final ColBertEmbedder multiLingualEmbedder;
static final Embedder.Context indexingContext;
static final Embedder.Context queryContext;
+
static {
indexingContext = new Embedder.Context("schema.indexing");
queryContext = new Embedder.Context("query(qt)");
embedder = getEmbedder();
multiLingualEmbedder = getMultiLingualEmbedder();
}
+
private static ColBertEmbedder getEmbedder() {
String vocabPath = "src/test/models/onnx/transformer/real_tokenizer.json";
String modelPath = "src/test/models/onnx/transformer/colbert-dummy-v2.onnx";
@@ -235,4 +241,5 @@ public class ColBertEmbedderTest {
return new ColBertEmbedder(new OnnxRuntime(), Embedder.Runtime.testInstance(), builder.build());
}
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index ac9dc4e4eca..d27c7cf0168 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -634,6 +634,7 @@ public interface Tensor {
public Builder value(double cellValue) {
return tensorBuilder.cell(addressBuilder.build(), cellValue);
}
+
public Builder value(float cellValue) {
return tensorBuilder.cell(addressBuilder.build(), cellValue);
}