diff options
author | Jon Bratseth <bratseth@oath.com> | 2019-12-16 20:29:48 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-12-16 20:29:48 +0100 |
commit | 40e8a8b4ac2a021ede5a5babd42976ab313ce0b8 (patch) | |
tree | 253ee93b860f20a9c1deeb4cf0f6a31945bf6bf8 | |
parent | 6f5128b0d386b712aa94be3336a967990b096111 (diff) | |
parent | baa6a81aa07f37a543c836710b4c65b7831fd9db (diff) |
Merge pull request #11548 from vespa-engine/bratseth/mixed-tensor-parse
Bratseth/mixed tensor parse
19 files changed, 812 insertions, 226 deletions
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 554a36aef86..9e9dfae2bc7 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -133,7 +133,7 @@ rankprofile[].fef.property[].value "3" rankprofile[].fef.property[].name "vespa.rank.firstphase" rankprofile[].fef.property[].value "rankingExpression(firstphase)" rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" -rankprofile[].fef.property[].value "reduce(tensor(d0[1])(attribute{x:(rankingExpression(functionNotLabel))}), sum)" +rankprofile[].fef.property[].value "reduce(tensor(d0[1])(attribute{x:rankingExpression(functionNotLabel)}), sum)" rankprofile[].fef.property[].name "vespa.type.attribute.f2" rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" rankprofile[].fef.property[].name "vespa.type.attribute.f3" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index c3380bed19c..6e0e7e3e148 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -93,7 +93,7 @@ search tensor { rank-profile profile8 { first-phase { - expression: sum(tensor(d0[1])(attribute{x:(functionNotLabel)})) + expression: sum(tensor(d0[1])(attribute{x:functionNotLabel()})) } function functionNotLabel() { diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index debcd11fdbd..bde3b6abb6c 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -894,9 +894,9 @@ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", - "public final com.yahoo.tensor.TensorType tensorType()", + "public final com.yahoo.tensor.TensorType tensorType(java.util.List)", "public final com.yahoo.tensor.TensorType$Value optionalTensorValueTypeParameter()", - "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder)", + "public final void tensorTypeDimension(com.yahoo.tensor.TensorType$Builder, java.util.List)", "public final java.lang.String tensorFunctionName()", "public final com.yahoo.searchlib.rankingexpression.rule.Function unaryFunctionName()", "public final com.yahoo.searchlib.rankingexpression.rule.Function binaryFunctionName()", @@ -910,9 +910,16 @@ "public final java.util.List tagCommaLeadingList()", "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()", - "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType, java.util.List)", "public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)", - "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.tensor.functions.DynamicTensor mixedTensorValueBody(com.yahoo.tensor.TensorType, java.util.List)", + "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType, java.util.List)", + "public final void keyValueOrMixedBlock(com.yahoo.tensor.TensorType, java.util.List, java.util.Map)", + "public final void keyValue(com.yahoo.tensor.TensorType, java.util.Map)", + "public final void mixedBlock(com.yahoo.tensor.TensorType, java.util.List, java.util.Map)", + "public final java.util.List indexedTensorCells()", + "public final void indexedTensorCellSubspaceList(java.util.List)", + "public final void indexedTensorCellSubspace(java.util.List)", "public final void tensorCell(com.yahoo.tensor.TensorType, java.util.Map)", "public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)", "public final void labelAndDimensionValues(java.util.List)", @@ -1616,7 +1623,8 @@ "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", "public static java.util.Map wrapScalars(java.util.Map)", - "public static java.util.List wrapScalars(java.util.List)", + "public static void wrapScalarBlock(com.yahoo.tensor.TensorType, java.util.List, java.lang.String, java.util.List, java.util.Map)", + "public static java.util.List wrapScalars(com.yahoo.tensor.TensorType, java.util.List, java.util.List)", "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)" ], "fields": [] diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 0a67ab5534e..6200515462b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -8,6 +8,7 @@ import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; +import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -18,6 +19,7 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; +import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.LinkedHashMap; @@ -91,8 +93,50 @@ public class TensorFunctionNode extends CompositeNode { return functions; } - public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) { - return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList()); + public static void wrapScalarBlock(TensorType type, + List<String> dimensionOrder, + String mappedDimensionLabel, + List<ExpressionNode> nodes, + Map<TensorAddress, ScalarFunction<Reference>> receivingMap) { + TensorType denseSubtype = new TensorType(type.valueType(), + type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList())); + List<String> denseDimensionOrder = new ArrayList<>(dimensionOrder); + denseDimensionOrder.retainAll(denseSubtype.dimensionNames()); + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype, denseDimensionOrder); + if (indexes.size() != nodes.size()) + throw new IllegalArgumentException("At '" + mappedDimensionLabel + "': Need " + indexes.size() + + " values to fill a dense subspace of " + type + " but got " + nodes.size()); + for (ExpressionNode node : nodes) { + indexes.next(); + + // Insert the mapped dimension into the dense subspace address of indexes + String[] labels = new String[type.rank()]; + int indexedDimensionsIndex = 0; + int allDimensionsIndex = 0; + for (TensorType.Dimension dimension : type.dimensions()) { + if (dimension.isIndexed()) + labels[allDimensionsIndex++] = String.valueOf(indexes.indexesForReading()[indexedDimensionsIndex++]); + else + labels[allDimensionsIndex++] = mappedDimensionLabel; + } + + receivingMap.put(TensorAddress.of(labels), wrapScalar(node)); + } + } + + public static List<ScalarFunction<Reference>> wrapScalars(TensorType type, + List<String> dimensionOrder, + List<ExpressionNode> nodes) { + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type, dimensionOrder); + if (indexes.size() != nodes.size()) + throw new IllegalArgumentException("Need " + indexes.size() + " values to fill " + type + " but got " + nodes.size()); + + List<ScalarFunction<Reference>> wrapped = new ArrayList<>(nodes.size()); + while (indexes.hasNext()) { + indexes.next(); + wrapped.add(wrapScalar(nodes.get((int)indexes.toSourceValueIndex()))); + } + return wrapped; } public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) { diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index de3ad6b5d8c..e413e398183 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -6,7 +6,6 @@ */ options { CACHE_TOKENS = true; - STATIC = false; DEBUG_PARSER = false; USER_TOKEN_MANAGER = false; ERROR_REPORTING = true; @@ -476,13 +475,14 @@ TensorFunctionNode tensorConcat() : TensorFunctionNode tensorGenerate() : { TensorType type; + List dimensionOrder = new ArrayList(); TensorFunctionNode expression; } { - <TENSOR> type = tensorType() + <TENSOR> type = tensorType(dimensionOrder) ( expression = tensorGenerateBody(type) | - expression = tensorValueBody(type) + expression = tensorValueBody(type, dimensionOrder) ) { return expression; } } @@ -501,7 +501,7 @@ TensorFunctionNode tensorRange() : TensorType type; } { - <RANGE> type = tensorType() + <RANGE> type = tensorType(null) { return new TensorFunctionNode(new Range(type)); } } @@ -510,7 +510,7 @@ TensorFunctionNode tensorDiag() : TensorType type; } { - <DIAG> type = tensorType() + <DIAG> type = tensorType(null) { return new TensorFunctionNode(new Diag(type)); } } @@ -519,7 +519,7 @@ TensorFunctionNode tensorRandom() : TensorType type; } { - <RANDOM> type = tensorType() + <RANDOM> type = tensorType(null) { return new TensorFunctionNode(new Random(type)); } } @@ -619,7 +619,7 @@ Reduce.Aggregator tensorReduceAggregator() : { return Reduce.Aggregator.valueOf(token.image); } } -TensorType tensorType() : +TensorType tensorType(List dimensionOrder) : { TensorType.Builder builder; TensorType.Value valueType; @@ -628,8 +628,8 @@ TensorType tensorType() : valueType = optionalTensorValueTypeParameter() { builder = new TensorType.Builder(valueType); } <LBRACE> - ( tensorTypeDimension(builder) ) ? - ( <COMMA> tensorTypeDimension(builder) ) * + ( tensorTypeDimension(builder, dimensionOrder) ) ? + ( <COMMA> tensorTypeDimension(builder, dimensionOrder) ) * <RBRACE> { return builder.build(); } } @@ -643,13 +643,17 @@ TensorType.Value optionalTensorValueTypeParameter() : { return TensorType.Value.fromId(valueType); } } -void tensorTypeDimension(TensorType.Builder builder) : +void tensorTypeDimension(TensorType.Builder builder, List dimensionOrder) : { String name; int size; } { name = identifier() + { // Keep track of the order in which dimensions are written, if necessary + if (dimensionOrder != null) + dimensionOrder.add(name); + } ( ( <LCURLY> <RCURLY> { builder.mapped(name); } ) | LOOKAHEAD(2) ( <LSQUARE> <RSQUARE> { builder.indexed(name); } ) | @@ -832,15 +836,16 @@ Value primitiveValue() : { return Value.parse(sign + token.image); } } -TensorFunctionNode tensorValueBody(TensorType type) : +TensorFunctionNode tensorValueBody(TensorType type, List dimensionOrder) : { DynamicTensor dynamicTensor; } { <COLON> ( + LOOKAHEAD(2) dynamicTensor = mixedTensorValueBody(type, dimensionOrder) | dynamicTensor = mappedTensorValueBody(type) | - dynamicTensor = indexedTensorValueBody(type) + dynamicTensor = indexedTensorValueBody(type, dimensionOrder) ) { return new TensorFunctionNode(dynamicTensor); } } @@ -851,23 +856,82 @@ DynamicTensor mappedTensorValueBody(TensorType type) : } { <LCURLY> - ( tensorCell(type, cells))* + [ tensorCell(type, cells)] ( <COMMA> tensorCell(type, cells))* <RCURLY> { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } -DynamicTensor indexedTensorValueBody(TensorType type) : +DynamicTensor mixedTensorValueBody(TensorType type, List dimensionOrder) : +{ + java.util.Map cells = new LinkedHashMap(); +} +{ + <LCURLY> + keyValueOrMixedBlock(type, dimensionOrder, cells) + ( <COMMA> keyValueOrMixedBlock(type, dimensionOrder, cells))* + <RCURLY> + { return DynamicTensor.from(type, cells); } +} + +DynamicTensor indexedTensorValueBody(TensorType type, List dimensionOrder) : +{ + List cells; +} +{ + cells = indexedTensorCells() + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(type, dimensionOrder, cells)); } +} + +void keyValueOrMixedBlock(TensorType type, List dimensionOrder, java.util.Map cellMap) : {} +{ + LOOKAHEAD(3) mixedBlock(type, dimensionOrder, cellMap) | keyValue(type, cellMap) +} + +void keyValue(TensorType type, java.util.Map cellMap) : +{ + String label; + ExpressionNode value; +} +{ + label = tag() <COLON> value = expression() + { cellMap.put(TensorAddress.ofLabels(label), TensorFunctionNode.wrapScalar(value)); } +} + +void mixedBlock(TensorType type, List dimensionOrder, java.util.Map cellMap) : +{ + String label; + List cells; +} +{ + label = tag() <COLON> cells = indexedTensorCells() + { TensorFunctionNode.wrapScalarBlock(type, dimensionOrder, label, cells, cellMap); } +} + +List indexedTensorCells() : { List cells = new ArrayList(); +} +{ + <LSQUARE> indexedTensorCellSubspaceList(cells) <RSQUARE> + { return cells; } +} + +void indexedTensorCellSubspaceList(List cells) : +{ +} +{ + indexedTensorCellSubspace(cells) ( LOOKAHEAD(2) <COMMA> indexedTensorCellSubspace(cells) )* +} + +void indexedTensorCellSubspace(List cells) : +{ ExpressionNode value; } { - <LSQUARE> // TODO: Parse inner square brackets properly - ( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* - ( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* -// <RSQUARE> - { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } + ( <LSQUARE> indexedTensorCellSubspaceList(cells) <RSQUARE> ) + | + ( value = expression() { cells.add(value); } ) } void tensorCell(TensorType type, java.util.Map cells) : diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index e3d3ac7b2e1..b750a7607cc 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -191,7 +191,7 @@ public class RankingExpressionTestCase { // Accessing a function in a dynamic tensor, short form assertSerialization(List.of("tensor(x[2]):{{x:0}:rankingExpression(scalarFunction),{x:1}:rankingExpression(scalarFunction)}"), - "tensor(x[2]):[scalarFunction(), scalarFunction()]]", + "tensor(x[2]):[scalarFunction(), scalarFunction()]", functions, false); // Accessing a function in a dynamic tensor, long form diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 00750c70d2c..26861dd3cd6 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -14,6 +14,7 @@ import com.yahoo.tensor.TensorType; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; /** * Tests expression evaluation @@ -402,6 +403,51 @@ public class EvaluationTestCase { "{ {x:0}:7 }", "tensor(x{}):{ {x:0}:2 }"); tester.assertEvaluates("tensor<float>(d0[1],x[3]):[[1.0, 0.5, 0.25]]", "tensor<float>(d0[1],x[3]):[[one,one_half,a_quarter]]"); + tester.assertEvaluates("tensor(x[2],y[3]):[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]", + "tensor(x[2],y[3]):[[one,one_half,a_quarter],[a_quarter,one_half,one]]"); + tester.assertEvaluates("tensor(x{},y[2]):{{x:a,y:0}:1.0, {x:a,y:1}:0.5, {x:b,y:0}:0.25, {x:b,y:1}:2.0}", + "tensor(x{},y[2]):{{x:a,y:0}:one, {x:a,y:1}:one_half, {x:b,y:0}:a_quarter, {x:b,y:1}:2}"); + tester.assertEvaluates("tensor(x{},y[2]):{a:[1.0, 0.5], b:[0.25, 2]}", + "tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter, 2]}"); + tester.assertEvaluates("tensor(key{},x[2],y[3]):{key1:[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]," + + " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}", + "tensor(key{},x[2],y[3]):{key1:[[one,one_half,a_quarter],[a_quarter,one_half,one]]," + + " key2:[[1,2,3],[4,5,6]]}"); + tester.assertEvaluates("tensor(x{}):{{x:a}:1, {x:b}:-2, {x:cee}:0.5}", "tensor(x{}):{a:1, b:-2, cee:one_half}"); + + // Opposite order in the expression: + // - indexed + tester.assertEvaluates("tensor(x[3],y[2]):[[1.0, 0.25], [0.5,0.5], [0.25, 1.0]]", + "tensor(y[2],x[3]):[[one,one_half,a_quarter],[a_quarter,one_half,one]]"); + // - mixed + tester.assertEvaluates("tensor(key{},x[3],y[2]):{key1:[[1.0, 0.25], [0.5,0.5], [0.25, 1.0]]," + + " key2:[[1.0, 4.00], [2.0,5.0], [3.00, 6.0]]}", + "tensor(key{},y[2],x[3]):{key1:[[one,one_half,a_quarter],[a_quarter,one_half,one]]," + + " key2:[[1,2,3],[4,5,6]]}"); + // Opposite order in literal parsing: + // - indexed + tester.assertEvaluates("tensor(y[2],x[3]):[[1,0.25,0.5],[0.5,0.25,1]]", + "tensor(x[3],y[2]):[[one,one_half], [a_quarter,a_quarter], [one_half,one]]"); + // - mixed + tester.assertEvaluates("tensor(key{},y[2],x[3]):{key1:[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]," + + " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}", + "tensor(key{},x[3],y[2]):{key1:[[one,a_quarter],[one_half,one_half],[a_quarter,one]]," + + " key2:[[1,4],[2,5],[3,6]]}"); + + try { + new RankingExpression("tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter]}"); + fail("Expected exception"); + } + catch (Exception e) { + assertEquals("At 'b': Need 2 values to fill a dense subspace of tensor(x{},y[2]) but got 1", e.getMessage()); + } + try { + new RankingExpression("tensor(x[2],y[3]):[[1,2,3,4],[4,5,6]]"); + fail("Expected exception"); + } + catch (Exception e) { + assertEquals("Need 6 values to fill tensor(x[2],y[3]) but got 7", e.getMessage()); + } } @Test diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 19328f5dbb2..a4a9a1e1b24 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -693,6 +693,7 @@ "methods": [ "public void <init>(int)", "public com.yahoo.tensor.DimensionSizes$Builder set(int, long)", + "public com.yahoo.tensor.DimensionSizes$Builder add(long)", "public long size(int)", "public int dimensions()", "public com.yahoo.tensor.DimensionSizes build()" @@ -776,15 +777,14 @@ }, "com.yahoo.tensor.IndexedTensor$BoundBuilder": { "superClass": "com.yahoo.tensor.IndexedTensor$Builder", - "interfaces": [], + "interfaces": [ + "com.yahoo.tensor.IndexedTensor$DirectIndexBuilder" + ], "attributes": [ "public", "abstract" ], - "methods": [ - "public abstract void cellByDirectIndex(long, double)", - "public abstract void cellByDirectIndex(long, float)" - ], + "methods": [], "fields": [] }, "com.yahoo.tensor.IndexedTensor$Builder": { @@ -813,6 +813,21 @@ ], "fields": [] }, + "com.yahoo.tensor.IndexedTensor$DirectIndexBuilder": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public", + "interface", + "abstract" + ], + "methods": [ + "public abstract com.yahoo.tensor.TensorType type()", + "public abstract void cellByDirectIndex(long, double)", + "public abstract void cellByDirectIndex(long, float)" + ], + "fields": [] + }, "com.yahoo.tensor.IndexedTensor$Indexes": { "superClass": "java.lang.Object", "interfaces": [], @@ -822,14 +837,17 @@ ], "methods": [ "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)", + "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType, java.util.List)", "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)", "public com.yahoo.tensor.TensorAddress toAddress()", "public long[] indexesCopy()", "public long[] indexesForReading()", + "public long toSourceValueIndex()", "public java.util.List toList()", "public java.lang.String toString()", "public abstract long size()", - "public abstract void next()" + "public abstract void next()", + "public abstract boolean hasNext()" ], "fields": [ "protected final long[] indexes" @@ -943,6 +961,7 @@ ], "methods": [ "public long denseSubspaceSize()", + "public com.yahoo.tensor.IndexedTensor$DirectIndexBuilder denseSubspaceBuilder(com.yahoo.tensor.TensorAddress)", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, float)", "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", "public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])", @@ -1035,8 +1054,8 @@ ], "methods": [ "public void <init>(int)", - "public void add(java.lang.String, long)", - "public void add(java.lang.String, java.lang.String)", + "public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, long)", + "public com.yahoo.tensor.PartialAddress$Builder add(java.lang.String, java.lang.String)", "public com.yahoo.tensor.PartialAddress build()" ], "fields": [] @@ -1236,6 +1255,7 @@ "methods": [ "public void <init>()", "public static com.yahoo.tensor.TensorAddress of(java.lang.String[])", + "public static varargs com.yahoo.tensor.TensorAddress ofLabels(java.lang.String[])", "public static varargs com.yahoo.tensor.TensorAddress of(long[])", "public abstract int size()", "public abstract java.lang.String label(int)", @@ -1395,6 +1415,7 @@ "public" ], "methods": [ + "public void <init>(com.yahoo.tensor.TensorType$Value, java.util.Collection)", "public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", "public com.yahoo.tensor.TensorType$Value valueType()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index d81c02fb75f..202817ece42 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -71,6 +71,7 @@ public final class DimensionSizes { */ public final static class Builder { + private int dimensionIndex = 0; private long[] sizes; public Builder(int dimensions) { @@ -82,6 +83,11 @@ public final class DimensionSizes { return this; } + public Builder add(long size) { + sizes[dimensionIndex++] = size; + return this; + } + /** * Returns the length of this in the nth dimension * diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 176ddfefc13..ba3a35e8eda 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor { indexes.next(); // start brackets - for (int i = 0; i < indexes.rightDimensionsWhichAreAtStart(); i++) + for (int i = 0; i < indexes.nextDimensionsAtStart(); i++) b.append("["); // value @@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalStateException("Unexpected value type " + type.valueType()); // end bracket and comma - for (int i = 0; i < indexes.rightDimensionsWhichAreAtEnd(); i++) + for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++) b.append("]"); if (index < size() - 1) b.append(", "); @@ -375,8 +375,22 @@ public abstract class IndexedTensor implements Tensor { } + public interface DirectIndexBuilder { + + TensorType type(); + + + + /** Sets a value by its <i>standard value order</i> index */ + void cellByDirectIndex(long index, double value); + + /** Sets a value by its <i>standard value order</i> index */ + void cellByDirectIndex(long index, float value); + + } + /** A bound builder can create the double array directly */ - public static abstract class BoundBuilder extends Builder { + public static abstract class BoundBuilder extends Builder implements DirectIndexBuilder { private DimensionSizes sizes; @@ -393,14 +407,16 @@ public abstract class IndexedTensor implements Tensor { throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; } - BoundBuilder fill(float [] values) { + + BoundBuilder fill(float[] values) { long index = 0; for (float value : values) { cellByDirectIndex(index++, value); } return this; } - BoundBuilder fill(double [] values) { + + BoundBuilder fill(double[] values) { long index = 0; for (double value : values) { cellByDirectIndex(index++, value); @@ -410,12 +426,6 @@ public abstract class IndexedTensor implements Tensor { DimensionSizes sizes() { return sizes; } - /** Sets a value by its <i>standard value order</i> index */ - public abstract void cellByDirectIndex(long index, double value); - - /** Sets a value by its <i>standard value order</i> index */ - public abstract void cellByDirectIndex(long index, float value); - } /** @@ -767,6 +777,10 @@ public abstract class IndexedTensor implements Tensor { return of(DimensionSizes.of(type)); } + public static Indexes of(TensorType type, List<String> iterateDimensionOrder) { + return of(DimensionSizes.of(type), toIterationOrder(iterateDimensionOrder, type)); + } + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -779,6 +793,10 @@ public abstract class IndexedTensor implements Tensor { return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size); } + private static Indexes of(DimensionSizes sizes, List<Integer> iterateDimensions) { + return of(sizes, sizes, iterateDimensions); + } + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) { return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions)); } @@ -812,6 +830,16 @@ public abstract class IndexedTensor implements Tensor { } } + private static List<Integer> toIterationOrder(List<String> dimensionNames, TensorType type) { + if (dimensionNames == null) return completeIterationOrder(type.rank()); + + List<Integer> iterationDimensions = new ArrayList<>(type.rank()); + for (int i = 0; i < type.rank(); i++) + iterationDimensions.add(type.rank() - 1 - type.indexOfDimension(dimensionNames.get(i)).get()); + return iterationDimensions; + } + + /** Since the right dimensions binds closest, iteration order is the opposite of the tensor order */ private static List<Integer> completeIterationOrder(int length) { List<Integer> iterationDimensions = new ArrayList<>(length); for (int i = 0; i < length; i++) @@ -844,7 +872,7 @@ public abstract class IndexedTensor implements Tensor { /** Returns a copy of the indexes of this which must not be modified */ public long[] indexesForReading() { return indexes; } - long toSourceValueIndex() { + public long toSourceValueIndex() { return IndexedTensor.toValueIndex(indexes, sourceSizes); } @@ -869,27 +897,15 @@ public abstract class IndexedTensor implements Tensor { public abstract void next(); - /** Returns the number of dimensions from the right which are currently at the start position (0) */ - int rightDimensionsWhichAreAtStart() { - int dimension = indexes.length - 1; - int atStartCount = 0; - while (dimension >= 0 && indexes[dimension] == 0) { - atStartCount++; - dimension--; - } - return atStartCount; - } + /** Returns whether further values are available by calling next() */ + public abstract boolean hasNext(); + + /** Returns the number of dimensions in iteration order which are currently at the start position (0) */ + abstract int nextDimensionsAtStart(); + + /** Returns the number of dimensions in iteration order which are currently at their end position */ + abstract int nextDimensionsAtEnd(); - /** Returns the number of dimensions from the right which are currently at the end position */ - int rightDimensionsWhichAreAtEnd() { - int dimension = indexes.length - 1; - int atEndCount = 0; - while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) { - atEndCount++; - dimension--; - } - return atEndCount; - } } private final static class EmptyIndexes extends Indexes { @@ -904,10 +920,21 @@ public abstract class IndexedTensor implements Tensor { @Override public void next() {} + @Override + public boolean hasNext() { return false; } + + @Override + int nextDimensionsAtStart() { return 0; } + + @Override + int nextDimensionsAtEnd() { return 0; } + } private final static class SingleValueIndexes extends Indexes { + private boolean exhausted = false; + private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) { super(sourceSizes, iterateSizes, indexes); } @@ -916,7 +943,16 @@ public abstract class IndexedTensor implements Tensor { public long size() { return 1; } @Override - public void next() {} + public void next() { exhausted = true; } + + @Override + public boolean hasNext() { return ! exhausted; } + + @Override + int nextDimensionsAtStart() { return 1; } + + @Override + int nextDimensionsAtEnd() { return 1; } } @@ -945,7 +981,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -957,6 +993,31 @@ public abstract class IndexedTensor implements Tensor { indexes[iterateDimensions.get(iterateDimensionsIndex)]++; } + @Override + public boolean hasNext() { + for (int iterateDimension : iterateDimensions) { + if (indexes[iterateDimension] + 1 < dimensionSizes().size(iterateDimension)) + return true; // some dimension is not at the end + } + return false; + } + + @Override + int nextDimensionsAtStart() { + int dimension = 0; + while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == 0) + dimension++; + return dimension; + } + + @Override + int nextDimensionsAtEnd() { + int dimension = 0; + while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == dimensionSizes().size(iterateDimensions.get(dimension)) - 1) + dimension++; + return dimension; + } + } /** In this case we can reuse the source index computation for the iteration index */ @@ -969,7 +1030,7 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { + public long toSourceValueIndex() { return lastComputedSourceValueIndex = super.toSourceValueIndex(); } @@ -1016,7 +1077,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -1026,11 +1087,22 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { return currentSourceValueIndex; } + public long toSourceValueIndex() { return currentSourceValueIndex; } @Override long toIterationValueIndex() { return currentIterationValueIndex; } + @Override + public boolean hasNext() { + return indexes[iterateDimension] + 1 < size; + } + + @Override + int nextDimensionsAtStart() { return currentSourceValueIndex == 0 ? 1 : 0; } + + @Override + int nextDimensionsAtEnd() { return currentSourceValueIndex == size - 1 ? 1 : 0; } + } /** In this case we only need to keep track of one index */ @@ -1068,7 +1140,7 @@ public abstract class IndexedTensor implements Tensor { * Advances this to the next cell in the standard indexed tensor cell order. * The first call to this will put it at the first position. * - * @throws RuntimeException if this is called more times than its size + * @throws RuntimeException if this is called when hasNext returns false */ @Override public void next() { @@ -1077,11 +1149,22 @@ public abstract class IndexedTensor implements Tensor { } @Override - long toSourceValueIndex() { return currentValueIndex; } + public boolean hasNext() { + return indexes[iterateDimension] + 1 < size; + } + + @Override + public long toSourceValueIndex() { return currentValueIndex; } @Override long toIterationValueIndex() { return currentValueIndex; } + @Override + int nextDimensionsAtStart() { return currentValueIndex == 0 ? 1 : 0; } + + @Override + int nextDimensionsAtEnd() { return currentValueIndex == size - 1 ? 1 : 0; } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 1cde1fcdbb7..0c4efe78113 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -217,25 +217,34 @@ public class MixedTensor implements Tensor { public static class BoundBuilder extends Builder { /** For each sparse partial address, hold a dense subspace */ - final private Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); - final private Index.Builder indexBuilder; - final private Index index; + private final Map<TensorAddress, double[]> denseSubspaceMap = new HashMap<>(); + private final Index.Builder indexBuilder; + private final Index index; + private final TensorType denseSubtype; private BoundBuilder(TensorType type) { super(type); indexBuilder = new Index.Builder(type); index = indexBuilder.index(); + denseSubtype = new TensorType(type.valueType(), + type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList())); } public long denseSubspaceSize() { return index.denseSubspaceSize(); } - private double[] denseSubspace(TensorAddress sparsePartial) { - if (!denseSubspaceMap.containsKey(sparsePartial)) { - denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]); + private double[] denseSubspace(TensorAddress sparseAddress) { + if (!denseSubspaceMap.containsKey(sparseAddress)) { + denseSubspaceMap.put(sparseAddress, new double[(int)denseSubspaceSize()]); } - return denseSubspaceMap.get(sparsePartial); + return denseSubspaceMap.get(sparseAddress); + } + + public IndexedTensor.DirectIndexBuilder denseSubspaceBuilder(TensorAddress sparseAddress) { + double[] values = new double[(int)denseSubspaceSize()]; + denseSubspaceMap.put(sparseAddress, values); + return new DenseSubspaceBuilder(denseSubtype, values); } @Override @@ -280,7 +289,6 @@ public class MixedTensor implements Tensor { } - /** * Temporarily stores all cells to find bounds of indexed dimensions, * then creates a tensor using BoundBuilder. This is due to the @@ -491,6 +499,31 @@ public class MixedTensor implements Tensor { } + private static class DenseSubspaceBuilder implements IndexedTensor.DirectIndexBuilder { + + private final TensorType type; + private final double[] values; + + public DenseSubspaceBuilder(TensorType type, double[] values) { + this.type = type; + this.values = values; + } + + @Override + public TensorType type() { return type; } + + @Override + public void cellByDirectIndex(long index, double value) { + values[(int)index] = value; + } + + @Override + public void cellByDirectIndex(long index, float value) { + values[(int)index] = value; + } + + } + public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { TensorType.Builder builder = new TensorType.Builder(valueType); for (TensorType.Dimension dimension : dimensions) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index 4eca9c47402..84f26d96725 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -122,16 +122,18 @@ public class PartialAddress { labels = new Object[size]; } - public void add(String dimensionName, long label) { + public Builder add(String dimensionName, long label) { dimensionNames[index] = dimensionName; labels[index] = label; index++; + return this; } - public void add(String dimensionName, String label) { + public Builder add(String dimensionName, String label) { dimensionNames[index] = dimensionName; labels[index] = label; index++; + return this; } public PartialAddress build() { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index e705445c5a7..4770ad1b1f0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -21,6 +21,10 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return new StringTensorAddress(labels); } + public static TensorAddress ofLabels(String ... labels) { + return new StringTensorAddress(labels); + } + public static TensorAddress of(long ... labels) { return new NumericTensorAddress(labels); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 4d8b34b7dcf..5a1fd98a009 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import java.util.ArrayList; +import java.util.List; import java.util.Optional; /** @@ -9,14 +11,30 @@ import java.util.Optional; class TensorParser { static Tensor tensorFrom(String tensorString, Optional<TensorType> explicitType) { + try { + return tensorFromBody(tensorString, explicitType); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Could not parse '" + tensorString + "' as a tensor" + + (explicitType.isPresent() ? " of type " + explicitType.get() : ""), + e); + } + } + + static Tensor tensorFromBody(String tensorString, Optional<TensorType> explicitType) { Optional<TensorType> type; String valueString; + // The order in which dimensions are written in the type string. + // This allows the user's explicit dimension order to decide what (dense) dimensions map to what, rather than + // the natural order of the tensor. + List<String> dimensionOrder; + tensorString = tensorString.trim(); if (tensorString.startsWith("tensor")) { int colonIndex = tensorString.indexOf(':'); String typeString = tensorString.substring(0, colonIndex); - TensorType typeFromString = TensorTypeParser.fromSpec(typeString); + dimensionOrder = new ArrayList<>(); + TensorType typeFromString = TensorTypeParser.fromSpec(typeString, dimensionOrder); if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString)) throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + "passed type " + explicitType.get()); @@ -26,14 +44,19 @@ class TensorParser { else { type = explicitType; valueString = tensorString; + dimensionOrder = null; } valueString = valueString.trim(); - if (valueString.startsWith("{")) { + if (valueString.startsWith("{") && + (type.isEmpty() || type.get().rank() == 0 || valueString.substring(1).trim().startsWith("{") || valueString.substring(1).trim().equals("}"))) { return tensorFromSparseValueString(valueString, type); } + else if (valueString.startsWith("{")) { + return tensorFromMixedValueString(valueString, type, dimensionOrder); + } else if (valueString.startsWith("[")) { - return tensorFromDenseValueString(valueString, type); + return tensorFromDenseValueString(valueString, type, dimensionOrder); } else { if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty)) @@ -54,8 +77,7 @@ class TensorParser { String s = valueString.substring(1).trim(); // remove tensor start int firstKeyOrTensorEnd = s.indexOf('}'); if (firstKeyOrTensorEnd < 0) - throw new IllegalArgumentException("Excepted a number or a string starting by {, [ or tensor(...):, got '" + - valueString + "'"); + throw new IllegalArgumentException("Excepted a number or a string starting by '{', '[' or 'tensor(...):...'"); String addressBody = s.substring(0, firstKeyOrTensorEnd).trim(); if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor @@ -79,138 +101,312 @@ class TensorParser { try { valueString = valueString.trim(); Tensor.Builder builder = Tensor.Builder.of(type.orElse(typeFromSparseValueString(valueString))); - return fromCellString(builder, valueString); + SparseValueParser parser = new SparseValueParser(valueString, builder); + parser.parse(); + return builder.build(); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Excepted a number or a string starting by { or tensor(, got '" + - valueString + "'"); + throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('"); } } - private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) { + private static Tensor tensorFromMixedValueString(String valueString, + Optional<TensorType> type, + List<String> dimensionOrder) { + if (type.isEmpty()) + throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " + + "on the form 'tensor(dimensions):..."); + if (type.get().dimensions().stream().filter(d -> ! d.isIndexed()).count() != 1) + throw new IllegalArgumentException("The mixed tensor form requires a type with a single mapped dimension, " + + "but got " + type.get()); + + + try { + valueString = valueString.trim(); + if ( ! valueString.startsWith("{") && valueString.endsWith("}")) + throw new IllegalArgumentException("A mixed tensor must be enclosed in {}"); + Tensor.Builder builder = Tensor.Builder.of(type.get()); + MixedValueParser parser = new MixedValueParser(valueString, dimensionOrder, builder); + parser.parse(); + return builder.build(); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Excepted a number or a string starting by '{' or 'tensor('"); + } + } + + private static Tensor tensorFromDenseValueString(String valueString, + Optional<TensorType> type, + List<String> dimensionOrder) { if (type.isEmpty()) throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " + "on the form 'tensor(dimensions):..."); - if (type.get().dimensions().stream().anyMatch(d -> ( d.size().isEmpty()))) + if (type.get().dimensions().stream().anyMatch(d -> (d.size().isEmpty()))) throw new IllegalArgumentException("The dense tensor form requires a tensor type containing " + "only dense dimensions with a given size"); - IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)IndexedTensor.Builder.of(type.get()); - long index = 0; - int currentChar; - int nextNumberEnd = 0; - // Since we know the dimensions the brackets are just syntactic sugar: - while ((currentChar = nextStartCharIndex(nextNumberEnd + 1, valueString)) < valueString.length()) { - nextNumberEnd = nextStopCharIndex(currentChar, valueString); - if (currentChar == nextNumberEnd) return builder.build(); - - TensorType.Value cellValueType = builder.type().valueType(); - String cellValueString = valueString.substring(currentChar, nextNumberEnd); + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get()); + new DenseValueParser(valueString, dimensionOrder, builder).parse(); + return builder.build(); + } + + private static abstract class ValueParser { + + protected final String string; + protected int position = 0; + + protected ValueParser(String string) { + this.string = string; + } + + protected void skipSpace() { + while (position < string.length() && string.charAt(position) == ' ') + position++; + } + + protected void consume(char character) { + skipSpace(); + + if (position >= string.length()) + throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character + + "' but got the end of the string"); + if ( string.charAt(position) != character) + throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character + + "' but got '" + string.charAt(position) + "'"); + position++; + } + + protected Number consumeNumber(TensorType.Value cellValueType) { + skipSpace(); + + int nextNumberEnd = nextStopCharIndex(position, string); try { - if (cellValueType == TensorType.Value.DOUBLE) - builder.cellByDirectIndex(index, Double.parseDouble(cellValueString)); - else if (cellValueType == TensorType.Value.FLOAT) - builder.cellByDirectIndex(index, Float.parseFloat(cellValueString)); - else - throw new IllegalArgumentException(cellValueType + " is not supported"); + String cellValueString = string.substring(position, nextNumberEnd); + try { + if (cellValueType == TensorType.Value.DOUBLE) + return Double.parseDouble(cellValueString); + else if (cellValueType == TensorType.Value.FLOAT) + return Float.parseFloat(cellValueString); + else + throw new IllegalArgumentException(cellValueType + " is not supported"); + } catch (NumberFormatException e) { + throw new IllegalArgumentException("At value position " + position + ": '" + + cellValueString + "' is not a valid " + cellValueType); + } } - catch (NumberFormatException e) { - throw new IllegalArgumentException("At index " + index + ": '" + - cellValueString + "' is not a valid " + cellValueType); + finally { + position = nextNumberEnd; } - index++; } - return builder.build(); - } - /** Returns the position of the next character that should contain a number, or if none the string length */ - private static int nextStartCharIndex(int charIndex, String valueString) { - for (; charIndex < valueString.length(); charIndex++) { - if (valueString.charAt(charIndex) == ']') continue; - if (valueString.charAt(charIndex) == '[') continue; - if (valueString.charAt(charIndex) == ',') continue; - if (valueString.charAt(charIndex) == ' ') continue; - return charIndex; + protected int nextStopCharIndex(int position, String valueString) { + while (position < valueString.length()) { + if (valueString.charAt(position) == ',') return position; + if (valueString.charAt(position) == ']') return position; + if (valueString.charAt(position) == '}') return position; + position++; + } + throw new IllegalArgumentException("Malformed tensor value '" + valueString + + "': Expected a ',', ']' or '}' after position " + position); } - return valueString.length(); + } - private static int nextStopCharIndex(int charIndex, String valueString) { - while (charIndex < valueString.length()) { - if (valueString.charAt(charIndex) == ',') return charIndex; - if (valueString.charAt(charIndex) == ']') return charIndex; - charIndex++; + /** A single-use dense tensor string parser */ + private static class DenseValueParser extends ValueParser { + + private final IndexedTensor.DirectIndexBuilder builder; + private final IndexedTensor.Indexes indexes; + private final boolean hasInnerStructure; + + public DenseValueParser(String string, + List<String> dimensionOrder, + IndexedTensor.DirectIndexBuilder builder) { + super(string); + this.builder = builder; + indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder); + hasInnerStructure = hasInnerStructure(string); } - throw new IllegalArgumentException("Malformed tensor value '" + valueString + - "': Expected a ',' or ']' after position " + charIndex); - } - private static Tensor fromCellString(Tensor.Builder builder, String s) { - int index = 1; - index = skipSpace(index, s); - while (index + 1 < s.length()) { - int keyOrTensorEnd = s.indexOf('}', index); - TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); - if (keyOrTensorEnd < s.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty - addLabels(s.substring(index, keyOrTensorEnd + 1), addressBuilder); - index = keyOrTensorEnd + 1; - index = skipSpace(index, s); - if ( s.charAt(index) != ':') - throw new IllegalArgumentException("Expecting a ':' after " + s.substring(index) + ", got '" + s + "'"); - index++; - } - int valueEnd = s.indexOf(',', index); - if (valueEnd < 0) { // last value - valueEnd = s.indexOf('}', index); - if (valueEnd < 0) - throw new IllegalArgumentException("A tensor string must end by '}'"); + public void parse() { + if (!hasInnerStructure) + consume('['); + + while (indexes.hasNext()) { + indexes.next(); + for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; i++) + consume('['); + consumeNumber(); + for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; i++) + consume(']'); + if (indexes.hasNext()) + consume(','); } - TensorAddress address = addressBuilder.build(); - TensorType.Value cellValueType = builder.type().valueType(); - String cellValueString = s.substring(index, valueEnd).trim(); - try { - if (cellValueType == TensorType.Value.DOUBLE) - builder.cell(address, Double.parseDouble(cellValueString)); - else if (cellValueType == TensorType.Value.FLOAT) - builder.cell(address, Float.parseFloat(cellValueString)); + if (!hasInnerStructure) + consume(']'); + } + + public int position() { return position; } + + /** Are there inner square brackets in this or is it just a flat list of numbers until ']'? */ + private static boolean hasInnerStructure(String valueString) { + valueString = valueString.trim(); + valueString = valueString.substring(1); + int firstLeftBracket = valueString.indexOf('['); + return firstLeftBracket >= 0 && firstLeftBracket < valueString.indexOf(']'); + } + + protected void consumeNumber() { + Number number = consumeNumber(builder.type().valueType()); + if (builder.type().valueType() == TensorType.Value.DOUBLE) + builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Double)number); + else if (builder.type().valueType() == TensorType.Value.FLOAT) + builder.cellByDirectIndex(indexes.toSourceValueIndex(), (Float)number); + } + + } + + /** + * Parses mixed tensor short forms {a:[1,2], ...} AND 1d mapped tensor short form {a:b, ...}. + */ + private static class MixedValueParser extends ValueParser { + + private final Tensor.Builder builder; + private List<String> dimensionOrder; + + public MixedValueParser(String string, List<String> dimensionOrder, Tensor.Builder builder) { + super(string); + this.dimensionOrder = dimensionOrder; + this.builder = builder; + } + + private void parse() { + TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get(); + TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension)); + if (dimensionOrder != null) + dimensionOrder.remove(mappedDimension.name()); + + skipSpace(); + consume('{'); + skipSpace(); + while (position + 1 < string.length()) { + int labelEnd = string.indexOf(':', position); + if (labelEnd <= position) + throw new IllegalArgumentException("A mixed tensor value must be on the form {sparse-label:[dense subspace], ...}, or {sparse-label:value, ...}"); + String label = string.substring(position, labelEnd); + position = labelEnd + 1; + skipSpace(); + + TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build(); + if (builder.type().rank() > 1) + parseDenseSubspace(mappedAddress, dimensionOrder); else - throw new IllegalArgumentException(cellValueType + " is not supported"); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" + - cellValueString + "' is not a valid " + cellValueType); + consumeNumber(mappedAddress); + if ( ! consumeOptional(',')) + consume('}'); + skipSpace(); } + } - index = valueEnd+1; - index = skipSpace(index, s); + private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) { + DenseValueParser denseParser = new DenseValueParser(string.substring(position), + denseDimensionOrder, + ((MixedTensor.BoundBuilder)builder).denseSubspaceBuilder(sparseAddress)); + denseParser.parse(); + position+= denseParser.position(); + } + + private boolean consumeOptional(char character) { + skipSpace(); + + if (position >= string.length()) + return false; + if ( string.charAt(position) != character) + return false; + + position++; + return true; + } + + private void consumeNumber(TensorAddress address) { + Number number = consumeNumber(builder.type().valueType()); + if (builder.type().valueType() == TensorType.Value.DOUBLE) + builder.cell(address, (Double)number); + else if (builder.type().valueType() == TensorType.Value.FLOAT) + builder.cell(address, (Float)number); } - return builder.build(); - } - private static int skipSpace(int index, String s) { - while (index < s.length() && s.charAt(index) == ' ') - index++; - return index; } - /** Creates a tenor address from a string on the form {dimension1:label1,dimension2:label2,...} */ - private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { - mapAddressString = mapAddressString.trim(); - if ( ! (mapAddressString.startsWith("{") && mapAddressString.endsWith("}"))) - throw new IllegalArgumentException("Expecting a tensor address enclosed in {}, got '" + mapAddressString + "'"); + private static class SparseValueParser extends ValueParser { - String addressBody = mapAddressString.substring(1, mapAddressString.length() - 1).trim(); - if (addressBody.isEmpty()) return; + private final Tensor.Builder builder; - for (String elementString : addressBody.split(",")) { - String[] pair = elementString.split(":"); - if (pair.length != 2) - throw new IllegalArgumentException("Expecting argument elements on the form dimension:label, " + - "got '" + elementString + "'"); - String dimension = pair[0].trim(); - builder.add(dimension, pair[1].trim()); + public SparseValueParser(String string, Tensor.Builder builder) { + super(string); + this.builder = builder; } + + private void parse() { + consume('{'); + skipSpace(); + while (position + 1 < string.length()) { + int keyOrTensorEnd = string.indexOf('}', position); + TensorAddress.Builder addressBuilder = new TensorAddress.Builder(builder.type()); + if (keyOrTensorEnd < string.length() - 1) { // Key end: This has a key - otherwise TensorAddress is empty + addLabels(string.substring(position, keyOrTensorEnd + 1), addressBuilder); + position = keyOrTensorEnd + 1; + skipSpace(); + consume(':'); + } + int valueEnd = string.indexOf(',', position); + if (valueEnd < 0) { // last value + valueEnd = string.indexOf('}', position); + if (valueEnd < 0) + throw new IllegalArgumentException("A sparse tensor string must end by '}'"); + } + + TensorAddress address = addressBuilder.build(); + TensorType.Value cellValueType = builder.type().valueType(); + String cellValueString = string.substring(position, valueEnd).trim(); + try { + if (cellValueType == TensorType.Value.DOUBLE) + builder.cell(address, Double.parseDouble(cellValueString)); + else if (cellValueType == TensorType.Value.FLOAT) + builder.cell(address, Float.parseFloat(cellValueString)); + else + throw new IllegalArgumentException(cellValueType + " is not supported"); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("At " + address.toString(builder.type()) + ": '" + + cellValueString + "' is not a valid " + cellValueType); + } + + position = valueEnd+1; + skipSpace(); + } + } + + /** Creates a tensor address from a string on the form {dimension1:label1,dimension2:label2,...} */ + private static void addLabels(String mapAddressString, TensorAddress.Builder builder) { + mapAddressString = mapAddressString.trim(); + if ( ! (mapAddressString.startsWith("{") && mapAddressString.endsWith("}"))) + throw new IllegalArgumentException("Expecting a tensor address enclosed in {}, got '" + mapAddressString + "'"); + + String addressBody = mapAddressString.substring(1, mapAddressString.length() - 1).trim(); + if (addressBody.isEmpty()) return; + + for (String elementString : addressBody.split(",")) { + String[] pair = elementString.split(":"); + if (pair.length != 2) + throw new IllegalArgumentException("Expecting argument elements on the form dimension:label, " + + "got '" + elementString + "'"); + String dimension = pair[0].trim(); + builder.add(dimension, pair[1].trim()); + } + } + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 95cc70804e2..ca3f8ff28a4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -82,7 +82,7 @@ public class TensorType { private final TensorType mappedSubtype; - private TensorType(Value valueType, Collection<Dimension> dimensions) { + public TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); Collections.sort(dimensionList); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index def3ab6b4ec..4fdb0906740 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -24,6 +24,13 @@ public class TensorTypeParser { private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}"); public static TensorType fromSpec(String specString) { + return fromSpec(specString, null); + } + + /** + * @param dimensionOrder if not null, this will be populated with the dimension names in the order they are written + */ + static TensorType fromSpec(String specString, List<String> dimensionOrder) { specString = specString.trim(); if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING)) throw formatException(specString); @@ -48,10 +55,14 @@ public class TensorTypeParser { List<TensorType.Dimension> dimensions = new ArrayList<>(); for (String element : dimensionsSpec.split(",")) { String trimmedElement = element.trim(); - boolean success = tryParseIndexedDimension(trimmedElement, dimensions) || - tryParseMappedDimension(trimmedElement, dimensions); - if ( ! success) + TensorType.Dimension dimension = tryParseIndexedDimension(trimmedElement); + if (dimension == null) + dimension = tryParseMappedDimension(trimmedElement); + if (dimension == null) throw formatException(specString, "Dimension '" + element + "' is on the wrong format"); + dimensions.add(dimension); + if (dimensionOrder != null) + dimensionOrder.add(dimension.name()); } return new TensorType.Builder(valueType, dimensions).build(); } @@ -68,29 +79,26 @@ public class TensorTypeParser { } } - private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) { + private static TensorType.Dimension tryParseIndexedDimension(String element) { Matcher matcher = indexedPattern.matcher(element); if (matcher.matches()) { String dimensionName = matcher.group(1); String dimensionSize = matcher.group(2); - if (dimensionSize.isEmpty()) { - dimensions.add(TensorType.Dimension.indexed(dimensionName)); - } else { - dimensions.add(TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize))); - } - return true; + if (dimensionSize.isEmpty()) + return TensorType.Dimension.indexed(dimensionName); + else + return TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize)); } - return false; + return null; } - private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) { + private static TensorType.Dimension tryParseMappedDimension(String element) { Matcher matcher = mappedPattern.matcher(element); if (matcher.matches()) { String dimensionName = matcher.group(1); - dimensions.add(TensorType.Dimension.mapped(dimensionName)); - return true; + return TensorType.Dimension.mapped(dimensionName); } - return false; + return null; } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java index 1928971820c..6f9a5c13886 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java @@ -22,6 +22,12 @@ public class TensorParserTestCase { } @Test + public void testSingle() { + assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor(x[1])")).cell(1.0, 0).build(), + "tensor(x[1]):[1.0]"); + } + + @Test public void testDenseParsing() { assertDense(Tensor.Builder.of(TensorType.fromSpec("tensor()")).build(), "tensor():{0.0}"); @@ -55,18 +61,9 @@ public class TensorParserTestCase { .cell(3.0, 1, 0, 0) .cell(4.0, 1, 1, 0) .cell(5.0, 2, 0, 0) - .cell(6.0, 2, 1, 0).build(), - "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [6.0]]]"); - assertEquals("Messy input", - Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])")) - .cell( 1.0, 0, 0, 0) - .cell( 2.0, 0, 1, 0) - .cell( 3.0, 1, 0, 0) - .cell( 4.0, 1, 1, 0) - .cell( 5.0, 2, 0, 0) .cell(-6.0, 2, 1, 0).build(), - Tensor.from("tensor( x[3],y[2],z[1]) : [ [ [1.0, 2.0, 3.0] , [4.0, 5,-6.0] ] ]")); - assertEquals("Skipping syntactic sugar", + "tensor(x[3],y[2],z[1]):[[[1.0], [2.0]], [[3.0], [4.0]], [[5.0], [-6.0]]]"); + assertEquals("Skipping structure", Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2],z[1])")) .cell( 1.0, 0, 0, 0) .cell( 2.0, 0, 1, 0) @@ -77,6 +74,59 @@ public class TensorParserTestCase { Tensor.from("tensor( x[3],y[2],z[1]) : [1.0, 2.0, 3.0 , 4.0, 5, -6.0]")); } + @Test + public void testDenseWrongOrder() { + assertEquals("Opposite order of dimensions", + Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2])")) + .cell(1, 0, 0) + .cell(4, 0, 1) + .cell(2, 1, 0) + .cell(5, 1, 1) + .cell(3, 2, 0) + .cell(6, 2, 1).build(), + Tensor.from("tensor(y[2],x[3]):[[1,2,3],[4,5,6]]")); + } + + @Test + public void testMixedParsing() { + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])")) + .cell(TensorAddress.ofLabels("a", "0"), 1) + .cell(TensorAddress.ofLabels("a", "1"), 2) + .cell(TensorAddress.ofLabels("b", "0"), 3) + .cell(TensorAddress.ofLabels("b", "1"), 4).build(), + Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}")); + } + + @Test + public void testSparseShortFormParsing() { + assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{})")) + .cell(TensorAddress.ofLabels("a"), 1) + .cell(TensorAddress.ofLabels("b"), 2).build(), + Tensor.from("tensor(key{}):{a:1, b:2}")); + } + + @Test + public void testMixedWrongOrder() { + assertEquals("Opposite order of dimensions", + Tensor.Builder.of(TensorType.fromSpec("tensor(key{},x[3],y[2])")) + .cell(TensorAddress.ofLabels("key1", "0", "0"), 1) + .cell(TensorAddress.ofLabels("key1", "0", "1"), 4) + .cell(TensorAddress.ofLabels("key1", "1", "0"), 2) + .cell(TensorAddress.ofLabels("key1", "1", "1"), 5) + .cell(TensorAddress.ofLabels("key1", "2", "0"), 3) + .cell(TensorAddress.ofLabels("key1", "2", "1"), 6) + .cell(TensorAddress.ofLabels("key2", "0", "0"), 7) + .cell(TensorAddress.ofLabels("key2", "0", "1"), 10) + .cell(TensorAddress.ofLabels("key2", "1", "0"), 8) + .cell(TensorAddress.ofLabels("key2", "1", "1"), 11) + .cell(TensorAddress.ofLabels("key2", "2", "0"), 9) + .cell(TensorAddress.ofLabels("key2", "2", "1"), 12).build(), + Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}")); + assertEquals("Opposite order of dimensions", + Tensor.from("tensor(key{},x[3],y[2]):{key1:[[1,4],[2,5],[3,6]], key2:[[7,10],[8,11],[9,12]]}"), + Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}")); + } + private void assertDense(Tensor expectedTensor, String denseFormat) { assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat)); assertEquals(denseFormat, expectedTensor.toString()); @@ -92,8 +142,12 @@ public class TensorParserTestCase { "{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}"); assertIllegal("At {x:0}: '1-.0' is not a valid double", "{{x:0}:1-.0}"); - assertIllegal("At index 0: '1-.0' is not a valid double", + assertIllegal("At value position 1: '1-.0' is not a valid double", "tensor(x[1]):[1-.0]"); + assertIllegal("At value position 5: Expected a ',' but got ']'", + "tensor(x[3]):[1, 2]"); + assertIllegal("At value position 8: Expected a ']' but got ','", + "tensor(x[3]):[1, 2, 3, 4]"); } private void assertIllegal(String message, String tensor) { @@ -102,7 +156,7 @@ public class TensorParserTestCase { fail("Expected an IllegalArgumentException when parsing " + tensor); } catch (IllegalArgumentException e) { - assertEquals(message, e.getMessage()); + assertEquals(message, e.getCause().getMessage()); } } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 11365531019..9f077cb7b00 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -56,7 +56,8 @@ public class TensorTestCase { fail("Expected parse error"); } catch (IllegalArgumentException expected) { - assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", expected.getMessage()); + assertEquals("Excepted a number or a string starting by {, [ or tensor(...):, got '--'", + expected.getCause().getMessage()); } } @@ -259,9 +260,9 @@ public class TensorTestCase { assertLargest("{d1:l1,d2:l1}:6.0, {d1:l1,d2:l2}:6.0", "tensor(d1{},d2{}):{{d1:l1,d2:l1}:6.0,{d1:l1,d2:l3}:5.0,{d1:l1,d2:l2}:6.0}"); assertLargest("{x:1,y:1}:4.0", - "tensor(x[2],y[2]):[[1,2],[3,4]"); + "tensor(x[2],y[2]):[[1,2],[3,4]]"); assertLargest("{x:0,y:0}:4.0, {x:1,y:1}:4.0", - "tensor(x[2],y[2]):[[4,2],[3,4]"); + "tensor(x[2],y[2]):[[4,2],[3,4]]"); } @Test @@ -273,9 +274,9 @@ public class TensorTestCase { assertSmallest("{d1:l1,d2:l1}:5.0, {d1:l1,d2:l2}:5.0", "tensor(d1{},d2{}):{{d1:l1,d2:l1}:5.0,{d1:l1,d2:l3}:6.0,{d1:l1,d2:l2}:5.0}"); assertSmallest("{x:0,y:0}:1.0", - "tensor(x[2],y[2]):[[1,2],[3,4]"); + "tensor(x[2],y[2]):[[1,2],[3,4]]"); assertSmallest("{x:0,y:1}:2.0", - "tensor(x[2],y[2]):[[4,2],[3,4]"); + "tensor(x[2],y[2]):[[4,2],[3,4]]"); } private void assertLargest(String expectedCells, String tensorString) { diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index 2a34bc11b76..2231d32281a 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -10,6 +10,7 @@ import org.junit.Ignore; import org.junit.Test; import java.util.Collections; +import java.util.HashMap; import java.util.List; import static org.junit.Assert.assertEquals; @@ -20,21 +21,36 @@ import static org.junit.Assert.assertEquals; public class DynamicTensorTestCase { @Test - public void testDynamicTensorFunction() { + public void testDynamicIndexedRank1TensorFunction() { TensorType dense = TensorType.fromSpec("tensor(x[3])"); DynamicTensor<Name> t1 = DynamicTensor.from(dense, List.of(new Constant(1), new Constant(2), new Constant(3))); assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString()); + } + @Test + public void testDynamicMappedRank1TensorFunction() { TensorType sparse = TensorType.fromSpec("tensor(x{})"); DynamicTensor<Name> t2 = DynamicTensor.from(sparse, Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), - new Constant(5))); + new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } + @Test + public void testDynamicMappedRank2TensorFunction() { + TensorType sparse = TensorType.fromSpec("tensor(x{},y{})"); + HashMap<TensorAddress, ScalarFunction<Name>> values = new HashMap<>(); + values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "b").build(), + new Constant(5)); + values.put(new TensorAddress.Builder(sparse).add("x", "a").add("y", "c").build(), + new Constant(7)); + DynamicTensor<Name> t2 = DynamicTensor.from(sparse, values); + assertEquals(Tensor.from(sparse, "{{x:a,y:b}:5, {x:a,y:c}:7}"), t2.evaluate()); + } + @Ignore // Enable for benchmarking public void benchMarkTensorAddressBuilder() { long start = System.nanoTime(); |