summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/test/derived/tensor/attributes.cfg21
-rw-r--r--config-model/src/test/derived/tensor/documenttypes.cfg5
-rw-r--r--config-model/src/test/derived/tensor/rank-profiles.cfg30
-rw-r--r--config-model/src/test/derived/tensor/summary.cfg4
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd23
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java3
-rw-r--r--searchlib/abi-spec.json4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java73
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java2
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java10
-rw-r--r--vespajlib/abi-spec.json21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java22
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java7
17 files changed, 227 insertions, 70 deletions
diff --git a/config-model/src/test/derived/tensor/attributes.cfg b/config-model/src/test/derived/tensor/attributes.cfg
index 0c556aad868..2e0a207d249 100644
--- a/config-model/src/test/derived/tensor/attributes.cfg
+++ b/config-model/src/test/derived/tensor/attributes.cfg
@@ -82,3 +82,24 @@ attribute[].upperbound 9223372036854775807
attribute[].densepostinglistthreshold 0.4
attribute[].tensortype "tensor<float>(x[10])"
attribute[].imported false
+attribute[].name "f6"
+attribute[].datatype FLOAT
+attribute[].collectiontype SINGLE
+attribute[].removeifzero false
+attribute[].createifnonexistent false
+attribute[].fastsearch false
+attribute[].huge false
+attribute[].ismutable false
+attribute[].sortascending true
+attribute[].sortfunction UCA
+attribute[].sortstrength PRIMARY
+attribute[].sortlocale ""
+attribute[].enablebitvectors false
+attribute[].enableonlybitvector false
+attribute[].fastaccess false
+attribute[].arity 8
+attribute[].lowerbound -9223372036854775808
+attribute[].upperbound 9223372036854775807
+attribute[].densepostinglistthreshold 0.4
+attribute[].tensortype ""
+attribute[].imported false
diff --git a/config-model/src/test/derived/tensor/documenttypes.cfg b/config-model/src/test/derived/tensor/documenttypes.cfg
index af1748e484e..72fae572b76 100644
--- a/config-model/src/test/derived/tensor/documenttypes.cfg
+++ b/config-model/src/test/derived/tensor/documenttypes.cfg
@@ -40,6 +40,10 @@ documenttype[].datatype[].sstruct.field[].name "f5"
documenttype[].datatype[].sstruct.field[].id 329055840
documenttype[].datatype[].sstruct.field[].datatype 21
documenttype[].datatype[].sstruct.field[].detailedtype "tensor<float>(x[10])"
+documenttype[].datatype[].sstruct.field[].name "f6"
+documenttype[].datatype[].sstruct.field[].id 596352344
+documenttype[].datatype[].sstruct.field[].datatype 1
+documenttype[].datatype[].sstruct.field[].detailedtype ""
documenttype[].datatype[].id -1903234535
documenttype[].datatype[].type STRUCT
documenttype[].datatype[].array.element.id 0
@@ -60,3 +64,4 @@ documenttype[].fieldsets{[document]}.fields[] "f2"
documenttype[].fieldsets{[document]}.fields[] "f3"
documenttype[].fieldsets{[document]}.fields[] "f4"
documenttype[].fieldsets{[document]}.fields[] "f5"
+documenttype[].fieldsets{[document]}.fields[] "f6"
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg
index 1ce9227d323..617901130a6 100644
--- a/config-model/src/test/derived/tensor/rank-profiles.cfg
+++ b/config-model/src/test/derived/tensor/rank-profiles.cfg
@@ -80,3 +80,33 @@ rankprofile[].fef.property[].name "vespa.type.attribute.f4"
rankprofile[].fef.property[].value "tensor(x[10],y[20])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
+rankprofile[].name "profile5"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "rankingExpression(firstphase)"
+rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
+rankprofile[].fef.property[].value "reduce(tensor<float>(d0[1],x[2]):{{d0:0,x:0}:attribute(f6),{d0:0,x:1}:reduce(attribute(f5), sum)}, sum)"
+rankprofile[].fef.property[].name "vespa.type.attribute.f2"
+rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
+rankprofile[].fef.property[].name "vespa.type.attribute.f3"
+rankprofile[].fef.property[].value "tensor(x{})"
+rankprofile[].fef.property[].name "vespa.type.attribute.f4"
+rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].name "vespa.type.attribute.f5"
+rankprofile[].fef.property[].value "tensor<float>(x[10])"
+rankprofile[].name "profile6"
+rankprofile[].fef.property[].name "rankingExpression(joinedtensors).rankingScript"
+rankprofile[].fef.property[].value "tensor(i[10])(i) * attribute(f4)"
+rankprofile[].fef.property[].name "rankingExpression(joinedtensors).type"
+rankprofile[].fef.property[].value "tensor(i[10],x[10],y[20])"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "rankingExpression(firstphase)"
+rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
+rankprofile[].fef.property[].value "reduce(tensor<float>(d0[1],x[2]):{{d0:0,x:0}:attribute(f6),{d0:0,x:1}:reduce(rankingExpression(joinedtensors), sum)}, sum)"
+rankprofile[].fef.property[].name "vespa.type.attribute.f2"
+rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
+rankprofile[].fef.property[].name "vespa.type.attribute.f3"
+rankprofile[].fef.property[].value "tensor(x{})"
+rankprofile[].fef.property[].name "vespa.type.attribute.f4"
+rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].name "vespa.type.attribute.f5"
+rankprofile[].fef.property[].value "tensor<float>(x[10])"
diff --git a/config-model/src/test/derived/tensor/summary.cfg b/config-model/src/test/derived/tensor/summary.cfg
index 903b6033297..fb32eacbb4c 100644
--- a/config-model/src/test/derived/tensor/summary.cfg
+++ b/config-model/src/test/derived/tensor/summary.cfg
@@ -15,7 +15,7 @@ classes[].fields[].name "summaryfeatures"
classes[].fields[].type "featuredata"
classes[].fields[].name "documentid"
classes[].fields[].type "longstring"
-classes[].id 193983608
+classes[].id 1476352352
classes[].name "attributeprefetch"
classes[].fields[].name "f2"
classes[].fields[].type "tensor"
@@ -25,6 +25,8 @@ classes[].fields[].name "f4"
classes[].fields[].type "tensor"
classes[].fields[].name "f5"
classes[].fields[].type "tensor"
+classes[].fields[].name "f6"
+classes[].fields[].type "float"
classes[].fields[].name "rankfeatures"
classes[].fields[].type "featuredata"
classes[].fields[].name "summaryfeatures"
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index b31352a2105..13727d1ec49 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -17,6 +17,9 @@ search tensor {
field f5 type tensor<float>(x[10]) {
indexing: attribute | summary
}
+ field f6 type float {
+ indexing: attribute
+ }
}
rank-profile profile1 {
@@ -55,4 +58,24 @@ search tensor {
}
+ rank-profile profile5 {
+
+ first-phase {
+ expression: sum(tensor<float>(d0[1],x[2]):[[attribute(f6), sum(attribute(f5))]])
+ }
+
+ }
+
+ rank-profile profile6 {
+
+ first-phase {
+ expression: sum(tensor<float>(d0[1],x[2]):[[attribute(f6), sum(joinedtensors())]])
+ }
+
+ function joinedtensors() {
+ expression: tensor(i[10])(i) * attribute(f4)
+ }
+
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index fc895b07d53..01fd7ee55bd 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -9,9 +9,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.VariableTensor;
-import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
import java.util.List;
@@ -53,7 +50,7 @@ public class Const extends IntermediateOperation {
} else {
expressionNode = new ReferenceNode(Reference.simple("constant", vespaName()));
}
- return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode);
+ return new TensorFunctionNode.ExpressionTensorFunction(expressionNode);
}
/** Constant names are prefixed by "modelName_" to avoid name conflicts between models */
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 26b376cce1c..87a3f1a8e66 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -3,7 +3,6 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
-import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -74,7 +73,7 @@ public abstract class IntermediateOperation {
if (function == null) {
if (isConstant()) {
ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName()));
- function = new TensorFunctionNode.TensorFunctionExpressionNode(constant);
+ function = new TensorFunctionNode.ExpressionTensorFunction(constant);
} else if (outputs.size() > 1) {
rankingExpressionFunction = lazyGetFunction();
function = new VariableTensor(rankingExpressionFunctionName(), type.type());
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 1258601a2d1..8d7bf4f9f14 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -1580,7 +1580,7 @@
],
"fields": []
},
- "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode": {
+ "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction": {
"superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
"interfaces": [],
"attributes": [
@@ -1612,7 +1612,7 @@
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
"public static java.util.Map wrap(java.util.Map)",
"public static java.util.List wrap(java.util.List)"
],
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index e6e49e07c34..4ffd40f00f7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.PrimitiveTensorFunction;
+import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
@@ -21,7 +22,6 @@ import java.util.Deque;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
-import java.util.function.Function;
import java.util.stream.Collectors;
/**
@@ -49,8 +49,8 @@ public class TensorFunctionNode extends CompositeNode {
}
private ExpressionNode toExpressionNode(TensorFunction f) {
- if (f instanceof TensorFunctionExpressionNode)
- return ((TensorFunctionExpressionNode)f).expression;
+ if (f instanceof ExpressionTensorFunction)
+ return ((ExpressionTensorFunction)f).expression;
else
return new TensorFunctionNode(f);
}
@@ -58,7 +58,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
List<TensorFunction> wrappedChildren = children.stream()
- .map(TensorFunctionExpressionNode::new)
+ .map(ExpressionTensorFunction::new)
.collect(Collectors.toList());
return new TensorFunctionNode(function.withArguments(wrappedChildren));
}
@@ -66,7 +66,7 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) {
// Serialize as primitive
- return string.append(function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this)));
+ return string.append(function.toPrimitive().toString(new ExpressionToStringContext(context, path, this)));
}
@Override
@@ -77,29 +77,29 @@ public class TensorFunctionNode extends CompositeNode {
return new TensorValue(function.evaluate(context));
}
- public static TensorFunctionExpressionNode wrap(ExpressionNode node) {
- return new TensorFunctionExpressionNode(node);
+ public static ExpressionTensorFunction wrap(ExpressionNode node) {
+ return new ExpressionTensorFunction(node);
}
- public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) {
- Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>();
+ public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) {
+ Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>();
for (var entry : nodes.entrySet())
- closures.put(entry.getKey(), new ExpressionClosure(entry.getValue()));
- return closures;
+ functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue()));
+ return functions;
}
- public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) {
- List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>();
+ public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) {
+ List<ScalarFunction> functions = new ArrayList<>();
for (var entry : nodes)
- closures.add(new ExpressionClosure(entry));
- return closures;
+ functions.add(new ExpressionScalarFunction(entry));
+ return functions;
}
- private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> {
+ private static class ExpressionScalarFunction implements ScalarFunction {
private final ExpressionNode expression;
- public ExpressionClosure(ExpressionNode expression) {
+ public ExpressionScalarFunction(ExpressionNode expression) {
this.expression = expression;
}
@@ -110,7 +110,18 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public String toString() {
- return expression.toString();
+ return toString(ExpressionToStringContext.empty);
+ }
+
+ @Override
+ public String toString(ToStringContext c) {
+ if (c instanceof ExpressionToStringContext) {
+ ExpressionToStringContext context = (ExpressionToStringContext) c;
+ return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString();
+ }
+ else {
+ return expression.toString();
+ }
}
}
@@ -119,12 +130,12 @@ public class TensorFunctionNode extends CompositeNode {
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
*/
- public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction {
+ public static class ExpressionTensorFunction extends PrimitiveTensorFunction {
/** An expression which produces a tensor */
private final ExpressionNode expression;
- public TensorFunctionExpressionNode(ExpressionNode expression) {
+ public ExpressionTensorFunction(ExpressionNode expression) {
this.expression = expression;
}
@@ -132,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode {
public List<TensorFunction> arguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
- .map(TensorFunctionExpressionNode::new)
+ .map(ExpressionTensorFunction::new)
.collect(Collectors.toList());
else
return Collections.emptyList();
@@ -142,9 +153,9 @@ public class TensorFunctionNode extends CompositeNode {
public TensorFunction withArguments(List<TensorFunction> arguments) {
if (arguments.size() == 0) return this;
List<ExpressionNode> unwrappedChildren = arguments.stream()
- .map(arg -> ((TensorFunctionExpressionNode)arg).expression)
+ .map(arg -> ((ExpressionTensorFunction)arg).expression)
.collect(Collectors.toList());
- return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren));
+ return new ExpressionTensorFunction(((CompositeNode)expression).setChildren(unwrappedChildren));
}
@Override
@@ -163,13 +174,13 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public String toString() {
- return toString(ExpressionNodeToStringContext.empty);
+ return toString(ExpressionToStringContext.empty);
}
@Override
public String toString(ToStringContext c) {
- if (c instanceof ExpressionNodeToStringContext) {
- ExpressionNodeToStringContext context = (ExpressionNodeToStringContext) c;
+ if (c instanceof ExpressionToStringContext) {
+ ExpressionToStringContext context = (ExpressionToStringContext) c;
return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString();
}
else {
@@ -180,17 +191,17 @@ public class TensorFunctionNode extends CompositeNode {
}
/** Allows passing serialization context arguments through TensorFunctions */
- private static class ExpressionNodeToStringContext implements ToStringContext {
+ private static class ExpressionToStringContext implements ToStringContext {
final SerializationContext context;
final Deque<String> path;
final CompositeNode parent;
- public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(new SerializationContext(),
- null,
- null);
+ public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(),
+ null,
+ null);
- public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
+ public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) {
this.context = context;
this.path = path;
this.parent = parent;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
index 6d687b015f1..9a38b5efc1f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
@@ -83,7 +83,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E
ExpressionNode arg1 = node.children().get(0);
ExpressionNode arg2 = node.children().get(1);
- TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrap(arg1);
+ TensorFunctionNode.ExpressionTensorFunction expression = TensorFunctionNode.wrap(arg1);
Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
String dimension = ((ReferenceNode) arg2).getName();
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index a41f24b3b8a..e7024b87452 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -1,20 +1,16 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression;
-import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
-import com.yahoo.searchlib.rankingexpression.rule.ConstantNode;
-import com.yahoo.searchlib.rankingexpression.rule.Function;
import com.yahoo.searchlib.rankingexpression.rule.IfNode;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.FunctionNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode;
import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.TensorFunction;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
@@ -65,7 +61,7 @@ public class RankingExpressionTestCase {
ReferenceNode input = new ReferenceNode("input");
ReferenceNode constant = new ReferenceNode("constant");
ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant);
- Reduce sum = new Reduce(new TensorFunctionNode.TensorFunctionExpressionNode(product), Reduce.Aggregator.sum);
+ Reduce sum = new Reduce(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum);
RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum));
RankingExpression expected = new RankingExpression("sum(input * constant)");
@@ -156,9 +152,9 @@ public class RankingExpressionTestCase {
"xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)");
assertSerialization("tensor(x{}):{{x:a}:1 + 2 + 3,{x:b}:if (1 > 2, 3, 4),{x:c}:reduce(tensor0 * tensor1, sum)}",
"tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }");
- assertSerialization("tensor(x[3]):[1.0,2.0,3]",
+ assertSerialization("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3}",
"tensor(x[3]):[1.0, 2.0, 3]");
- assertSerialization("tensor(x[3]):[1.0,reduce(tensor0 * tensor1, sum),3]",
+ assertSerialization("tensor(x[3]):{{x:0}:1.0,{x:1}:reduce(tensor0 * tensor1, sum),{x:2}:3}",
"tensor(x[3]):[1.0, sum(tensor0*tensor1), 3]");
}
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 6a93a17a8c1..47b066b15a6 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -707,6 +707,7 @@
"final"
],
"methods": [
+ "public static com.yahoo.tensor.DimensionSizes of(com.yahoo.tensor.TensorType)",
"public long size(int)",
"public int dimensions()",
"public long totalSize()",
@@ -820,7 +821,9 @@
"abstract"
],
"methods": [
+ "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)",
"public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)",
+ "public com.yahoo.tensor.TensorAddress toAddress()",
"public long[] indexesCopy()",
"public long[] indexesForReading()",
"public java.util.List toList()",
@@ -1603,6 +1606,7 @@
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
"public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.Map)",
"public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.List)"
],
@@ -1832,6 +1836,23 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.ScalarFunction": {
+ "superClass": "java.lang.Object",
+ "interfaces": [
+ "java.util.function.Function"
+ ],
+ "attributes": [
+ "public",
+ "interface",
+ "abstract"
+ ],
+ "methods": [
+ "public abstract java.lang.Double apply(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public bridge synthetic java.lang.Object apply(java.lang.Object)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.ScalarFunctions$Abs": {
"superClass": "java.lang.Object",
"interfaces": [
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
index c0d817459d0..d81c02fb75f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java
@@ -18,6 +18,21 @@ public final class DimensionSizes {
}
/**
+ * Create sizes from a type containing bound indexed dimensions only.
+ *
+ * @throws IllegalStateException if the type contains dimensions which are not bound and indexed
+ */
+ public static DimensionSizes of(TensorType type) {
+ Builder b = new Builder(type.rank());
+ for (int i = 0; i < type.rank(); i++) {
+ if ( type.dimensions().get(i).type() != TensorType.Dimension.Type.indexedBound)
+ throw new IllegalArgumentException(type + " contains dimensions without a size");
+ b.set(i, type.dimensions().get(i).size().get());
+ }
+ return b.build();
+ }
+
+ /**
* Returns the length of this in the nth dimension
*
* @throws IllegalArgumentException if the index is larger than the number of dimensions in this tensor minus one
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index 15476567fb2..176ddfefc13 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -758,6 +758,15 @@ public abstract class IndexedTensor implements Tensor {
protected final long[] indexes;
+ /**
+ * Create indexes from a type containing bound indexed dimensions only.
+ *
+ * @throws IllegalStateException if the type contains dimensions which are not bound and indexed
+ */
+ public static Indexes of(TensorType type) {
+ return of(DimensionSizes.of(type));
+ }
+
public static Indexes of(DimensionSizes sizes) {
return of(sizes, sizes);
}
@@ -824,7 +833,7 @@ public abstract class IndexedTensor implements Tensor {
}
/** Returns the address of the current position of these indexes */
- private TensorAddress toAddress() {
+ public TensorAddress toAddress() {
return TensorAddress.of(indexes);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index 9ce2496c65b..b8b644f8b49 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -46,21 +46,28 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
TensorType type() { return type; }
+ @Override
+ public String toString(ToStringContext context) {
+ return type().toString() + ":" + contentToString(context);
+ }
+
+ abstract String contentToString(ToStringContext context);
+
/** Creates a dynamic tensor function. The cell addresses must match the type. */
- public static DynamicTensor from(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) {
+ public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) {
return new MappedDynamicTensor(type, cells);
}
/** Creates a dynamic tensor function for a bound, indexed tensor */
- public static DynamicTensor from(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) {
+ public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) {
return new IndexedDynamicTensor(type, cells);
}
private static class MappedDynamicTensor extends DynamicTensor {
- private final ImmutableMap<TensorAddress, Function<EvaluationContext<?> , Double>> cells;
+ private final ImmutableMap<TensorAddress, ScalarFunction> cells;
- MappedDynamicTensor(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) {
+ MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) {
super(type);
this.cells = ImmutableMap.copyOf(cells);
}
@@ -74,11 +81,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return type().toString() + ":" + contentToString();
- }
-
- private String contentToString() {
+ String contentToString(ToStringContext context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.values().iterator().next() + "}";
@@ -86,7 +89,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
StringBuilder b = new StringBuilder("{");
for (var cell : cells.entrySet()) {
- b.append(cell.getKey().toString(type())).append(":").append(cell.getValue());
+ b.append(cell.getKey().toString(type())).append(":").append(cell.getValue().toString(context));
b.append(",");
}
if (b.length() > 1)
@@ -100,9 +103,9 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
private static class IndexedDynamicTensor extends DynamicTensor {
- private final List<Function<EvaluationContext<?>, Double>> cells;
+ private final List<ScalarFunction> cells;
- IndexedDynamicTensor(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) {
+ IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) {
super(type);
if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))
throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
@@ -119,24 +122,22 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) {
- return type().toString() + ":" + contentToString();
- }
-
- private String contentToString() {
+ String contentToString(ToStringContext context) {
if (type().dimensions().isEmpty()) {
if (cells.isEmpty()) return "{}";
return "{" + cells.get(0) + "}";
}
- StringBuilder b = new StringBuilder("[");
+ IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type());
+ StringBuilder b = new StringBuilder("{");
for (var cell : cells) {
- b.append(cell);
+ indexes.next();
+ b.append(indexes.toAddress().toString(type())).append(":").append(cell.toString(context));
b.append(",");
}
if (b.length() > 1)
b.setLength(b.length() - 1);
- b.append("]");
+ b.append("}");
return b.toString();
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
new file mode 100644
index 00000000000..c6a244b64df
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -0,0 +1,22 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.evaluation.EvaluationContext;
+
+import java.util.function.Function;
+
+/**
+ * A function which returns a scalar
+ *
+ * @author bratseth
+ */
+public interface ScalarFunction extends Function<EvaluationContext<?>, Double> {
+
+ @Override
+ Double apply(EvaluationContext<?> context);
+
+ default String toString(ToStringContext context) {
+ return toString();
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index 82652fb0e5d..925da9d3c89 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -24,15 +24,17 @@ public class DynamicTensorTestCase {
DynamicTensor t1 = DynamicTensor.from(dense,
List.of(new Constant(1), new Constant(2), new Constant(3)));
assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate());
+ assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString());
TensorType sparse = TensorType.fromSpec("tensor(x{})");
DynamicTensor t2 = DynamicTensor.from(sparse,
Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
new Constant(5)));
assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
+ assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
- private static class Constant implements Function<EvaluationContext<?>, Double> {
+ private static class Constant implements ScalarFunction {
private final double value;
@@ -41,6 +43,9 @@ public class DynamicTensorTestCase {
@Override
public Double apply(EvaluationContext<?> evaluationContext) { return value; }
+ @Override
+ public String toString() { return String.valueOf(value); }
+
}
}