diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 12:45:20 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-11-26 12:45:20 +0200 |
commit | 2b59a443c1c37dcdcc5d77fe13b93b5ce383fee2 (patch) | |
tree | 8e0c3e037cbc96f83bc975c0d152b10ab44e6c41 | |
parent | f9da8909ad49e2bb494dd445344f429dc82fabce (diff) |
Parse@ tensor value expressions
5 files changed, 102 insertions, 63 deletions
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 8d7bf4f9f14..d5970a4b69e 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -872,25 +872,25 @@ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode arg()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode function()", "public final com.yahoo.searchlib.rankingexpression.rule.FunctionNode scalarOrTensorFunction()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorFunction()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMap()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorReduce()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorReduceComposites()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorJoin()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRename()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorConcat()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerate()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerateBody(com.yahoo.tensor.TensorType)", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRange()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorDiag()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRandom()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorL1Normalize()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorL2Normalize()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMatmul()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorSoftmax()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorXwPlusB()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorArgmax()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorArgmin()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorFunction()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMap()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduce()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduceComposites()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorJoin()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRename()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorConcat()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorGenerate()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorGenerateBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRange()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorDiag()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRandom()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL1Normalize()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL2Normalize()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", "public final com.yahoo.tensor.TensorType tensorType()", @@ -909,11 +909,13 @@ "public final java.util.List tagCommaLeadingList()", "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorValueBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType)", "public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)", "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)", "public final void tensorCell(com.yahoo.tensor.TensorType, java.util.Map)", "public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)", + "public final void labelAndDimensionValues(java.util.List)", + "public final java.util.List valueAddress()", "public void <init>(java.io.InputStream)", "public void <init>(java.io.InputStream, java.lang.String)", "public void ReInit(java.io.InputStream)", diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 01eed897bfd..3e9649cd9c6 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -18,7 +18,6 @@ PARSER_BEGIN(RankingExpressionParser) package com.yahoo.searchlib.rankingexpression.parser; import com.yahoo.searchlib.rankingexpression.rule.*; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.*; @@ -231,26 +230,28 @@ TruthOperator comparator() : { } ExpressionNode value() : { - ExpressionNode ret; + ExpressionNode value; boolean neg = false; boolean not = false; + List valueAddress; } { ( [ <NOT> { not = true; } ] [ LOOKAHEAD(2) <SUB> { neg = true; } ] - ( ret = constantPrimitive() | - LOOKAHEAD(2) ret = ifExpression() | - LOOKAHEAD(4) ret = function() | - ret = feature() | - ret = legacyQueryFeature() | - ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) + ( value = constantPrimitive() | + LOOKAHEAD(2) value = ifExpression() | + LOOKAHEAD(4) value = function() | + value = feature() | + value = legacyQueryFeature() | + ( <LBRACE> value = expression() <RBRACE> { value = new EmbracedNode(value); } ) ) ) + [ LOOKAHEAD(2) valueAddress = valueAddress() { value = new TensorFunctionNode(new Value(TensorFunctionNode.wrap(value), valueAddress)); } ] { - ret = not ? new NotNode(ret) : ret; - ret = neg ? new NegativeNode(ret) : ret; - return ret; + value = not ? new NotNode(value) : value; + value = neg ? new NegativeNode(value) : value; + return value; } } @@ -323,12 +324,12 @@ List<ExpressionNode> args() : { return arguments; } } -// TODO: Replace use of this for macro arguments with value() +// TODO: Replace use of this for function arguments with value() // For that to work with the current search execution framework -// we need to generate another macro for the argument such that we can replace -// instances of the argument with the reference to that macro in the same way +// we need to generate another function for the argument such that we can replace +// instances of the argument with the reference to that function in the same way // as we replace by constants/names today (this can make for some fun combinatorial explosion). -// Simon also points out that we should stop doing macro expansion in the toString of a macro. +// We should also stop doing function expansion in the toString of a function. // - Jon 2014-05-02 ExpressionNode arg() : { @@ -368,9 +369,9 @@ FunctionNode scalarOrTensorFunction() : ) } -ExpressionNode tensorFunction() : +TensorFunctionNode tensorFunction() : { - ExpressionNode tensorExpression; + TensorFunctionNode tensorExpression; } { ( @@ -395,7 +396,7 @@ ExpressionNode tensorFunction() : { return tensorExpression; } } -ExpressionNode tensorMap() : +TensorFunctionNode tensorMap() : { ExpressionNode tensor; LambdaFunctionNode doubleMapper; @@ -403,10 +404,10 @@ ExpressionNode tensorMap() : { <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> { return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor), - doubleMapper.asDoubleUnaryOperator())); } + doubleMapper.asDoubleUnaryOperator())); } } -ExpressionNode tensorReduce() : +TensorFunctionNode tensorReduce() : { ExpressionNode tensor; Reduce.Aggregator aggregator; @@ -417,7 +418,7 @@ ExpressionNode tensorReduce() : { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } -ExpressionNode tensorReduceComposites() : +TensorFunctionNode tensorReduceComposites() : { ExpressionNode tensor; Reduce.Aggregator aggregator; @@ -429,7 +430,7 @@ ExpressionNode tensorReduceComposites() : { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } -ExpressionNode tensorJoin() : +TensorFunctionNode tensorJoin() : { ExpressionNode tensor1, tensor2; LambdaFunctionNode doubleJoiner; @@ -441,7 +442,7 @@ ExpressionNode tensorJoin() : doubleJoiner.asDoubleBinaryOperator())); } } -ExpressionNode tensorRename() : +TensorFunctionNode tensorRename() : { ExpressionNode tensor; List<String> fromDimensions, toDimensions; @@ -454,7 +455,7 @@ ExpressionNode tensorRename() : { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); } } -ExpressionNode tensorConcat() : +TensorFunctionNode tensorConcat() : { ExpressionNode tensor1, tensor2; String dimension; @@ -466,10 +467,10 @@ ExpressionNode tensorConcat() : dimension)); } } -ExpressionNode tensorGenerate() : +TensorFunctionNode tensorGenerate() : { TensorType type; - ExpressionNode expression; + TensorFunctionNode expression; } { <TENSOR> type = tensorType() @@ -480,7 +481,7 @@ ExpressionNode tensorGenerate() : { return expression; } } -ExpressionNode tensorGenerateBody(TensorType type) : +TensorFunctionNode tensorGenerateBody(TensorType type) : { ExpressionNode generator; } @@ -489,7 +490,7 @@ ExpressionNode tensorGenerateBody(TensorType type) : { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); } } -ExpressionNode tensorRange() : +TensorFunctionNode tensorRange() : { TensorType type; } @@ -498,7 +499,7 @@ ExpressionNode tensorRange() : { return new TensorFunctionNode(new Range(type)); } } -ExpressionNode tensorDiag() : +TensorFunctionNode tensorDiag() : { TensorType type; } @@ -507,7 +508,7 @@ ExpressionNode tensorDiag() : { return new TensorFunctionNode(new Diag(type)); } } -ExpressionNode tensorRandom() : +TensorFunctionNode tensorRandom() : { TensorType type; } @@ -516,7 +517,7 @@ ExpressionNode tensorRandom() : { return new TensorFunctionNode(new Random(type)); } } -ExpressionNode tensorL1Normalize() : +TensorFunctionNode tensorL1Normalize() : { ExpressionNode tensor; String dimension; @@ -526,7 +527,7 @@ ExpressionNode tensorL1Normalize() : { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } -ExpressionNode tensorL2Normalize() : +TensorFunctionNode tensorL2Normalize() : { ExpressionNode tensor; String dimension; @@ -536,7 +537,7 @@ ExpressionNode tensorL2Normalize() : { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } -ExpressionNode tensorMatmul() : +TensorFunctionNode tensorMatmul() : { ExpressionNode tensor1, tensor2; String dimension; @@ -548,7 +549,7 @@ ExpressionNode tensorMatmul() : dimension)); } } -ExpressionNode tensorSoftmax() : +TensorFunctionNode tensorSoftmax() : { ExpressionNode tensor; String dimension; @@ -558,7 +559,7 @@ ExpressionNode tensorSoftmax() : { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); } } -ExpressionNode tensorXwPlusB() : +TensorFunctionNode tensorXwPlusB() : { ExpressionNode tensor1, tensor2, tensor3; String dimension; @@ -574,7 +575,7 @@ ExpressionNode tensorXwPlusB() : dimension)); } } -ExpressionNode tensorArgmax() : +TensorFunctionNode tensorArgmax() : { ExpressionNode tensor; String dimension; @@ -584,7 +585,7 @@ ExpressionNode tensorArgmax() : { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); } } -ExpressionNode tensorArgmin() : +TensorFunctionNode tensorArgmin() : { ExpressionNode tensor; String dimension; @@ -811,20 +812,20 @@ ConstantNode constantPrimitive() : ( <INTEGER> { value = token.image; } | <FLOAT> { value = token.image; } | <STRING> { value = token.image; } ) - { return new ConstantNode(Value.parse(sign + value),sign + value); } + { return new ConstantNode(com.yahoo.searchlib.rankingexpression.evaluation.Value.parse(sign + value),sign + value); } } -Value primitiveValue() : +com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue() : { String sign = ""; } { ( <SUB> { sign = "-";} ) ? ( <INTEGER> | <FLOAT> | <STRING> ) - { return Value.parse(sign + token.image); } + { return com.yahoo.searchlib.rankingexpression.evaluation.Value.parse(sign + token.image); } } -ExpressionNode tensorValueBody(TensorType type) : +TensorFunctionNode tensorValueBody(TensorType type) : { DynamicTensor dynamicTensor; } @@ -882,4 +883,34 @@ void labelAndDimension(TensorAddress.Builder addressBuilder) : { dimension = identifier() <COLON> label = tag() { addressBuilder.add(dimension, label); } +} + +void labelAndDimensionValues(List addressValues) : +{ + String dimension, label; +} +{ + dimension = identifier() <COLON> label = tag() + { addressValues.add(new Value.DimensionValue(dimension, label)); } +} + +/** A tensor address (possibly on short form) represented as a list because the tensor type is not available */ +List valueAddress() : +{ + List dimensionValues = new ArrayList(); + String label; +} +{ + ( + ( <LSQUARE> ( <INTEGER> { dimensionValues.add(new Value.DimensionValue(token.image)); } ) <RSQUARE> ) + | + LOOKAHEAD(3) ( <LCURLY> label = tag() { dimensionValues.add(new Value.DimensionValue(label)); } <RCURLY> ) + | + ( <LCURLY> + ( labelAndDimensionValues(dimensionValues))* + ( <COMMA> labelAndDimensionValues(dimensionValues))* + <RCURLY> + ) + ) + { return dimensionValues;} }
\ No newline at end of file diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 6064035702e..30e35139edc 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -317,6 +317,12 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {x:0,y:0,z:0}:1, {x:0,y:0,z:1}:0, {x:0,y:1,z:0}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:0}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:1, }", "diag(x[2],y[2],z[2])"); tester.assertEvaluates("6", "reduce(random(x[2],y[3]), count)"); + // tensor value + tester.assertEvaluates("3.0", "tensor0{x:1}", "{ {x:0}:1, {x:1}:3 }"); + tester.assertEvaluates("1.2", "tensor0{key:foo,x:0}", true, "{ {key:foo,x:0}:1.2, {key:bar,x:0}:3 }"); + tester.assertEvaluates("3.0", "tensor0{bar}", true, "{ {x:foo}:1, {x:bar}:3 }"); + tester.assertEvaluates("3.3", "tensor0[2]", "tensor(values[4]):[1.1, 2.2, 3.3, 4.4]]"); + // composite functions tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }"); tester.assertEvaluates("{ {x:0}:0.31622776601683794, {x:1}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }"); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 30f7185959c..3812dd26370 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -110,7 +110,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return Long.parseLong(labels[i]); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected a long label in " + this + " at position " + i); + throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java index 3113d48335a..0a881c0a290 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java @@ -82,7 +82,7 @@ public class Value extends PrimitiveTensorFunction { if (cellAddress.get(0).index().isPresent()) return "[" + cellAddress.get(0).index().get() + "]"; else - return "{" + cellAddress.get(0).index().get() + "}"; + return "{" + cellAddress.get(0).label() + "}"; } else { return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}"; |