diff options
Diffstat (limited to 'indexinglanguage')
6 files changed, 86 insertions, 11 deletions
diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java index 29399a38fa9..1a9caaa5ca1 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/EmbedExpression.java @@ -119,7 +119,7 @@ public class EmbedExpression extends Expression { "Don't know what tensor type to embed into"); targetType = toTargetTensor(context.getInputType(this, outputField)); if ( ! validTarget(targetType)) - throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, " + + throw new VerificationException(this, "The embedding target field must either be a dense 1d tensor, a mapped 1d tensor," + "an array of dense 1d tensors, or a mixed 2d tensor"); context.setValueType(createdOutputType()); } @@ -134,14 +134,14 @@ public class EmbedExpression extends Expression { if ( ! ( dataType instanceof TensorDataType)) throw new IllegalArgumentException("Expected a tensor data type but got " + dataType); return ((TensorDataType)dataType).getTensorType(); - } private boolean validTarget(TensorType target) { - if (target.dimensions().size() == 1 && target.indexedSubtype().rank() == 1) - return true; - if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 && target.mappedSubtype().rank() == 1) + if (target.dimensions().size() == 1) //indexed or mapped 1d tensor return true; + if (target.dimensions().size() == 2 && target.indexedSubtype().rank() == 1 + && target.mappedSubtype().rank() == 1) + return true; //mixed mapped-indexed 2d tensor return false; } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/TokenizeExpression.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/TokenizeExpression.java index 169b79a62af..b807ad4cb65 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/TokenizeExpression.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/expressions/TokenizeExpression.java @@ -69,6 +69,9 @@ public final class TokenizeExpression extends Expression { if (config.hasNonDefaultMaxTokenLength()) { ret.append(" max-length:" + config.getMaxTokenizeLength()); } + if (config.hasNonDefaultMaxTermOccurrences()) { + ret.append(" max-occurrences:" + config.getMaxTermOccurrences()); + } return ret.toString(); } diff --git a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/linguistics/AnnotatorConfig.java b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/linguistics/AnnotatorConfig.java index 5c1bf0813c4..7b6f350d831 100644 --- a/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/linguistics/AnnotatorConfig.java +++ b/indexinglanguage/src/main/java/com/yahoo/vespa/indexinglanguage/linguistics/AnnotatorConfig.java @@ -95,6 +95,10 @@ public class AnnotatorConfig implements Cloneable { return maxTokenizeLength != DEFAULT_MAX_TOKENIZE_LENGTH; } + public boolean hasNonDefaultMaxTermOccurrences() { + return maxTermOccurrences != DEFAULT_MAX_TERM_OCCURRENCES; + } + @Override public boolean equals(Object obj) { if (!(obj instanceof AnnotatorConfig rhs)) { diff --git a/indexinglanguage/src/main/javacc/IndexingParser.jj b/indexinglanguage/src/main/javacc/IndexingParser.jj index ea05f33d745..42bbd26cee6 100644 --- a/indexinglanguage/src/main/javacc/IndexingParser.jj +++ b/indexinglanguage/src/main/javacc/IndexingParser.jj @@ -173,6 +173,7 @@ TOKEN : <JOIN: "join"> | <LOWER_CASE: "lowercase"> | <MAX_LENGTH: "max-length"> | + <MAX_OCCURRENCES: "max-occurrences"> | <NGRAM: "ngram"> | <NORMALIZE: "normalize"> | <NOW: "now"> | @@ -664,10 +665,12 @@ AnnotatorConfig tokenizeCfg() : AnnotatorConfig val = new AnnotatorConfig(annotatorCfg); String str = "SHORTEST"; Integer maxLength; + Integer maxTermOccurrences; } { ( <STEM> ( <COLON> str = string() ) ? { val.setStemMode(str); } | <MAX_LENGTH> <COLON> maxLength = integer() { val.setMaxTokenLength(maxLength); } | + <MAX_OCCURRENCES> <COLON> maxTermOccurrences = integer() { val.setMaxTermOccurrences(maxTermOccurrences); } | <NORMALIZE> { val.setRemoveAccents(true); } )+ { return val; } } diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java index 2b28756a6a8..6206c2efe7a 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/ScriptTestCase.java @@ -315,7 +315,7 @@ public class ScriptTestCase { } @Test - public void testArrayEmbedToSparseTensor() throws ParseException { + public void testArrayEmbedToMixedTensor() throws ParseException { Map<String, Embedder> embedders = Map.of("emb1", new MockEmbedder("myDocument.mySparseTensor")); TensorType tensorType = TensorType.fromSpec("tensor(passage{}, d[4])"); @@ -348,19 +348,65 @@ public class ScriptTestCase { sparseTensor.getTensor().get()); } + @SuppressWarnings("OptionalGetWithoutIsPresent") + @Test + public void testEmbedToSparseTensor() throws ParseException { + + Embedder mappedEmbedder = new MockEmbedder("myDocument.mySparseTensor", 0,true); + Map<String, Embedder> embedders = Map.of("emb1",mappedEmbedder); + + TensorType tensorType = TensorType.fromSpec("tensor(t{})"); + var expression = Expression.fromString("input text | embed | attribute 'mySparseTensor'", + new SimpleLinguistics(), + embedders); + + SimpleTestAdapter adapter = new SimpleTestAdapter(); + adapter.createField(new Field("text", DataType.STRING)); + + var tensorField = new Field("mySparseTensor", new TensorDataType(tensorType)); + adapter.createField(tensorField); + + var text = new StringFieldValue("abc"); + adapter.setValue("text", text); + expression.setStatementOutput(new DocumentType("myDocument"), tensorField); + + // Necessary to resolve output type + VerificationContext verificationContext = new VerificationContext(adapter); + assertEquals(new TensorDataType(tensorType), expression.verify(verificationContext)); + + ExecutionContext context = new ExecutionContext(adapter); + context.setValue(text); + expression.execute(context); + assertTrue(adapter.values.containsKey("mySparseTensor")); + var sparseTensor = (TensorFieldValue)adapter.values.get("mySparseTensor"); + assertEquals(Tensor.from(tensorType, "tensor(t{}):{97:97.0, 98:98.0, 99:99.0}"), + sparseTensor.getTensor().get()); + } + // An embedder which returns the char value of each letter in the input. */ private static class MockEmbedder implements Embedder { private final String expectedDestination; private final int addition; + private final boolean mappedTensor; + + public MockEmbedder(String expectedDestination) { - this(expectedDestination, 0); + this(expectedDestination, 0, false); + } + public MockEmbedder(String expectedDestination, boolean mapped) { + this(expectedDestination, 0,mapped); + } + + public MockEmbedder(String expectedDestination,int addition) { + this(expectedDestination, addition,false); } - public MockEmbedder(String expectedDestination, int addition) { + public MockEmbedder(String expectedDestination, int addition, boolean mappedTensor) { this.expectedDestination = expectedDestination; this.addition = addition; + this.mappedTensor = mappedTensor; } @Override @@ -372,11 +418,20 @@ public class ScriptTestCase { public Tensor embed(String text, Embedder.Context context, TensorType tensorType) { assertEquals(expectedDestination, context.getDestination()); var b = Tensor.Builder.of(tensorType); - for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) - b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + if (mappedTensor) { + for(int i = 0; i < text.length(); i++) { + var value = text.charAt(i) + addition; + b.cell(). + label(tensorType.dimensions().get(0).name(), text.charAt(i)) + .value(value); + } + } else { + for (int i = 0; i < tensorType.dimensions().get(0).size().get(); i++) + b.cell(i < text.length() ? text.charAt(i) + addition : 0, i); + + } return b.build(); } - } private void assertThrows(Runnable r, String msg) { diff --git a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java index 6acc2bf32f3..a7ed7ae3e72 100644 --- a/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java +++ b/indexinglanguage/src/test/java/com/yahoo/vespa/indexinglanguage/parser/ExpressionTestCase.java @@ -7,6 +7,8 @@ import com.yahoo.language.simple.SimpleLinguistics; import com.yahoo.vespa.indexinglanguage.expressions.*; import org.junit.Test; +import java.util.Optional; + import static org.junit.Assert.assertEquals; /** @@ -70,6 +72,7 @@ public class ExpressionTestCase { assertExpression(TokenizeExpression.class, "tokenize stem:\"ALL\" normalize"); assertExpression(TokenizeExpression.class, "tokenize stem:\"ALL\""); assertExpression(TokenizeExpression.class, "tokenize normalize"); + assertExpression(TokenizeExpression.class, "tokenize max-occurrences: 15", Optional.of("tokenize max-occurrences:15")); assertExpression(ToLongExpression.class, "to_long"); assertExpression(ToPositionExpression.class, "to_pos"); assertExpression(ToStringExpression.class, "to_string"); @@ -85,9 +88,16 @@ public class ExpressionTestCase { } private static void assertExpression(Class expectedClass, String str) throws ParseException { + assertExpression(expectedClass, str, Optional.empty()); + } + + private static void assertExpression(Class expectedClass, String str, Optional<String> expStr) throws ParseException { Linguistics linguistics = new SimpleLinguistics(); Expression foo = Expression.fromString(str, linguistics, Embedder.throwsOnUse.asMap()); assertEquals(expectedClass, foo.getClass()); + if (expStr.isPresent()) { + assertEquals(expStr.get(), foo.toString()); + } Expression bar = Expression.fromString(foo.toString(), linguistics, Embedder.throwsOnUse.asMap()); assertEquals(foo.hashCode(), bar.hashCode()); assertEquals(foo, bar); |