summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java4
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/query/WandItem.java4
-rw-r--r--searchlib/abi-spec.json14
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java41
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj96
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java3
-rw-r--r--vespajlib/abi-spec.json17
-rw-r--r--vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java146
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java46
13 files changed, 319 insertions, 64 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index 4acb47df179..bbfd2004caa 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -432,7 +432,7 @@ public class ConvertedModel {
if (node instanceof ReferenceNode) {
ReferenceNode referenceNode = (ReferenceNode) node;
if (model.inputTypeSpec(referenceNode.getName()).isPresent()) {
- return reduceBatchDimensionExpression(TensorFunctionNode.wrapArgument(node), typeContext);
+ return reduceBatchDimensionExpression(TensorFunctionNode.wrap(node), typeContext);
}
}
if (node instanceof CompositeNode) {
@@ -485,7 +485,7 @@ public class ConvertedModel {
new GeneratorLambdaFunctionNode(expandDimensionsType,
generatedExpression)
.asLongListToDoubleOperator());
- Join expand = new Join(TensorFunctionNode.wrapArgument(node), generatedFunction, ScalarFunctions.multiply());
+ Join expand = new Join(TensorFunctionNode.wrap(node), generatedFunction, ScalarFunctions.multiply());
return new TensorFunctionNode(expand);
}
return node;
diff --git a/container-search/src/main/java/com/yahoo/prelude/query/WandItem.java b/container-search/src/main/java/com/yahoo/prelude/query/WandItem.java
index 20f034df1df..a70d653b90a 100644
--- a/container-search/src/main/java/com/yahoo/prelude/query/WandItem.java
+++ b/container-search/src/main/java/com/yahoo/prelude/query/WandItem.java
@@ -24,8 +24,8 @@ public class WandItem extends WeightedSetItem {
/**
* Creates an empty WandItem.
*
- * @param fieldName The name of the weighted set field to search with this WandItem.
- * @param targetNumHits The target for minimum number of hits to produce by the backend search operator handling this WandItem.
+ * @param fieldName the name of the weighted set field to search with this WandItem.
+ * @param targetNumHits the target for minimum number of hits to produce by the backend search operator handling this WandItem.
*/
public WandItem(String fieldName, int targetNumHits) {
super(fieldName);
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 5ef3cd61366..1258601a2d1 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -909,11 +909,11 @@
"public final java.util.List tagCommaLeadingList()",
"public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode tensorValueBody(com.yahoo.tensor.TensorType)",
- "public final void mappedTensorValueBody(com.yahoo.tensor.Tensor$Builder)",
- "public final void indexedTensorValueBody(com.yahoo.tensor.Tensor$Builder)",
- "public final void tensorCell(com.yahoo.tensor.Tensor$Builder$CellBuilder)",
- "public final void labelAndDimension(com.yahoo.tensor.Tensor$Builder$CellBuilder)",
+ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final void tensorCell(com.yahoo.tensor.TensorType, java.util.Map)",
+ "public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)",
"public void <init>(java.io.InputStream)",
"public void <init>(java.io.InputStream, java.lang.String)",
"public void ReInit(java.io.InputStream)",
@@ -1612,7 +1612,9 @@
"public java.lang.StringBuilder toString(java.lang.StringBuilder, com.yahoo.searchlib.rankingexpression.rule.SerializationContext, java.util.Deque, com.yahoo.searchlib.rankingexpression.rule.CompositeNode)",
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
- "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrapArgument(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
+ "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$TensorFunctionExpressionNode wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
+ "public static java.util.Map wrap(java.util.Map)",
+ "public static java.util.List wrap(java.util.List)"
],
"fields": []
},
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index c1732aabf0b..e6e49e07c34 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -7,6 +7,7 @@ import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -14,9 +15,13 @@ import com.yahoo.tensor.functions.PrimitiveTensorFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
+import java.util.LinkedHashMap;
import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
import java.util.stream.Collectors;
/**
@@ -72,10 +77,44 @@ public class TensorFunctionNode extends CompositeNode {
return new TensorValue(function.evaluate(context));
}
- public static TensorFunctionExpressionNode wrapArgument(ExpressionNode node) {
+ public static TensorFunctionExpressionNode wrap(ExpressionNode node) {
return new TensorFunctionExpressionNode(node);
}
+ public static Map<TensorAddress, Function<EvaluationContext<?>, Double>> wrap(Map<TensorAddress, ExpressionNode> nodes) {
+ Map<TensorAddress, Function<EvaluationContext<?>, Double>> closures = new LinkedHashMap<>();
+ for (var entry : nodes.entrySet())
+ closures.put(entry.getKey(), new ExpressionClosure(entry.getValue()));
+ return closures;
+ }
+
+ public static List<Function<EvaluationContext<?>, Double>> wrap(List<ExpressionNode> nodes) {
+ List<Function<EvaluationContext<?>, Double>> closures = new ArrayList<>();
+ for (var entry : nodes)
+ closures.add(new ExpressionClosure(entry));
+ return closures;
+ }
+
+ private static class ExpressionClosure implements java.util.function.Function<EvaluationContext<?> , Double> {
+
+ private final ExpressionNode expression;
+
+ public ExpressionClosure(ExpressionNode expression) {
+ this.expression = expression;
+ }
+
+ @Override
+ public Double apply(EvaluationContext<?> context) {
+ return expression.evaluate((Context)context).asDouble();
+ }
+
+ @Override
+ public String toString() {
+ return expression.toString();
+ }
+
+ }
+
/**
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
index 979c5b0f88c..6d687b015f1 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
@@ -83,7 +83,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E
ExpressionNode arg1 = node.children().get(0);
ExpressionNode arg2 = node.children().get(1);
- TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrapArgument(arg1);
+ TensorFunctionNode.TensorFunctionExpressionNode expression = TensorFunctionNode.wrap(arg1);
Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
String dimension = ((ReferenceNode) arg2).getName();
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 8f411bf6593..47555d95e58 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -30,6 +30,7 @@ import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
+@SuppressWarnings({"rawtypes", "unchecked"})
public class RankingExpressionParser {
}
@@ -401,8 +402,8 @@ ExpressionNode tensorMap() :
}
{
<MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Map(TensorFunctionNode.wrapArgument(tensor),
- doubleMapper.asDoubleUnaryOperator())); }
+ { return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor),
+ doubleMapper.asDoubleUnaryOperator())); }
}
ExpressionNode tensorReduce() :
@@ -413,7 +414,7 @@ ExpressionNode tensorReduce() :
}
{
<REDUCE> <LBRACE> tensor = expression() <COMMA> aggregator = tensorReduceAggregator() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); }
}
ExpressionNode tensorReduceComposites() :
@@ -425,7 +426,7 @@ ExpressionNode tensorReduceComposites() :
{
aggregator = tensorReduceAggregator()
<LBRACE> tensor = expression() dimensions = tagCommaLeadingList() <RBRACE>
- { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrapArgument(tensor), aggregator, dimensions)); }
+ { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); }
}
ExpressionNode tensorJoin() :
@@ -435,9 +436,9 @@ ExpressionNode tensorJoin() :
}
{
<JOIN> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> doubleJoiner = lambdaFunction() <RBRACE>
- { return new TensorFunctionNode(new Join(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- doubleJoiner.asDoubleBinaryOperator())); }
+ { return new TensorFunctionNode(new Join(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ doubleJoiner.asDoubleBinaryOperator())); }
}
ExpressionNode tensorRename() :
@@ -450,7 +451,7 @@ ExpressionNode tensorRename() :
fromDimensions = bracedIdentifierList() <COMMA>
toDimensions = bracedIdentifierList()
<RBRACE>
- { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrapArgument(tensor), fromDimensions, toDimensions)); }
+ { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); }
}
ExpressionNode tensorConcat() :
@@ -460,8 +461,8 @@ ExpressionNode tensorConcat() :
}
{
<CONCAT> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = tag() <RBRACE>
- { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
+ { return new TensorFunctionNode(new Concat(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
dimension)); }
}
@@ -522,7 +523,7 @@ ExpressionNode tensorL1Normalize() :
}
{
<L1_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorL2Normalize() :
@@ -532,7 +533,7 @@ ExpressionNode tensorL2Normalize() :
}
{
<L2_NORMALIZE> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorMatmul() :
@@ -542,9 +543,9 @@ ExpressionNode tensorMatmul() :
}
{
<MATMUL> <LBRACE> tensor1 = expression() <COMMA> tensor2 = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- dimension)); }
+ { return new TensorFunctionNode(new Matmul(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ dimension)); }
}
ExpressionNode tensorSoftmax() :
@@ -554,7 +555,7 @@ ExpressionNode tensorSoftmax() :
}
{
<SOFTMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorXwPlusB() :
@@ -567,9 +568,9 @@ ExpressionNode tensorXwPlusB() :
tensor2 = expression() <COMMA>
tensor3 = expression() <COMMA>
dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrapArgument(tensor1),
- TensorFunctionNode.wrapArgument(tensor2),
- TensorFunctionNode.wrapArgument(tensor3),
+ { return new TensorFunctionNode(new XwPlusB(TensorFunctionNode.wrap(tensor1),
+ TensorFunctionNode.wrap(tensor2),
+ TensorFunctionNode.wrap(tensor3),
dimension)); }
}
@@ -580,7 +581,7 @@ ExpressionNode tensorArgmax() :
}
{
<ARGMAX> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); }
}
ExpressionNode tensorArgmin() :
@@ -590,7 +591,7 @@ ExpressionNode tensorArgmin() :
}
{
<ARGMIN> <LBRACE> tensor = expression() <COMMA> dimension = identifier() <RBRACE>
- { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrapArgument(tensor), dimension)); }
+ { return new TensorFunctionNode(new Argmin(TensorFunctionNode.wrap(tensor), dimension)); }
}
LambdaFunctionNode lambdaFunction() :
@@ -823,63 +824,62 @@ Value primitiveValue() :
{ return Value.parse(sign + token.image); }
}
-ConstantNode tensorValueBody(TensorType type) :
+ExpressionNode tensorValueBody(TensorType type) :
{
- Tensor.Builder builder = Tensor.Builder.of(type);
+ DynamicTensor dynamicTensor;
}
{
<COLON>
(
- mappedTensorValueBody(builder) |
- indexedTensorValueBody(builder)
+ dynamicTensor = mappedTensorValueBody(type) |
+ dynamicTensor = indexedTensorValueBody(type)
)
- { return new ConstantNode(new TensorValue(builder.build())); }
+ { return new TensorFunctionNode(dynamicTensor); }
}
-void mappedTensorValueBody(Tensor.Builder builder) : {}
+DynamicTensor mappedTensorValueBody(TensorType type) :
+{
+ java.util.Map cells = new LinkedHashMap();
+}
{
<LCURLY>
- ( tensorCell(builder.cell()))*
- ( <COMMA> tensorCell(builder.cell()))*
+ ( tensorCell(type, cells))*
+ ( <COMMA> tensorCell(type, cells))*
<RCURLY>
+ { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
}
-void indexedTensorValueBody(Tensor.Builder builder) :
+DynamicTensor indexedTensorValueBody(TensorType type) :
{
- IndexedTensor.BoundBuilder indexedBuilder;
- long index = 0;
- double value;
+ List cells = new ArrayList();
+ ExpressionNode value;
}
{
- {
- if ( ! (builder instanceof IndexedTensor.BoundBuilder))
- throw new IllegalArgumentException("The tensor short form [n, n, ...] can only be used for indexed " +
- "bound tensors, not " + builder.type());
- indexedBuilder = (IndexedTensor.BoundBuilder)builder;
- }
<LSQUARE>
- ( value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )*
- ( <COMMA> value = doubleNumber() { indexedBuilder.cellByDirectIndex(index++, value); } )*
+ ( value = expression() { cells.add(value); } )*
+ ( <COMMA> value = expression() { cells.add(value); } )*
<RSQUARE>
+ { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
}
-void tensorCell(Tensor.Builder.CellBuilder cellBuilder) :
+void tensorCell(TensorType type, java.util.Map cells) :
{
- double value;
+ ExpressionNode value;
+ TensorAddress.Builder addressBuilder = new TensorAddress.Builder(type);
}
{
<LCURLY>
- ( labelAndDimension(cellBuilder))*
- ( <COMMA> labelAndDimension(cellBuilder))*
+ ( labelAndDimension(addressBuilder))*
+ ( <COMMA> labelAndDimension(addressBuilder))*
<RCURLY>
- <COLON> value = doubleNumber() { cellBuilder.value(value); }
+ <COLON> value = expression() { cells.put(addressBuilder.build(), value); }
}
-void labelAndDimension(Tensor.Builder.CellBuilder cellBuilder) :
+void labelAndDimension(TensorAddress.Builder addressBuilder) :
{
String dimension, label;
}
{
dimension = identifier() <COLON> label = tag()
- { cellBuilder.label(dimension, label); }
+ { addressBuilder.add(dimension, label); }
} \ No newline at end of file
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index 571e1f4d608..a41f24b3b8a 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -154,7 +154,12 @@ public class RankingExpressionTestCase {
"map(constant(tensor0), f(a)(cos(a))) + l2_normalize(attribute(tensor1), x)");
assertSerialization("join(reduce(join(reduce(join(constant(tensor0), attribute(tensor1), f(a,b)(a * b)), sum, x), attribute(tensor1), f(a,b)(a * b)), sum, y), query(tensor2), f(a,b)(a + b))",
"xw_plus_b(matmul(constant(tensor0), attribute(tensor1), x), attribute(tensor1), query(tensor2), y)");
-
+ assertSerialization("tensor(x{}):{{x:a}:1 + 2 + 3,{x:b}:if (1 > 2, 3, 4),{x:c}:reduce(tensor0 * tensor1, sum)}",
+ "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }");
+ assertSerialization("tensor(x[3]):[1.0,2.0,3]",
+ "tensor(x[3]):[1.0, 2.0, 3]");
+ assertSerialization("tensor(x[3]):[1.0,reduce(tensor0 * tensor1, sum),3]",
+ "tensor(x[3]):[1.0, sum(tensor0*tensor1), 3]");
}
@Test
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 7aafb8efee7..e28daefdabf 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -368,6 +368,9 @@ public class EvaluationTestCase {
"tensor(x{}):{}");
tester.assertEvaluates("tensor():{{}:1}",
"tensor():{{}:1}");
+ tester.assertEvaluates("tensor(x{}):{ {x:a}:6.0, {x:b}:4.0, {x:c}:14.0 }",
+ "tensor(x{}):{ {x:a}:1+2+3, {x:b}:if(1>2,3,4), {x:c}:sum(tensor0*tensor1) }",
+ "{ {x:0}:7 }", "tensor(x{}):{ {x:0}:2 }");
}
@Test
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 9590a97ea55..6a93a17a8c1 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1591,6 +1591,23 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.DynamicTensor": {
+ "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
+ "interfaces": [],
+ "attributes": [
+ "public",
+ "abstract"
+ ],
+ "methods": [
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.Map)",
+ "public static com.yahoo.tensor.functions.DynamicTensor from(com.yahoo.tensor.TensorType, java.util.List)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.Generate": {
"superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
"interfaces": [],
diff --git a/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java b/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java
index ca0baf95ee2..7db43a7442a 100644
--- a/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java
+++ b/vespajlib/src/main/java/com/yahoo/collections/CopyOnWriteHashMap.java
@@ -19,7 +19,6 @@ import java.util.Set;
*
* @author bratseth
*/
-@Beta
public class CopyOnWriteHashMap<K,V> extends AbstractMap<K,V> implements Cloneable {
private Map<K,V> map;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
index f5ef88016ac..15476567fb2 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java
@@ -745,8 +745,6 @@ public abstract class IndexedTensor implements Tensor {
}
- // TODO: Make dimensionSizes a class
-
/**
* An array of indexes into this tensor which are able to find the next index in the value order.
* next() can be called once per element in the dimensions we iterate over. It must be called once
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
new file mode 100644
index 00000000000..9ce2496c65b
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -0,0 +1,146 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.tensor.IndexedTensor;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.List;
+import java.util.Map;
+import java.util.function.Function;
+
+/**
+ * A function which is a tensor whose values are computed by individual lambda functions on evaluation.
+ *
+ * @author bratseth
+ */
+public abstract class DynamicTensor extends PrimitiveTensorFunction {
+
+ private final TensorType type;
+
+ DynamicTensor(TensorType type) {
+ this.type = type;
+ }
+
+ @Override
+ public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
+
+ @Override
+ public List<TensorFunction> arguments() { return Collections.emptyList(); }
+
+ @Override
+ public TensorFunction withArguments(List<TensorFunction> arguments) {
+ if (arguments.size() != 0)
+ throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size());
+ return this;
+ }
+
+ @Override
+ public PrimitiveTensorFunction toPrimitive() { return this; }
+
+ TensorType type() { return type; }
+
+ /** Creates a dynamic tensor function. The cell addresses must match the type. */
+ public static DynamicTensor from(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) {
+ return new MappedDynamicTensor(type, cells);
+ }
+
+ /** Creates a dynamic tensor function for a bound, indexed tensor */
+ public static DynamicTensor from(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) {
+ return new IndexedDynamicTensor(type, cells);
+ }
+
+ private static class MappedDynamicTensor extends DynamicTensor {
+
+ private final ImmutableMap<TensorAddress, Function<EvaluationContext<?> , Double>> cells;
+
+ MappedDynamicTensor(TensorType type, Map<TensorAddress, Function<EvaluationContext<?> , Double>> cells) {
+ super(type);
+ this.cells = ImmutableMap.copyOf(cells);
+ }
+
+ @Override
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor.Builder builder = Tensor.Builder.of(type());
+ for (var cell : cells.entrySet())
+ builder.cell(cell.getKey(), cell.getValue().apply(context));
+ return builder.build();
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return type().toString() + ":" + contentToString();
+ }
+
+ private String contentToString() {
+ if (type().dimensions().isEmpty()) {
+ if (cells.isEmpty()) return "{}";
+ return "{" + cells.values().iterator().next() + "}";
+ }
+
+ StringBuilder b = new StringBuilder("{");
+ for (var cell : cells.entrySet()) {
+ b.append(cell.getKey().toString(type())).append(":").append(cell.getValue());
+ b.append(",");
+ }
+ if (b.length() > 1)
+ b.setLength(b.length() - 1);
+ b.append("}");
+
+ return b.toString();
+ }
+
+ }
+
+ private static class IndexedDynamicTensor extends DynamicTensor {
+
+ private final List<Function<EvaluationContext<?>, Double>> cells;
+
+ IndexedDynamicTensor(TensorType type, List<Function<EvaluationContext<?> , Double>> cells) {
+ super(type);
+ if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))
+ throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
+ "only indexed, bound dimensions, but this has " + type);
+ this.cells = List.copyOf(cells);
+ }
+
+ @Override
+ public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
+ for (int i = 0; i < cells.size(); i++)
+ builder.cellByDirectIndex(i, cells.get(i).apply(context));
+ return builder.build();
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return type().toString() + ":" + contentToString();
+ }
+
+ private String contentToString() {
+ if (type().dimensions().isEmpty()) {
+ if (cells.isEmpty()) return "{}";
+ return "{" + cells.get(0) + "}";
+ }
+
+ StringBuilder b = new StringBuilder("[");
+ for (var cell : cells) {
+ b.append(cell);
+ b.append(",");
+ }
+ if (b.length() > 1)
+ b.setLength(b.length() - 1);
+ b.append("]");
+
+ return b.toString();
+ }
+
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
new file mode 100644
index 00000000000..82652fb0e5d
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -0,0 +1,46 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import org.junit.Test;
+
+import java.util.Collections;
+import java.util.List;
+import java.util.function.Function;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class DynamicTensorTestCase {
+
+ @Test
+ public void testDynamicTensorFunction() {
+ TensorType dense = TensorType.fromSpec("tensor(x[3])");
+ DynamicTensor t1 = DynamicTensor.from(dense,
+ List.of(new Constant(1), new Constant(2), new Constant(3)));
+ assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate());
+
+ TensorType sparse = TensorType.fromSpec("tensor(x{})");
+ DynamicTensor t2 = DynamicTensor.from(sparse,
+ Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
+ new Constant(5)));
+ assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
+ }
+
+ private static class Constant implements Function<EvaluationContext<?>, Double> {
+
+ private final double value;
+
+ public Constant(double value) { this.value = value; }
+
+ @Override
+ public Double apply(EvaluationContext<?> evaluationContext) { return value; }
+
+ }
+
+}