aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2019-12-16 20:29:48 +0100
committerGitHub <noreply@github.com>2019-12-16 20:29:48 +0100
commit40e8a8b4ac2a021ede5a5babd42976ab313ce0b8 (patch)
tree253ee93b860f20a9c1deeb4cf0f6a31945bf6bf8
parent6f5128b0d386b712aa94be3336a967990b096111 (diff)
parentbaa6a81aa07f37a543c836710b4c65b7831fd9db (diff)
Merge pull request #11548 from vespa-engine/bratseth/mixed-tensor-parse
Bratseth/mixed tensor parse
-rw-r--r--config-model/src/test/derived/tensor/rank-profiles.cfg2
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd2
-rw-r--r--searchlib/abi-spec.json18
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java48
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj102
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java46
-rw-r--r--vespajlib/abi-spec.json37
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java161
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java49
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java404
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java38
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java80
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java20
19 files changed, 812 insertions, 226 deletions
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg
index 554a36aef86..9e9dfae2bc7 100644
--- a/config-model/src/test/derived/tensor/rank-profiles.cfg
+++ b/config-model/src/test/derived/tensor/rank-profiles.cfg
@@ -133,7 +133,7 @@ rankprofile[].fef.property[].value "3"
rankprofile[].fef.property[].name "vespa.rank.firstphase"
rankprofile[].fef.property[].value "rankingExpression(firstphase)"
rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
-rankprofile[].fef.property[].value "reduce(tensor(d0[1])(attribute{x:(rankingExpression(functionNotLabel))}), sum)"
+rankprofile[].fef.property[].value "reduce(tensor(d0[1])(attribute{x:rankingExpression(functionNotLabel)}), sum)"
rankprofile[].fef.property[].name "vespa.type.attribute.f2"
rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
rankprofile[].fef.property[].name "vespa.type.attribute.f3"
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index c3380bed19c..6e0e7e3e148 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -93,7 +93,7 @@ search tensor {
rank-profile profile8 {
first-phase {
- expression: sum(tensor(d0[1])(attribute{x:(functionNotLabel)}))
+ expression: sum(tensor(d0[1])(attribute{x:functionNotLabel()}))
}
function functionNotLabel() {
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index debcd11fdbd..bde3b6abb6c 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -894,9 +894,9 @@
"public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()",
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
- "public final com.yahoo.tensor.TensorType tensorType()",
+ "public final com.yahoo.tensor.TensorType tensorType(java.util.List)",
"public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()",
- "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)",
+ "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder, java.util.List)",
"public final java.lang.String tensorFunctionName()",
"public final com.yahoo.searchlib.rankingexpression.rule.Function unaryFunctionName()",
"public final com.yahoo.searchlib.rankingexpression.rule.Function binaryFunctionName()",
@@ -910,9 +910,16 @@
"public final java.util.List tagCommaLeadingList()",
"public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()",
- "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType, java.util.List)",
"public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)",
- "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.tensor.functions.DynamicTensor mixedTensorValueBody(com.yahoo.tensor.TensorType, java.util.List)",
+ "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType, java.util.List)",
+ "public final void keyValueOrMixedBlock(com.yahoo.tensor.TensorType, java.util.List, java.util.Map)",
+ "public final void keyValue(com.yahoo.tensor.TensorType, java.util.Map)",
+ "public final void mixedBlock(com.yahoo.tensor.TensorType, java.util.List, java.util.Map)",
+ "public final java.util.List indexedTensorCells()",
+ "public final void indexedTensorCellSubspaceList(java.util.List)",
+ "public final void indexedTensorCellSubspace(java.util.List)",
"public final void tensorCell(com.yahoo.tensor.TensorType, java.util.Map)",
"public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)",
"public final void labelAndDimensionValues(java.util.List)",
@@ -1616,7 +1623,8 @@
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
"public static java.util.Map wrapScalars(java.util.Map)",
- "public static java.util.List wrapScalars(java.util.List)",
+ "public static void wrapScalarBlock(com.yahoo.tensor.TensorType, java.util.List, java.lang.String, java.util.List, java.util.Map)",
+ "public static java.util.List wrapScalars(com.yahoo.tensor.TensorType, java.util.List, java.util.List)",
"public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
],
"fields": []
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 0a67ab5534e..6200515462b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
+import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
@@ -18,6 +19,7 @@ import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
@@ -91,8 +93,50 @@ public class TensorFunctionNode extends CompositeNode {
return functions;
}
- public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) {
- return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList());
+ public static void wrapScalarBlock(TensorType type,
+ List<String> dimensionOrder,
+ String mappedDimensionLabel,
+ List<ExpressionNode> nodes,
+ Map<TensorAddress, ScalarFunction<Reference>> receivingMap) {
+ TensorType denseSubtype = new TensorType(type.valueType(),
+ type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()));
+ List<String> denseDimensionOrder = new ArrayList<>(dimensionOrder);
+ denseDimensionOrder.retainAll(denseSubtype.dimensionNames());
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype, denseDimensionOrder);
+ if (indexes.size() != nodes.size())
+ throw new IllegalArgumentException("At '" + mappedDimensionLabel + "': Need " + indexes.size() +
+ " values to fill a dense subspace of " + type + " but got " + nodes.size());
+ for (ExpressionNode node : nodes) {
+ indexes.next();
+
+ // Insert the mapped dimension into the dense subspace address of indexes
+ String[] labels = new String[type.rank()];
+ int indexedDimensionsIndex = 0;
+ int allDimensionsIndex = 0;
+ for (TensorType.Dimension dimension : type.dimensions()) {
+ if (dimension.isIndexed())
+ labels[allDimensionsIndex++] = String.valueOf(indexes.indexesForReading()[indexedDimensionsIndex++]);
+ else
+ labels[allDimensionsIndex++] = mappedDimensionLabel;
+ }
+
+ receivingMap.put(TensorAddress.of(labels), wrapScalar(node));
+ }
+ }
+
+ public static List<ScalarFunction<Reference>> wrapScalars(TensorType type,
+ List<String> dimensionOrder,
+ List<ExpressionNode> nodes) {
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type, dimensionOrder);
+ if (indexes.size() != nodes.size())
+ throw new IllegalArgumentException("Need " + indexes.size() + " values to fill " + type + " but got " + nodes.size());
+
+ List<ScalarFunction<Reference>> wrapped = new ArrayList<>(nodes.size());
+ while (indexes.hasNext()) {
+ indexes.next();
+ wrapped.add(wrapScalar(nodes.get((int)indexes.toSourceValueIndex())));
+ }
+ return wrapped;
}
public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) {
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index de3ad6b5d8c..e413e398183 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -6,7 +6,6 @@
*/
options {
CACHE_TOKENS = true;
- STATIC = false;
DEBUG_PARSER = false;
USER_TOKEN_MANAGER = false;
ERROR_REPORTING = true;
@@ -476,13 +475,14 @@ TensorFunctionNode tensorConcat() :
TensorFunctionNode tensorGenerate() :
{
TensorType type;
+ List dimensionOrder = new ArrayList();
TensorFunctionNode expression;
}
{
- <TENSOR> type = tensorType()
+ <TENSOR> type = tensorType(dimensionOrder)
(
expression = tensorGenerateBody(type) |
- expression = tensorValueBody(type)
+ expression = tensorValueBody(type, dimensionOrder)
)
{ return expression; }
}
@@ -501,7 +501,7 @@ TensorFunctionNode tensorRange() :
TensorType type;
}
{
- <RANGE> type = tensorType()
+ <RANGE> type = tensorType(null)
{ return new TensorFunctionNode(new Range(type)); }
}
@@ -510,7 +510,7 @@ TensorFunctionNode tensorDiag() :
TensorType type;
}
{
- <DIAG> type = tensorType()
+ <DIAG> type = tensorType(null)
{ return new TensorFunctionNode(new Diag(type)); }
}
@@ -519,7 +519,7 @@ TensorFunctionNode tensorRandom() :
TensorType type;
}
{
- <RANDOM> type = tensorType()
+ <RANDOM> type = tensorType(null)
{ return new TensorFunctionNode(new Random(type)); }
}
@@ -619,7 +619,7 @@ Reduce.Aggregator tensorReduceAggregator() :
{ return Reduce.Aggregator.valueOf(token.image); }
}
-TensorType tensorType() :
+TensorType tensorType(List dimensionOrder) :
{
TensorType.Builder builder;
TensorType.Value valueType;
@@ -628,8 +628,8 @@ TensorType tensorType() :
valueType = optionalTensorValueTypeParameter()
{ builder = new TensorType.Builder(valueType); }
<LBRACE>
- ( tensorTypeDimension(builder) ) ?
- ( <COMMA> tensorTypeDimension(builder) ) *
+ ( tensorTypeDimension(builder, dimensionOrder) ) ?
+ ( <COMMA> tensorTypeDimension(builder, dimensionOrder) ) *
<RBRACE>
{ return builder.build(); }
}
@@ -643,13 +643,17 @@ TensorType.Value optionalTensorValueTypeParameter() :
{ return TensorType.Value.fromId(valueType); }
}
-void tensorTypeDimension(TensorType.Builder builder) :
+void tensorTypeDimension(TensorType.Builder builder, List dimensionOrder) :
{
String name;
int size;
}
{
name = identifier()
+ { // Keep track of the order in which dimensions are written, if necessary
+ if (dimensionOrder != null)
+ dimensionOrder.add(name);
+ }
(
( <LCURLY> <RCURLY> { builder.mapped(name); } ) |
LOOKAHEAD(2) ( <LSQUARE> <RSQUARE> { builder.indexed(name); } ) |
@@ -832,15 +836,16 @@ Value primitiveValue() :
{ return Value.parse(sign + token.image); }
}
-TensorFunctionNode tensorValueBody(TensorType type) :
+TensorFunctionNode tensorValueBody(TensorType type, List dimensionOrder) :
{
DynamicTensor dynamicTensor;
}
{
<COLON>
(
+ LOOKAHEAD(2) dynamicTensor = mixedTensorValueBody(type, dimensionOrder) |
dynamicTensor = mappedTensorValueBody(type) |
- dynamicTensor = indexedTensorValueBody(type)
+ dynamicTensor = indexedTensorValueBody(type, dimensionOrder)
)
{ return new TensorFunctionNode(dynamicTensor); }
}
@@ -851,23 +856,82 @@ DynamicTensor mappedTensorValueBody(TensorType type) :
}
{
<LCURLY>
- ( tensorCell(type, cells))*
+ [ tensorCell(type, cells)]
( <COMMA> tensorCell(type, cells))*
<RCURLY>
{ return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
-DynamicTensor indexedTensorValueBody(TensorType type) :
+DynamicTensor mixedTensorValueBody(TensorType type, List dimensionOrder) :
+{
+ java.util.Map cells = new LinkedHashMap();
+}
+{
+ <LCURLY>
+ keyValueOrMixedBlock(type, dimensionOrder, cells)
+ ( <COMMA> keyValueOrMixedBlock(type, dimensionOrder, cells))*
+ <RCURLY>
+ { return DynamicTensor.from(type, cells); }
+}
+
+DynamicTensor indexedTensorValueBody(TensorType type, List dimensionOrder) :
+{
+ List cells;
+}
+{
+ cells = indexedTensorCells()
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(type, dimensionOrder, cells)); }
+}
+
+void keyValueOrMixedBlock(TensorType type, List dimensionOrder, java.util.Map cellMap) : {}
+{
+ LOOKAHEAD(3) mixedBlock(type, dimensionOrder, cellMap) | keyValue(type, cellMap)
+}
+
+void keyValue(TensorType type, java.util.Map cellMap) :
+{
+ String label;
+ ExpressionNode value;
+}
+{
+ label = tag() <COLON> value = expression()
+ { cellMap.put(TensorAddress.ofLabels(label), TensorFunctionNode.wrapScalar(value)); }
+}
+
+void mixedBlock(TensorType type, List dimensionOrder, java.util.Map cellMap) :
+{
+ String label;
+ List cells;
+}
+{
+ label = tag() <COLON> cells = indexedTensorCells()
+ { TensorFunctionNode.wrapScalarBlock(type, dimensionOrder, label, cells, cellMap); }
+}
+
+List indexedTensorCells() :
{
List cells = new ArrayList();
+}
+{
+ <LSQUARE> indexedTensorCellSubspaceList(cells) <RSQUARE>
+ { return cells; }
+}
+
+void indexedTensorCellSubspaceList(List cells) :
+{
+}
+{
+ indexedTensorCellSubspace(cells) ( LOOKAHEAD(2) <COMMA> indexedTensorCellSubspace(cells) )*
+}
+
+void indexedTensorCellSubspace(List cells) :
+{
ExpressionNode value;
}
{
- <LSQUARE> // TODO: Parse inner square brackets properly
- ( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
- ( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
-// <RSQUARE>
- { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
+ ( <LSQUARE> indexedTensorCellSubspaceList(cells) <RSQUARE> )
+ |
+ ( value = expression() { cells.add(value); } )
}
void tensorCell(TensorType type, java.util.Map cells) :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index e3d3ac7b2e1..b750a7607cc 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -191,7 +191,7 @@ public class RankingExpressionTestCase {
// Accessing a function in a dynamic tensor, short form
assertSerialization(List.of("tensor(x[2]):{{x:0}:rankingExpression(scalarFunction),{x:1}:rankingExpression(scalarFunction)}"),
- "tensor(x[2]):[scalarFunction(), scalarFunction()]]",
+ "tensor(x[2]):[scalarFunction(), scalarFunction()]",
functions, false);
// Accessing a function in a dynamic tensor, long form
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 00750c70d2c..26861dd3cd6 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -14,6 +14,7 @@ import com.yahoo.tensor.TensorType;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
/**
* Tests expression evaluation
@@ -402,6 +403,51 @@ public class EvaluationTestCase {
"{ {x:0}:7 }", "tensor(x{}):{ {x:0}:2 }");
tester.assertEvaluates("tensor<float>(d0[1],x[3]):[[1.0, 0.5, 0.25]]",
"tensor<float>(d0[1],x[3]):[[one,one_half,a_quarter]]");
+ tester.assertEvaluates("tensor(x[2],y[3]):[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]",
+ "tensor(x[2],y[3]):[[one,one_half,a_quarter],[a_quarter,one_half,one]]");
+ tester.assertEvaluates("tensor(x{},y[2]):{{x:a,y:0}:1.0, {x:a,y:1}:0.5, {x:b,y:0}:0.25, {x:b,y:1}:2.0}",
+ "tensor(x{},y[2]):{{x:a,y:0}:one, {x:a,y:1}:one_half, {x:b,y:0}:a_quarter, {x:b,y:1}:2}");
+ tester.assertEvaluates("tensor(x{},y[2]):{a:[1.0, 0.5], b:[0.25, 2]}",
+ "tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter, 2]}");
+ tester.assertEvaluates("tensor(key{},x[2],y[3]):{key1:[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]," +
+ " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}",
+ "tensor(key{},x[2],y[3]):{key1:[[one,one_half,a_quarter],[a_quarter,one_half,one]]," +
+ " key2:[[1,2,3],[4,5,6]]}");
+ tester.assertEvaluates("tensor(x{}):{{x:a}:1, {x:b}:-2, {x:cee}:0.5}", "tensor(x{}):{a:1, b:-2, cee:one_half}");
+
+ // Opposite order in the expression:
+ // - indexed
+ tester.assertEvaluates("tensor(x[3],y[2]):[[1.0, 0.25], [0.5,0.5], [0.25, 1.0]]",
+ "tensor(y[2],x[3]):[[one,one_half,a_quarter],[a_quarter,one_half,one]]");
+ // - mixed
+ tester.assertEvaluates("tensor(key{},x[3],y[2]):{key1:[[1.0, 0.25], [0.5,0.5], [0.25, 1.0]]," +
+ " key2:[[1.0, 4.00], [2.0,5.0], [3.00, 6.0]]}",
+ "tensor(key{},y[2],x[3]):{key1:[[one,one_half,a_quarter],[a_quarter,one_half,one]]," +
+ " key2:[[1,2,3],[4,5,6]]}");
+ // Opposite order in literal parsing:
+ // - indexed
+ tester.assertEvaluates("tensor(y[2],x[3]):[[1,0.25,0.5],[0.5,0.25,1]]",
+ "tensor(x[3],y[2]):[[one,one_half], [a_quarter,a_quarter], [one_half,one]]");
+ // - mixed
+ tester.assertEvaluates("tensor(key{},y[2],x[3]):{key1:[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]," +
+ " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}",
+ "tensor(key{},x[3],y[2]):{key1:[[one,a_quarter],[one_half,one_half],[a_quarter,one]]," +
+ " key2:[[1,4],[2,5],[3,6]]}");
+
+ try {
+ new RankingExpression("tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter]}");
+ fail("Expected exception");
+ }
+ catch (Exception e) {
+ assertEquals("At 'b': Need 2 values to fill a dense subspace of tensor(x{},y[2]) but got 1", e.getMessage());
+ }
+ try {
+ new RankingExpression("tensor(x[2],y[3]):[[1,2,3,4],[4,5,6]]");
+ fail("Expected exception");
+ }
+ catch (Exception e) {
+ assertEquals("Need 6 values to fill tensor(x[2],y[3]) but got 7", e.getMessage());
+ }
}
@Test
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 19328f5dbb2..a4a9a1e1b24 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -693,6 +693,7 @@
"methods": [
"public void <init>(int)",
"public com.yahoo.tensor.DimensionSizes$Builder set(int, long)",
+ "public com.yahoo.tensor.DimensionSizes$Builder add(long)",
"public long size(int)",
"public int dimensions()",
"public com.yahoo.tensor.DimensionSizes build()"
@@ -776,15 +777,14 @@
},
"com.yahoo.tensor.IndexedTensor$BoundBuilder": {
"superClass": "com.yahoo.tensor.IndexedTensor$Builder",
- "interfaces": [],
+ "interfaces": [
+ "com.yahoo.tensor.IndexedTensor$DirectIndexBuilder"
+ ],
"attributes": [
"public",
"abstract"
],
- "methods": [
- "public abstract void cellByDirectIndex(long, double)",
- "public abstract void cellByDirectIndex(long, float)"
- ],
+ "methods": [],
"fields": []
},
"com.yahoo.tensor.IndexedTensor$Builder": {
@@ -813,6 +813,21 @@
],
"fields": []
},
+ "com.yahoo.tensor.IndexedTensor$DirectIndexBuilder": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods": [
+ "public abstract com.yahoo.tensor.TensorType type()",
+ "public abstract void cellByDirectIndex(long, double)",
+ "public abstract void cellByDirectIndex(long, float)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.IndexedTensor$Indexes": {
"superClass": "java.lang.Object",
"interfaces": [],
@@ -822,14 +837,17 @@
],
"methods": [
"public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)",
+ "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType, java.util.List)",
"public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)",
"public com.yahoo.tensor.TensorAddress toAddress()",
"public long[] indexesCopy()",
"public long[] indexesForReading()",
+ "public long toSourceValueIndex()",
"public java.util.List toList()",
"public java.lang.String toString()",
"public abstract long size()",
- "public abstract void next()"
+ "public abstract void next()",
+ "public abstract boolean hasNext()"
],
"fields": [
"protected final long[] indexes"
@@ -943,6 +961,7 @@
],
"methods": [
"public long denseSubspaceSize()",
+ "public com.yahoo.tensor.IndexedTensor$DirectIndexBuilder denseSubspaceBuilder(com.yahoo.tensor.TensorAddress)",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)",
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
"public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])",
@@ -1035,8 +1054,8 @@
],
"methods": [
"public void <init>(int)",
- "public void add(java.lang.String, long)",
- "public void add(java.lang.String, java.lang.String)",
+ "public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, long)",
+ "public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, java.lang.String)",
"public com.yahoo.tensor.PartialAddress build()"
],
"fields": []
@@ -1236,6 +1255,7 @@
"methods": [
"public void <init>()",
"public static com.yahoo.tensor.TensorAddress of(java.lang.String[])",
+ "public static varargs com.yahoo.tensor.TensorAddress ofLabels(java.lang.String[])",
"public static varargs com.yahoo.tensor.TensorAddress of(long[])",
"public abstract int size()",
"public abstract java.lang.String label(int)",
@@ -1395,6 +1415,7 @@
"public"
],
"methods": [
+ "public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)",
"public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])",
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
"public com.yahoo.tensor.TensorType$Value valueType()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index d81c02fb75f..202817ece42 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -71,6 +71,7 @@ public final class DimensionSizes {
*/
public final static class Builder {
+ private int dimensionIndex = 0;
private long[] sizes;
public Builder(int dimensions) {
@@ -82,6 +83,11 @@ public final class DimensionSizes {
return this;
}
+ public Builder add(long size) {
+ sizes[dimensionIndex++] = size;
+ return this;
+ }
+
/**
* Returns the length of this in the nth dimension
*
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 176ddfefc13..ba3a35e8eda 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor {
indexes.next();
// start brackets
- for (int i = 0; i < indexes.rightDimensionsWhichAreAtStart(); i++)
+ for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
b.append("[");
// value
@@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalStateException("Unexpected value type " + type.valueType());
// end bracket and comma
- for (int i = 0; i < indexes.rightDimensionsWhichAreAtEnd(); i++)
+ for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
b.append("]");
if (index < size() - 1)
b.append(", ");
@@ -375,8 +375,22 @@ public abstract class IndexedTensor implements Tensor {
}
+ public interface DirectIndexBuilder {
+
+ TensorType type();
+
+
+
+ /** Sets a value by its <i>standard value order</i> index */
+ void cellByDirectIndex(long index, double value);
+
+ /** Sets a value by its <i>standard value order</i> index */
+ void cellByDirectIndex(long index, float value);
+
+ }
+
/** A bound builder can create the double array directly */
- public static abstract class BoundBuilder extends Builder {
+ public static abstract class BoundBuilder extends Builder implements DirectIndexBuilder {
private DimensionSizes sizes;
@@ -393,14 +407,16 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type);
this.sizes = sizes;
}
- BoundBuilder fill(float [] values) {
+
+ BoundBuilder fill(float[] values) {
long index = 0;
for (float value : values) {
cellByDirectIndex(index++, value);
}
return this;
}
- BoundBuilder fill(double [] values) {
+
+ BoundBuilder fill(double[] values) {
long index = 0;
for (double value : values) {
cellByDirectIndex(index++, value);
@@ -410,12 +426,6 @@ public abstract class IndexedTensor implements Tensor {
DimensionSizes sizes() { return sizes; }
- /** Sets a value by its <i>standard value order</i> index */
- public abstract void cellByDirectIndex(long index, double value);
-
- /** Sets a value by its <i>standard value order</i> index */
- public abstract void cellByDirectIndex(long index, float value);
-
}
/**
@@ -767,6 +777,10 @@ public abstract class IndexedTensor implements Tensor {
return of(DimensionSizes.of(type));
}
+ public static Indexes of(TensorType type, List<String> iterateDimensionOrder) {
+ return of(DimensionSizes.of(type), toIterationOrder(iterateDimensionOrder, type));
+ }
+
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
}
@@ -779,6 +793,10 @@ public abstract class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size);
}
+ private static Indexes of(DimensionSizes sizes, List<Integer> iterateDimensions) {
+ return of(sizes, sizes, iterateDimensions);
+ }
+
private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) {
return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions));
}
@@ -812,6 +830,16 @@ public abstract class IndexedTensor implements Tensor {
}
}
+ private static List<Integer> toIterationOrder(List<String> dimensionNames, TensorType type) {
+ if (dimensionNames == null) return completeIterationOrder(type.rank());
+
+ List<Integer> iterationDimensions = new ArrayList<>(type.rank());
+ for (int i = 0; i < type.rank(); i++)
+ iterationDimensions.add(type.rank() - 1 - type.indexOfDimension(dimensionNames.get(i)).get());
+ return iterationDimensions;
+ }
+
+ /** Since the right dimensions binds closest, iteration order is the opposite of the tensor order */
private static List<Integer> completeIterationOrder(int length) {
List<Integer> iterationDimensions = new ArrayList<>(length);
for (int i = 0; i < length; i++)
@@ -844,7 +872,7 @@ public abstract class IndexedTensor implements Tensor {
/** Returns a copy of the indexes of this which must not be modified */
public long[] indexesForReading() { return indexes; }
- long toSourceValueIndex() {
+ public long toSourceValueIndex() {
return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
@@ -869,27 +897,15 @@ public abstract class IndexedTensor implements Tensor {
public abstract void next();
- /** Returns the number of dimensions from the right which are currently at the start position (0) */
- int rightDimensionsWhichAreAtStart() {
- int dimension = indexes.length - 1;
- int atStartCount = 0;
- while (dimension >= 0 && indexes[dimension] == 0) {
- atStartCount++;
- dimension--;
- }
- return atStartCount;
- }
+ /** Returns whether further values are available by calling next() */
+ public abstract boolean hasNext();
+
+ /** Returns the number of dimensions in iteration order which are currently at the start position (0) */
+ abstract int nextDimensionsAtStart();
+
+ /** Returns the number of dimensions in iteration order which are currently at their end position */
+ abstract int nextDimensionsAtEnd();
- /** Returns the number of dimensions from the right which are currently at the end position */
- int rightDimensionsWhichAreAtEnd() {
- int dimension = indexes.length - 1;
- int atEndCount = 0;
- while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) {
- atEndCount++;
- dimension--;
- }
- return atEndCount;
- }
}
private final static class EmptyIndexes extends Indexes {
@@ -904,10 +920,21 @@ public abstract class IndexedTensor implements Tensor {
@Override
public void next() {}
+ @Override
+ public boolean hasNext() { return false; }
+
+ @Override
+ int nextDimensionsAtStart() { return 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return 0; }
+
}
private final static class SingleValueIndexes extends Indexes {
+ private boolean exhausted = false;
+
private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) {
super(sourceSizes, iterateSizes, indexes);
}
@@ -916,7 +943,16 @@ public abstract class IndexedTensor implements Tensor {
public long size() { return 1; }
@Override
- public void next() {}
+ public void next() { exhausted = true; }
+
+ @Override
+ public boolean hasNext() { return ! exhausted; }
+
+ @Override
+ int nextDimensionsAtStart() { return 1; }
+
+ @Override
+ int nextDimensionsAtEnd() { return 1; }
}
@@ -945,7 +981,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -957,6 +993,31 @@ public abstract class IndexedTensor implements Tensor {
indexes[iterateDimensions.get(iterateDimensionsIndex)]++;
}
+ @Override
+ public boolean hasNext() {
+ for (int iterateDimension : iterateDimensions) {
+ if (indexes[iterateDimension] + 1 < dimensionSizes().size(iterateDimension))
+ return true; // some dimension is not at the end
+ }
+ return false;
+ }
+
+ @Override
+ int nextDimensionsAtStart() {
+ int dimension = 0;
+ while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == 0)
+ dimension++;
+ return dimension;
+ }
+
+ @Override
+ int nextDimensionsAtEnd() {
+ int dimension = 0;
+ while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == dimensionSizes().size(iterateDimensions.get(dimension)) - 1)
+ dimension++;
+ return dimension;
+ }
+
}
/** In this case we can reuse the source index computation for the iteration index */
@@ -969,7 +1030,7 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() {
+ public long toSourceValueIndex() {
return lastComputedSourceValueIndex = super.toSourceValueIndex();
}
@@ -1016,7 +1077,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -1026,11 +1087,22 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() { return currentSourceValueIndex; }
+ public long toSourceValueIndex() { return currentSourceValueIndex; }
@Override
long toIterationValueIndex() { return currentIterationValueIndex; }
+ @Override
+ public boolean hasNext() {
+ return indexes[iterateDimension] + 1 < size;
+ }
+
+ @Override
+ int nextDimensionsAtStart() { return currentSourceValueIndex == 0 ? 1 : 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return currentSourceValueIndex == size - 1 ? 1 : 0; }
+
}
/** In this case we only need to keep track of one index */
@@ -1068,7 +1140,7 @@ public abstract class IndexedTensor implements Tensor {
* Advances this to the next cell in the standard indexed tensor cell order.
* The first call to this will put it at the first position.
*
- * @throws RuntimeException if this is called more times than its size
+ * @throws RuntimeException if this is called when hasNext returns false
*/
@Override
public void next() {
@@ -1077,11 +1149,22 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() { return currentValueIndex; }
+ public boolean hasNext() {
+ return indexes[iterateDimension] + 1 < size;
+ }
+
+ @Override
+ public long toSourceValueIndex() { return currentValueIndex; }
@Override
long toIterationValueIndex() { return currentValueIndex; }
+ @Override
+ int nextDimensionsAtStart() { return currentValueIndex == 0 ? 1 : 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return currentValueIndex == size - 1 ? 1 : 0; }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 1cde1fcdbb7..0c4efe78113 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -217,25 +217,34 @@ public class MixedTensor implements Tensor {
public static class BoundBuilder extends Builder {
/** For each sparse partial address, hold a dense subspace */
- final private Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>();
- final private Index.Builder indexBuilder;
- final private Index index;
+ private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>();
+ private final Index.Builder indexBuilder;
+ private final Index index;
+ private final TensorType denseSubtype;
private BoundBuilder(TensorType type) {
super(type);
indexBuilder = new Index.Builder(type);
index = indexBuilder.index();
+ denseSubtype = new TensorType(type.valueType(),
+ type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()));
}
public long denseSubspaceSize() {
return index.denseSubspaceSize();
}
- private double[] denseSubspace(TensorAddress sparsePartial) {
- if (!denseSubspaceMap.containsKey(sparsePartial)) {
- denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]);
+ private double[] denseSubspace(TensorAddress sparseAddress) {
+ if (!denseSubspaceMap.containsKey(sparseAddress)) {
+ denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]);
}
- return denseSubspaceMap.get(sparsePartial);
+ return denseSubspaceMap.get(sparseAddress);
+ }
+
+ public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) {
+ double[] values = new double[(int)denseSubspaceSize()];
+ denseSubspaceMap.put(sparseAddress, values);
+ return new DenseSubspaceBuilder(denseSubtype, values);
}
@Override
@@ -280,7 +289,6 @@ public class MixedTensor implements Tensor {
}
-
/**
* Temporarily stores all cells to find bounds of indexed dimensions,
* then creates a tensor using BoundBuilder. This is due to the
@@ -491,6 +499,31 @@ public class MixedTensor implements Tensor {
}
+ private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder {
+
+ private final TensorType type;
+ private final double[] values;
+
+ public DenseSubspaceBuilder(TensorType type, double[] values) {
+ this.type = type;
+ this.values = values;
+ }
+
+ @Override
+ public TensorType type() { return type; }
+
+ @Override
+ public void cellByDirectIndex(long index, double value) {
+ values[(int)index] = value;
+ }
+
+ @Override
+ public void cellByDirectIndex(long index, float value) {
+ values[(int)index] = value;
+ }
+
+ }
+
public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
TensorType.Builder builder = new TensorType.Builder(valueType);
for (TensorType.Dimension dimension : dimensions) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index 4eca9c47402..84f26d96725 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -122,16 +122,18 @@ public class PartialAddress {
labels = new Object[size];
}
- public void add(String dimensionName, long label) {
+ public Builder add(String dimensionName, long label) {
dimensionNames[index] = dimensionName;
labels[index] = label;
index++;
+ return this;
}
- public void add(String dimensionName, String label) {
+ public Builder add(String dimensionName, String label) {
dimensionNames[index] = dimensionName;
labels[index] = label;
index++;
+ return this;
}
public PartialAddress build() {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index e705445c5a7..4770ad1b1f0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -21,6 +21,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return new StringTensorAddress(labels);
}
+ public static TensorAddress ofLabels(String ... labels) {
+ return new StringTensorAddress(labels);
+ }
+
public static TensorAddress of(long ... labels) {
return new NumericTensorAddress(labels);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 4d8b34b7dcf..5a1fd98a009 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import java.util.ArrayList;
+import java.util.List;
import java.util.Optional;
/**
@@ -9,14 +11,30 @@ import java.util.Optional;
class TensorParser {
static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) {
+ try {
+ return tensorFromBody(tensorString, explicitType);
+ } catch (IllegalArgumentException e) {
+ throw new IllegalArgumentException("Could not parse '" + tensorString + "' as a tensor" +
+ (explicitType.isPresent() ? " of type " + explicitType.get() : ""),
+ e);
+ }
+ }
+
+ static Tensor tensorFromBody(String tensorString, Optional<TensorType> explicitType) {
Optional<TensorType> type;
String valueString;
+ // The order in which dimensions are written in the type string.
+ // This allows the user's explicit dimension order to decide what (dense) dimensions map to what, rather than
+ // the natural order of the tensor.
+ List<String> dimensionOrder;
+
tensorString = tensorString.trim();
if (tensorString.startsWith("tensor")) {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
- TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
+ dimensionOrder = new ArrayList<>();
+ TensorType typeFromString = TensorTypeParser.fromSpec(typeString, dimensionOrder);
if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
"passed type " + explicitType.get());
@@ -26,14 +44,19 @@ class TensorParser {
else {
type = explicitType;
valueString = tensorString;
+ dimensionOrder = null;
}
valueString = valueString.trim();
- if (valueString.startsWith("{")) {
+ if (valueString.startsWith("{") &&
+ (type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) {
return tensorFromSparseValueString(valueString, type);
}
+ else if (valueString.startsWith("{")) {
+ return tensorFromMixedValueString(valueString, type, dimensionOrder);
+ }
else if (valueString.startsWith("[")) {
- return tensorFromDenseValueString(valueString, type);
+ return tensorFromDenseValueString(valueString, type, dimensionOrder);
}
else {
if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty))
@@ -54,8 +77,7 @@ class TensorParser {
String s = valueString.substring(1).trim(); // remove tensor start
int firstKeyOrTensorEnd = s.indexOf('}');
if (firstKeyOrTensorEnd < 0)
- throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" +
- valueString + "'");
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'");
String addressBody = s.substring(0, firstKeyOrTensorEnd).trim();
if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor
if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor
@@ -79,138 +101,312 @@ class TensorParser {
try {
valueString = valueString.trim();
Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString)));
- return fromCellString(builder, valueString);
+ SparseValueParser parser = new SparseValueParser(valueString, builder);
+ parser.parse();
+ return builder.build();
}
catch (NumberFormatException e) {
- throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" +
- valueString + "'");
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
}
}
- private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) {
+ private static Tensor tensorFromMixedValueString(String valueString,
+ Optional<TensorType> type,
+ List<String> dimensionOrder) {
+ if (type.isEmpty())
+ throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " +
+ "on the form 'tensor(dimensions):...");
+ if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1)
+ throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " +
+ "but got " + type.get());
+
+
+ try {
+ valueString = valueString.trim();
+ if ( ! valueString.startsWith("{") && valueString.endsWith("}"))
+ throw new IllegalArgumentException("A mixed tensor must be enclosed in {}");
+ Tensor.Builder builder = Tensor.Builder.of(type.get());
+ MixedValueParser parser = new MixedValueParser(valueString, dimensionOrder, builder);
+ parser.parse();
+ return builder.build();
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('");
+ }
+ }
+
+ private static Tensor tensorFromDenseValueString(String valueString,
+ Optional<TensorType> type,
+ List<String> dimensionOrder) {
if (type.isEmpty())
throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
- if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty())))
+ if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty())))
throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " +
"only dense dimensions with a given size");
- IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get());
- long index = 0;
- int currentChar;
- int nextNumberEnd = 0;
- // Since we know the dimensions the brackets are just syntactic sugar:
- while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) {
- nextNumberEnd = nextStopCharIndex(currentChar, valueString);
- if (currentChar == nextNumberEnd) return builder.build();
-
- TensorType.Value cellValueType = builder.type().valueType();
- String cellValueString = valueString.substring(currentChar, nextNumberEnd);
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get());
+ new DenseValueParser(valueString, dimensionOrder, builder).parse();
+ return builder.build();
+ }
+
+ private static abstract class ValueParser {
+
+ protected final String string;
+ protected int position = 0;
+
+ protected ValueParser(String string) {
+ this.string = string;
+ }
+
+ protected void skipSpace() {
+ while (position < string.length() && string.charAt(position) == ' ')
+ position++;
+ }
+
+ protected void consume(char character) {
+ skipSpace();
+
+ if (position >= string.length())
+ throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character +
+ "' but got the end of the string");
+ if ( string.charAt(position) != character)
+ throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character +
+ "' but got '" + string.charAt(position) + "'");
+ position++;
+ }
+
+ protected Number consumeNumber(TensorType.Value cellValueType) {
+ skipSpace();
+
+ int nextNumberEnd = nextStopCharIndex(position, string);
try {
- if (cellValueType == TensorType.Value.DOUBLE)
- builder.cellByDirectIndex(index, Double.parseDouble(cellValueString));
- else if (cellValueType == TensorType.Value.FLOAT)
- builder.cellByDirectIndex(index, Float.parseFloat(cellValueString));
- else
- throw new IllegalArgumentException(cellValueType + " is not supported");
+ String cellValueString = string.substring(position, nextNumberEnd);
+ try {
+ if (cellValueType == TensorType.Value.DOUBLE)
+ return Double.parseDouble(cellValueString);
+ else if (cellValueType == TensorType.Value.FLOAT)
+ return Float.parseFloat(cellValueString);
+ else
+ throw new IllegalArgumentException(cellValueType + " is not supported");
+ } catch (NumberFormatException e) {
+ throw new IllegalArgumentException("At value position " + position + ": '" +
+ cellValueString + "' is not a valid " + cellValueType);
+ }
}
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("At index " + index + ": '" +
- cellValueString + "' is not a valid " + cellValueType);
+ finally {
+ position = nextNumberEnd;
}
- index++;
}
- return builder.build();
- }
- /** Returns the position of the next character that should contain a number, or if none the string length */
- private static int nextStartCharIndex(int charIndex, String valueString) {
- for (; charIndex < valueString.length(); charIndex++) {
- if (valueString.charAt(charIndex) == ']') continue;
- if (valueString.charAt(charIndex) == '[') continue;
- if (valueString.charAt(charIndex) == ',') continue;
- if (valueString.charAt(charIndex) == ' ') continue;
- return charIndex;
+ protected int nextStopCharIndex(int position, String valueString) {
+ while (position < valueString.length()) {
+ if (valueString.charAt(position) == ',') return position;
+ if (valueString.charAt(position) == ']') return position;
+ if (valueString.charAt(position) == '}') return position;
+ position++;
+ }
+ throw new IllegalArgumentException("Malformed tensor value '" + valueString +
+ "': Expected a ',', ']' or '}' after position " + position);
}
- return valueString.length();
+
}
- private static int nextStopCharIndex(int charIndex, String valueString) {
- while (charIndex < valueString.length()) {
- if (valueString.charAt(charIndex) == ',') return charIndex;
- if (valueString.charAt(charIndex) == ']') return charIndex;
- charIndex++;
+ /** A single-use dense tensor string parser */
+ private static class DenseValueParser extends ValueParser {
+
+ private final IndexedTensor.DirectIndexBuilder builder;
+ private final IndexedTensor.Indexes indexes;
+ private final boolean hasInnerStructure;
+
+ public DenseValueParser(String string,
+ List<String> dimensionOrder,
+ IndexedTensor.DirectIndexBuilder builder) {
+ super(string);
+ this.builder = builder;
+ indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder);
+ hasInnerStructure = hasInnerStructure(string);
}
- throw new IllegalArgumentException("Malformed tensor value '" + valueString +
- "': Expected a ',' or ']' after position " + charIndex);
- }
- private static Tensor fromCellString(Tensor.Builder builder, String s) {
- int index = 1;
- index = skipSpace(index, s);
- while (index + 1 < s.length()) {
- int keyOrTensorEnd = s.indexOf('}', index);
- TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type());
- if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty
- addLabels(s.substring(index, keyOrTensorEnd + 1), addressBuilder);
- index = keyOrTensorEnd + 1;
- index = skipSpace(index, s);
- if ( s.charAt(index) != ':')
- throw new IllegalArgumentException("Expecting a ':' after " + s.substring(index) + ", got '" + s + "'");
- index++;
- }
- int valueEnd = s.indexOf(',', index);
- if (valueEnd < 0) { // last value
- valueEnd = s.indexOf('}', index);
- if (valueEnd < 0)
- throw new IllegalArgumentException("A tensor string must end by '}'");
+ public void parse() {
+ if (!hasInnerStructure)
+ consume('[');
+
+ while (indexes.hasNext()) {
+ indexes.next();
+ for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; i++)
+ consume('[');
+ consumeNumber();
+ for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; i++)
+ consume(']');
+ if (indexes.hasNext())
+ consume(',');
}
- TensorAddress address = addressBuilder.build();
- TensorType.Value cellValueType = builder.type().valueType();
- String cellValueString = s.substring(index, valueEnd).trim();
- try {
- if (cellValueType == TensorType.Value.DOUBLE)
- builder.cell(address, Double.parseDouble(cellValueString));
- else if (cellValueType == TensorType.Value.FLOAT)
- builder.cell(address, Float.parseFloat(cellValueString));
+ if (!hasInnerStructure)
+ consume(']');
+ }
+
+ public int position() { return position; }
+
+ /** Are there inner square brackets in this or is it just a flat list of numbers until ']'? */
+ private static boolean hasInnerStructure(String valueString) {
+ valueString = valueString.trim();
+ valueString = valueString.substring(1);
+ int firstLeftBracket = valueString.indexOf('[');
+ return firstLeftBracket >= 0 && firstLeftBracket < valueString.indexOf(']');
+ }
+
+ protected void consumeNumber() {
+ Number number = consumeNumber(builder.type().valueType());
+ if (builder.type().valueType() == TensorType.Value.DOUBLE)
+ builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number);
+ else if (builder.type().valueType() == TensorType.Value.FLOAT)
+ builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number);
+ }
+
+ }
+
+ /**
+ * Parses mixed tensor short forms {a:[1,2], ...} AND 1d mapped tensor short form {a:b, ...}.
+ */
+ private static class MixedValueParser extends ValueParser {
+
+ private final Tensor.Builder builder;
+ private List<String> dimensionOrder;
+
+ public MixedValueParser(String string, List<String> dimensionOrder, Tensor.Builder builder) {
+ super(string);
+ this.dimensionOrder = dimensionOrder;
+ this.builder = builder;
+ }
+
+ private void parse() {
+ TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get();
+ TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension));
+ if (dimensionOrder != null)
+ dimensionOrder.remove(mappedDimension.name());
+
+ skipSpace();
+ consume('{');
+ skipSpace();
+ while (position + 1 < string.length()) {
+ int labelEnd = string.indexOf(':', position);
+ if (labelEnd <= position)
+ throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...}, or {sparse-label:value, ...}");
+ String label = string.substring(position, labelEnd);
+ position = labelEnd + 1;
+ skipSpace();
+
+ TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build();
+ if (builder.type().rank() > 1)
+ parseDenseSubspace(mappedAddress, dimensionOrder);
else
- throw new IllegalArgumentException(cellValueType + " is not supported");
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" +
- cellValueString + "' is not a valid " + cellValueType);
+ consumeNumber(mappedAddress);
+ if ( ! consumeOptional(','))
+ consume('}');
+ skipSpace();
}
+ }
- index = valueEnd+1;
- index = skipSpace(index, s);
+ private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) {
+ DenseValueParser denseParser = new DenseValueParser(string.substring(position),
+ denseDimensionOrder,
+ ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(sparseAddress));
+ denseParser.parse();
+ position+= denseParser.position();
+ }
+
+ private boolean consumeOptional(char character) {
+ skipSpace();
+
+ if (position >= string.length())
+ return false;
+ if ( string.charAt(position) != character)
+ return false;
+
+ position++;
+ return true;
+ }
+
+ private void consumeNumber(TensorAddress address) {
+ Number number = consumeNumber(builder.type().valueType());
+ if (builder.type().valueType() == TensorType.Value.DOUBLE)
+ builder.cell(address, (Double)number);
+ else if (builder.type().valueType() == TensorType.Value.FLOAT)
+ builder.cell(address, (Float)number);
}
- return builder.build();
- }
- private static int skipSpace(int index, String s) {
- while (index < s.length() && s.charAt(index) == ' ')
- index++;
- return index;
}
- /** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */
- private static void addLabels(String mapAddressString, TensorAddress.Builder builder) {
- mapAddressString = mapAddressString.trim();
- if ( ! (mapAddressString.startsWith("{") && mapAddressString.endsWith("}")))
- throw new IllegalArgumentException("Expecting a tensor address enclosed in {}, got '" + mapAddressString + "'");
+ private static class SparseValueParser extends ValueParser {
- String addressBody = mapAddressString.substring(1, mapAddressString.length() - 1).trim();
- if (addressBody.isEmpty()) return;
+ private final Tensor.Builder builder;
- for (String elementString : addressBody.split(",")) {
- String[] pair = elementString.split(":");
- if (pair.length != 2)
- throw new IllegalArgumentException("Expecting argument elements on the form dimension:label, " +
- "got '" + elementString + "'");
- String dimension = pair[0].trim();
- builder.add(dimension, pair[1].trim());
+ public SparseValueParser(String string, Tensor.Builder builder) {
+ super(string);
+ this.builder = builder;
}
+
+ private void parse() {
+ consume('{');
+ skipSpace();
+ while (position + 1 < string.length()) {
+ int keyOrTensorEnd = string.indexOf('}', position);
+ TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type());
+ if (keyOrTensorEnd < string.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty
+ addLabels(string.substring(position, keyOrTensorEnd + 1), addressBuilder);
+ position = keyOrTensorEnd + 1;
+ skipSpace();
+ consume(':');
+ }
+ int valueEnd = string.indexOf(',', position);
+ if (valueEnd < 0) { // last value
+ valueEnd = string.indexOf('}', position);
+ if (valueEnd < 0)
+ throw new IllegalArgumentException("A sparse tensor string must end by '}'");
+ }
+
+ TensorAddress address = addressBuilder.build();
+ TensorType.Value cellValueType = builder.type().valueType();
+ String cellValueString = string.substring(position, valueEnd).trim();
+ try {
+ if (cellValueType == TensorType.Value.DOUBLE)
+ builder.cell(address, Double.parseDouble(cellValueString));
+ else if (cellValueType == TensorType.Value.FLOAT)
+ builder.cell(address, Float.parseFloat(cellValueString));
+ else
+ throw new IllegalArgumentException(cellValueType + " is not supported");
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" +
+ cellValueString + "' is not a valid " + cellValueType);
+ }
+
+ position = valueEnd+1;
+ skipSpace();
+ }
+ }
+
+ /** Creates a tensor address from a string on the form {dimension1:label1,dimension2:label2,...} */
+ private static void addLabels(String mapAddressString, TensorAddress.Builder builder) {
+ mapAddressString = mapAddressString.trim();
+ if ( ! (mapAddressString.startsWith("{") && mapAddressString.endsWith("}")))
+ throw new IllegalArgumentException("Expecting a tensor address enclosed in {}, got '" + mapAddressString + "'");
+
+ String addressBody = mapAddressString.substring(1, mapAddressString.length() - 1).trim();
+ if (addressBody.isEmpty()) return;
+
+ for (String elementString : addressBody.split(",")) {
+ String[] pair = elementString.split(":");
+ if (pair.length != 2)
+ throw new IllegalArgumentException("Expecting argument elements on the form dimension:label, " +
+ "got '" + elementString + "'");
+ String dimension = pair[0].trim();
+ builder.add(dimension, pair[1].trim());
+ }
+ }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 95cc70804e2..ca3f8ff28a4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -82,7 +82,7 @@ public class TensorType {
private final TensorType mappedSubtype;
- private TensorType(Value valueType, Collection<Dimension> dimensions) {
+ public TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
List<Dimension> dimensionList = new ArrayList<>(dimensions);
Collections.sort(dimensionList);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index def3ab6b4ec..4fdb0906740 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -24,6 +24,13 @@ public class TensorTypeParser {
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
public static TensorType fromSpec(String specString) {
+ return fromSpec(specString, null);
+ }
+
+ /**
+ * @param dimensionOrder if not null, this will be populated with the dimension names in the order they are written
+ */
+ static TensorType fromSpec(String specString, List<String> dimensionOrder) {
specString = specString.trim();
if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING))
throw formatException(specString);
@@ -48,10 +55,14 @@ public class TensorTypeParser {
List<TensorType.Dimension> dimensions = new ArrayList<>();
for (String element : dimensionsSpec.split(",")) {
String trimmedElement = element.trim();
- boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
- tryParseMappedDimension(trimmedElement, dimensions);
- if ( ! success)
+ TensorType.Dimension dimension = tryParseIndexedDimension(trimmedElement);
+ if (dimension == null)
+ dimension = tryParseMappedDimension(trimmedElement);
+ if (dimension == null)
throw formatException(specString, "Dimension '" + element + "' is on the wrong format");
+ dimensions.add(dimension);
+ if (dimensionOrder != null)
+ dimensionOrder.add(dimension.name());
}
return new TensorType.Builder(valueType, dimensions).build();
}
@@ -68,29 +79,26 @@ public class TensorTypeParser {
}
}
- private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
+ private static TensorType.Dimension tryParseIndexedDimension(String element) {
Matcher matcher = indexedPattern.matcher(element);
if (matcher.matches()) {
String dimensionName = matcher.group(1);
String dimensionSize = matcher.group(2);
- if (dimensionSize.isEmpty()) {
- dimensions.add(TensorType.Dimension.indexed(dimensionName));
- } else {
- dimensions.add(TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize)));
- }
- return true;
+ if (dimensionSize.isEmpty())
+ return TensorType.Dimension.indexed(dimensionName);
+ else
+ return TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize));
}
- return false;
+ return null;
}
- private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) {
+ private static TensorType.Dimension tryParseMappedDimension(String element) {
Matcher matcher = mappedPattern.matcher(element);
if (matcher.matches()) {
String dimensionName = matcher.group(1);
- dimensions.add(TensorType.Dimension.mapped(dimensionName));
- return true;
+ return TensorType.Dimension.mapped(dimensionName);
}
- return false;
+ return null;
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
index 1928971820c..6f9a5c13886 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -22,6 +22,12 @@ public class TensorParserTestCase {
}
@Test
+ public void testSingle() {
+ assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(),
+ "tensor(x[1]):[1.0]");
+ }
+
+ @Test
public void testDenseParsing() {
assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor()")).build(),
"tensor():{0.0}");
@@ -55,18 +61,9 @@ public class TensorParserTestCase {
.cell(3.0, 1, 0, 0)
.cell(4.0, 1, 1, 0)
.cell(5.0, 2, 0, 0)
- .cell(6.0, 2, 1, 0).build(),
- "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]");
- assertEquals("Messy input",
- Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])"))
- .cell( 1.0, 0, 0, 0)
- .cell( 2.0, 0, 1, 0)
- .cell( 3.0, 1, 0, 0)
- .cell( 4.0, 1, 1, 0)
- .cell( 5.0, 2, 0, 0)
.cell(-6.0, 2, 1, 0).build(),
- Tensor.from("tensor( x[3],y[2],z[1]) : [ [ [1.0, 2.0, 3.0] , [4.0, 5,-6.0] ] ]"));
- assertEquals("Skipping syntactic sugar",
+ "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [-6.0]]]");
+ assertEquals("Skipping structure",
Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])"))
.cell( 1.0, 0, 0, 0)
.cell( 2.0, 0, 1, 0)
@@ -77,6 +74,59 @@ public class TensorParserTestCase {
Tensor.from("tensor( x[3],y[2],z[1]) : [1.0, 2.0, 3.0 , 4.0, 5, -6.0]"));
}
+ @Test
+ public void testDenseWrongOrder() {
+ assertEquals("Opposite order of dimensions",
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2])"))
+ .cell(1, 0, 0)
+ .cell(4, 0, 1)
+ .cell(2, 1, 0)
+ .cell(5, 1, 1)
+ .cell(3, 2, 0)
+ .cell(6, 2, 1).build(),
+ Tensor.from("tensor(y[2],x[3]):[[1,2,3],[4,5,6]]"));
+ }
+
+ @Test
+ public void testMixedParsing() {
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])"))
+ .cell(TensorAddress.ofLabels("a", "0"), 1)
+ .cell(TensorAddress.ofLabels("a", "1"), 2)
+ .cell(TensorAddress.ofLabels("b", "0"), 3)
+ .cell(TensorAddress.ofLabels("b", "1"), 4).build(),
+ Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}"));
+ }
+
+ @Test
+ public void testSparseShortFormParsing() {
+ assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{})"))
+ .cell(TensorAddress.ofLabels("a"), 1)
+ .cell(TensorAddress.ofLabels("b"), 2).build(),
+ Tensor.from("tensor(key{}):{a:1, b:2}"));
+ }
+
+ @Test
+ public void testMixedWrongOrder() {
+ assertEquals("Opposite order of dimensions",
+ Tensor.Builder.of(TensorType.fromSpec("tensor(key{},x[3],y[2])"))
+ .cell(TensorAddress.ofLabels("key1", "0", "0"), 1)
+ .cell(TensorAddress.ofLabels("key1", "0", "1"), 4)
+ .cell(TensorAddress.ofLabels("key1", "1", "0"), 2)
+ .cell(TensorAddress.ofLabels("key1", "1", "1"), 5)
+ .cell(TensorAddress.ofLabels("key1", "2", "0"), 3)
+ .cell(TensorAddress.ofLabels("key1", "2", "1"), 6)
+ .cell(TensorAddress.ofLabels("key2", "0", "0"), 7)
+ .cell(TensorAddress.ofLabels("key2", "0", "1"), 10)
+ .cell(TensorAddress.ofLabels("key2", "1", "0"), 8)
+ .cell(TensorAddress.ofLabels("key2", "1", "1"), 11)
+ .cell(TensorAddress.ofLabels("key2", "2", "0"), 9)
+ .cell(TensorAddress.ofLabels("key2", "2", "1"), 12).build(),
+ Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}"));
+ assertEquals("Opposite order of dimensions",
+ Tensor.from("tensor(key{},x[3],y[2]):{key1:[[1,4],[2,5],[3,6]], key2:[[7,10],[8,11],[9,12]]}"),
+ Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}"));
+ }
+
private void assertDense(Tensor expectedTensor, String denseFormat) {
assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat));
assertEquals(denseFormat, expectedTensor.toString());
@@ -92,8 +142,12 @@ public class TensorParserTestCase {
"{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}");
assertIllegal("At {x:0}: '1-.0' is not a valid double",
"{{x:0}:1-.0}");
- assertIllegal("At index 0: '1-.0' is not a valid double",
+ assertIllegal("At value position 1: '1-.0' is not a valid double",
"tensor(x[1]):[1-.0]");
+ assertIllegal("At value position 5: Expected a ',' but got ']'",
+ "tensor(x[3]):[1, 2]");
+ assertIllegal("At value position 8: Expected a ']' but got ','",
+ "tensor(x[3]):[1, 2, 3, 4]");
}
private void assertIllegal(String message, String tensor) {
@@ -102,7 +156,7 @@ public class TensorParserTestCase {
fail("Expected an IllegalArgumentException when parsing " + tensor);
}
catch (IllegalArgumentException e) {
- assertEquals(message, e.getMessage());
+ assertEquals(message, e.getCause().getMessage());
}
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index 11365531019..9f077cb7b00 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -56,7 +56,8 @@ public class TensorTestCase {
fail("Expected parse error");
}
catch (IllegalArgumentException expected) {
- assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", expected.getMessage());
+ assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'",
+ expected.getCause().getMessage());
}
}
@@ -259,9 +260,9 @@ public class TensorTestCase {
assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0",
"tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l3}:5.0,{d1:l1,d2:l2}:6.0}");
assertLargest("{x:1,y:1}:4.0",
- "tensor(x[2],y[2]):[[1,2],[3,4]");
+ "tensor(x[2],y[2]):[[1,2],[3,4]]");
assertLargest("{x:0,y:0}:4.0, {x:1,y:1}:4.0",
- "tensor(x[2],y[2]):[[4,2],[3,4]");
+ "tensor(x[2],y[2]):[[4,2],[3,4]]");
}
@Test
@@ -273,9 +274,9 @@ public class TensorTestCase {
assertSmallest("{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:5.0",
"tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l3}:6.0,{d1:l1,d2:l2}:5.0}");
assertSmallest("{x:0,y:0}:1.0",
- "tensor(x[2],y[2]):[[1,2],[3,4]");
+ "tensor(x[2],y[2]):[[1,2],[3,4]]");
assertSmallest("{x:0,y:1}:2.0",
- "tensor(x[2],y[2]):[[4,2],[3,4]");
+ "tensor(x[2],y[2]):[[4,2],[3,4]]");
}
private void assertLargest(String expectedCells, String tensorString) {
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index 2a34bc11b76..2231d32281a 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -10,6 +10,7 @@ import org.junit.Ignore;
import org.junit.Test;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import static org.junit.Assert.assertEquals;
@@ -20,21 +21,36 @@ import static org.junit.Assert.assertEquals;
public class DynamicTensorTestCase {
@Test
- public void testDynamicTensorFunction() {
+ public void testDynamicIndexedRank1TensorFunction() {
TensorType dense = TensorType.fromSpec("tensor(x[3])");
DynamicTensor<Name> t1 = DynamicTensor.from(dense,
List.of(new Constant(1), new Constant(2), new Constant(3)));
assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate());
assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString());
+ }
+ @Test
+ public void testDynamicMappedRank1TensorFunction() {
TensorType sparse = TensorType.fromSpec("tensor(x{})");
DynamicTensor<Name> t2 = DynamicTensor.from(sparse,
Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
- new Constant(5)));
+ new Constant(5)));
assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
+ @Test
+ public void testDynamicMappedRank2TensorFunction() {
+ TensorType sparse = TensorType.fromSpec("tensor(x{},y{})");
+ HashMap<TensorAddress, ScalarFunction<Name>> values = new HashMap<>();
+ values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "b").build(),
+ new Constant(5));
+ values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "c").build(),
+ new Constant(7));
+ DynamicTensor<Name> t2 = DynamicTensor.from(sparse, values);
+ assertEquals(Tensor.from(sparse, "{{x:a,y:b}:5, {x:a,y:c}:7}"), t2.evaluate());
+ }
+
@Ignore // Enable for benchmarking
public void benchMarkTensorAddressBuilder() {
long start = System.nanoTime();