diff options
Diffstat (limited to 'searchlib/src/main/javacc/RankingExpressionParser.jj')
-rwxr-xr-x | searchlib/src/main/javacc/RankingExpressionParser.jj | 96 |
1 files changed, 48 insertions, 48 deletions
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 8f411bf6593..47555d95e58 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +@SuppressWarnings({"rawtypes", "unchecked"}) public class RankingExpressionParser { } @@ -401,8 +402,8 @@ ExpressionNode tensorMap() : } { <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), - doubleMapper.asDoubleUnaryOperator())); } + { return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor), + doubleMapper.asDoubleUnaryOperator())); } } ExpressionNode tensorReduce() : @@ -413,7 +414,7 @@ ExpressionNode tensorReduce() : } { <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } ExpressionNode tensorReduceComposites() : @@ -425,7 +426,7 @@ ExpressionNode tensorReduceComposites() : { aggregator = tensorReduceAggregator() <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } ExpressionNode tensorJoin() : @@ -435,9 +436,9 @@ ExpressionNode tensorJoin() : } { <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - doubleJoiner.asDoubleBinaryOperator())); } + { return new TensorFunctionNode(new Join(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + doubleJoiner.asDoubleBinaryOperator())); } } ExpressionNode tensorRename() : @@ -450,7 +451,7 @@ ExpressionNode tensorRename() : fromDimensions = bracedIdentifierList() <COMMA> toDimensions = bracedIdentifierList() <RBRACE> - { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } + { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); } } ExpressionNode tensorConcat() : @@ -460,8 +461,8 @@ ExpressionNode tensorConcat() : } { <CONCAT> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = tag() <RBRACE> - { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), + { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), dimension)); } } @@ -522,7 +523,7 @@ ExpressionNode tensorL1Normalize() : } { <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorL2Normalize() : @@ -532,7 +533,7 @@ ExpressionNode tensorL2Normalize() : } { <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorMatmul() : @@ -542,9 +543,9 @@ ExpressionNode tensorMatmul() : } { <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - dimension)); } + { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + dimension)); } } ExpressionNode tensorSoftmax() : @@ -554,7 +555,7 @@ ExpressionNode tensorSoftmax() : } { <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorXwPlusB() : @@ -567,9 +568,9 @@ ExpressionNode tensorXwPlusB() : tensor2 = expression() <COMMA> tensor3 = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - TensorFunctionNode.wrapArgument(tensor3), + { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + TensorFunctionNode.wrap(tensor3), dimension)); } } @@ -580,7 +581,7 @@ ExpressionNode tensorArgmax() : } { <ARGMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorArgmin() : @@ -590,7 +591,7 @@ ExpressionNode tensorArgmin() : } { <ARGMIN> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrap(tensor), dimension)); } } LambdaFunctionNode lambdaFunction() : @@ -823,63 +824,62 @@ Value primitiveValue() : { return Value.parse(sign + token.image); } } -ConstantNode tensorValueBody(TensorType type) : +ExpressionNode tensorValueBody(TensorType type) : { - Tensor.Builder builder = Tensor.Builder.of(type); + DynamicTensor dynamicTensor; } { <COLON> ( - mappedTensorValueBody(builder) | - indexedTensorValueBody(builder) + dynamicTensor = mappedTensorValueBody(type) | + dynamicTensor = indexedTensorValueBody(type) ) - { return new ConstantNode(new TensorValue(builder.build())); } + { return new TensorFunctionNode(dynamicTensor); } } -void mappedTensorValueBody(Tensor.Builder builder) : {} +DynamicTensor mappedTensorValueBody(TensorType type) : +{ + java.util.Map cells = new LinkedHashMap(); +} { <LCURLY> - ( tensorCell(builder.cell()))* - ( <COMMA> tensorCell(builder.cell()))* + ( tensorCell(type, cells))* + ( <COMMA> tensorCell(type, cells))* <RCURLY> + { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } } -void indexedTensorValueBody(Tensor.Builder builder) : +DynamicTensor indexedTensorValueBody(TensorType type) : { - IndexedTensor.BoundBuilder indexedBuilder; - long index = 0; - double value; + List cells = new ArrayList(); + ExpressionNode value; } { - { - if ( ! (builder instanceof IndexedTensor.BoundBuilder)) - throw new IllegalArgumentException("The tensor short form [n, n, ...] can only be used for indexed " + - "bound tensors, not " + builder.type()); - indexedBuilder = (IndexedTensor.BoundBuilder)builder; - } <LSQUARE> - ( value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* - ( <COMMA> value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* + ( value = expression() { cells.add(value); } )* + ( <COMMA> value = expression() { cells.add(value); } )* <RSQUARE> + { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } } -void tensorCell(Tensor.Builder.CellBuilder cellBuilder) : +void tensorCell(TensorType type, java.util.Map cells) : { - double value; + ExpressionNode value; + TensorAddress.Builder addressBuilder = new TensorAddress.Builder(type); } { <LCURLY> - ( labelAndDimension(cellBuilder))* - ( <COMMA> labelAndDimension(cellBuilder))* + ( labelAndDimension(addressBuilder))* + ( <COMMA> labelAndDimension(addressBuilder))* <RCURLY> - <COLON> value = doubleNumber() { cellBuilder.value(value); } + <COLON> value = expression() { cells.put(addressBuilder.build(), value); } } -void labelAndDimension(Tensor.Builder.CellBuilder cellBuilder) : +void labelAndDimension(TensorAddress.Builder addressBuilder) : { String dimension, label; } { dimension = identifier() <COLON> label = tag() - { cellBuilder.label(dimension, label); } + { addressBuilder.add(dimension, label); } }
\ No newline at end of file |