summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-12-14 08:34:09 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2019-12-14 08:34:09 +0100
commitf5ccf036b4f7368f217a6bcbffc1699aac5eac2d (patch)
tree749afd3b29f52b918c67099c1742cb9db50211cf
parent3954dbe2403bdbb21e9a558fbc55fd137afa40f8 (diff)
Interpret dimensions in written order
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java18
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj43
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java23
-rw-r--r--vespajlib/abi-spec.json3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java95
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java66
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java38
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java37
9 files changed, 241 insertions, 88 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 11fc581640d..a248fa6dd45 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -19,6 +19,7 @@ import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.LinkedHashMap;
@@ -93,14 +94,15 @@ public class TensorFunctionNode extends CompositeNode {
}
public static void wrapScalarBlock(TensorType type,
+ List<String> dimensionOrder,
String mappedDimensionLabel,
List<ExpressionNode> nodes,
Map<TensorAddress, ScalarFunction<Reference>> receivingMap) {
- TensorType.Dimension sparseDimension = type.dimensions().stream().filter(d -> ! d.isIndexed()).findFirst().get();
TensorType denseSubtype = new TensorType(type.valueType(),
type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()));
-
- IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype);
+ List<String> denseDimensionOrder = new ArrayList<>(dimensionOrder);
+ denseDimensionOrder.retainAll(denseSubtype.dimensionNames());
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(denseSubtype, denseDimensionOrder);
for (ExpressionNode node : nodes) {
indexes.next();
@@ -119,7 +121,15 @@ public class TensorFunctionNode extends CompositeNode {
}
}
- public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) {
+ public static List<ScalarFunction<Reference>> wrapScalars(TensorType type,
+ List<String> dimensionOrder,
+ List<ExpressionNode> nodes) {
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type, dimensionOrder);
+ List<ScalarFunction<Reference>> wrapped = new ArrayList<>();
+ while (indexes.hasNext()) {
+ indexes.next();
+ wrapped.add(wrapScalar(nodes.get((int)indexes.toSourceValueIndex())));
+ }
return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList());
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 71456d0ed00..22d2abd4aef 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -475,13 +475,14 @@ TensorFunctionNode tensorConcat() :
TensorFunctionNode tensorGenerate() :
{
TensorType type;
+ List dimensionOrder = new ArrayList();
TensorFunctionNode expression;
}
{
- <TENSOR> type = tensorType()
+ <TENSOR> type = tensorType(dimensionOrder)
(
expression = tensorGenerateBody(type) |
- expression = tensorValueBody(type)
+ expression = tensorValueBody(type, dimensionOrder)
)
{ return expression; }
}
@@ -500,7 +501,7 @@ TensorFunctionNode tensorRange() :
TensorType type;
}
{
- <RANGE> type = tensorType()
+ <RANGE> type = tensorType(null)
{ return new TensorFunctionNode(new Range(type)); }
}
@@ -509,7 +510,7 @@ TensorFunctionNode tensorDiag() :
TensorType type;
}
{
- <DIAG> type = tensorType()
+ <DIAG> type = tensorType(null)
{ return new TensorFunctionNode(new Diag(type)); }
}
@@ -518,7 +519,7 @@ TensorFunctionNode tensorRandom() :
TensorType type;
}
{
- <RANDOM> type = tensorType()
+ <RANDOM> type = tensorType(null)
{ return new TensorFunctionNode(new Random(type)); }
}
@@ -618,7 +619,7 @@ Reduce.Aggregator tensorReduceAggregator() :
{ return Reduce.Aggregator.valueOf(token.image); }
}
-TensorType tensorType() :
+TensorType tensorType(List dimensionOrder) :
{
TensorType.Builder builder;
TensorType.Value valueType;
@@ -627,8 +628,8 @@ TensorType tensorType() :
valueType = optionalTensorValueTypeParameter()
{ builder = new TensorType.Builder(valueType); }
<LBRACE>
- ( tensorTypeDimension(builder) ) ?
- ( <COMMA> tensorTypeDimension(builder) ) *
+ ( tensorTypeDimension(builder, dimensionOrder) ) ?
+ ( <COMMA> tensorTypeDimension(builder, dimensionOrder) ) *
<RBRACE>
{ return builder.build(); }
}
@@ -642,13 +643,17 @@ TensorType.Value optionalTensorValueTypeParameter() :
{ return TensorType.Value.fromId(valueType); }
}
-void tensorTypeDimension(TensorType.Builder builder) :
+void tensorTypeDimension(TensorType.Builder builder, List dimensionOrder) :
{
String name;
int size;
}
{
name = identifier()
+ { // Keep track of the order in which dimensions are written, if necessary
+ if (dimensionOrder != null)
+ dimensionOrder.add(name);
+ }
(
( <LCURLY> <RCURLY> { builder.mapped(name); } ) |
LOOKAHEAD(2) ( <LSQUARE> <RSQUARE> { builder.indexed(name); } ) |
@@ -831,16 +836,16 @@ Value primitiveValue() :
{ return Value.parse(sign + token.image); }
}
-TensorFunctionNode tensorValueBody(TensorType type) :
+TensorFunctionNode tensorValueBody(TensorType type, List dimensionOrder) :
{
DynamicTensor dynamicTensor;
}
{
<COLON>
(
- LOOKAHEAD(2) dynamicTensor = mixedTensorValueBody(type) |
+ LOOKAHEAD(2) dynamicTensor = mixedTensorValueBody(type, dimensionOrder) |
dynamicTensor = mappedTensorValueBody(type) |
- dynamicTensor = indexedTensorValueBody(type)
+ dynamicTensor = indexedTensorValueBody(type, dimensionOrder)
)
{ return new TensorFunctionNode(dynamicTensor); }
}
@@ -857,35 +862,35 @@ DynamicTensor mappedTensorValueBody(TensorType type) :
{ return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
-DynamicTensor mixedTensorValueBody(TensorType type) :
+DynamicTensor mixedTensorValueBody(TensorType type, List dimensionOrder) :
{
java.util.Map cells = new LinkedHashMap();
}
{
<LCURLY>
- mixedBlock(type, cells)
- ( <COMMA> mixedBlock(type, cells))*
+ mixedBlock(type, dimensionOrder, cells)
+ ( <COMMA> mixedBlock(type, dimensionOrder, cells))*
<RCURLY>
{ return DynamicTensor.from(type, cells); }
}
-DynamicTensor indexedTensorValueBody(TensorType type) :
+DynamicTensor indexedTensorValueBody(TensorType type, List dimensionOrder) :
{
List cells;
}
{
cells = indexedTensorCells()
- { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells, type, dimensionOrder)); }
}
-void mixedBlock(TensorType type, java.util.Map cellMap) :
+void mixedBlock(TensorType type, List dimensionOrder, java.util.Map cellMap) :
{
String label;
List cells;
}
{
label = tag() <COLON> cells = indexedTensorCells()
- { TensorFunctionNode.wrapScalarBlock(type, label, cells, cellMap); }
+ { TensorFunctionNode.wrapScalarBlock(type, dimensionOrder, label, cells, cellMap); }
}
List indexedTensorCells() :
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 0601043f2ce..fa65ce0408b 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -408,6 +408,29 @@ public class EvaluationTestCase {
"tensor(x{},y[2]):{{x:a,y:0}:one, {x:a,y:1}:one_half, {x:b,y:0}:a_quarter, {x:b,y:1}:2}");
tester.assertEvaluates("tensor(x{},y[2]):{a:[1.0, 0.5], b:[0.25, 2]}",
"tensor(x{},y[2]):{a:[one, one_half], b:[a_quarter, 2]}");
+ tester.assertEvaluates("tensor(key{},x[2],y[3]):{key1:[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]," +
+ " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}",
+ "tensor(key{},x[2],y[3]):{key1:[[one,one_half,a_quarter],[a_quarter,one_half,one]]," +
+ " key2:[[1,2,3],[4,5,6]]}");
+
+ // Opposite order in the expression:
+ // - indexed
+ tester.assertEvaluates("tensor(x[3],y[2]):[[1.0, 0.25], [0.5,0.5], [0.25, 1.0]]",
+ "tensor(y[2],x[3]):[[one,one_half,a_quarter],[a_quarter,one_half,one]]");
+ // - mixed
+ tester.assertEvaluates("tensor(key{},x[3],y[2]):{key1:[[1.0, 0.25], [0.5,0.5], [0.25, 1.0]]," +
+ " key2:[[1.0, 4.00], [2.0,5.0], [3.00, 6.0]]}",
+ "tensor(key{},y[2],x[3]):{key1:[[one,one_half,a_quarter],[a_quarter,one_half,one]]," +
+ " key2:[[1,2,3],[4,5,6]]}");
+ // Opposite order in literal parsing:
+ // - indexed
+ tester.assertEvaluates("tensor(y[2],x[3]):[[1,0.25,0.5],[0.5,0.25,1]]",
+ "tensor(x[3],y[2]):[[one,one_half], [a_quarter,a_quarter], [one_half,one]]");
+ // - mixed
+ tester.assertEvaluates("tensor(key{},y[2],x[3]):{key1:[[1.0, 0.5, 0.25],[0.25, 0.5, 1.0]]," +
+ " key2:[[1.0, 2.0, 3.00],[4.00, 5.0, 6.0]]}",
+ "tensor(key{},x[3],y[2]):{key1:[[one,a_quarter],[one_half,one_half],[a_quarter,one]]," +
+ " key2:[[1,4],[2,5],[3,6]]}");
}
@Test
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index d91b38a8a96..1fcdf7f5cca 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -693,6 +693,7 @@
"methods": [
"public void <init>(int)",
"public com.yahoo.tensor.DimensionSizes$Builder set(int, long)",
+ "public com.yahoo.tensor.DimensionSizes$Builder add(long)",
"public long size(int)",
"public int dimensions()",
"public com.yahoo.tensor.DimensionSizes build()"
@@ -836,10 +837,12 @@
],
"methods": [
"public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)",
+ "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType, java.util.List)",
"public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)",
"public com.yahoo.tensor.TensorAddress toAddress()",
"public long[] indexesCopy()",
"public long[] indexesForReading()",
+ "public long toSourceValueIndex()",
"public java.util.List toList()",
"public java.lang.String toString()",
"public abstract long size()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index d81c02fb75f..202817ece42 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -71,6 +71,7 @@ public final class DimensionSizes {
*/
public final static class Builder {
+ private int dimensionIndex = 0;
private long[] sizes;
public Builder(int dimensions) {
@@ -82,6 +83,11 @@ public final class DimensionSizes {
return this;
}
+ public Builder add(long size) {
+ sizes[dimensionIndex++] = size;
+ return this;
+ }
+
/**
* Returns the length of this in the nth dimension
*
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 30923976fa5..ba3a35e8eda 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -218,7 +218,7 @@ public abstract class IndexedTensor implements Tensor {
indexes.next();
// start brackets
- for (int i = 0; i < indexes.rightDimensionsAtStart(); i++)
+ for (int i = 0; i < indexes.nextDimensionsAtStart(); i++)
b.append("[");
// value
@@ -230,7 +230,7 @@ public abstract class IndexedTensor implements Tensor {
throw new IllegalStateException("Unexpected value type " + type.valueType());
// end bracket and comma
- for (int i = 0; i < indexes.rightDimensionsAtEnd(); i++)
+ for (int i = 0; i < indexes.nextDimensionsAtEnd(); i++)
b.append("]");
if (index < size() - 1)
b.append(", ");
@@ -777,6 +777,10 @@ public abstract class IndexedTensor implements Tensor {
return of(DimensionSizes.of(type));
}
+ public static Indexes of(TensorType type, List<String> iterateDimensionOrder) {
+ return of(DimensionSizes.of(type), toIterationOrder(iterateDimensionOrder, type));
+ }
+
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
}
@@ -789,6 +793,10 @@ public abstract class IndexedTensor implements Tensor {
return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size);
}
+ private static Indexes of(DimensionSizes sizes, List<Integer> iterateDimensions) {
+ return of(sizes, sizes, iterateDimensions);
+ }
+
private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List<Integer> iterateDimensions) {
return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions));
}
@@ -822,6 +830,16 @@ public abstract class IndexedTensor implements Tensor {
}
}
+ private static List<Integer> toIterationOrder(List<String> dimensionNames, TensorType type) {
+ if (dimensionNames == null) return completeIterationOrder(type.rank());
+
+ List<Integer> iterationDimensions = new ArrayList<>(type.rank());
+ for (int i = 0; i < type.rank(); i++)
+ iterationDimensions.add(type.rank() - 1 - type.indexOfDimension(dimensionNames.get(i)).get());
+ return iterationDimensions;
+ }
+
+ /** Since the right dimensions binds closest, iteration order is the opposite of the tensor order */
private static List<Integer> completeIterationOrder(int length) {
List<Integer> iterationDimensions = new ArrayList<>(length);
for (int i = 0; i < length; i++)
@@ -854,7 +872,7 @@ public abstract class IndexedTensor implements Tensor {
/** Returns a copy of the indexes of this which must not be modified */
public long[] indexesForReading() { return indexes; }
- long toSourceValueIndex() {
+ public long toSourceValueIndex() {
return IndexedTensor.toValueIndex(indexes, sourceSizes);
}
@@ -882,27 +900,12 @@ public abstract class IndexedTensor implements Tensor {
/** Returns whether further values are available by calling next() */
public abstract boolean hasNext();
- /** Returns the number of dimensions from the right which are currently at the start position (0) */
- int rightDimensionsAtStart() {
- int dimension = indexes.length - 1;
- int atStartCount = 0;
- while (dimension >= 0 && indexes[dimension] == 0) {
- atStartCount++;
- dimension--;
- }
- return atStartCount;
- }
+ /** Returns the number of dimensions in iteration order which are currently at the start position (0) */
+ abstract int nextDimensionsAtStart();
+
+ /** Returns the number of dimensions in iteration order which are currently at their end position */
+ abstract int nextDimensionsAtEnd();
- /** Returns the number of dimensions from the right which are currently at the end position */
- int rightDimensionsAtEnd() {
- int dimension = indexes.length - 1;
- int atEndCount = 0;
- while (dimension >= 0 && indexes[dimension] == dimensionSizes().size(dimension) - 1) {
- atEndCount++;
- dimension--;
- }
- return atEndCount;
- }
}
private final static class EmptyIndexes extends Indexes {
@@ -920,6 +923,12 @@ public abstract class IndexedTensor implements Tensor {
@Override
public boolean hasNext() { return false; }
+ @Override
+ int nextDimensionsAtStart() { return 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return 0; }
+
}
private final static class SingleValueIndexes extends Indexes {
@@ -939,6 +948,12 @@ public abstract class IndexedTensor implements Tensor {
@Override
public boolean hasNext() { return ! exhausted; }
+ @Override
+ int nextDimensionsAtStart() { return 1; }
+
+ @Override
+ int nextDimensionsAtEnd() { return 1; }
+
}
private static class MultiDimensionIndexes extends Indexes {
@@ -987,6 +1002,22 @@ public abstract class IndexedTensor implements Tensor {
return false;
}
+ @Override
+ int nextDimensionsAtStart() {
+ int dimension = 0;
+ while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == 0)
+ dimension++;
+ return dimension;
+ }
+
+ @Override
+ int nextDimensionsAtEnd() {
+ int dimension = 0;
+ while (dimension < iterateDimensions.size() && indexes[iterateDimensions.get(dimension)] == dimensionSizes().size(iterateDimensions.get(dimension)) - 1)
+ dimension++;
+ return dimension;
+ }
+
}
/** In this case we can reuse the source index computation for the iteration index */
@@ -999,7 +1030,7 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() {
+ public long toSourceValueIndex() {
return lastComputedSourceValueIndex = super.toSourceValueIndex();
}
@@ -1056,7 +1087,7 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() { return currentSourceValueIndex; }
+ public long toSourceValueIndex() { return currentSourceValueIndex; }
@Override
long toIterationValueIndex() { return currentIterationValueIndex; }
@@ -1066,6 +1097,12 @@ public abstract class IndexedTensor implements Tensor {
return indexes[iterateDimension] + 1 < size;
}
+ @Override
+ int nextDimensionsAtStart() { return currentSourceValueIndex == 0 ? 1 : 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return currentSourceValueIndex == size - 1 ? 1 : 0; }
+
}
/** In this case we only need to keep track of one index */
@@ -1117,11 +1154,17 @@ public abstract class IndexedTensor implements Tensor {
}
@Override
- long toSourceValueIndex() { return currentValueIndex; }
+ public long toSourceValueIndex() { return currentValueIndex; }
@Override
long toIterationValueIndex() { return currentValueIndex; }
+ @Override
+ int nextDimensionsAtStart() { return currentValueIndex == 0 ? 1 : 0; }
+
+ @Override
+ int nextDimensionsAtEnd() { return currentValueIndex == size - 1 ? 1 : 0; }
+
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 8d07a1ed9a8..ea21249bede 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
@@ -23,11 +24,17 @@ class TensorParser {
Optional<TensorType> type;
String valueString;
+ // The order in which dimensions are written in the type string.
+ // This allows the user's explicit dimension order to decide what (dense) dimensions map to what, rather than
+ // the natural order of the tensor.
+ List<String> dimensionOrder;
+
tensorString = tensorString.trim();
if (tensorString.startsWith("tensor")) {
int colonIndex = tensorString.indexOf(':');
String typeString = tensorString.substring(0, colonIndex);
- TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
+ dimensionOrder = new ArrayList<>();
+ TensorType typeFromString = TensorTypeParser.fromSpec(typeString, dimensionOrder);
if (explicitType.isPresent() && ! explicitType.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
"passed type " + explicitType.get());
@@ -37,6 +44,7 @@ class TensorParser {
else {
type = explicitType;
valueString = tensorString;
+ dimensionOrder = null;
}
valueString = valueString.trim();
@@ -45,10 +53,10 @@ class TensorParser {
return tensorFromSparseValueString(valueString, type);
}
else if (valueString.startsWith("{")) {
- return tensorFromMixedValueString(valueString, type);
+ return tensorFromMixedValueString(valueString, type, dimensionOrder);
}
else if (valueString.startsWith("[")) {
- return tensorFromDenseValueString(valueString, type);
+ return tensorFromDenseValueString(valueString, type, dimensionOrder);
}
else {
if (explicitType.isPresent() && ! explicitType.get().equals(TensorType.empty))
@@ -102,7 +110,9 @@ class TensorParser {
}
}
- private static Tensor tensorFromMixedValueString(String valueString, Optional<TensorType> type) {
+ private static Tensor tensorFromMixedValueString(String valueString,
+ Optional<TensorType> type,
+ List<String> dimensionOrder) {
if (type.isEmpty())
throw new IllegalArgumentException("The mixed tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
@@ -117,7 +127,7 @@ class TensorParser {
throw new IllegalArgumentException("A mixed tensor must be enclosed in {}");
// TODO: Check if there is also at least one bound indexed dimension
MixedTensor.BoundBuilder builder = (MixedTensor.BoundBuilder)Tensor.Builder.of(type.get());
- MixedValueParser parser = new MixedValueParser(valueString, builder);
+ MixedValueParser parser = new MixedValueParser(valueString, dimensionOrder, builder);
parser.parse();
return builder.build();
}
@@ -126,7 +136,9 @@ class TensorParser {
}
}
- private static Tensor tensorFromDenseValueString(String valueString, Optional<TensorType> type) {
+ private static Tensor tensorFromDenseValueString(String valueString,
+ Optional<TensorType> type,
+ List<String> dimensionOrder) {
if (type.isEmpty())
throw new IllegalArgumentException("The dense tensor form requires an explicit tensor type " +
"on the form 'tensor(dimensions):...");
@@ -135,7 +147,7 @@ class TensorParser {
"only dense dimensions with a given size");
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder) IndexedTensor.Builder.of(type.get());
- new DenseValueParser(valueString, builder).parse();
+ new DenseValueParser(valueString, dimensionOrder, builder).parse();
return builder.build();
}
@@ -157,10 +169,10 @@ class TensorParser {
skipSpace();
if (position >= string.length())
- throw new IllegalArgumentException("At position " + position + ": Expected a '" + character +
+ throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character +
"' but got the end of the string");
if ( string.charAt(position) != character)
- throw new IllegalArgumentException("At position " + position + ": Expected a '" + character +
+ throw new IllegalArgumentException("At value position " + position + ": Expected a '" + character +
"' but got '" + string.charAt(position) + "'");
position++;
}
@@ -176,10 +188,12 @@ class TensorParser {
private long tensorIndex = 0;
- public DenseValueParser(String string, IndexedTensor.DirectIndexBuilder builder) {
+ public DenseValueParser(String string,
+ List<String> dimensionOrder,
+ IndexedTensor.DirectIndexBuilder builder) {
super(string);
this.builder = builder;
- indexes = IndexedTensor.Indexes.of(builder.type());
+ indexes = IndexedTensor.Indexes.of(builder.type(), dimensionOrder);
hasInnerStructure = hasInnerStructure(string);
}
@@ -189,10 +203,10 @@ class TensorParser {
while (indexes.hasNext()) {
indexes.next();
- for (int i = 0; i < indexes.rightDimensionsAtStart() && hasInnerStructure; i++)
+ for (int i = 0; i < indexes.nextDimensionsAtStart() && hasInnerStructure; i++)
consume('[');
consumeNumber();
- for (int i = 0; i < indexes.rightDimensionsAtEnd() && hasInnerStructure; i++)
+ for (int i = 0; i < indexes.nextDimensionsAtEnd() && hasInnerStructure; i++)
consume(']');
if (indexes.hasNext())
consume(',');
@@ -220,14 +234,14 @@ class TensorParser {
String cellValueString = string.substring(position, nextNumberEnd);
try {
if (cellValueType == TensorType.Value.DOUBLE)
- builder.cellByDirectIndex(tensorIndex++, Double.parseDouble(cellValueString));
+ builder.cellByDirectIndex(indexes.toSourceValueIndex(), Double.parseDouble(cellValueString));
else if (cellValueType == TensorType.Value.FLOAT)
- builder.cellByDirectIndex(tensorIndex++, Float.parseFloat(cellValueString));
+ builder.cellByDirectIndex(indexes.toSourceValueIndex(), Float.parseFloat(cellValueString));
else
throw new IllegalArgumentException(cellValueType + " is not supported");
}
catch (NumberFormatException e) {
- throw new IllegalArgumentException("At position " + position + ": '" +
+ throw new IllegalArgumentException("At value position " + position + ": '" +
cellValueString + "' is not a valid " + cellValueType);
}
position = nextNumberEnd;
@@ -248,15 +262,19 @@ class TensorParser {
private static class MixedValueParser extends ValueParser {
private final MixedTensor.BoundBuilder builder;
+ private List<String> dimensionOrder;
- public MixedValueParser(String string, MixedTensor.BoundBuilder builder) {
+ public MixedValueParser(String string, List<String> dimensionOrder, MixedTensor.BoundBuilder builder) {
super(string);
+ this.dimensionOrder = dimensionOrder;
this.builder = builder;
}
private void parse() {
- TensorType.Dimension sparseDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get();
- TensorType sparseSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(sparseDimension));
+ TensorType.Dimension mappedDimension = builder.type().dimensions().stream().filter(d -> ! d.isIndexed()).findAny().get();
+ TensorType mappedSubtype = MixedTensor.createPartialType(builder.type().valueType(), List.of(mappedDimension));
+ if (dimensionOrder != null)
+ dimensionOrder.remove(mappedDimension.name());
skipSpace();
consume('{');
@@ -269,16 +287,18 @@ class TensorParser {
position = labelEnd + 1;
skipSpace();
- TensorAddress sparseAddress = new TensorAddress.Builder(sparseSubtype).add(sparseDimension.name(), label).build();
- parseDenseSubspace(sparseAddress);
+ TensorAddress mappedAddress = new TensorAddress.Builder(mappedSubtype).add(mappedDimension.name(), label).build();
+ parseDenseSubspace(mappedAddress, dimensionOrder);
if ( ! consumeOptional(','))
consume('}');
skipSpace();
}
}
- private void parseDenseSubspace(TensorAddress sparseAddress) {
- DenseValueParser denseParser = new DenseValueParser(string.substring(position), builder.denseSubspaceBuilder(sparseAddress));
+ private void parseDenseSubspace(TensorAddress sparseAddress, List<String> denseDimensionOrder) {
+ DenseValueParser denseParser = new DenseValueParser(string.substring(position),
+ denseDimensionOrder,
+ builder.denseSubspaceBuilder(sparseAddress));
denseParser.parse();
position+= denseParser.position();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index def3ab6b4ec..4fdb0906740 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -24,6 +24,13 @@ public class TensorTypeParser {
private static final Pattern mappedPattern = Pattern.compile("(\\w+)\\{\\}");
public static TensorType fromSpec(String specString) {
+ return fromSpec(specString, null);
+ }
+
+ /**
+ * @param dimensionOrder if not null, this will be populated with the dimension names in the order they are written
+ */
+ static TensorType fromSpec(String specString, List<String> dimensionOrder) {
specString = specString.trim();
if ( ! specString.startsWith(START_STRING) || ! specString.endsWith(END_STRING))
throw formatException(specString);
@@ -48,10 +55,14 @@ public class TensorTypeParser {
List<TensorType.Dimension> dimensions = new ArrayList<>();
for (String element : dimensionsSpec.split(",")) {
String trimmedElement = element.trim();
- boolean success = tryParseIndexedDimension(trimmedElement, dimensions) ||
- tryParseMappedDimension(trimmedElement, dimensions);
- if ( ! success)
+ TensorType.Dimension dimension = tryParseIndexedDimension(trimmedElement);
+ if (dimension == null)
+ dimension = tryParseMappedDimension(trimmedElement);
+ if (dimension == null)
throw formatException(specString, "Dimension '" + element + "' is on the wrong format");
+ dimensions.add(dimension);
+ if (dimensionOrder != null)
+ dimensionOrder.add(dimension.name());
}
return new TensorType.Builder(valueType, dimensions).build();
}
@@ -68,29 +79,26 @@ public class TensorTypeParser {
}
}
- private static boolean tryParseIndexedDimension(String element, List<TensorType.Dimension> dimensions) {
+ private static TensorType.Dimension tryParseIndexedDimension(String element) {
Matcher matcher = indexedPattern.matcher(element);
if (matcher.matches()) {
String dimensionName = matcher.group(1);
String dimensionSize = matcher.group(2);
- if (dimensionSize.isEmpty()) {
- dimensions.add(TensorType.Dimension.indexed(dimensionName));
- } else {
- dimensions.add(TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize)));
- }
- return true;
+ if (dimensionSize.isEmpty())
+ return TensorType.Dimension.indexed(dimensionName);
+ else
+ return TensorType.Dimension.indexed(dimensionName, Integer.valueOf(dimensionSize));
}
- return false;
+ return null;
}
- private static boolean tryParseMappedDimension(String element, List<TensorType.Dimension> dimensions) {
+ private static TensorType.Dimension tryParseMappedDimension(String element) {
Matcher matcher = mappedPattern.matcher(element);
if (matcher.matches()) {
String dimensionName = matcher.group(1);
- dimensions.add(TensorType.Dimension.mapped(dimensionName));
- return true;
+ return TensorType.Dimension.mapped(dimensionName);
}
- return false;
+ return null;
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
index b2aba5b02eb..9dfdee29845 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorParserTestCase.java
@@ -75,6 +75,19 @@ public class TensorParserTestCase {
}
@Test
+ public void testDenseWrongOrder() {
+ assertEquals("Opposite order of dimensions",
+ Tensor.Builder.of(TensorType.fromSpec("tensor(x[3],y[2])"))
+ .cell(1, 0, 0)
+ .cell(4, 0, 1)
+ .cell(2, 1, 0)
+ .cell(5, 1, 1)
+ .cell(3, 2, 0)
+ .cell(6, 2, 1).build(),
+ Tensor.from("tensor(y[2],x[3]):[[1,2,3],[4,5,6]]"));
+ }
+
+ @Test
public void testMixedParsing() {
assertEquals(Tensor.Builder.of(TensorType.fromSpec("tensor(key{}, x[2])"))
.cell(TensorAddress.ofLabels("a", "0"), 1)
@@ -84,6 +97,28 @@ public class TensorParserTestCase {
Tensor.from("tensor(key{}, x[2]):{a:[1, 2], b:[3, 4]}"));
}
+ @Test
+ public void testMixedWrongOrder() {
+ assertEquals("Opposite order of dimensions",
+ Tensor.Builder.of(TensorType.fromSpec("tensor(key{},x[3],y[2])"))
+ .cell(TensorAddress.ofLabels("key1", "0", "0"), 1)
+ .cell(TensorAddress.ofLabels("key1", "0", "1"), 4)
+ .cell(TensorAddress.ofLabels("key1", "1", "0"), 2)
+ .cell(TensorAddress.ofLabels("key1", "1", "1"), 5)
+ .cell(TensorAddress.ofLabels("key1", "2", "0"), 3)
+ .cell(TensorAddress.ofLabels("key1", "2", "1"), 6)
+ .cell(TensorAddress.ofLabels("key2", "0", "0"), 7)
+ .cell(TensorAddress.ofLabels("key2", "0", "1"), 10)
+ .cell(TensorAddress.ofLabels("key2", "1", "0"), 8)
+ .cell(TensorAddress.ofLabels("key2", "1", "1"), 11)
+ .cell(TensorAddress.ofLabels("key2", "2", "0"), 9)
+ .cell(TensorAddress.ofLabels("key2", "2", "1"), 12).build(),
+ Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}"));
+ assertEquals("Opposite order of dimensions",
+ Tensor.from("tensor(key{},x[3],y[2]):{key1:[[1,4],[2,5],[3,6]], key2:[[7,10],[8,11],[9,12]]}"),
+ Tensor.from("tensor(key{},y[2],x[3]):{key1:[[1,2,3],[4,5,6]], key2:[[7,8,9],[10,11,12]]}"));
+ }
+
private void assertDense(Tensor expectedTensor, String denseFormat) {
assertEquals(denseFormat, expectedTensor, Tensor.from(denseFormat));
assertEquals(denseFormat, expectedTensor.toString());
@@ -99,7 +134,7 @@ public class TensorParserTestCase {
"{{\"x\":\"l0\", \"y\":\"l0\"}:1.0, {\"x\":\"l0\", \"y\":\"l1\"}:2.0}");
assertIllegal("At {x:0}: '1-.0' is not a valid double",
"{{x:0}:1-.0}");
- assertIllegal("At position 1: '1-.0' is not a valid double",
+ assertIllegal("At value position 1: '1-.0' is not a valid double",
"tensor(x[1]):[1-.0]");
}