aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/javacc/RankingExpressionParser.jj
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-11-05 22:49:08 +0100
committerJon Bratseth <bratseth@verizonmedia.com>2019-11-05 22:49:08 +0100
commited8c274dc76794efa692efba6cf509b058b13648 (patch)
treec1dcb9fbc70b851be5cfdb8c335089283715f698 /searchlib/src/main/javacc/RankingExpressionParser.jj
parent64c5daa351557869e64786188afa75ed3b59991b (diff)
Literal tensors with value expressions
Diffstat (limited to 'searchlib/src/main/javacc/RankingExpressionParser.jj')
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj96
1 files changed, 48 insertions, 48 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 8f411bf6593..47555d95e58 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -30,6 +30,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+@SuppressWarnings({"rawtypes", "unchecked"})
public class RankingExpressionParser {
}
@@ -401,8 +402,8 @@ ExpressionNode tensorMap() :
}
{
<MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor),
- doubleMapper.asDoubleUnaryOperator())); }
+ { return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor),
+ doubleMapper.asDoubleUnaryOperator())); }
}
ExpressionNode tensorReduce() :
@@ -413,7 +414,7 @@ ExpressionNode tensorReduce() :
}
{
<REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); }
}
ExpressionNode tensorReduceComposites() :
@@ -425,7 +426,7 @@ ExpressionNode tensorReduceComposites() :
{
aggregator = tensorReduceAggregator()
<LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); }
}
ExpressionNode tensorJoin() :
@@ -435,9 +436,9 @@ ExpressionNode tensorJoin() :
}
{
<JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- doubleJoiner.asDoubleBinaryOperator())); }
+ { return new TensorFunctionNode(new Join(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ doubleJoiner.asDoubleBinaryOperator())); }
}
ExpressionNode tensorRename() :
@@ -450,7 +451,7 @@ ExpressionNode tensorRename() :
fromDimensions = bracedIdentifierList() <COMMA>
toDimensions = bracedIdentifierList()
<RBRACE>
- { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
+ { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); }
}
ExpressionNode tensorConcat() :
@@ -460,8 +461,8 @@ ExpressionNode tensorConcat() :
}
{
<CONCAT> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = tag() <RBRACE>
- { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
+ { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
dimension)); }
}
@@ -522,7 +523,7 @@ ExpressionNode tensorL1Normalize() :
}
{
<L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorL2Normalize() :
@@ -532,7 +533,7 @@ ExpressionNode tensorL2Normalize() :
}
{
<L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorMatmul() :
@@ -542,9 +543,9 @@ ExpressionNode tensorMatmul() :
}
{
<MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- dimension)); }
+ { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ dimension)); }
}
ExpressionNode tensorSoftmax() :
@@ -554,7 +555,7 @@ ExpressionNode tensorSoftmax() :
}
{
<SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorXwPlusB() :
@@ -567,9 +568,9 @@ ExpressionNode tensorXwPlusB() :
tensor2 = expression() <COMMA>
tensor3 = expression() <COMMA>
dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- TensorFunctionNode.wrapArgument(tensor3),
+ { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ TensorFunctionNode.wrap(tensor3),
dimension)); }
}
@@ -580,7 +581,7 @@ ExpressionNode tensorArgmax() :
}
{
<ARGMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorArgmin() :
@@ -590,7 +591,7 @@ ExpressionNode tensorArgmin() :
}
{
<ARGMIN> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrap(tensor), dimension)); }
}
LambdaFunctionNode lambdaFunction() :
@@ -823,63 +824,62 @@ Value primitiveValue() :
{ return Value.parse(sign + token.image); }
}
-ConstantNode tensorValueBody(TensorType type) :
+ExpressionNode tensorValueBody(TensorType type) :
{
- Tensor.Builder builder = Tensor.Builder.of(type);
+ DynamicTensor dynamicTensor;
}
{
<COLON>
(
- mappedTensorValueBody(builder) |
- indexedTensorValueBody(builder)
+ dynamicTensor = mappedTensorValueBody(type) |
+ dynamicTensor = indexedTensorValueBody(type)
)
- { return new ConstantNode(new TensorValue(builder.build())); }
+ { return new TensorFunctionNode(dynamicTensor); }
}
-void mappedTensorValueBody(Tensor.Builder builder) : {}
+DynamicTensor mappedTensorValueBody(TensorType type) :
+{
+ java.util.Map cells = new LinkedHashMap();
+}
{
<LCURLY>
- ( tensorCell(builder.cell()))*
- ( <COMMA> tensorCell(builder.cell()))*
+ ( tensorCell(type, cells))*
+ ( <COMMA> tensorCell(type, cells))*
<RCURLY>
+ { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
}
-void indexedTensorValueBody(Tensor.Builder builder) :
+DynamicTensor indexedTensorValueBody(TensorType type) :
{
- IndexedTensor.BoundBuilder indexedBuilder;
- long index = 0;
- double value;
+ List cells = new ArrayList();
+ ExpressionNode value;
}
{
- {
- if ( ! (builder instanceof IndexedTensor.BoundBuilder))
- throw new IllegalArgumentException("The tensor short form [n, n, ...] can only be used for indexed " +
- "bound tensors, not " + builder.type());
- indexedBuilder = (IndexedTensor.BoundBuilder)builder;
- }
<LSQUARE>
- ( value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )*
- ( <COMMA> value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )*
+ ( value = expression() { cells.add(value); } )*
+ ( <COMMA> value = expression() { cells.add(value); } )*
<RSQUARE>
+ { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
}
-void tensorCell(Tensor.Builder.CellBuilder cellBuilder) :
+void tensorCell(TensorType type, java.util.Map cells) :
{
- double value;
+ ExpressionNode value;
+ TensorAddress.Builder addressBuilder = new TensorAddress.Builder(type);
}
{
<LCURLY>
- ( labelAndDimension(cellBuilder))*
- ( <COMMA> labelAndDimension(cellBuilder))*
+ ( labelAndDimension(addressBuilder))*
+ ( <COMMA> labelAndDimension(addressBuilder))*
<RCURLY>
- <COLON> value = doubleNumber() { cellBuilder.value(value); }
+ <COLON> value = expression() { cells.put(addressBuilder.build(), value); }
}
-void labelAndDimension(Tensor.Builder.CellBuilder cellBuilder) :
+void labelAndDimension(TensorAddress.Builder addressBuilder) :
{
String dimension, label;
}
{
dimension = identifier() <COLON> label = tag()
- { cellBuilder.label(dimension, label); }
+ { addressBuilder.add(dimension, label); }
} \ No newline at end of file