diff options
13 files changed, 319 insertions, 64 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index 4acb47df179..bbfd2004caa 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -432,7 +432,7 @@ public class ConvertedModel { if (node instanceof ReferenceNode) { ReferenceNode referenceNode = (ReferenceNode) node; if (model.inputTypeSpec(referenceNode.getName()).isPresent()) { - return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext); + return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext); } } if (node instanceof CompositeNode) { @@ -485,7 +485,7 @@ public class ConvertedModel { new GeneratorLambdaFunctionNode(expandDimensionsType, generatedExpression) .asLongListToDoubleOperator()); - Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply()); + Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply()); return new TensorFunctionNode(expand); } return node; diff --git a/container-search/src/main/java/com/yahoo/prelude/query/WandItem.java b/container-search/src/main/java/com/yahoo/prelude/query/WandItem.java index 20f034df1df..a70d653b90a 100644 --- a/container-search/src/main/java/com/yahoo/prelude/query/WandItem.java +++ b/container-search/src/main/java/com/yahoo/prelude/query/WandItem.java @@ -24,8 +24,8 @@ public class WandItem extends WeightedSetItem { /** * Creates an empty WandItem. * - * @param fieldName The name of the weighted set field to search with this WandItem. - * @param targetNumHits The target for minimum number of hits to produce by the backend search operator handling this WandItem. + * @param fieldName the name of the weighted set field to search with this WandItem. + * @param targetNumHits the target for minimum number of hits to produce by the backend search operator handling this WandItem. */ public WandItem(String fieldName, int targetNumHits) { super(fieldName); diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 5ef3cd61366..1258601a2d1 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -909,11 +909,11 @@ "public final java.util.List tagCommaLeadingList()", "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()", - "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode tensorValueBody(com.yahoo.tensor.TensorType)", - "public final void mappedTensorValueBody(com.yahoo.tensor.Tensor$Builder)", - "public final void indexedTensorValueBody(com.yahoo.tensor.Tensor$Builder)", - "public final void tensorCell(com.yahoo.tensor.Tensor$Builder$CellBuilder)", - "public final void labelAndDimension(com.yahoo.tensor.Tensor$Builder$CellBuilder)", + "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorValueBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)", + "public final void tensorCell(com.yahoo.tensor.TensorType, java.util.Map)", + "public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)", "public void <init>(java.io.InputStream)", "public void <init>(java.io.InputStream, java.lang.String)", "public void ReInit(java.io.InputStream)", @@ -1612,7 +1612,9 @@ "public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)", "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", - "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrapArgument(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)" + "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", + "public static java.util.Map wrap(java.util.Map)", + "public static java.util.List wrap(java.util.List)" ], "fields": [] }, diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index c1732aabf0b..e6e49e07c34 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.TypeContext; @@ -14,9 +15,13 @@ import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; +import java.util.ArrayList; import java.util.Collections; import java.util.Deque; +import java.util.LinkedHashMap; import java.util.List; +import java.util.Map; +import java.util.function.Function; import java.util.stream.Collectors; /** @@ -72,10 +77,44 @@ public class TensorFunctionNode extends CompositeNode { return new TensorValue(function.evaluate(context)); } - public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) { + public static TensorFunctionExpressionNode wrap(ExpressionNode node) { return new TensorFunctionExpressionNode(node); } + public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) { + Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>(); + for (var entry : nodes.entrySet()) + closures.put(entry.getKey(), new ExpressionClosure(entry.getValue())); + return closures; + } + + public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) { + List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>(); + for (var entry : nodes) + closures.add(new ExpressionClosure(entry)); + return closures; + } + + private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> { + + private final ExpressionNode expression; + + public ExpressionClosure(ExpressionNode expression) { + this.expression = expression; + } + + @Override + public Double apply(EvaluationContext<?> context) { + return expression.evaluate((Context)context).asDouble(); + } + + @Override + public String toString() { + return expression.toString(); + } + + } + /** * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java index 979c5b0f88c..6d687b015f1 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -83,7 +83,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E ExpressionNode arg1 = node.children().get(0); ExpressionNode arg2 = node.children().get(1); - TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1); + TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrap(arg1); Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 8f411bf6593..47555d95e58 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -30,6 +30,7 @@ import java.util.ArrayList; import java.util.List; import java.util.Optional; +@SuppressWarnings({"rawtypes", "unchecked"}) public class RankingExpressionParser { } @@ -401,8 +402,8 @@ ExpressionNode tensorMap() : } { <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor), - doubleMapper.asDoubleUnaryOperator())); } + { return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor), + doubleMapper.asDoubleUnaryOperator())); } } ExpressionNode tensorReduce() : @@ -413,7 +414,7 @@ ExpressionNode tensorReduce() : } { <REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } ExpressionNode tensorReduceComposites() : @@ -425,7 +426,7 @@ ExpressionNode tensorReduceComposites() : { aggregator = tensorReduceAggregator() <LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE> - { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); } + { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } ExpressionNode tensorJoin() : @@ -435,9 +436,9 @@ ExpressionNode tensorJoin() : } { <JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE> - { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - doubleJoiner.asDoubleBinaryOperator())); } + { return new TensorFunctionNode(new Join(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + doubleJoiner.asDoubleBinaryOperator())); } } ExpressionNode tensorRename() : @@ -450,7 +451,7 @@ ExpressionNode tensorRename() : fromDimensions = bracedIdentifierList() <COMMA> toDimensions = bracedIdentifierList() <RBRACE> - { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); } + { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); } } ExpressionNode tensorConcat() : @@ -460,8 +461,8 @@ ExpressionNode tensorConcat() : } { <CONCAT> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = tag() <RBRACE> - { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), + { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), dimension)); } } @@ -522,7 +523,7 @@ ExpressionNode tensorL1Normalize() : } { <L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorL2Normalize() : @@ -532,7 +533,7 @@ ExpressionNode tensorL2Normalize() : } { <L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorMatmul() : @@ -542,9 +543,9 @@ ExpressionNode tensorMatmul() : } { <MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - dimension)); } + { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + dimension)); } } ExpressionNode tensorSoftmax() : @@ -554,7 +555,7 @@ ExpressionNode tensorSoftmax() : } { <SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorXwPlusB() : @@ -567,9 +568,9 @@ ExpressionNode tensorXwPlusB() : tensor2 = expression() <COMMA> tensor3 = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1), - TensorFunctionNode.wrapArgument(tensor2), - TensorFunctionNode.wrapArgument(tensor3), + { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrap(tensor1), + TensorFunctionNode.wrap(tensor2), + TensorFunctionNode.wrap(tensor3), dimension)); } } @@ -580,7 +581,7 @@ ExpressionNode tensorArgmax() : } { <ARGMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); } } ExpressionNode tensorArgmin() : @@ -590,7 +591,7 @@ ExpressionNode tensorArgmin() : } { <ARGMIN> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE> - { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrapArgument(tensor), dimension)); } + { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrap(tensor), dimension)); } } LambdaFunctionNode lambdaFunction() : @@ -823,63 +824,62 @@ Value primitiveValue() : { return Value.parse(sign + token.image); } } -ConstantNode tensorValueBody(TensorType type) : +ExpressionNode tensorValueBody(TensorType type) : { - Tensor.Builder builder = Tensor.Builder.of(type); + DynamicTensor dynamicTensor; } { <COLON> ( - mappedTensorValueBody(builder) | - indexedTensorValueBody(builder) + dynamicTensor = mappedTensorValueBody(type) | + dynamicTensor = indexedTensorValueBody(type) ) - { return new ConstantNode(new TensorValue(builder.build())); } + { return new TensorFunctionNode(dynamicTensor); } } -void mappedTensorValueBody(Tensor.Builder builder) : {} +DynamicTensor mappedTensorValueBody(TensorType type) : +{ + java.util.Map cells = new LinkedHashMap(); +} { <LCURLY> - ( tensorCell(builder.cell()))* - ( <COMMA> tensorCell(builder.cell()))* + ( tensorCell(type, cells))* + ( <COMMA> tensorCell(type, cells))* <RCURLY> + { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } } -void indexedTensorValueBody(Tensor.Builder builder) : +DynamicTensor indexedTensorValueBody(TensorType type) : { - IndexedTensor.BoundBuilder indexedBuilder; - long index = 0; - double value; + List cells = new ArrayList(); + ExpressionNode value; } { - { - if ( ! (builder instanceof IndexedTensor.BoundBuilder)) - throw new IllegalArgumentException("The tensor short form [n, n, ...] can only be used for indexed " + - "bound tensors, not " + builder.type()); - indexedBuilder = (IndexedTensor.BoundBuilder)builder; - } <LSQUARE> - ( value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* - ( <COMMA> value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )* + ( value = expression() { cells.add(value); } )* + ( <COMMA> value = expression() { cells.add(value); } )* <RSQUARE> + { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } } -void tensorCell(Tensor.Builder.CellBuilder cellBuilder) : +void tensorCell(TensorType type, java.util.Map cells) : { - double value; + ExpressionNode value; + TensorAddress.Builder addressBuilder = new TensorAddress.Builder(type); } { <LCURLY> - ( labelAndDimension(cellBuilder))* - ( <COMMA> labelAndDimension(cellBuilder))* + ( labelAndDimension(addressBuilder))* + ( <COMMA> labelAndDimension(addressBuilder))* <RCURLY> - <COLON> value = doubleNumber() { cellBuilder.value(value); } + <COLON> value = expression() { cells.put(addressBuilder.build(), value); } } -void labelAndDimension(Tensor.Builder.CellBuilder cellBuilder) : +void labelAndDimension(TensorAddress.Builder addressBuilder) : { String dimension, label; } { dimension = identifier() <COLON> label = tag() - { cellBuilder.label(dimension, label); } + { addressBuilder.add(dimension, label); } }
\ No newline at end of file diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index 571e1f4d608..a41f24b3b8a 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -154,7 +154,12 @@ public class RankingExpressionTestCase { "map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)"); assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))", "xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)"); - + assertSerialization("tensor(x{}):{{x:a}:1 + 2 + 3,{x:b}:if (1 > 2, 3, 4),{x:c}:reduce(tensor0 * tensor1, sum)}", + "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }"); + assertSerialization("tensor(x[3]):[1.0,2.0,3]", + "tensor(x[3]):[1.0, 2.0, 3]"); + assertSerialization("tensor(x[3]):[1.0,reduce(tensor0 * tensor1, sum),3]", + "tensor(x[3]):[1.0, sum(tensor0*tensor1), 3]"); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 7aafb8efee7..e28daefdabf 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -368,6 +368,9 @@ public class EvaluationTestCase { "tensor(x{}):{}"); tester.assertEvaluates("tensor():{{}:1}", "tensor():{{}:1}"); + tester.assertEvaluates("tensor(x{}):{ {x:a}:6.0, {x:b}:4.0, {x:c}:14.0 }", + "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }", + "{ {x:0}:7 }", "tensor(x{}):{ {x:0}:2 }"); } @Test diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 9590a97ea55..6a93a17a8c1 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1591,6 +1591,23 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.DynamicTensor": { + "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", + "interfaces": [], + "attributes": [ + "public", + "abstract" + ], + "methods": [ + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.Map)", + "public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.List)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.Generate": { "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", "interfaces": [], diff --git a/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java b/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java index ca0baf95ee2..7db43a7442a 100644 --- a/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java +++ b/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java @@ -19,7 +19,6 @@ import java.util.Set; * * @author bratseth */ -@Beta public class CopyOnWriteHashMap<K,V> extends AbstractMap<K,V> implements Cloneable { private Map<K,V> map; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index f5ef88016ac..15476567fb2 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -745,8 +745,6 @@ public abstract class IndexedTensor implements Tensor { } - // TODO: Make dimensionSizes a class - /** * An array of indexes into this tensor which are able to find the next index in the value order. * next() can be called once per element in the dimensions we iterate over. It must be called once diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java new file mode 100644 index 00000000000..9ce2496c65b --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -0,0 +1,146 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.collect.ImmutableMap; +import com.yahoo.tensor.IndexedTensor; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Function; + +/** + * A function which is a tensor whose values are computed by individual lambda functions on evaluation. + * + * @author bratseth + */ +public abstract class DynamicTensor extends PrimitiveTensorFunction { + + private final TensorType type; + + DynamicTensor(TensorType type) { + this.type = type; + } + + @Override + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + + @Override + public List<TensorFunction> arguments() { return Collections.emptyList(); } + + @Override + public TensorFunction withArguments(List<TensorFunction> arguments) { + if (arguments.size() != 0) + throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); + return this; + } + + @Override + public PrimitiveTensorFunction toPrimitive() { return this; } + + TensorType type() { return type; } + + /** Creates a dynamic tensor function. The cell addresses must match the type. */ + public static DynamicTensor from(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) { + return new MappedDynamicTensor(type, cells); + } + + /** Creates a dynamic tensor function for a bound, indexed tensor */ + public static DynamicTensor from(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) { + return new IndexedDynamicTensor(type, cells); + } + + private static class MappedDynamicTensor extends DynamicTensor { + + private final ImmutableMap<TensorAddress, Function<EvaluationContext<?> , Double>> cells; + + MappedDynamicTensor(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) { + super(type); + this.cells = ImmutableMap.copyOf(cells); + } + + @Override + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor.Builder builder = Tensor.Builder.of(type()); + for (var cell : cells.entrySet()) + builder.cell(cell.getKey(), cell.getValue().apply(context)); + return builder.build(); + } + + @Override + public String toString(ToStringContext context) { + return type().toString() + ":" + contentToString(); + } + + private String contentToString() { + if (type().dimensions().isEmpty()) { + if (cells.isEmpty()) return "{}"; + return "{" + cells.values().iterator().next() + "}"; + } + + StringBuilder b = new StringBuilder("{"); + for (var cell : cells.entrySet()) { + b.append(cell.getKey().toString(type())).append(":").append(cell.getValue()); + b.append(","); + } + if (b.length() > 1) + b.setLength(b.length() - 1); + b.append("}"); + + return b.toString(); + } + + } + + private static class IndexedDynamicTensor extends DynamicTensor { + + private final List<Function<EvaluationContext<?>, Double>> cells; + + IndexedDynamicTensor(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) { + super(type); + if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) + throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + + "only indexed, bound dimensions, but this has " + type); + this.cells = List.copyOf(cells); + } + + @Override + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); + for (int i = 0; i < cells.size(); i++) + builder.cellByDirectIndex(i, cells.get(i).apply(context)); + return builder.build(); + } + + @Override + public String toString(ToStringContext context) { + return type().toString() + ":" + contentToString(); + } + + private String contentToString() { + if (type().dimensions().isEmpty()) { + if (cells.isEmpty()) return "{}"; + return "{" + cells.get(0) + "}"; + } + + StringBuilder b = new StringBuilder("["); + for (var cell : cells) { + b.append(cell); + b.append(","); + } + if (b.length() > 1) + b.setLength(b.length() - 1); + b.append("]"); + + return b.toString(); + } + + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java new file mode 100644 index 00000000000..82652fb0e5d --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -0,0 +1,46 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import org.junit.Test; + +import java.util.Collections; +import java.util.List; +import java.util.function.Function; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class DynamicTensorTestCase { + + @Test + public void testDynamicTensorFunction() { + TensorType dense = TensorType.fromSpec("tensor(x[3])"); + DynamicTensor t1 = DynamicTensor.from(dense, + List.of(new Constant(1), new Constant(2), new Constant(3))); + assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); + + TensorType sparse = TensorType.fromSpec("tensor(x{})"); + DynamicTensor t2 = DynamicTensor.from(sparse, + Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), + new Constant(5))); + assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); + } + + private static class Constant implements Function<EvaluationContext<?>, Double> { + + private final double value; + + public Constant(double value) { this.value = value; } + + @Override + public Double apply(EvaluationContext<?> evaluationContext) { return value; } + + } + +} |