diff options
17 files changed, 227 insertions, 70 deletions
diff --git a/config-model/src/test/derived/tensor/attributes.cfg b/config-model/src/test/derived/tensor/attributes.cfg index 0c556aad868..2e0a207d249 100644 --- a/config-model/src/test/derived/tensor/attributes.cfg +++ b/config-model/src/test/derived/tensor/attributes.cfg @@ -82,3 +82,24 @@ attribute[].upperbound 9223372036854775807 attribute[].densepostinglistthreshold 0.4 attribute[].tensortype "tensor<float>(x[10])" attribute[].imported false +attribute[].name "f6" +attribute[].datatype FLOAT +attribute[].collectiontype SINGLE +attribute[].removeifzero false +attribute[].createifnonexistent false +attribute[].fastsearch false +attribute[].huge false +attribute[].ismutable false +attribute[].sortascending true +attribute[].sortfunction UCA +attribute[].sortstrength PRIMARY +attribute[].sortlocale "" +attribute[].enablebitvectors false +attribute[].enableonlybitvector false +attribute[].fastaccess false +attribute[].arity 8 +attribute[].lowerbound -9223372036854775808 +attribute[].upperbound 9223372036854775807 +attribute[].densepostinglistthreshold 0.4 +attribute[].tensortype "" +attribute[].imported false diff --git a/config-model/src/test/derived/tensor/documenttypes.cfg b/config-model/src/test/derived/tensor/documenttypes.cfg index af1748e484e..72fae572b76 100644 --- a/config-model/src/test/derived/tensor/documenttypes.cfg +++ b/config-model/src/test/derived/tensor/documenttypes.cfg @@ -40,6 +40,10 @@ documenttype[].datatype[].sstruct.field[].name "f5" documenttype[].datatype[].sstruct.field[].id 329055840 documenttype[].datatype[].sstruct.field[].datatype 21 documenttype[].datatype[].sstruct.field[].detailedtype "tensor<float>(x[10])" +documenttype[].datatype[].sstruct.field[].name "f6" +documenttype[].datatype[].sstruct.field[].id 596352344 +documenttype[].datatype[].sstruct.field[].datatype 1 +documenttype[].datatype[].sstruct.field[].detailedtype "" documenttype[].datatype[].id -1903234535 documenttype[].datatype[].type STRUCT documenttype[].datatype[].array.element.id 0 @@ -60,3 +64,4 @@ documenttype[].fieldsets{[document]}.fields[] "f2" documenttype[].fieldsets{[document]}.fields[] "f3" documenttype[].fieldsets{[document]}.fields[] "f4" documenttype[].fieldsets{[document]}.fields[] "f5" +documenttype[].fieldsets{[document]}.fields[] "f6" diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 1ce9227d323..617901130a6 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -80,3 +80,33 @@ rankprofile[].fef.property[].name "vespa.type.attribute.f4" rankprofile[].fef.property[].value "tensor(x[10],y[20])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" +rankprofile[].name "profile5" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(tensor<float>(d0[1],x[2]):{{d0:0,x:0}:attribute(f6),{d0:0,x:1}:reduce(attribute(f5), sum)}, sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f2" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" +rankprofile[].fef.property[].name "vespa.type.attribute.f3" +rankprofile[].fef.property[].value "tensor(x{})" +rankprofile[].fef.property[].name "vespa.type.attribute.f4" +rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].name "vespa.type.attribute.f5" +rankprofile[].fef.property[].value "tensor<float>(x[10])" +rankprofile[].name "profile6" +rankprofile[].fef.property[].name "rankingExpression(joinedtensors).rankingScript" +rankprofile[].fef.property[].value "tensor(i[10])(i) * attribute(f4)" +rankprofile[].fef.property[].name "rankingExpression(joinedtensors).type" +rankprofile[].fef.property[].value "tensor(i[10],x[10],y[20])" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(tensor<float>(d0[1],x[2]):{{d0:0,x:0}:attribute(f6),{d0:0,x:1}:reduce(rankingExpression(joinedtensors), sum)}, sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f2" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" +rankprofile[].fef.property[].name "vespa.type.attribute.f3" +rankprofile[].fef.property[].value "tensor(x{})" +rankprofile[].fef.property[].name "vespa.type.attribute.f4" +rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].name "vespa.type.attribute.f5" +rankprofile[].fef.property[].value "tensor<float>(x[10])" diff --git a/config-model/src/test/derived/tensor/summary.cfg b/config-model/src/test/derived/tensor/summary.cfg index 903b6033297..fb32eacbb4c 100644 --- a/config-model/src/test/derived/tensor/summary.cfg +++ b/config-model/src/test/derived/tensor/summary.cfg @@ -15,7 +15,7 @@ classes[].fields[].name "summaryfeatures" classes[].fields[].type "featuredata" classes[].fields[].name "documentid" classes[].fields[].type "longstring" -classes[].id 193983608 +classes[].id 1476352352 classes[].name "attributeprefetch" classes[].fields[].name "f2" classes[].fields[].type "tensor" @@ -25,6 +25,8 @@ classes[].fields[].name "f4" classes[].fields[].type "tensor" classes[].fields[].name "f5" classes[].fields[].type "tensor" +classes[].fields[].name "f6" +classes[].fields[].type "float" classes[].fields[].name "rankfeatures" classes[].fields[].type "featuredata" classes[].fields[].name "summaryfeatures" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index b31352a2105..13727d1ec49 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -17,6 +17,9 @@ search tensor { field f5 type tensor<float>(x[10]) { indexing: attribute | summary } + field f6 type float { + indexing: attribute + } } rank-profile profile1 { @@ -55,4 +58,24 @@ search tensor { } + rank-profile profile5 { + + first-phase { + expression: sum(tensor<float>(d0[1],x[2]):[[attribute(f6), sum(attribute(f5))]]) + } + + } + + rank-profile profile6 { + + first-phase { + expression: sum(tensor<float>(d0[1],x[2]):[[attribute(f6), sum(joinedtensors())]]) + } + + function joinedtensors() { + expression: tensor(i[10])(i) * attribute(f4) + } + + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index fc895b07d53..01fd7ee55bd 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -9,9 +9,6 @@ import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; -import com.yahoo.tensor.TensorType; -import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; import java.util.List; @@ -53,7 +50,7 @@ public class Const extends IntermediateOperation { } else { expressionNode = new ReferenceNode(Reference.simple("constant", vespaName())); } - return new TensorFunctionNode.TensorFunctionExpressionNode(expressionNode); + return new TensorFunctionNode.ExpressionTensorFunction(expressionNode); } /** Constant names are prefixed by "modelName_" to avoid name conflicts between models */ diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 26b376cce1c..87a3f1a8e66 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,7 +3,6 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; -import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -74,7 +73,7 @@ public abstract class IntermediateOperation { if (function == null) { if (isConstant()) { ExpressionNode constant = new ReferenceNode(Reference.simple("constant", vespaName())); - function = new TensorFunctionNode.TensorFunctionExpressionNode(constant); + function = new TensorFunctionNode.ExpressionTensorFunction(constant); } else if (outputs.size() > 1) { rankingExpressionFunction = lazyGetFunction(); function = new VariableTensor(rankingExpressionFunctionName(), type.type()); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 1258601a2d1..8d7bf4f9f14 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -1580,7 +1580,7 @@ ], "fields": [] }, - "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode": { + "com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction": { "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", "interfaces": [], "attributes": [ @@ -1612,7 +1612,7 @@ "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", - "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", + "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", "public static java.util.Map wrap(java.util.Map)", "public static java.util.List wrap(java.util.List)" ], diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index e6e49e07c34..4ffd40f00f7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.PrimitiveTensorFunction; +import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -21,7 +22,6 @@ import java.util.Deque; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; -import java.util.function.Function; import java.util.stream.Collectors; /** @@ -49,8 +49,8 @@ public class TensorFunctionNode extends CompositeNode { } private ExpressionNode toExpressionNode(TensorFunction f) { - if (f instanceof TensorFunctionExpressionNode) - return ((TensorFunctionExpressionNode)f).expression; + if (f instanceof ExpressionTensorFunction) + return ((ExpressionTensorFunction)f).expression; else return new TensorFunctionNode(f); } @@ -58,7 +58,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { List<TensorFunction> wrappedChildren = children.stream() - .map(TensorFunctionExpressionNode::new) + .map(ExpressionTensorFunction::new) .collect(Collectors.toList()); return new TensorFunctionNode(function.withArguments(wrappedChildren)); } @@ -66,7 +66,7 @@ public class TensorFunctionNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder string, SerializationContext context, Deque<String> path, CompositeNode parent) { // Serialize as primitive - return string.append(function.toPrimitive().toString(new ExpressionNodeToStringContext(context, path, this))); + return string.append(function.toPrimitive().toString(new ExpressionToStringContext(context, path, this))); } @Override @@ -77,29 +77,29 @@ public class TensorFunctionNode extends CompositeNode { return new TensorValue(function.evaluate(context)); } - public static TensorFunctionExpressionNode wrap(ExpressionNode node) { - return new TensorFunctionExpressionNode(node); + public static ExpressionTensorFunction wrap(ExpressionNode node) { + return new ExpressionTensorFunction(node); } - public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) { - Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>(); + public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) { + Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>(); for (var entry : nodes.entrySet()) - closures.put(entry.getKey(), new ExpressionClosure(entry.getValue())); - return closures; + functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue())); + return functions; } - public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) { - List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>(); + public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) { + List<ScalarFunction> functions = new ArrayList<>(); for (var entry : nodes) - closures.add(new ExpressionClosure(entry)); - return closures; + functions.add(new ExpressionScalarFunction(entry)); + return functions; } - private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> { + private static class ExpressionScalarFunction implements ScalarFunction { private final ExpressionNode expression; - public ExpressionClosure(ExpressionNode expression) { + public ExpressionScalarFunction(ExpressionNode expression) { this.expression = expression; } @@ -110,7 +110,18 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString() { - return expression.toString(); + return toString(ExpressionToStringContext.empty); + } + + @Override + public String toString(ToStringContext c) { + if (c instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext) c; + return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString(); + } + else { + return expression.toString(); + } } } @@ -119,12 +130,12 @@ public class TensorFunctionNode extends CompositeNode { * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ - public static class TensorFunctionExpressionNode extends PrimitiveTensorFunction { + public static class ExpressionTensorFunction extends PrimitiveTensorFunction { /** An expression which produces a tensor */ private final ExpressionNode expression; - public TensorFunctionExpressionNode(ExpressionNode expression) { + public ExpressionTensorFunction(ExpressionNode expression) { this.expression = expression; } @@ -132,7 +143,7 @@ public class TensorFunctionNode extends CompositeNode { public List<TensorFunction> arguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() - .map(TensorFunctionExpressionNode::new) + .map(ExpressionTensorFunction::new) .collect(Collectors.toList()); else return Collections.emptyList(); @@ -142,9 +153,9 @@ public class TensorFunctionNode extends CompositeNode { public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() == 0) return this; List<ExpressionNode> unwrappedChildren = arguments.stream() - .map(arg -> ((TensorFunctionExpressionNode)arg).expression) + .map(arg -> ((ExpressionTensorFunction)arg).expression) .collect(Collectors.toList()); - return new TensorFunctionExpressionNode(((CompositeNode)expression).setChildren(unwrappedChildren)); + return new ExpressionTensorFunction(((CompositeNode)expression).setChildren(unwrappedChildren)); } @Override @@ -163,13 +174,13 @@ public class TensorFunctionNode extends CompositeNode { @Override public String toString() { - return toString(ExpressionNodeToStringContext.empty); + return toString(ExpressionToStringContext.empty); } @Override public String toString(ToStringContext c) { - if (c instanceof ExpressionNodeToStringContext) { - ExpressionNodeToStringContext context = (ExpressionNodeToStringContext) c; + if (c instanceof ExpressionToStringContext) { + ExpressionToStringContext context = (ExpressionToStringContext) c; return expression.toString(new StringBuilder(),context.context, context.path, context.parent).toString(); } else { @@ -180,17 +191,17 @@ public class TensorFunctionNode extends CompositeNode { } /** Allows passing serialization context arguments through TensorFunctions */ - private static class ExpressionNodeToStringContext implements ToStringContext { + private static class ExpressionToStringContext implements ToStringContext { final SerializationContext context; final Deque<String> path; final CompositeNode parent; - public static final ExpressionNodeToStringContext empty = new ExpressionNodeToStringContext(new SerializationContext(), - null, - null); + public static final ExpressionToStringContext empty = new ExpressionToStringContext(new SerializationContext(), + null, + null); - public ExpressionNodeToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { + public ExpressionToStringContext(SerializationContext context, Deque<String> path, CompositeNode parent) { this.context = context; this.path = path; this.parent = parent; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java index 6d687b015f1..9a38b5efc1f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -83,7 +83,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E ExpressionNode arg1 = node.children().get(0); ExpressionNode arg2 = node.children().get(1); - TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrap(arg1); + TensorFunctionNode.ExpressionTensorFunction expression = TensorFunctionNode.wrap(arg1); Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index a41f24b3b8a..e7024b87452 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -1,20 +1,16 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression; -import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; -import com.yahoo.searchlib.rankingexpression.rule.ConstantNode; -import com.yahoo.searchlib.rankingexpression.rule.Function; import com.yahoo.searchlib.rankingexpression.rule.IfNode; import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.FunctionNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode; import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.TensorFunction; import org.junit.Test; import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; @@ -65,7 +61,7 @@ public class RankingExpressionTestCase { ReferenceNode input = new ReferenceNode("input"); ReferenceNode constant = new ReferenceNode("constant"); ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant); - Reduce sum = new Reduce(new TensorFunctionNode.TensorFunctionExpressionNode(product), Reduce.Aggregator.sum); + Reduce sum = new Reduce(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum); RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum)); RankingExpression expected = new RankingExpression("sum(input * constant)"); @@ -156,9 +152,9 @@ public class RankingExpressionTestCase { "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)"); assertSerialization("tensor(x{}):{{x:a}:1 + 2 + 3,{x:b}:if (1 > 2, 3, 4),{x:c}:reduce(tensor0 * tensor1, sum)}", "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }"); - assertSerialization("tensor(x[3]):[1.0,2.0,3]", + assertSerialization("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3}", "tensor(x[3]):[1.0, 2.0, 3]"); - assertSerialization("tensor(x[3]):[1.0,reduce(tensor0 * tensor1, sum),3]", + assertSerialization("tensor(x[3]):{{x:0}:1.0,{x:1}:reduce(tensor0 * tensor1, sum),{x:2}:3}", "tensor(x[3]):[1.0, sum(tensor0*tensor1), 3]"); } diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 6a93a17a8c1..47b066b15a6 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -707,6 +707,7 @@ "final" ], "methods": [ + "public static com.yahoo.tensor.DimensionSizes of(com.yahoo.tensor.TensorType)", "public long size(int)", "public int dimensions()", "public long totalSize()", @@ -820,7 +821,9 @@ "abstract" ], "methods": [ + "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.TensorType)", "public static com.yahoo.tensor.IndexedTensor$Indexes of(com.yahoo.tensor.DimensionSizes)", + "public com.yahoo.tensor.TensorAddress toAddress()", "public long[] indexesCopy()", "public long[] indexesForReading()", "public java.util.List toList()", @@ -1603,6 +1606,7 @@ "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", "public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.Map)", "public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.List)" ], @@ -1832,6 +1836,23 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.ScalarFunction": { + "superClass": "java.lang.Object", + "interfaces": [ + "java.util.function.Function" + ], + "attributes": [ + "public", + "interface", + "abstract" + ], + "methods": [ + "public abstract java.lang.Double apply(com.yahoo.tensor.evaluation.EvaluationContext)", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public bridge synthetic java.lang.Object apply(java.lang.Object)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.ScalarFunctions$Abs": { "superClass": "java.lang.Object", "interfaces": [ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index c0d817459d0..d81c02fb75f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -18,6 +18,21 @@ public final class DimensionSizes { } /** + * Create sizes from a type containing bound indexed dimensions only. + * + * @throws IllegalStateException if the type contains dimensions which are not bound and indexed + */ + public static DimensionSizes of(TensorType type) { + Builder b = new Builder(type.rank()); + for (int i = 0; i < type.rank(); i++) { + if ( type.dimensions().get(i).type() != TensorType.Dimension.Type.indexedBound) + throw new IllegalArgumentException(type + " contains dimensions without a size"); + b.set(i, type.dimensions().get(i).size().get()); + } + return b.build(); + } + + /** * Returns the length of this in the nth dimension * * @throws IllegalArgumentException if the index is larger than the number of dimensions in this tensor minus one diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 15476567fb2..176ddfefc13 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -758,6 +758,15 @@ public abstract class IndexedTensor implements Tensor { protected final long[] indexes; + /** + * Create indexes from a type containing bound indexed dimensions only. + * + * @throws IllegalStateException if the type contains dimensions which are not bound and indexed + */ + public static Indexes of(TensorType type) { + return of(DimensionSizes.of(type)); + } + public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); } @@ -824,7 +833,7 @@ public abstract class IndexedTensor implements Tensor { } /** Returns the address of the current position of these indexes */ - private TensorAddress toAddress() { + public TensorAddress toAddress() { return TensorAddress.of(indexes); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index 9ce2496c65b..b8b644f8b49 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -46,21 +46,28 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { TensorType type() { return type; } + @Override + public String toString(ToStringContext context) { + return type().toString() + ":" + contentToString(context); + } + + abstract String contentToString(ToStringContext context); + /** Creates a dynamic tensor function. The cell addresses must match the type. */ - public static DynamicTensor from(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) { + public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) { return new MappedDynamicTensor(type, cells); } /** Creates a dynamic tensor function for a bound, indexed tensor */ - public static DynamicTensor from(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) { + public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) { return new IndexedDynamicTensor(type, cells); } private static class MappedDynamicTensor extends DynamicTensor { - private final ImmutableMap<TensorAddress, Function<EvaluationContext<?> , Double>> cells; + private final ImmutableMap<TensorAddress, ScalarFunction> cells; - MappedDynamicTensor(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) { + MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) { super(type); this.cells = ImmutableMap.copyOf(cells); } @@ -74,11 +81,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return type().toString() + ":" + contentToString(); - } - - private String contentToString() { + String contentToString(ToStringContext context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; return "{" + cells.values().iterator().next() + "}"; @@ -86,7 +89,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { StringBuilder b = new StringBuilder("{"); for (var cell : cells.entrySet()) { - b.append(cell.getKey().toString(type())).append(":").append(cell.getValue()); + b.append(cell.getKey().toString(type())).append(":").append(cell.getValue().toString(context)); b.append(","); } if (b.length() > 1) @@ -100,9 +103,9 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { private static class IndexedDynamicTensor extends DynamicTensor { - private final List<Function<EvaluationContext<?>, Double>> cells; + private final List<ScalarFunction> cells; - IndexedDynamicTensor(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) { + IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) { super(type); if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + @@ -119,24 +122,22 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { - return type().toString() + ":" + contentToString(); - } - - private String contentToString() { + String contentToString(ToStringContext context) { if (type().dimensions().isEmpty()) { if (cells.isEmpty()) return "{}"; return "{" + cells.get(0) + "}"; } - StringBuilder b = new StringBuilder("["); + IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(type()); + StringBuilder b = new StringBuilder("{"); for (var cell : cells) { - b.append(cell); + indexes.next(); + b.append(indexes.toAddress().toString(type())).append(":").append(cell.toString(context)); b.append(","); } if (b.length() > 1) b.setLength(b.length() - 1); - b.append("]"); + b.append("}"); return b.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java new file mode 100644 index 00000000000..c6a244b64df --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -0,0 +1,22 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.evaluation.EvaluationContext; + +import java.util.function.Function; + +/** + * A function which returns a scalar + * + * @author bratseth + */ +public interface ScalarFunction extends Function<EvaluationContext<?>, Double> { + + @Override + Double apply(EvaluationContext<?> context); + + default String toString(ToStringContext context) { + return toString(); + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index 82652fb0e5d..925da9d3c89 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -24,15 +24,17 @@ public class DynamicTensorTestCase { DynamicTensor t1 = DynamicTensor.from(dense, List.of(new Constant(1), new Constant(2), new Constant(3))); assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); + assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString()); TensorType sparse = TensorType.fromSpec("tensor(x{})"); DynamicTensor t2 = DynamicTensor.from(sparse, Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); + assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } - private static class Constant implements Function<EvaluationContext<?>, Double> { + private static class Constant implements ScalarFunction { private final double value; @@ -41,6 +43,9 @@ public class DynamicTensorTestCase { @Override public Double apply(EvaluationContext<?> evaluationContext) { return value; } + @Override + public String toString() { return String.valueOf(value); } + } } |