aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2019-11-28 10:24:57 +0100
committerGitHub <noreply@github.com>2019-11-28 10:24:57 +0100
commita907f095507bfd9aec0d6bd168217b4a0471b651 (patch)
treef0fe57f98b48829ba186e4f543748b2c6f25fe4a
parenta5e8e198dabc9dcfc710200d2ed170193f9b253b (diff)
parent0e18f68b1583b3391859b3def7f3a168b5212d15 (diff)
Merge pull request #11435 from vespa-engine/bratseth/value-function
Bratseth/value function
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java8
-rw-r--r--config-model/src/test/derived/tensor/rank-profiles.cfg17
-rw-r--r--config-model/src/test/derived/tensor/tensor.sd12
-rw-r--r--model-evaluation/abi-spec.json2
-rw-r--r--searchlib/abi-spec.json58
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java5
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java2
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java75
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java2
-rwxr-xr-xsearchlib/src/main/javacc/RankingExpressionParser.jj135
-rwxr-xr-xsearchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java2
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java18
-rw-r--r--vespajlib/abi-spec.json49
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java41
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/Name.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java13
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java37
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java117
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java24
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java23
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java11
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java26
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java21
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java5
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java15
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java185
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java31
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java11
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java16
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java9
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java66
52 files changed, 959 insertions, 422 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index bbfd2004caa..55979023119 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -411,7 +411,7 @@ public class ConvertedModel {
}
// Modify any renames in expression to disregard batch dimension
else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) {
- TensorFunction childFunction = (((TensorFunctionNode) children.get(0)).function());
+ TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function());
TensorType childType = childFunction.type(typeContext);
Rename rename = (Rename) tensorFunction;
List<String> from = new ArrayList<>();
@@ -422,10 +422,10 @@ public class ConvertedModel {
throw new IllegalArgumentException("Rename does not contain dimension '" +
dimension + "' in child expression type: " + childType);
}
- from.add(rename.fromDimensions().get(i));
- to.add(rename.toDimensions().get(i));
+ from.add((String)rename.fromDimensions().get(i));
+ to.add((String)rename.toDimensions().get(i));
}
- return new TensorFunctionNode(new Rename(childFunction, from, to));
+ return new TensorFunctionNode(new Rename<>(childFunction, from, to));
}
}
}
diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg
index 617901130a6..cebfa244159 100644
--- a/config-model/src/test/derived/tensor/rank-profiles.cfg
+++ b/config-model/src/test/derived/tensor/rank-profiles.cfg
@@ -110,3 +110,20 @@ rankprofile[].fef.property[].name "vespa.type.attribute.f4"
rankprofile[].fef.property[].value "tensor(x[10],y[20])"
rankprofile[].fef.property[].name "vespa.type.attribute.f5"
rankprofile[].fef.property[].value "tensor<float>(x[10])"
+rankprofile[].name "profile7"
+rankprofile[].fef.property[].name "rankingExpression(reshaped).rankingScript"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],x[2])({x:1 - x, y:d0})"
+rankprofile[].fef.property[].name "rankingExpression(reshaped).type"
+rankprofile[].fef.property[].value "tensor<float>(d0[1],x[2])"
+rankprofile[].fef.property[].name "vespa.rank.firstphase"
+rankprofile[].fef.property[].value "rankingExpression(firstphase)"
+rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript"
+rankprofile[].fef.property[].value "reduce(rankingExpression(reshaped), sum)"
+rankprofile[].fef.property[].name "vespa.type.attribute.f2"
+rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])"
+rankprofile[].fef.property[].name "vespa.type.attribute.f3"
+rankprofile[].fef.property[].value "tensor(x{})"
+rankprofile[].fef.property[].name "vespa.type.attribute.f4"
+rankprofile[].fef.property[].value "tensor(x[10],y[20])"
+rankprofile[].fef.property[].name "vespa.type.attribute.f5"
+rankprofile[].fef.property[].value "tensor<float>(x[10])"
diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd
index 13727d1ec49..15d56517a43 100644
--- a/config-model/src/test/derived/tensor/tensor.sd
+++ b/config-model/src/test/derived/tensor/tensor.sd
@@ -78,4 +78,16 @@ search tensor {
}
+ rank-profile profile7 {
+
+ first-phase {
+ expression: sum(reshaped())
+ }
+
+ function reshaped() {
+ expression: tensor<float>(d0[1],x[2])(attribute(f2){x:1-x, y:d0})
+ }
+
+ }
+
}
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 1857511cbf2..5a75c8b31ea 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -39,7 +39,7 @@
"public java.util.Set names()",
"public java.util.Set arguments()",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value defaultValue()",
- "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)"
+ "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)"
],
"fields": []
},
diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json
index 8d7bf4f9f14..abfae426ad0 100644
--- a/searchlib/abi-spec.json
+++ b/searchlib/abi-spec.json
@@ -342,7 +342,7 @@
"fields": []
},
"com.yahoo.searchlib.rankingexpression.Reference": {
- "superClass": "com.yahoo.tensor.evaluation.TypeContext$Name",
+ "superClass": "com.yahoo.tensor.evaluation.Name",
"interfaces": [],
"attributes": [
"public"
@@ -414,7 +414,7 @@
"public final double getDouble(int)",
"public com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext clone()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()",
- "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)",
+ "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)",
"public bridge synthetic java.lang.Object clone()"
],
"fields": []
@@ -521,7 +521,7 @@
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)",
"public com.yahoo.searchlib.rankingexpression.evaluation.DoubleOnlyArrayContext clone()",
"public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()",
- "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)",
+ "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)",
"public bridge synthetic java.lang.Object clone()"
],
"fields": []
@@ -592,7 +592,7 @@
"public java.util.Set names()",
"public java.lang.String toString()",
"public static com.yahoo.searchlib.rankingexpression.evaluation.MapContext fromString(java.lang.String)",
- "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)"
+ "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)"
],
"fields": []
},
@@ -610,7 +610,7 @@
"public com.yahoo.tensor.TensorType getType(java.lang.String)",
"public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)",
"public java.util.Map bindings()",
- "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)"
+ "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)"
],
"fields": []
},
@@ -872,25 +872,25 @@
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode arg()",
"public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode function()",
"public final com.yahoo.searchlib.rankingexpression.rule.FunctionNode scalarOrTensorFunction()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorFunction()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMap()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorReduce()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorReduceComposites()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorJoin()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRename()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorConcat()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerate()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerateBody(com.yahoo.tensor.TensorType)",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRange()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorDiag()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRandom()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorL1Normalize()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorL2Normalize()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMatmul()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorSoftmax()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorXwPlusB()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorArgmax()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorArgmin()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorFunction()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMap()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduce()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduceComposites()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorJoin()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRename()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorConcat()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorGenerate()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorGenerateBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRange()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorDiag()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRandom()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL1Normalize()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL2Normalize()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()",
"public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()",
"public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()",
"public final com.yahoo.tensor.TensorType tensorType()",
@@ -909,11 +909,14 @@
"public final java.util.List tagCommaLeadingList()",
"public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()",
"public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()",
- "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorValueBody(com.yahoo.tensor.TensorType)",
+ "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType)",
"public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)",
"public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)",
"public final void tensorCell(com.yahoo.tensor.TensorType, java.util.Map)",
"public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)",
+ "public final void labelAndDimensionValues(java.util.List)",
+ "public final java.util.List valueAddress()",
+ "public final com.yahoo.tensor.functions.Value$DimensionValue dimensionValue(java.util.Optional)",
"public void <init>(java.io.InputStream)",
"public void <init>(java.io.InputStream, java.lang.String)",
"public void ReInit(java.io.InputStream)",
@@ -1613,8 +1616,9 @@
"public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
"public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)",
"public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)",
- "public static java.util.Map wrap(java.util.Map)",
- "public static java.util.List wrap(java.util.List)"
+ "public static java.util.Map wrapScalars(java.util.Map)",
+ "public static java.util.List wrapScalars(java.util.List)",
+ "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)"
],
"fields": []
},
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
index 829a796eee0..fa2d0f1ee45 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java
@@ -7,19 +7,18 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.NameNode;
import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
-import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.evaluation.Name;
import java.util.Deque;
import java.util.Objects;
import java.util.Optional;
-import java.util.stream.Collectors;
/**
* A reference to a feature, function, or value in ranking expressions
*
* @author bratseth
*/
-public class Reference extends TypeContext.Name {
+public class Reference extends Name {
private final int hashCode;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
index d68f8c85ad1..cf17c6465f3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java
@@ -34,7 +34,7 @@ public abstract class Context implements EvaluationContext<Reference> {
@Override
public TensorType getType(String reference) {
- throw new UnsupportedOperationException("Not able to parse gereral references from string form");
+ throw new UnsupportedOperationException("Not able to parse general references from string form");
}
/** Returns a variable as a tensor */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
index 63cea371d14..41b01c9a2cb 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java
@@ -2,6 +2,7 @@
package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex;
import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport;
import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer;
@@ -58,11 +59,12 @@ public class TensorOptimizer extends Optimizer {
* The ReduceJoin class determines whether or not the arguments are
* compatible with the optimization.
*/
+ @SuppressWarnings("unchecked")
private ExpressionNode optimizeReduceJoin(ExpressionNode node) {
if ( ! (node instanceof TensorFunctionNode)) {
return node;
}
- TensorFunction function = ((TensorFunctionNode) node).function();
+ TensorFunction<Reference> function = ((TensorFunctionNode) node).function();
if ( ! (function instanceof Reduce)) {
return node;
}
@@ -74,10 +76,10 @@ public class TensorOptimizer extends Optimizer {
if ( ! (child instanceof TensorFunctionNode)) {
return node;
}
- TensorFunction argument = ((TensorFunctionNode) child).function();
+ TensorFunction<Reference> argument = ((TensorFunctionNode) child).function();
if (argument instanceof Join) {
report.incMetric("Replaced reduce->join", 1);
- return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument));
+ return new TensorFunctionNode(new ReduceJoin<>((Reduce<Reference>)function, (Join<Reference>)argument));
}
return node;
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
index 4ffd40f00f7..cec8837abcd 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java
@@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.ScalarFunction;
import com.yahoo.tensor.functions.TensorFunction;
import com.yahoo.tensor.functions.ToStringContext;
+import java.sql.Ref;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
@@ -32,14 +33,14 @@ import java.util.stream.Collectors;
@Beta
public class TensorFunctionNode extends CompositeNode {
- private final TensorFunction function;
+ private final TensorFunction<Reference> function;
- public TensorFunctionNode(TensorFunction function) {
+ public TensorFunctionNode(TensorFunction<Reference> function) {
this.function = function;
}
/** Returns the tensor function wrapped by this */
- public TensorFunction function() { return function; }
+ public TensorFunction<Reference> function() { return function; }
@Override
public List<ExpressionNode> children() {
@@ -48,7 +49,7 @@ public class TensorFunctionNode extends CompositeNode {
.collect(Collectors.toList());
}
- private ExpressionNode toExpressionNode(TensorFunction f) {
+ private ExpressionNode toExpressionNode(TensorFunction<Reference> f) {
if (f instanceof ExpressionTensorFunction)
return ((ExpressionTensorFunction)f).expression;
else
@@ -57,9 +58,9 @@ public class TensorFunctionNode extends CompositeNode {
@Override
public CompositeNode setChildren(List<ExpressionNode> children) {
- List<TensorFunction> wrappedChildren = children.stream()
- .map(ExpressionTensorFunction::new)
- .collect(Collectors.toList());
+ List<TensorFunction<Reference>> wrappedChildren = children.stream()
+ .map(ExpressionTensorFunction::new)
+ .collect(Collectors.toList());
return new TensorFunctionNode(function.withArguments(wrappedChildren));
}
@@ -81,21 +82,22 @@ public class TensorFunctionNode extends CompositeNode {
return new ExpressionTensorFunction(node);
}
- public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) {
- Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>();
+ public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) {
+ Map<TensorAddress, ScalarFunction<Reference>> functions = new LinkedHashMap<>();
for (var entry : nodes.entrySet())
- functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue()));
+ functions.put(entry.getKey(), wrapScalar(entry.getValue()));
return functions;
}
- public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) {
- List<ScalarFunction> functions = new ArrayList<>();
- for (var entry : nodes)
- functions.add(new ExpressionScalarFunction(entry));
- return functions;
+ public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) {
+ return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList());
+ }
+
+ public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) {
+ return new ExpressionScalarFunction(node);
}
- private static class ExpressionScalarFunction implements ScalarFunction {
+ private static class ExpressionScalarFunction implements ScalarFunction<Reference> {
private final ExpressionNode expression;
@@ -104,8 +106,8 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public Double apply(EvaluationContext<?> context) {
- return expression.evaluate((Context)context).asDouble();
+ public Double apply(EvaluationContext<Reference> context) {
+ return expression.evaluate(new ContextWrapper(context)).asDouble();
}
@Override
@@ -130,7 +132,7 @@ public class TensorFunctionNode extends CompositeNode {
* A tensor function implemented by an expression.
* This allows us to pass expressions as tensor function arguments.
*/
- public static class ExpressionTensorFunction extends PrimitiveTensorFunction {
+ public static class ExpressionTensorFunction extends PrimitiveTensorFunction<Reference> {
/** An expression which produces a tensor */
private final ExpressionNode expression;
@@ -140,7 +142,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public List<TensorFunction> arguments() {
+ public List<TensorFunction<Reference>> arguments() {
if (expression instanceof CompositeNode)
return ((CompositeNode)expression).children().stream()
.map(ExpressionTensorFunction::new)
@@ -150,7 +152,7 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<Reference> withArguments(List<TensorFunction<Reference>> arguments) {
if (arguments.size() == 0) return this;
List<ExpressionNode> unwrappedChildren = arguments.stream()
.map(arg -> ((ExpressionTensorFunction)arg).expression)
@@ -159,16 +161,15 @@ public class TensorFunctionNode extends CompositeNode {
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<Reference> toPrimitive() { return this; }
@Override
- @SuppressWarnings("unchecked") // Generics awkwardness
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
- return expression.type((TypeContext<Reference>)context);
+ public TensorType type(TypeContext<Reference> context) {
+ return expression.type(context);
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<Reference> context) {
return expression.evaluate((Context)context).asTensor();
}
@@ -209,4 +210,26 @@ public class TensorFunctionNode extends CompositeNode {
}
+ /** Turns an EvaluationContext into a Context */
+ // TODO: We should be able to change RankingExpression.evaluate to take an EvaluationContext and then get rid of this
+ private static class ContextWrapper extends Context {
+
+ private final EvaluationContext<Reference> delegate;
+
+ public ContextWrapper(EvaluationContext<Reference> delegate) {
+ this.delegate = delegate;
+ }
+
+ @Override
+ public Value get(String name) {
+ return new TensorValue(delegate.getTensor(name));
+ }
+
+ @Override
+ public TensorType getType(Reference name) {
+ return delegate.getType(name);
+ }
+
+ }
+
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
index 9a38b5efc1f..9bed4a4ea7c 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java
@@ -87,7 +87,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E
Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name());
String dimension = ((ReferenceNode) arg2).getName();
- return new TensorFunctionNode(new Reduce(expression, aggregator, dimension));
+ return new TensorFunctionNode(new Reduce<>(expression, aggregator, dimension));
}
}
diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj
index 01eed897bfd..c7870182939 100755
--- a/searchlib/src/main/javacc/RankingExpressionParser.jj
+++ b/searchlib/src/main/javacc/RankingExpressionParser.jj
@@ -18,7 +18,6 @@ PARSER_BEGIN(RankingExpressionParser)
package com.yahoo.searchlib.rankingexpression.parser;
import com.yahoo.searchlib.rankingexpression.rule.*;
-import com.yahoo.searchlib.rankingexpression.evaluation.Value;
import com.yahoo.searchlib.rankingexpression.evaluation.StringValue;
import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
import com.yahoo.tensor.*;
@@ -231,26 +230,28 @@ TruthOperator comparator() : { }
ExpressionNode value() :
{
- ExpressionNode ret;
+ ExpressionNode value;
boolean neg = false;
boolean not = false;
+ List valueAddress;
}
{
(
[ <NOT> { not = true; } ]
[ LOOKAHEAD(2) <SUB> { neg = true; } ]
- ( ret = constantPrimitive() |
- LOOKAHEAD(2) ret = ifExpression() |
- LOOKAHEAD(4) ret = function() |
- ret = feature() |
- ret = legacyQueryFeature() |
- ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) )
+ ( value = constantPrimitive() |
+ LOOKAHEAD(2) value = ifExpression() |
+ LOOKAHEAD(4) value = function() |
+ value = feature() |
+ value = legacyQueryFeature() |
+ ( <LBRACE> value = expression() <RBRACE> { value = new EmbracedNode(value); } ) )
)
+ [ LOOKAHEAD(2) valueAddress = valueAddress() { value = new TensorFunctionNode(new Value(TensorFunctionNode.wrap(value), valueAddress)); } ]
{
- ret = not ? new NotNode(ret) : ret;
- ret = neg ? new NegativeNode(ret) : ret;
- return ret;
+ value = not ? new NotNode(value) : value;
+ value = neg ? new NegativeNode(value) : value;
+ return value;
}
}
@@ -323,12 +324,12 @@ List<ExpressionNode> args() :
{ return arguments; }
}
-// TODO: Replace use of this for macro arguments with value()
+// TODO: Replace use of this for function arguments with value()
// For that to work with the current search execution framework
-// we need to generate another macro for the argument such that we can replace
-// instances of the argument with the reference to that macro in the same way
+// we need to generate another function for the argument such that we can replace
+// instances of the argument with the reference to that function in the same way
// as we replace by constants/names today (this can make for some fun combinatorial explosion).
-// Simon also points out that we should stop doing macro expansion in the toString of a macro.
+// We should also stop doing function expansion in the toString of a function.
// - Jon 2014-05-02
ExpressionNode arg() :
{
@@ -368,9 +369,9 @@ FunctionNode scalarOrTensorFunction() :
)
}
-ExpressionNode tensorFunction() :
+TensorFunctionNode tensorFunction() :
{
- ExpressionNode tensorExpression;
+ TensorFunctionNode tensorExpression;
}
{
(
@@ -395,7 +396,7 @@ ExpressionNode tensorFunction() :
{ return tensorExpression; }
}
-ExpressionNode tensorMap() :
+TensorFunctionNode tensorMap() :
{
ExpressionNode tensor;
LambdaFunctionNode doubleMapper;
@@ -403,10 +404,10 @@ ExpressionNode tensorMap() :
{
<MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE>
{ return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor),
- doubleMapper.asDoubleUnaryOperator())); }
+ doubleMapper.asDoubleUnaryOperator())); }
}
-ExpressionNode tensorReduce() :
+TensorFunctionNode tensorReduce() :
{
ExpressionNode tensor;
Reduce.Aggregator aggregator;
@@ -417,7 +418,7 @@ ExpressionNode tensorReduce() :
{ return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); }
}
-ExpressionNode tensorReduceComposites() :
+TensorFunctionNode tensorReduceComposites() :
{
ExpressionNode tensor;
Reduce.Aggregator aggregator;
@@ -429,7 +430,7 @@ ExpressionNode tensorReduceComposites() :
{ return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); }
}
-ExpressionNode tensorJoin() :
+TensorFunctionNode tensorJoin() :
{
ExpressionNode tensor1, tensor2;
LambdaFunctionNode doubleJoiner;
@@ -441,7 +442,7 @@ ExpressionNode tensorJoin() :
doubleJoiner.asDoubleBinaryOperator())); }
}
-ExpressionNode tensorRename() :
+TensorFunctionNode tensorRename() :
{
ExpressionNode tensor;
List<String> fromDimensions, toDimensions;
@@ -454,7 +455,7 @@ ExpressionNode tensorRename() :
{ return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); }
}
-ExpressionNode tensorConcat() :
+TensorFunctionNode tensorConcat() :
{
ExpressionNode tensor1, tensor2;
String dimension;
@@ -466,10 +467,10 @@ ExpressionNode tensorConcat() :
dimension)); }
}
-ExpressionNode tensorGenerate() :
+TensorFunctionNode tensorGenerate() :
{
TensorType type;
- ExpressionNode expression;
+ TensorFunctionNode expression;
}
{
<TENSOR> type = tensorType()
@@ -480,16 +481,16 @@ ExpressionNode tensorGenerate() :
{ return expression; }
}
-ExpressionNode tensorGenerateBody(TensorType type) :
+TensorFunctionNode tensorGenerateBody(TensorType type) :
{
ExpressionNode generator;
}
{
<LBRACE> generator = expression() <RBRACE>
- { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); }
+ { return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(generator))); }
}
-ExpressionNode tensorRange() :
+TensorFunctionNode tensorRange() :
{
TensorType type;
}
@@ -498,7 +499,7 @@ ExpressionNode tensorRange() :
{ return new TensorFunctionNode(new Range(type)); }
}
-ExpressionNode tensorDiag() :
+TensorFunctionNode tensorDiag() :
{
TensorType type;
}
@@ -507,7 +508,7 @@ ExpressionNode tensorDiag() :
{ return new TensorFunctionNode(new Diag(type)); }
}
-ExpressionNode tensorRandom() :
+TensorFunctionNode tensorRandom() :
{
TensorType type;
}
@@ -516,7 +517,7 @@ ExpressionNode tensorRandom() :
{ return new TensorFunctionNode(new Random(type)); }
}
-ExpressionNode tensorL1Normalize() :
+TensorFunctionNode tensorL1Normalize() :
{
ExpressionNode tensor;
String dimension;
@@ -526,7 +527,7 @@ ExpressionNode tensorL1Normalize() :
{ return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
-ExpressionNode tensorL2Normalize() :
+TensorFunctionNode tensorL2Normalize() :
{
ExpressionNode tensor;
String dimension;
@@ -536,7 +537,7 @@ ExpressionNode tensorL2Normalize() :
{ return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); }
}
-ExpressionNode tensorMatmul() :
+TensorFunctionNode tensorMatmul() :
{
ExpressionNode tensor1, tensor2;
String dimension;
@@ -548,7 +549,7 @@ ExpressionNode tensorMatmul() :
dimension)); }
}
-ExpressionNode tensorSoftmax() :
+TensorFunctionNode tensorSoftmax() :
{
ExpressionNode tensor;
String dimension;
@@ -558,7 +559,7 @@ ExpressionNode tensorSoftmax() :
{ return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); }
}
-ExpressionNode tensorXwPlusB() :
+TensorFunctionNode tensorXwPlusB() :
{
ExpressionNode tensor1, tensor2, tensor3;
String dimension;
@@ -574,7 +575,7 @@ ExpressionNode tensorXwPlusB() :
dimension)); }
}
-ExpressionNode tensorArgmax() :
+TensorFunctionNode tensorArgmax() :
{
ExpressionNode tensor;
String dimension;
@@ -584,7 +585,7 @@ ExpressionNode tensorArgmax() :
{ return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); }
}
-ExpressionNode tensorArgmin() :
+TensorFunctionNode tensorArgmin() :
{
ExpressionNode tensor;
String dimension;
@@ -811,20 +812,20 @@ ConstantNode constantPrimitive() :
( <INTEGER> { value = token.image; } |
<FLOAT> { value = token.image; } |
<STRING> { value = token.image; } )
- { return new ConstantNode(Value.parse(sign + value),sign + value); }
+ { return new ConstantNode(com.yahoo.searchlib.rankingexpression.evaluation.Value.parse(sign + value),sign + value); }
}
-Value primitiveValue() :
+com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue() :
{
String sign = "";
}
{
( <SUB> { sign = "-";} ) ?
( <INTEGER> | <FLOAT> | <STRING> )
- { return Value.parse(sign + token.image); }
+ { return com.yahoo.searchlib.rankingexpression.evaluation.Value.parse(sign + token.image); }
}
-ExpressionNode tensorValueBody(TensorType type) :
+TensorFunctionNode tensorValueBody(TensorType type) :
{
DynamicTensor dynamicTensor;
}
@@ -846,7 +847,7 @@ DynamicTensor mappedTensorValueBody(TensorType type) :
( tensorCell(type, cells))*
( <COMMA> tensorCell(type, cells))*
<RCURLY>
- { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
DynamicTensor indexedTensorValueBody(TensorType type) :
@@ -859,7 +860,7 @@ DynamicTensor indexedTensorValueBody(TensorType type) :
( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )*
// <RSQUARE>
- { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); }
+ { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); }
}
void tensorCell(TensorType type, java.util.Map cells) :
@@ -882,4 +883,50 @@ void labelAndDimension(TensorAddress.Builder addressBuilder) :
{
dimension = identifier() <COLON> label = tag()
{ addressBuilder.add(dimension, label); }
+}
+
+void labelAndDimensionValues(List addressValues) :
+{
+ String dimension;
+ Value.DimensionValue dimensionValue;
+}
+{
+ dimension = identifier() <COLON> dimensionValue = dimensionValue(Optional.of(dimension))
+ { addressValues.add(dimensionValue); }
+}
+
+/** A tensor address (possibly on short form) represented as a list because the tensor type is not available */
+List valueAddress() :
+{
+ List dimensionValues = new ArrayList();
+ ExpressionNode valueExpression;
+ Value.DimensionValue dimensionValue;
+}
+{
+ (
+ ( <LSQUARE> ( valueExpression = expression() { dimensionValues.add(new Value.DimensionValue(TensorFunctionNode.wrapScalar(valueExpression))); } ) <RSQUARE> )
+ |
+ LOOKAHEAD(3) ( <LCURLY>
+ ( labelAndDimensionValues(dimensionValues))+
+ ( <COMMA> labelAndDimensionValues(dimensionValues))*
+ <RCURLY>
+ )
+ |
+ ( <LCURLY> dimensionValue = dimensionValue(Optional.empty()) { dimensionValues.add(dimensionValue); } <RCURLY> )
+ )
+ { return dimensionValues;}
+}
+
+Value.DimensionValue dimensionValue(Optional dimensionName) :
+{
+ ExpressionNode value;
+}
+{
+ value = expression()
+ {
+ if (value instanceof ReferenceNode && ((ReferenceNode)value).reference().isIdentifier())
+ return new Value.DimensionValue(dimensionName, ((ReferenceNode)value).reference().name());
+ else
+ return new Value.DimensionValue(dimensionName, TensorFunctionNode.wrapScalar(value));
+ }
} \ No newline at end of file
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
index e7024b87452..26fcec9efba 100755
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java
@@ -61,7 +61,7 @@ public class RankingExpressionTestCase {
ReferenceNode input = new ReferenceNode("input");
ReferenceNode constant = new ReferenceNode("constant");
ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant);
- Reduce sum = new Reduce(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum);
+ Reduce<Reference> sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum);
RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum));
RankingExpression expected = new RankingExpression("sum(input * constant)");
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
index 6064035702e..99047aeb79d 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java
@@ -317,6 +317,12 @@ public class EvaluationTestCase {
tester.assertEvaluates("{ {x:0,y:0,z:0}:1, {x:0,y:0,z:1}:0, {x:0,y:1,z:0}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:0}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:1, }", "diag(x[2],y[2],z[2])");
tester.assertEvaluates("6", "reduce(random(x[2],y[3]), count)");
+ // tensor value
+ tester.assertEvaluates("3.0", "tensor0{x:1}", "{ {x:0}:1, {x:1}:3 }");
+ tester.assertEvaluates("1.2", "tensor0{key:foo,x:0}", true, "{ {key:foo,x:0}:1.2, {key:bar,x:0}:3 }");
+ tester.assertEvaluates("3.0", "tensor0{bar}", true, "{ {x:foo}:1, {x:bar}:3 }");
+ tester.assertEvaluates("3.3", "tensor0[2]", "tensor(values[4]):[1.1, 2.2, 3.3, 4.4]]");
+
// composite functions
tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }");
tester.assertEvaluates("{ {x:0}:0.31622776601683794, {x:1}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }");
@@ -349,6 +355,18 @@ public class EvaluationTestCase {
tester.assertEvaluates("0",
"reduce(join(tensor0, tensor1, f(x,y) (if(x > y, 1.0, 0.0))), sum, tag) == reduce(tensor0, count, tag)",
"tensor(tag{}):{{tag:tag1}:10, {tag:tag2}:20}", "{25}");
+ tester.assertEvaluates("500",
+ "join(tensor0, tensor1, f(x,y) (x*y)){tag2}",
+ "tensor(tag{}):{{tag:tag1}:10, {tag:tag2}:20}", "{25}");
+ tester.assertEvaluates("tensor(j[3]):[3, 3, 3]",
+ "tensor(j[3])(tensor0[2])",
+ "tensor(values[5]):[1, 2, 3, 4, 5]");
+ tester.assertEvaluates("tensor(j[3]):[5, 4, 3]",
+ "tensor(j[3])(tensor0[4-j])",
+ "tensor(values[5]):[1, 2, 3, 4, 5]");
+ tester.assertEvaluates("tensor(j[2]):[6, 5]",
+ "tensor(j[2])(tensor0{key:bar,i:2-j})",
+ "tensor(key{},i[5]):{{key:foo,i:0}:1,{key:foo,i:1}:2,{key:foo,i:2}:2,{key:bar,i:0}:4,{key:bar,i:1}:5,{key:bar,i:2}:6}");
// tensor result dimensions are given from argument dimensions, not the resulting values
tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }");
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 4ec1a5a234b..59474021de2 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1448,12 +1448,12 @@
"public void <init>()",
"public void put(java.lang.String, com.yahoo.tensor.Tensor)",
"public com.yahoo.tensor.TensorType getType(java.lang.String)",
- "public com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)",
+ "public com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)",
"public com.yahoo.tensor.Tensor getTensor(java.lang.String)"
],
"fields": []
},
- "com.yahoo.tensor.evaluation.TypeContext$Name": {
+ "com.yahoo.tensor.evaluation.Name": {
"superClass": "java.lang.Object",
"interfaces": [],
"attributes": [
@@ -1477,7 +1477,7 @@
"abstract"
],
"methods": [
- "public abstract com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)",
+ "public abstract com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)",
"public abstract com.yahoo.tensor.TensorType getType(java.lang.String)"
],
"fields": []
@@ -1620,6 +1620,8 @@
],
"methods": [
"public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)",
+ "public static com.yahoo.tensor.functions.Generate free(com.yahoo.tensor.TensorType, java.util.function.Function)",
+ "public static com.yahoo.tensor.functions.Generate bound(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)",
"public java.util.List arguments()",
"public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)",
"public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
@@ -2506,6 +2508,47 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.Value$DimensionValue": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(java.lang.String, java.lang.String)",
+ "public void <init>(java.lang.String, int)",
+ "public void <init>(int)",
+ "public void <init>(java.lang.String)",
+ "public void <init>(com.yahoo.tensor.functions.ScalarFunction)",
+ "public void <init>(java.util.Optional, java.lang.String)",
+ "public void <init>(java.util.Optional, com.yahoo.tensor.functions.ScalarFunction)",
+ "public void <init>(java.lang.String, com.yahoo.tensor.functions.ScalarFunction)",
+ "public java.util.Optional dimension()",
+ "public java.util.Optional label()",
+ "public java.util.Optional index()",
+ "public java.lang.String toString()"
+ ],
+ "fields": []
+ },
+ "com.yahoo.tensor.functions.Value": {
+ "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.Value withArguments(java.util.List)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public java.lang.String toString()",
+ "public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.XwPlusB": {
"superClass": "com.yahoo.tensor.functions.CompositeTensorFunction",
"interfaces": [],
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index afd82751137..b8ef84cabb7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -1,6 +1,7 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.yahoo.tensor.evaluation.TypeContext;
import com.yahoo.tensor.functions.Argmax;
import com.yahoo.tensor.functions.Argmin;
import com.yahoo.tensor.functions.Concat;
@@ -37,7 +38,7 @@ import java.util.function.Function;
* Each cell is is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines
* the location of that cell. Both dimensions and labels are string on the form of an identifier or integer.
* <p>
- * The size of the set of dimensions of a tensor is called its <i>order</i>.
+ * The size of the set of dimensions of a tensor is called its <i>rank</i>.
* <p>
* In contrast to regular mathematical formulations of tensors, this definition of a tensor allows <i>sparseness</i>
* as there is no built-in notion of a contiguous space, and even in cases where a space is implied (such as when
@@ -144,25 +145,25 @@ public interface Tensor {
// ----------------- Primitive tensor functions
default Tensor map(DoubleUnaryOperator mapper) {
- return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate();
+ return new com.yahoo.tensor.functions.Map<>(new ConstantTensor<>(this), mapper).evaluate();
}
/** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
default Tensor reduce(Reduce.Aggregator aggregator, String ... dimensions) {
- return new Reduce(new ConstantTensor(this), aggregator, Arrays.asList(dimensions)).evaluate();
+ return new Reduce<>(new ConstantTensor<>(this), aggregator, Arrays.asList(dimensions)).evaluate();
}
/** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */
default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) {
- return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate();
+ return new Reduce<>(new ConstantTensor<>(this), aggregator, dimensions).evaluate();
}
default Tensor join(Tensor argument, DoubleBinaryOperator combinator) {
- return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate();
+ return new Join<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate();
}
default Tensor rename(String fromDimension, String toDimension) {
- return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension),
- Collections.singletonList(toDimension)).evaluate();
+ return new Rename<>(new ConstantTensor<>(this), Collections.singletonList(fromDimension),
+ Collections.singletonList(toDimension)).evaluate();
}
default Tensor concat(double argument, String dimension) {
@@ -170,50 +171,50 @@ public interface Tensor {
}
default Tensor concat(Tensor argument, String dimension) {
- return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
+ return new Concat<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), dimension).evaluate();
}
default Tensor rename(List<String> fromDimensions, List<String> toDimensions) {
- return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate();
+ return new Rename<>(new ConstantTensor<>(this), fromDimensions, toDimensions).evaluate();
}
static Tensor generate(TensorType type, Function<List<Long>, Double> valueSupplier) {
- return new Generate(type, valueSupplier).evaluate();
+ return new Generate<>(type, valueSupplier).evaluate();
}
// ----------------- Composite tensor functions which have a defined primitive mapping
default Tensor l1Normalize(String dimension) {
- return new L1Normalize(new ConstantTensor(this), dimension).evaluate();
+ return new L1Normalize<>(new ConstantTensor<>(this), dimension).evaluate();
}
default Tensor l2Normalize(String dimension) {
- return new L2Normalize(new ConstantTensor(this), dimension).evaluate();
+ return new L2Normalize<>(new ConstantTensor<>(this), dimension).evaluate();
}
default Tensor matmul(Tensor argument, String dimension) {
- return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate();
+ return new Matmul<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), dimension).evaluate();
}
default Tensor softmax(String dimension) {
- return new Softmax(new ConstantTensor(this), dimension).evaluate();
+ return new Softmax<>(new ConstantTensor<>(this), dimension).evaluate();
}
default Tensor xwPlusB(Tensor w, Tensor b, String dimension) {
- return new XwPlusB(new ConstantTensor(this), new ConstantTensor(w), new ConstantTensor(b), dimension).evaluate();
+ return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate();
}
default Tensor argmax(String dimension) {
- return new Argmax(new ConstantTensor(this), dimension).evaluate();
+ return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate();
}
- default Tensor argmin(String dimension) { return new Argmin(new ConstantTensor(this), dimension).evaluate(); }
+ default Tensor argmin(String dimension) { return new Argmin<>(new ConstantTensor<>(this), dimension).evaluate(); }
- static Tensor diag(TensorType type) { return new Diag(type).evaluate(); }
+ static Tensor diag(TensorType type) { return new Diag<>(type).evaluate(); }
- static Tensor random(TensorType type) { return new Random(type).evaluate(); }
+ static Tensor random(TensorType type) { return new Random<>(type).evaluate(); }
- static Tensor range(TensorType type) { return new Range(type).evaluate(); }
+ static Tensor range(TensorType type) { return new Range<>(type).evaluate(); }
// ----------------- Composite tensor functions mapped to primitives here on the fly
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 30f7185959c..3812dd26370 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -110,7 +110,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return Long.parseLong(labels[i]);
}
catch (NumberFormatException e) {
- throw new IllegalArgumentException("Expected a long label in " + this + " at position " + i);
+ throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'");
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
index 9ec105f8174..6e6b42cc1cd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.Tensor;
*
* @author bratseth
*/
-public interface EvaluationContext<NAMETYPE extends TypeContext.Name> extends TypeContext<NAMETYPE> {
+public interface EvaluationContext<NAMETYPE extends Name> extends TypeContext<NAMETYPE> {
/** Returns the tensor bound to this name, or null if none */
Tensor getTensor(String name);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
index e302e317418..f684987476f 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java
@@ -9,7 +9,7 @@ import java.util.HashMap;
/**
* @author bratseth
*/
-public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> {
+public class MapEvaluationContext<NAMETYPE extends Name> implements EvaluationContext<NAMETYPE> {
private final java.util.Map<String, Tensor> bindings = new HashMap<>();
@@ -17,14 +17,14 @@ public class MapEvaluationContext implements EvaluationContext<TypeContext.Name>
@Override
public TensorType getType(String name) {
- return getType(new Name(name));
+ Tensor tensor = bindings.get(name);
+ if (tensor == null) return null;
+ return tensor.type();
}
@Override
- public TensorType getType(Name name) {
- Tensor tensor = bindings.get(name.toString());
- if (tensor == null) return null;
- return tensor.type();
+ public TensorType getType(NAMETYPE name) {
+ return getType(name.name());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/Name.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/Name.java
new file mode 100644
index 00000000000..9033af1d7ec
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/Name.java
@@ -0,0 +1,28 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.evaluation;
+
+/** A name which is just a string. Names are value objects. */
+public class Name {
+
+ private final String name;
+
+ public Name(String name) {
+ this.name = name;
+ }
+
+ public String name() { return name; }
+
+ @Override
+ public String toString() { return name; }
+
+ @Override
+ public int hashCode() { return name.hashCode(); }
+
+ @Override
+ public boolean equals(Object other) {
+ if (other == this) return true;
+ if ( ! (other instanceof Name)) return false;
+ return ((Name)other).name.equals(this.name);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
index 1437fd91974..84d82b624ba 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java
@@ -8,7 +8,7 @@ import com.yahoo.tensor.TensorType;
*
* @author bratseth
*/
-public interface TypeContext<NAMETYPE extends TypeContext.Name> {
+public interface TypeContext<NAMETYPE extends Name> {
/**
* Returns the type of the tensor with this name.
@@ -26,31 +26,5 @@ public interface TypeContext<NAMETYPE extends TypeContext.Name> {
*/
TensorType getType(String name);
- /** A name which is just a string. Names are value objects. */
- class Name {
-
- private final String name;
-
- public Name(String name) {
- this.name = name;
- }
-
- public String name() { return name; }
-
- @Override
- public String toString() { return name; }
-
- @Override
- public int hashCode() { return name.hashCode(); }
-
- @Override
- public boolean equals(Object other) {
- if (other == this) return true;
- if ( ! (other instanceof Name)) return false;
- return ((Name)other).name.equals(this.name);
- }
-
- }
-
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
index c1cfa319664..8ea82aa4a79 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java
@@ -16,7 +16,7 @@ import java.util.Optional;
*
* @author bratseth
*/
-public class VariableTensor extends PrimitiveTensorFunction {
+public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
private final String name;
private final Optional<TensorType> requiredType;
@@ -33,16 +33,16 @@ public class VariableTensor extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) { return this; }
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { return this; }
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
TensorType givenType = context.getType(name);
if (givenType == null) return null;
verifyType(givenType);
@@ -50,7 +50,7 @@ public class VariableTensor extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = context.getTensor(name);
if (tensor == null) return null;
verifyType(tensor.type());
@@ -67,4 +67,5 @@ public class VariableTensor extends PrimitiveTensorFunction {
throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " +
requiredType.get() + " but was " + givenType);
}
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
index 3478061b32c..a365f0f4bdc 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java
@@ -1,38 +1,40 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.Name;
+
import java.util.Collections;
import java.util.List;
/**
* @author bratseth
*/
-public class Argmax extends CompositeTensorFunction {
+public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public Argmax(TensorFunction argument, String dimension) {
+ public Argmax(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size());
- return new Argmax(arguments.get(0), dimension);
+ return new Argmax<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(primitiveArgument,
- new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension),
- ScalarFunctions.equal());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(primitiveArgument,
+ new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimension),
+ ScalarFunctions.equal());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
index ba5b3c3e4b2..32ccdf51336 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java
@@ -1,38 +1,40 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.Name;
+
import java.util.Collections;
import java.util.List;
/**
* @author bratseth
*/
-public class Argmin extends CompositeTensorFunction {
+public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public Argmin(TensorFunction argument, String dimension) {
+ public Argmin(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size());
- return new Argmin(arguments.get(0), dimension);
+ return new Argmin<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(primitiveArgument,
- new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension),
- ScalarFunctions.equal());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(primitiveArgument,
+ new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimension),
+ ScalarFunctions.equal());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
index 5dd2cc442aa..eacc4493035 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
/**
@@ -12,17 +13,17 @@ import com.yahoo.tensor.evaluation.TypeContext;
*
* @author bratseth
*/
-public abstract class CompositeTensorFunction extends TensorFunction {
+public abstract class CompositeTensorFunction<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {
/** Finds the type this produces by first converting it to a primitive function */
@Override
- public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public final TensorType type(TypeContext<NAMETYPE> context) {
return toPrimitive().type(context);
}
/** Evaluates this by first converting it to a primitive function */
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
return toPrimitive().evaluate(context);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 42c6fe2f4aa..fff2ddaf320 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Arrays;
@@ -23,12 +24,12 @@ import java.util.stream.Collectors;
*
* @author bratseth
*/
-public class Concat extends PrimitiveTensorFunction {
+public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argumentA, argumentB;
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final String dimension;
- public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) {
+ public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(dimension, "The dimension cannot be null");
@@ -38,18 +39,18 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 2)
throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size());
- return new Concat(arguments.get(0), arguments.get(1), dimension);
+ return new Concat<>(arguments.get(0), arguments.get(1), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension);
}
@Override
@@ -58,7 +59,7 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return type(argumentA.type(context), argumentB.type(context));
}
@@ -86,7 +87,7 @@ public class Concat extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
index 7c1ce068c90..bb7481f7c64 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java
@@ -4,6 +4,7 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
@@ -14,7 +15,7 @@ import java.util.List;
*
* @author bratseth
*/
-public class ConstantTensor extends PrimitiveTensorFunction {
+public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
private final Tensor constant;
@@ -27,23 +28,23 @@ public class ConstantTensor extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); }
+ public TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; }
@Override
public String toString(ToStringContext context) { return constant.toString(); }
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
index e302f6606e7..203331a1c0d 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
@@ -14,7 +15,7 @@ import java.util.stream.Stream;
*
* @author bratseth
*/
-public class Diag extends CompositeTensorFunction {
+public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
private final TensorType type;
private final Function<List<Long>, Double> diagFunction;
@@ -25,18 +26,18 @@ public class Diag extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Generate(type, diagFunction);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Generate<>(type, diagFunction);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
index b8b644f8b49..416940a60eb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java
@@ -7,20 +7,19 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Map;
-import java.util.function.Function;
/**
* A function which is a tensor whose values are computed by individual lambda functions on evaluation.
*
* @author bratseth
*/
-public abstract class DynamicTensor extends PrimitiveTensorFunction {
+public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
private final TensorType type;
@@ -29,20 +28,20 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
+ public TensorType type(TypeContext<NAMETYPE> context) { return type; }
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 0)
throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
TensorType type() { return type; }
@@ -54,26 +53,26 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
abstract String contentToString(ToStringContext context);
/** Creates a dynamic tensor function. The cell addresses must match the type. */
- public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) {
- return new MappedDynamicTensor(type, cells);
+ public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) {
+ return new MappedDynamicTensor<>(type, cells);
}
/** Creates a dynamic tensor function for a bound, indexed tensor */
- public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) {
- return new IndexedDynamicTensor(type, cells);
+ public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, List<ScalarFunction<NAMETYPE>> cells) {
+ return new IndexedDynamicTensor<>(type, cells);
}
- private static class MappedDynamicTensor extends DynamicTensor {
+ private static class MappedDynamicTensor<NAMETYPE extends Name> extends DynamicTensor<NAMETYPE> {
- private final ImmutableMap<TensorAddress, ScalarFunction> cells;
+ private final ImmutableMap<TensorAddress, ScalarFunction<NAMETYPE>> cells;
- MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) {
+ MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) {
super(type);
this.cells = ImmutableMap.copyOf(cells);
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type());
for (var cell : cells.entrySet())
builder.cell(cell.getKey(), cell.getValue().apply(context));
@@ -101,11 +100,11 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
- private static class IndexedDynamicTensor extends DynamicTensor {
+ private static class IndexedDynamicTensor<NAMETYPE extends Name> extends DynamicTensor<NAMETYPE> {
- private final List<ScalarFunction> cells;
+ private final List<ScalarFunction<NAMETYPE>> cells;
- IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) {
+ IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) {
super(type);
if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound))
throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " +
@@ -114,7 +113,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type());
for (int i = 0; i < cells.size(); i++)
builder.cellByDirectIndex(i, cells.get(i).apply(context));
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
index 83cba3479e2..e5095178be7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java
@@ -6,11 +6,13 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
+import java.util.Optional;
import java.util.function.Function;
/**
@@ -18,25 +20,49 @@ import java.util.function.Function;
*
* @author bratseth
*/
-public class Generate extends PrimitiveTensorFunction {
+public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
private final TensorType type;
- private final Function<List<Long>, Double> generator;
+
+ // One of these are null
+ private final Function<List<Long>, Double> freeGenerator;
+ private final ScalarFunction<NAMETYPE> boundGenerator;
+
+ /** The same as Generate.free */
+ public Generate(TensorType type, Function<List<Long>, Double> generator) {
+ this(type, Objects.requireNonNull(generator), null);
+ }
/**
- * Creates a generated tensor
+ * Creates a generated tensor from a free function
*
* @param type the type of the tensor
* @param generator the function generating values from a list of numbers specifying the indexes of the
* tensor cell which will receive the value
* @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
*/
- public Generate(TensorType type, Function<List<Long>, Double> generator) {
+ public static <NAMETYPE extends Name> Generate<NAMETYPE> free(TensorType type, Function<List<Long>, Double> generator) {
+ return new Generate<>(type, Objects.requireNonNull(generator), null);
+ }
+
+ /**
+ * Creates a generated tensor from a bound function
+ *
+ * @param type the type of the tensor
+ * @param generator the function generating values from a list of numbers specifying the indexes of the
+ * tensor cell which will receive the value
+ * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound
+ */
+ public static <NAMETYPE extends Name> Generate<NAMETYPE> bound(TensorType type, ScalarFunction<NAMETYPE> generator) {
+ return new Generate<>(type, null, Objects.requireNonNull(generator));
+ }
+
+ private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction<NAMETYPE> boundGenerator) {
Objects.requireNonNull(type, "The argument tensor type cannot be null");
- Objects.requireNonNull(generator, "The argument function cannot be null");
validateType(type);
this.type = type;
- this.generator = generator;
+ this.freeGenerator = freeGenerator;
+ this.boundGenerator = boundGenerator;
}
private void validateType(TensorType type) {
@@ -46,28 +72,29 @@ public class Generate extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; }
+ public TensorType type(TypeContext<NAMETYPE> context) { return type; }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor.Builder builder = Tensor.Builder.of(type);
IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type));
+ GenerateContext generateContext = new GenerateContext(type, context);
for (int i = 0; i < indexes.size(); i++) {
indexes.next();
- builder.cell(generator.apply(indexes.toList()), indexes.indexesForReading());
+ builder.cell(generateContext.apply(indexes), indexes.indexesForReading());
}
return builder.build();
}
@@ -80,6 +107,70 @@ public class Generate extends PrimitiveTensorFunction {
}
@Override
- public String toString(ToStringContext context) { return type + "(" + generator + ")"; }
+ public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; }
+
+ private String generatorToString(ToStringContext context) {
+ if (freeGenerator != null)
+ return freeGenerator.toString();
+ else
+ return boundGenerator.toString(context);
+ }
+
+ /**
+ * A context for generating all the values of a tensor produced by evaluating Generate.
+ * This returns all the current index values as variables and falls back to delivering from the given
+ * evaluation context.
+ */
+ private class GenerateContext implements EvaluationContext<NAMETYPE> {
+
+ private final TensorType type;
+ private final EvaluationContext<NAMETYPE> context;
+
+ private IndexedTensor.Indexes indexes;
+
+ GenerateContext(TensorType type, EvaluationContext<NAMETYPE> context) {
+ this.type = type;
+ this.context = context;
+ }
+
+ @SuppressWarnings("unchecked")
+ double apply(IndexedTensor.Indexes indexes) {
+ if (freeGenerator != null) {
+ return freeGenerator.apply(indexes.toList());
+ }
+ else {
+ this.indexes = indexes;
+ return boundGenerator.apply(this);
+ }
+ }
+
+ @Override
+ public Tensor getTensor(String name) {
+ Optional<Integer> index = type.indexOfDimension(name);
+ if (index.isPresent()) // this is the name of a dimension
+ return Tensor.from(indexes.indexesForReading()[index.get()]);
+ else
+ return context.getTensor(name);
+ }
+
+ @Override
+ public TensorType getType(NAMETYPE name) {
+ Optional<Integer> index = type.indexOfDimension(name.name());
+ if (index.isPresent()) // this is the name of a dimension
+ return TensorType.empty;
+ else
+ return context.getType(name);
+ }
+
+ @Override
+ public TensorType getType(String name) {
+ Optional<Integer> index = type.indexOfDimension(name);
+ if (index.isPresent()) // this is the name of a dimension
+ return TensorType.empty;
+ else
+ return context.getType(name);
+ }
+
+ }
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 2939b964f04..1e0eaa7fad3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -10,6 +10,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.ArrayList;
@@ -31,12 +32,12 @@ import java.util.function.DoubleBinaryOperator;
*
* @author bratseth
*/
-public class Join extends PrimitiveTensorFunction {
+public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argumentA, argumentB;
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final DoubleBinaryOperator combinator;
- public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) {
+ public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) {
Objects.requireNonNull(argumentA, "The first argument tensor cannot be null");
Objects.requireNonNull(argumentB, "The second argument tensor cannot be null");
Objects.requireNonNull(combinator, "The combinator function cannot be null");
@@ -53,18 +54,18 @@ public class Join extends PrimitiveTensorFunction {
public DoubleBinaryOperator combinator() { return combinator; }
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size());
- return new Join(arguments.get(0), arguments.get(1), combinator);
+ return new Join<>(arguments.get(0), arguments.get(1), combinator);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
}
@Override
@@ -73,12 +74,12 @@ public class Join extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build();
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
index 7939457a101..ed4da6678ce 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java
@@ -1,39 +1,41 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.Name;
+
import java.util.Collections;
import java.util.List;
/**
* @author bratseth
*/
-public class L1Normalize extends CompositeTensorFunction {
+public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public L1Normalize(TensorFunction argument, String dimension) {
+ public L1Normalize(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size());
- return new L1Normalize(arguments.get(0), dimension);
+ return new L1Normalize<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
// join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y))
- return new Join(primitiveArgument,
- new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension),
- ScalarFunctions.divide());
+ return new Join<>(primitiveArgument,
+ new Reduce<>(primitiveArgument, Reduce.Aggregator.sum, dimension),
+ ScalarFunctions.divide());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
index 40edb8ba23f..93b2b377176 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java
@@ -1,40 +1,42 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.Name;
+
import java.util.Collections;
import java.util.List;
/**
* @author bratseth
*/
-public class L2Normalize extends CompositeTensorFunction {
+public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public L2Normalize(TensorFunction argument, String dimension) {
+ public L2Normalize(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size());
- return new L2Normalize(arguments.get(0), dimension);
+ return new L2Normalize<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(primitiveArgument,
- new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()),
- Reduce.Aggregator.sum,
- dimension),
- ScalarFunctions.sqrt()),
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(primitiveArgument,
+ new Map<>(new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.square()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.sqrt()),
ScalarFunctions.divide());
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
index 016c60c6897..0ddf0bb4e63 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java
@@ -5,6 +5,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
@@ -18,12 +19,12 @@ import java.util.function.DoubleUnaryOperator;
*
* @author bratseth
*/
-public class Map extends PrimitiveTensorFunction {
+public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final DoubleUnaryOperator mapper;
- public Map(TensorFunction argument, DoubleUnaryOperator mapper) {
+ public Map(TensorFunction<NAMETYPE> argument, DoubleUnaryOperator mapper) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(mapper, "The argument function cannot be null");
this.argument = argument;
@@ -32,31 +33,31 @@ public class Map extends PrimitiveTensorFunction {
public static TensorType outputType(TensorType inputType) { return inputType; }
- public TensorFunction argument() { return argument; }
+ public TensorFunction<NAMETYPE> argument() { return argument; }
public DoubleUnaryOperator mapper() { return mapper; }
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size());
- return new Map(arguments.get(0), mapper);
+ return new Map<>(arguments.get(0), mapper);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Map(argument.toPrimitive(), mapper);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Map<>(argument.toPrimitive(), mapper);
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return argument.type(context);
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor argument = argument().evaluate(context);
Tensor.Builder builder = Tensor.Builder.of(argument.type());
for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
index 7c65afc98f9..54bfdd4a732 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java
@@ -3,18 +3,19 @@ package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
import java.util.List;
/**
* @author bratseth
*/
-public class Matmul extends CompositeTensorFunction {
+public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument1, argument2;
+ private final TensorFunction<NAMETYPE> argument1, argument2;
private final String dimension;
- public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) {
+ public Matmul(TensorFunction<NAMETYPE> argument1, TensorFunction<NAMETYPE> argument2, String dimension) {
this.argument1 = argument1;
this.argument2 = argument2;
this.dimension = dimension;
@@ -25,22 +26,22 @@ public class Matmul extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argument1, argument2); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size());
- return new Matmul(arguments.get(0), arguments.get(1), dimension);
+ return new Matmul<>(arguments.get(0), arguments.get(1), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument1 = argument1.toPrimitive();
- TensorFunction primitiveArgument2 = argument2.toPrimitive();
- return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- dimension);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument1 = argument1.toPrimitive();
+ TensorFunction<NAMETYPE> primitiveArgument2 = argument2.toPrimitive();
+ return new Reduce<>(new Join<>(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
index e2aae39f11f..99117bb250e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor.functions;
+import com.yahoo.tensor.evaluation.Name;
+
/**
* A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions.
* All tensor implementations must implement all primitive tensor functions.
@@ -8,6 +10,6 @@ package com.yahoo.tensor.functions;
*
* @author bratseth
*/
-public abstract class PrimitiveTensorFunction extends TensorFunction {
+public abstract class PrimitiveTensorFunction<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> {
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
index 7175c91ed33..b459b1a8ddd 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
@@ -13,7 +14,7 @@ import java.util.stream.Stream;
*
* @author bratseth
*/
-public class Random extends CompositeTensorFunction {
+public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
private final TensorType type;
@@ -22,18 +23,18 @@ public class Random extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Generate(type, ScalarFunctions.random());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Generate<>(type, ScalarFunctions.random());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
index d951ec9ccbd..00d0e4b4818 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
@@ -15,7 +16,7 @@ import java.util.stream.Stream;
*
* @author bratseth
*/
-public class Range extends CompositeTensorFunction {
+public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
private final TensorType type;
private final Function<List<Long>, Double> rangeFunction;
@@ -26,18 +27,18 @@ public class Range extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.emptyList(); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 0)
throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size());
return this;
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Generate(type, rangeFunction);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Generate<>(type, rangeFunction);
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 017dc3920e6..1eb09a603fa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
@@ -24,21 +25,21 @@ import java.util.Set;
*
* @author bratseth
*/
-public class Reduce extends PrimitiveTensorFunction {
+public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
public enum Aggregator { avg, count, prod, sum, max, min; }
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final List<String> dimensions;
private final Aggregator aggregator;
/** Creates a reduce function reducing all dimensions */
- public Reduce(TensorFunction argument, Aggregator aggregator) {
+ public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator) {
this(argument, aggregator, Collections.emptyList());
}
/** Creates a reduce function reducing a single dimension */
- public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) {
+ public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, String dimension) {
this(argument, aggregator, Collections.singletonList(dimension));
}
@@ -51,7 +52,7 @@ public class Reduce extends PrimitiveTensorFunction {
* producing a dimensionless tensor (a scalar).
* @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor
*/
- public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) {
+ public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(aggregator, "The aggregator cannot be null");
Objects.requireNonNull(dimensions, "The dimensions cannot be null");
@@ -70,25 +71,25 @@ public class Reduce extends PrimitiveTensorFunction {
return b.build();
}
- public TensorFunction argument() { return argument; }
+ public TensorFunction<NAMETYPE> argument() { return argument; }
Aggregator aggregator() { return aggregator; }
List<String> dimensions() { return dimensions; }
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size());
- return new Reduce(arguments.get(0), aggregator, dimensions);
+ return new Reduce<>(arguments.get(0), aggregator, dimensions);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- return new Reduce(argument.toPrimitive(), aggregator, dimensions);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new Reduce<>(argument.toPrimitive(), aggregator, dimensions);
}
@Override
@@ -104,7 +105,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return type(argument.type(context), dimensions);
}
@@ -118,7 +119,7 @@ public class Reduce extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
return evaluate(this.argument.evaluate(context), dimensions, aggregator);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index 1134e8177ad..83807a20ec9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -7,7 +7,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
-import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.evaluation.Name;
import java.util.Arrays;
import java.util.List;
@@ -26,19 +26,19 @@ import java.util.stream.Collectors;
*
* @author lesters
*/
-public class ReduceJoin extends CompositeTensorFunction {
+public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argumentA, argumentB;
+ private final TensorFunction<NAMETYPE> argumentA, argumentB;
private final DoubleBinaryOperator combinator;
private final Reduce.Aggregator aggregator;
private final List<String> dimensions;
- public ReduceJoin(Reduce reduce, Join join) {
+ public ReduceJoin(Reduce<NAMETYPE> reduce, Join<NAMETYPE> join) {
this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions());
}
- public ReduceJoin(TensorFunction argumentA,
- TensorFunction argumentB,
+ public ReduceJoin(TensorFunction<NAMETYPE> argumentA,
+ TensorFunction<NAMETYPE> argumentB,
DoubleBinaryOperator combinator,
Reduce.Aggregator aggregator,
List<String> dimensions) {
@@ -50,25 +50,25 @@ public class ReduceJoin extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() {
+ public List<TensorFunction<NAMETYPE>> arguments() {
return ImmutableList.of(argumentA, argumentB);
}
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 2)
throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size());
- return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions);
+ return new ReduceJoin<>(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
- return new Reduce(join, aggregator, dimensions);
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ Join<NAMETYPE> join = new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator);
+ return new Reduce<>(join, aggregator, dimensions);
}
@Override
- public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public final Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index 5694684956e..275b546c0aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -6,6 +6,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.Collections;
@@ -20,18 +21,18 @@ import java.util.Objects;
*
* @author bratseth
*/
-public class Rename extends PrimitiveTensorFunction {
+public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final List<String> fromDimensions;
private final List<String> toDimensions;
private final Map<String, String> fromToMap;
- public Rename(TensorFunction argument, String fromDimension, String toDimension) {
+ public Rename(TensorFunction<NAMETYPE> argument, String fromDimension, String toDimension) {
this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension));
}
- public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) {
+ public Rename(TensorFunction<NAMETYPE> argument, List<String> fromDimensions, List<String> toDimensions) {
Objects.requireNonNull(argument, "The argument tensor cannot be null");
Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null");
Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null");
@@ -57,20 +58,20 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size());
- return new Rename(arguments.get(0), fromDimensions, toDimensions);
+ return new Rename<>(arguments.get(0), fromDimensions, toDimensions);
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return type(argument.type(context));
}
@@ -82,7 +83,7 @@ public class Rename extends PrimitiveTensorFunction {
}
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = argument.evaluate(context);
TensorType renamedType = type(tensor.type());
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
index c6a244b64df..07b3658fb58 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import java.util.function.Function;
@@ -10,10 +11,10 @@ import java.util.function.Function;
*
* @author bratseth
*/
-public interface ScalarFunction extends Function<EvaluationContext<?>, Double> {
+public interface ScalarFunction<NAMETYPE extends Name> extends Function<EvaluationContext<NAMETYPE>, Double> {
@Override
- Double apply(EvaluationContext<?> context);
+ Double apply(EvaluationContext<NAMETYPE> context);
default String toString(ToStringContext context) {
return toString();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
index bd732cdc11e..755711a4d44 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
import java.util.Collections;
import java.util.List;
@@ -10,12 +11,12 @@ import java.util.List;
/**
* @author bratseth
*/
-public class Softmax extends CompositeTensorFunction {
+public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
+ private final TensorFunction<NAMETYPE> argument;
private final String dimension;
- public Softmax(TensorFunction argument, String dimension) {
+ public Softmax(TensorFunction<NAMETYPE> argument, String dimension) {
this.argument = argument;
this.dimension = dimension;
}
@@ -25,23 +26,23 @@ public class Softmax extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return Collections.singletonList(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 1)
throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size());
- return new Softmax(arguments.get(0), dimension);
+ return new Softmax<>(arguments.get(0), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveArgument = argument.toPrimitive();
- return new Join(new Map(primitiveArgument, ScalarFunctions.exp()),
- new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()),
- Reduce.Aggregator.sum,
- dimension),
- ScalarFunctions.divide());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive();
+ return new Join<>(new Map<>(primitiveArgument, ScalarFunctions.exp()),
+ new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.exp()),
+ Reduce.Aggregator.sum,
+ dimension),
+ ScalarFunctions.divide());
}
@Override
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 810651bbcfb..b4c5dedbf4e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -5,6 +5,7 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
import java.util.List;
@@ -16,17 +17,17 @@ import java.util.List;
*
* @author bratseth
*/
-public abstract class TensorFunction {
+public abstract class TensorFunction<NAMETYPE extends Name> {
/** Returns the function arguments of this node in the order they are applied */
- public abstract List<TensorFunction> arguments();
+ public abstract List<TensorFunction<NAMETYPE>> arguments();
/**
* Returns a copy of this tensor function with the arguments replaced by the given list of arguments.
*
* @throws IllegalArgumentException if the argument list has the wrong size for this function
*/
- public abstract TensorFunction withArguments(List<TensorFunction> arguments);
+ public abstract TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments);
/**
* Translate this function - and all of its arguments recursively -
@@ -34,24 +35,24 @@ public abstract class TensorFunction {
*
* @return a tree of primitive functions implementing this
*/
- public abstract PrimitiveTensorFunction toPrimitive();
+ public abstract PrimitiveTensorFunction<NAMETYPE> toPrimitive();
/**
* Evaluates this tensor.
*
* @param context a context which must be passed to all nested functions when evaluating
*/
- public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context);
+ public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context);
/**
* Returns the type of the tensor this produces given the input types in the context
*
* @param context a context which must be passed to all nexted functions when evaluating
*/
- public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context);
+ public abstract TensorType type(TypeContext<NAMETYPE> context);
/** Evaluate with no context */
- public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); }
+ public final Tensor evaluate() { return evaluate(new MapEvaluationContext<>()); }
/**
* Return a string representation of this context.
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
new file mode 100644
index 00000000000..cb14711c0dd
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
@@ -0,0 +1,185 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.stream.Collectors;
+
+/**
+ * Returns the value of a cell of a tensor (as a rank 0 tensor).
+ *
+ * @author bratseth
+ */
+@Beta
+public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> argument;
+ private final List<DimensionValue<NAMETYPE>> cellAddress;
+
+ /**
+ * Creates a value function
+ *
+ * @param argument the tensor to return a cell value from
+ * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress
+ * because those require a type, but a type is not resolved until this is evaluated
+ */
+ public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) {
+ this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
+ if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
+ throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " +
+ "Specify dimension names explicitly");
+ this.cellAddress = cellAddress;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }
+
+ @Override
+ public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if (arguments.size() != 1)
+ throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
+ return new Value<NAMETYPE>(arguments.get(0), cellAddress);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor tensor = argument.evaluate(context);
+ if (tensor.type().rank() != cellAddress.size())
+ throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() +
+ " to a tensor of type " + tensor.type());
+ TensorAddress.Builder b = new TensorAddress.Builder(tensor.type());
+ for (int i = 0; i < cellAddress.size(); i++) {
+ if (cellAddress.get(i).label().isPresent())
+ b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
+ cellAddress.get(i).label().get());
+ else
+ b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
+ String.valueOf(cellAddress.get(i).index().get().apply(context).intValue()));
+ }
+ return Tensor.from(tensor.get(b.build()));
+ }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return new TensorType.Builder(argument.type(context).valueType()).build();
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ return toString();
+ }
+
+ @Override
+ public String toString() {
+ if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) {
+ if (cellAddress.get(0).index().isPresent())
+ return "[" + cellAddress.get(0).index().get() + "]";
+ else
+ return "{" + cellAddress.get(0).label() + "}";
+ }
+ else {
+ return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}";
+ }
+ }
+
+ public static class DimensionValue<NAMETYPE extends Name> {
+
+ private final Optional<String> dimension;
+
+ /** The label of this, or null if index is set */
+ private final String label;
+
+ /** The function returning the index of this, or null if label is set */
+ private final ScalarFunction<NAMETYPE> index;
+
+ public DimensionValue(String dimension, String label) {
+ this(Optional.of(dimension), label, null);
+ }
+
+ public DimensionValue(String dimension, int index) {
+ this(Optional.of(dimension), null, new ConstantScalarFunction<>(index));
+ }
+
+ public DimensionValue(int index) {
+ this(Optional.empty(), null, new ConstantScalarFunction<>(index));
+ }
+
+ public DimensionValue(String label) {
+ this(Optional.empty(), label, null);
+ }
+
+ public DimensionValue(ScalarFunction<NAMETYPE> index) {
+ this(Optional.empty(), null, index);
+ }
+
+ public DimensionValue(Optional<String> dimension, String label) {
+ this(dimension, label, null);
+ }
+
+ public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
+ this(dimension, null, index);
+ }
+
+ public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
+ this(Optional.of(dimension), null, index);
+ }
+
+ private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
+ this.dimension = dimension;
+ this.label = label;
+ this.index = index;
+ }
+
+ /**
+ * Returns the given name of the dimension, or null if dense form is used, such that name
+ * must be inferred from order
+ */
+ public Optional<String> dimension() { return dimension; }
+
+ /** Returns the label for this dimension or empty if it is provided by an index function */
+ public Optional<String> label() { return Optional.ofNullable(label); }
+
+ /** Returns the index expression for this dimension, or empty if it is not a number */
+ public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }
+
+ @Override
+ public String toString() {
+ StringBuilder b = new StringBuilder();
+ dimension.ifPresent(d -> b.append(d).append(":"));
+ if (label != null)
+ b.append(label);
+ else
+ b.append(index);
+ return b.toString();
+ }
+
+ }
+
+ private static class ConstantScalarFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {
+
+ private final Double value;
+
+ public ConstantScalarFunction(int value) {
+ this.value = (double)value;
+ }
+
+ @Override
+ public Double apply(EvaluationContext<NAMETYPE> context) {
+ return value;
+ }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
index 4c0748ee39a..53e23674617 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java
@@ -2,18 +2,19 @@
package com.yahoo.tensor.functions;
import com.google.common.collect.ImmutableList;
+import com.yahoo.tensor.evaluation.Name;
import java.util.List;
/**
* @author bratseth
*/
-public class XwPlusB extends CompositeTensorFunction {
+public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> {
- private final TensorFunction x, w, b;
+ private final TensorFunction<NAMETYPE> x, w, b;
private final String dimension;
- public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) {
+ public XwPlusB(TensorFunction<NAMETYPE> x, TensorFunction<NAMETYPE> w, TensorFunction<NAMETYPE> b, String dimension) {
this.x = x;
this.w = w;
this.b = b;
@@ -21,25 +22,25 @@ public class XwPlusB extends CompositeTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(x, w, b); }
@Override
- public TensorFunction withArguments(List<TensorFunction> arguments) {
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if ( arguments.size() != 3)
throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size());
- return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension);
+ return new XwPlusB<>(arguments.get(0), arguments.get(1), arguments.get(2), dimension);
}
@Override
- public PrimitiveTensorFunction toPrimitive() {
- TensorFunction primitiveX = x.toPrimitive();
- TensorFunction primitiveW = w.toPrimitive();
- TensorFunction primitiveB = b.toPrimitive();
- return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()),
- Reduce.Aggregator.sum,
- dimension),
- primitiveB,
- ScalarFunctions.add());
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ TensorFunction<NAMETYPE> primitiveX = x.toPrimitive();
+ TensorFunction<NAMETYPE> primitiveW = w.toPrimitive();
+ TensorFunction<NAMETYPE> primitiveB = b.toPrimitive();
+ return new Join<>(new Reduce<>(new Join<>(primitiveX, primitiveW, ScalarFunctions.multiply()),
+ Reduce.Aggregator.sum,
+ dimension),
+ primitiveB,
+ ScalarFunctions.add());
}
@Override
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java
index 439aac5578a..c334c58042c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java
@@ -2,17 +2,16 @@
package com.yahoo.tensor;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
import com.yahoo.tensor.functions.Reduce;
import com.yahoo.tensor.functions.TensorFunction;
-import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;
-import java.util.stream.Collectors;
/**
* Microbenchmark of a "dot product" of two mapped rank 2 tensors
@@ -42,10 +41,10 @@ public class MatrixDotProductBenchmark {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double largest = Double.MIN_VALUE;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
- Reduce.Aggregator.sum).toPrimitive();
- MapEvaluationContext context = new MapEvaluationContext();
+ TensorFunction<Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor),
+ new VariableTensor<>("argument"), (a, b) -> a * b),
+ Reduce.Aggregator.sum).toPrimitive();
+ MapEvaluationContext<Name> context = new MapEvaluationContext<>();
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index 7b856dde2d5..b3c6fbc6862 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
@@ -49,10 +50,10 @@ public class TensorFunctionBenchmark {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double largest = Double.MIN_VALUE;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
- Reduce.Aggregator.sum).toPrimitive();
- MapEvaluationContext context = new MapEvaluationContext();
+ TensorFunction<Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor),
+ new VariableTensor<>("argument"), (a, b) -> a * b),
+ Reduce.Aggregator.sum).toPrimitive();
+ MapEvaluationContext<Name> context = new MapEvaluationContext<>();
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
index c6fbb9c009d..11365531019 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor;
import com.google.common.collect.ImmutableList;
import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.VariableTensor;
import com.yahoo.tensor.functions.ConstantTensor;
import com.yahoo.tensor.functions.Join;
@@ -307,10 +308,10 @@ public class TensorTestCase {
private double dotProduct(Tensor tensor, List<Tensor> tensors) {
double sum = 0;
- TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor),
- new VariableTensor("argument"), (a, b) -> a * b),
- Reduce.Aggregator.sum).toPrimitive();
- MapEvaluationContext context = new MapEvaluationContext();
+ TensorFunction<Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor),
+ new VariableTensor<>("argument"), (a, b) -> a * b),
+ Reduce.Aggregator.sum).toPrimitive();
+ MapEvaluationContext<Name> context = new MapEvaluationContext<>();
for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor
context.put("argument", tensorElement);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
index eafa5c4addf..0476fe1c757 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java
@@ -101,8 +101,8 @@ public class ConcatTestCase {
private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) {
Tensor expectedAsTensor = Tensor.from(expected);
- TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension)
- .type(new MapEvaluationContext());
+ TensorType inferredType = new Concat<>(new ConstantTensor<>(a), new ConstantTensor<>(b), dimension)
+ .type(new MapEvaluationContext<>());
Tensor result = a.concat(b, dimension);
if (expectedType != null)
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
index 925da9d3c89..e16b7b90a1d 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java
@@ -5,11 +5,11 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
import org.junit.Test;
import java.util.Collections;
import java.util.List;
-import java.util.function.Function;
import static org.junit.Assert.assertEquals;
@@ -21,27 +21,27 @@ public class DynamicTensorTestCase {
@Test
public void testDynamicTensorFunction() {
TensorType dense = TensorType.fromSpec("tensor(x[3])");
- DynamicTensor t1 = DynamicTensor.from(dense,
- List.of(new Constant(1), new Constant(2), new Constant(3)));
+ DynamicTensor<Name> t1 = DynamicTensor.from(dense,
+ List.of(new Constant(1), new Constant(2), new Constant(3)));
assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate());
assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString());
TensorType sparse = TensorType.fromSpec("tensor(x{})");
- DynamicTensor t2 = DynamicTensor.from(sparse,
- Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
- new Constant(5)));
+ DynamicTensor<Name> t2 = DynamicTensor.from(sparse,
+ Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(),
+ new Constant(5)));
assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate());
assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString());
}
- private static class Constant implements ScalarFunction {
+ private static class Constant implements ScalarFunction<Name> {
private final double value;
public Constant(double value) { this.value = value; }
@Override
- public Double apply(EvaluationContext<?> evaluationContext) { return value; }
+ public Double apply(EvaluationContext<Name> evaluationContext) { return value; }
@Override
public String toString() { return String.valueOf(value); }
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
index e37bee2d990..e6560242d5c 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java
@@ -2,6 +2,7 @@
package com.yahoo.tensor.functions;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.Name;
import org.junit.Test;
import static org.junit.Assert.assertEquals;
@@ -15,14 +16,14 @@ public class TensorFunctionTestCase {
@Test
public void testTranslation() {
assertTranslated("join(tensor(x{}):{{x:1}:1.0}, reduce(tensor(x{}):{{x:1}:1.0}, sum, x), f(a,b)(a / b))",
- new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x"));
+ new L1Normalize<>(new ConstantTensor<>("{{x:1}:1.0}"), "x"));
assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))",
- new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
+ new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build()));
assertTranslated("join(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))",
- new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
+ new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x"));
}
- private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) {
+ private void assertTranslated(String expectedTranslation, TensorFunction<Name> inputFunction) {
assertEquals(expectedTranslation, inputFunction.toPrimitive().toString());
}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
new file mode 100644
index 00000000000..7127abde016
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
@@ -0,0 +1,66 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class ValueTestCase {
+
+ private static final double delta = 0.000001;
+
+ @Test
+ public void testValueFunctionGeneralForm() {
+ Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
+ Tensor result = new Value<>(new ConstantTensor<>(input),
+ List.of(new Value.DimensionValue<>("key", "bar"),
+ new Value.DimensionValue<>("x", 0)))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(2.3, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testValueFunctionSingleMappedDimension() {
+ Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }");
+ Tensor result = new Value<>(new ConstantTensor<>(input),
+ List.of(new Value.DimensionValue<>("foo")))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(1.4, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testValueFunctionSingleIndexedDimension() {
+ Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
+ Tensor result = new Value<>(new ConstantTensor<>(input),
+ List.of(new Value.DimensionValue<>(2)))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(3.3, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testValueFunctionShortFormWithMultipleDimensionsIsNotAllowed() {
+ try {
+ Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
+ new Value<>(new ConstantTensor<>(input),
+ List.of(new Value.DimensionValue<>("bar"),
+ new Value.DimensionValue<>(0)))
+ .evaluate();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Short form of cell addresses is only supported with a single dimension: Specify dimension names explicitly",
+ e.getMessage());
+ }
+ }
+
+}