diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2019-11-28 10:24:57 +0100 |
---|---|---|
committer | GitHub <noreply@github.com> | 2019-11-28 10:24:57 +0100 |
commit | a907f095507bfd9aec0d6bd168217b4a0471b651 (patch) | |
tree | f0fe57f98b48829ba186e4f543748b2c6f25fe4a | |
parent | a5e8e198dabc9dcfc710200d2ed170193f9b253b (diff) | |
parent | 0e18f68b1583b3391859b3def7f3a168b5212d15 (diff) |
Merge pull request #11435 from vespa-engine/bratseth/value-function
Bratseth/value function
52 files changed, 959 insertions, 422 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index bbfd2004caa..55979023119 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -411,7 +411,7 @@ public class ConvertedModel { } // Modify any renames in expression to disregard batch dimension else if (children.size() == 1 && children.get(0) instanceof TensorFunctionNode) { - TensorFunction childFunction = (((TensorFunctionNode) children.get(0)).function()); + TensorFunction<Reference> childFunction = (((TensorFunctionNode) children.get(0)).function()); TensorType childType = childFunction.type(typeContext); Rename rename = (Rename) tensorFunction; List<String> from = new ArrayList<>(); @@ -422,10 +422,10 @@ public class ConvertedModel { throw new IllegalArgumentException("Rename does not contain dimension '" + dimension + "' in child expression type: " + childType); } - from.add(rename.fromDimensions().get(i)); - to.add(rename.toDimensions().get(i)); + from.add((String)rename.fromDimensions().get(i)); + to.add((String)rename.toDimensions().get(i)); } - return new TensorFunctionNode(new Rename(childFunction, from, to)); + return new TensorFunctionNode(new Rename<>(childFunction, from, to)); } } } diff --git a/config-model/src/test/derived/tensor/rank-profiles.cfg b/config-model/src/test/derived/tensor/rank-profiles.cfg index 617901130a6..cebfa244159 100644 --- a/config-model/src/test/derived/tensor/rank-profiles.cfg +++ b/config-model/src/test/derived/tensor/rank-profiles.cfg @@ -110,3 +110,20 @@ rankprofile[].fef.property[].name "vespa.type.attribute.f4" rankprofile[].fef.property[].value "tensor(x[10],y[20])" rankprofile[].fef.property[].name "vespa.type.attribute.f5" rankprofile[].fef.property[].value "tensor<float>(x[10])" +rankprofile[].name "profile7" +rankprofile[].fef.property[].name "rankingExpression(reshaped).rankingScript" +rankprofile[].fef.property[].value "tensor<float>(d0[1],x[2])({x:1 - x, y:d0})" +rankprofile[].fef.property[].name "rankingExpression(reshaped).type" +rankprofile[].fef.property[].value "tensor<float>(d0[1],x[2])" +rankprofile[].fef.property[].name "vespa.rank.firstphase" +rankprofile[].fef.property[].value "rankingExpression(firstphase)" +rankprofile[].fef.property[].name "rankingExpression(firstphase).rankingScript" +rankprofile[].fef.property[].value "reduce(rankingExpression(reshaped), sum)" +rankprofile[].fef.property[].name "vespa.type.attribute.f2" +rankprofile[].fef.property[].value "tensor<float>(x[2],y[1])" +rankprofile[].fef.property[].name "vespa.type.attribute.f3" +rankprofile[].fef.property[].value "tensor(x{})" +rankprofile[].fef.property[].name "vespa.type.attribute.f4" +rankprofile[].fef.property[].value "tensor(x[10],y[20])" +rankprofile[].fef.property[].name "vespa.type.attribute.f5" +rankprofile[].fef.property[].value "tensor<float>(x[10])" diff --git a/config-model/src/test/derived/tensor/tensor.sd b/config-model/src/test/derived/tensor/tensor.sd index 13727d1ec49..15d56517a43 100644 --- a/config-model/src/test/derived/tensor/tensor.sd +++ b/config-model/src/test/derived/tensor/tensor.sd @@ -78,4 +78,16 @@ search tensor { } + rank-profile profile7 { + + first-phase { + expression: sum(reshaped()) + } + + function reshaped() { + expression: tensor<float>(d0[1],x[2])(attribute(f2){x:1-x, y:d0}) + } + + } + } diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json index 1857511cbf2..5a75c8b31ea 100644 --- a/model-evaluation/abi-spec.json +++ b/model-evaluation/abi-spec.json @@ -39,7 +39,7 @@ "public java.util.Set names()", "public java.util.Set arguments()", "public com.yahoo.searchlib.rankingexpression.evaluation.Value defaultValue()", - "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)" + "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)" ], "fields": [] }, diff --git a/searchlib/abi-spec.json b/searchlib/abi-spec.json index 8d7bf4f9f14..abfae426ad0 100644 --- a/searchlib/abi-spec.json +++ b/searchlib/abi-spec.json @@ -342,7 +342,7 @@ "fields": [] }, "com.yahoo.searchlib.rankingexpression.Reference": { - "superClass": "com.yahoo.tensor.evaluation.TypeContext$Name", + "superClass": "com.yahoo.tensor.evaluation.Name", "interfaces": [], "attributes": [ "public" @@ -414,7 +414,7 @@ "public final double getDouble(int)", "public com.yahoo.searchlib.rankingexpression.evaluation.ArrayContext clone()", "public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()", - "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)", + "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", "public bridge synthetic java.lang.Object clone()" ], "fields": [] @@ -521,7 +521,7 @@ "public final com.yahoo.searchlib.rankingexpression.evaluation.Value get(int)", "public com.yahoo.searchlib.rankingexpression.evaluation.DoubleOnlyArrayContext clone()", "public bridge synthetic com.yahoo.searchlib.rankingexpression.evaluation.AbstractArrayContext clone()", - "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)", + "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", "public bridge synthetic java.lang.Object clone()" ], "fields": [] @@ -592,7 +592,7 @@ "public java.util.Set names()", "public java.lang.String toString()", "public static com.yahoo.searchlib.rankingexpression.evaluation.MapContext fromString(java.lang.String)", - "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)" + "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)" ], "fields": [] }, @@ -610,7 +610,7 @@ "public com.yahoo.tensor.TensorType getType(java.lang.String)", "public com.yahoo.tensor.TensorType getType(com.yahoo.searchlib.rankingexpression.Reference)", "public java.util.Map bindings()", - "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)" + "public bridge synthetic com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)" ], "fields": [] }, @@ -872,25 +872,25 @@ "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode arg()", "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode function()", "public final com.yahoo.searchlib.rankingexpression.rule.FunctionNode scalarOrTensorFunction()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorFunction()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMap()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorReduce()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorReduceComposites()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorJoin()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRename()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorConcat()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerate()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorGenerateBody(com.yahoo.tensor.TensorType)", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRange()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorDiag()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorRandom()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorL1Normalize()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorL2Normalize()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorMatmul()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorSoftmax()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorXwPlusB()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorArgmax()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorArgmin()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorFunction()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMap()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduce()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorReduceComposites()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorJoin()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRename()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorConcat()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorGenerate()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorGenerateBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRange()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorDiag()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorRandom()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL1Normalize()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorL2Normalize()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorMatmul()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorSoftmax()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorXwPlusB()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmax()", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorArgmin()", "public final com.yahoo.searchlib.rankingexpression.rule.LambdaFunctionNode lambdaFunction()", "public final com.yahoo.tensor.functions.Reduce$Aggregator tensorReduceAggregator()", "public final com.yahoo.tensor.TensorType tensorType()", @@ -909,11 +909,14 @@ "public final java.util.List tagCommaLeadingList()", "public final com.yahoo.searchlib.rankingexpression.rule.ConstantNode constantPrimitive()", "public final com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue()", - "public final com.yahoo.searchlib.rankingexpression.rule.ExpressionNode tensorValueBody(com.yahoo.tensor.TensorType)", + "public final com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode tensorValueBody(com.yahoo.tensor.TensorType)", "public final com.yahoo.tensor.functions.DynamicTensor mappedTensorValueBody(com.yahoo.tensor.TensorType)", "public final com.yahoo.tensor.functions.DynamicTensor indexedTensorValueBody(com.yahoo.tensor.TensorType)", "public final void tensorCell(com.yahoo.tensor.TensorType, java.util.Map)", "public final void labelAndDimension(com.yahoo.tensor.TensorAddress$Builder)", + "public final void labelAndDimensionValues(java.util.List)", + "public final java.util.List valueAddress()", + "public final com.yahoo.tensor.functions.Value$DimensionValue dimensionValue(java.util.Optional)", "public void <init>(java.io.InputStream)", "public void <init>(java.io.InputStream, java.lang.String)", "public void ReInit(java.io.InputStream)", @@ -1613,8 +1616,9 @@ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", "public com.yahoo.searchlib.rankingexpression.evaluation.Value evaluate(com.yahoo.searchlib.rankingexpression.evaluation.Context)", "public static com.yahoo.searchlib.rankingexpression.rule.TensorFunctionNode$ExpressionTensorFunction wrap(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)", - "public static java.util.Map wrap(java.util.Map)", - "public static java.util.List wrap(java.util.List)" + "public static java.util.Map wrapScalars(java.util.Map)", + "public static java.util.List wrapScalars(java.util.List)", + "public static com.yahoo.tensor.functions.ScalarFunction wrapScalar(com.yahoo.searchlib.rankingexpression.rule.ExpressionNode)" ], "fields": [] }, diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java index 829a796eee0..fa2d0f1ee45 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/Reference.java @@ -7,19 +7,18 @@ import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; import com.yahoo.searchlib.rankingexpression.rule.NameNode; import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; import com.yahoo.searchlib.rankingexpression.rule.SerializationContext; -import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.evaluation.Name; import java.util.Deque; import java.util.Objects; import java.util.Optional; -import java.util.stream.Collectors; /** * A reference to a feature, function, or value in ranking expressions * * @author bratseth */ -public class Reference extends TypeContext.Name { +public class Reference extends Name { private final int hashCode; diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java index d68f8c85ad1..cf17c6465f3 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/Context.java @@ -34,7 +34,7 @@ public abstract class Context implements EvaluationContext<Reference> { @Override public TensorType getType(String reference) { - throw new UnsupportedOperationException("Not able to parse gereral references from string form"); + throw new UnsupportedOperationException("Not able to parse general references from string form"); } /** Returns a variable as a tensor */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java index 63cea371d14..41b01c9a2cb 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/tensoroptimization/TensorOptimizer.java @@ -2,6 +2,7 @@ package com.yahoo.searchlib.rankingexpression.evaluation.tensoroptimization; import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.ContextIndex; import com.yahoo.searchlib.rankingexpression.evaluation.OptimizationReport; import com.yahoo.searchlib.rankingexpression.evaluation.Optimizer; @@ -58,11 +59,12 @@ public class TensorOptimizer extends Optimizer { * The ReduceJoin class determines whether or not the arguments are * compatible with the optimization. */ + @SuppressWarnings("unchecked") private ExpressionNode optimizeReduceJoin(ExpressionNode node) { if ( ! (node instanceof TensorFunctionNode)) { return node; } - TensorFunction function = ((TensorFunctionNode) node).function(); + TensorFunction<Reference> function = ((TensorFunctionNode) node).function(); if ( ! (function instanceof Reduce)) { return node; } @@ -74,10 +76,10 @@ public class TensorOptimizer extends Optimizer { if ( ! (child instanceof TensorFunctionNode)) { return node; } - TensorFunction argument = ((TensorFunctionNode) child).function(); + TensorFunction<Reference> argument = ((TensorFunctionNode) child).function(); if (argument instanceof Join) { report.incMetric("Replaced reduce->join", 1); - return new TensorFunctionNode(new ReduceJoin((Reduce)function, (Join)argument)); + return new TensorFunctionNode(new ReduceJoin<>((Reduce<Reference>)function, (Join<Reference>)argument)); } return node; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java index 4ffd40f00f7..cec8837abcd 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/TensorFunctionNode.java @@ -16,6 +16,7 @@ import com.yahoo.tensor.functions.ScalarFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; +import java.sql.Ref; import java.util.ArrayList; import java.util.Collections; import java.util.Deque; @@ -32,14 +33,14 @@ import java.util.stream.Collectors; @Beta public class TensorFunctionNode extends CompositeNode { - private final TensorFunction function; + private final TensorFunction<Reference> function; - public TensorFunctionNode(TensorFunction function) { + public TensorFunctionNode(TensorFunction<Reference> function) { this.function = function; } /** Returns the tensor function wrapped by this */ - public TensorFunction function() { return function; } + public TensorFunction<Reference> function() { return function; } @Override public List<ExpressionNode> children() { @@ -48,7 +49,7 @@ public class TensorFunctionNode extends CompositeNode { .collect(Collectors.toList()); } - private ExpressionNode toExpressionNode(TensorFunction f) { + private ExpressionNode toExpressionNode(TensorFunction<Reference> f) { if (f instanceof ExpressionTensorFunction) return ((ExpressionTensorFunction)f).expression; else @@ -57,9 +58,9 @@ public class TensorFunctionNode extends CompositeNode { @Override public CompositeNode setChildren(List<ExpressionNode> children) { - List<TensorFunction> wrappedChildren = children.stream() - .map(ExpressionTensorFunction::new) - .collect(Collectors.toList()); + List<TensorFunction<Reference>> wrappedChildren = children.stream() + .map(ExpressionTensorFunction::new) + .collect(Collectors.toList()); return new TensorFunctionNode(function.withArguments(wrappedChildren)); } @@ -81,21 +82,22 @@ public class TensorFunctionNode extends CompositeNode { return new ExpressionTensorFunction(node); } - public static Map<TensorAddress, ScalarFunction> wrap(Map<TensorAddress, ExpressionNode> nodes) { - Map<TensorAddress, ScalarFunction> functions = new LinkedHashMap<>(); + public static Map<TensorAddress, ScalarFunction<Reference>> wrapScalars(Map<TensorAddress, ExpressionNode> nodes) { + Map<TensorAddress, ScalarFunction<Reference>> functions = new LinkedHashMap<>(); for (var entry : nodes.entrySet()) - functions.put(entry.getKey(), new ExpressionScalarFunction(entry.getValue())); + functions.put(entry.getKey(), wrapScalar(entry.getValue())); return functions; } - public static List<ScalarFunction> wrap(List<ExpressionNode> nodes) { - List<ScalarFunction> functions = new ArrayList<>(); - for (var entry : nodes) - functions.add(new ExpressionScalarFunction(entry)); - return functions; + public static List<ScalarFunction<Reference>> wrapScalars(List<ExpressionNode> nodes) { + return nodes.stream().map(node -> wrapScalar(node)).collect(Collectors.toList()); + } + + public static ScalarFunction<Reference> wrapScalar(ExpressionNode node) { + return new ExpressionScalarFunction(node); } - private static class ExpressionScalarFunction implements ScalarFunction { + private static class ExpressionScalarFunction implements ScalarFunction<Reference> { private final ExpressionNode expression; @@ -104,8 +106,8 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public Double apply(EvaluationContext<?> context) { - return expression.evaluate((Context)context).asDouble(); + public Double apply(EvaluationContext<Reference> context) { + return expression.evaluate(new ContextWrapper(context)).asDouble(); } @Override @@ -130,7 +132,7 @@ public class TensorFunctionNode extends CompositeNode { * A tensor function implemented by an expression. * This allows us to pass expressions as tensor function arguments. */ - public static class ExpressionTensorFunction extends PrimitiveTensorFunction { + public static class ExpressionTensorFunction extends PrimitiveTensorFunction<Reference> { /** An expression which produces a tensor */ private final ExpressionNode expression; @@ -140,7 +142,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public List<TensorFunction> arguments() { + public List<TensorFunction<Reference>> arguments() { if (expression instanceof CompositeNode) return ((CompositeNode)expression).children().stream() .map(ExpressionTensorFunction::new) @@ -150,7 +152,7 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<Reference> withArguments(List<TensorFunction<Reference>> arguments) { if (arguments.size() == 0) return this; List<ExpressionNode> unwrappedChildren = arguments.stream() .map(arg -> ((ExpressionTensorFunction)arg).expression) @@ -159,16 +161,15 @@ public class TensorFunctionNode extends CompositeNode { } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<Reference> toPrimitive() { return this; } @Override - @SuppressWarnings("unchecked") // Generics awkwardness - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { - return expression.type((TypeContext<Reference>)context); + public TensorType type(TypeContext<Reference> context) { + return expression.type(context); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<Reference> context) { return expression.evaluate((Context)context).asTensor(); } @@ -209,4 +210,26 @@ public class TensorFunctionNode extends CompositeNode { } + /** Turns an EvaluationContext into a Context */ + // TODO: We should be able to change RankingExpression.evaluate to take an EvaluationContext and then get rid of this + private static class ContextWrapper extends Context { + + private final EvaluationContext<Reference> delegate; + + public ContextWrapper(EvaluationContext<Reference> delegate) { + this.delegate = delegate; + } + + @Override + public Value get(String name) { + return new TensorValue(delegate.getTensor(name)); + } + + @Override + public TensorType getType(Reference name) { + return delegate.getType(name); + } + + } + } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java index 9a38b5efc1f..9bed4a4ea7c 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/transform/TensorMaxMinTransformer.java @@ -87,7 +87,7 @@ public class TensorMaxMinTransformer<CONTEXT extends TransformContext> extends E Reduce.Aggregator aggregator = Reduce.Aggregator.valueOf(node.getFunction().name()); String dimension = ((ReferenceNode) arg2).getName(); - return new TensorFunctionNode(new Reduce(expression, aggregator, dimension)); + return new TensorFunctionNode(new Reduce<>(expression, aggregator, dimension)); } } diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 01eed897bfd..c7870182939 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -18,7 +18,6 @@ PARSER_BEGIN(RankingExpressionParser) package com.yahoo.searchlib.rankingexpression.parser; import com.yahoo.searchlib.rankingexpression.rule.*; -import com.yahoo.searchlib.rankingexpression.evaluation.Value; import com.yahoo.searchlib.rankingexpression.evaluation.StringValue; import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue; import com.yahoo.tensor.*; @@ -231,26 +230,28 @@ TruthOperator comparator() : { } ExpressionNode value() : { - ExpressionNode ret; + ExpressionNode value; boolean neg = false; boolean not = false; + List valueAddress; } { ( [ <NOT> { not = true; } ] [ LOOKAHEAD(2) <SUB> { neg = true; } ] - ( ret = constantPrimitive() | - LOOKAHEAD(2) ret = ifExpression() | - LOOKAHEAD(4) ret = function() | - ret = feature() | - ret = legacyQueryFeature() | - ( <LBRACE> ret = expression() <RBRACE> { ret = new EmbracedNode(ret); } ) ) + ( value = constantPrimitive() | + LOOKAHEAD(2) value = ifExpression() | + LOOKAHEAD(4) value = function() | + value = feature() | + value = legacyQueryFeature() | + ( <LBRACE> value = expression() <RBRACE> { value = new EmbracedNode(value); } ) ) ) + [ LOOKAHEAD(2) valueAddress = valueAddress() { value = new TensorFunctionNode(new Value(TensorFunctionNode.wrap(value), valueAddress)); } ] { - ret = not ? new NotNode(ret) : ret; - ret = neg ? new NegativeNode(ret) : ret; - return ret; + value = not ? new NotNode(value) : value; + value = neg ? new NegativeNode(value) : value; + return value; } } @@ -323,12 +324,12 @@ List<ExpressionNode> args() : { return arguments; } } -// TODO: Replace use of this for macro arguments with value() +// TODO: Replace use of this for function arguments with value() // For that to work with the current search execution framework -// we need to generate another macro for the argument such that we can replace -// instances of the argument with the reference to that macro in the same way +// we need to generate another function for the argument such that we can replace +// instances of the argument with the reference to that function in the same way // as we replace by constants/names today (this can make for some fun combinatorial explosion). -// Simon also points out that we should stop doing macro expansion in the toString of a macro. +// We should also stop doing function expansion in the toString of a function. // - Jon 2014-05-02 ExpressionNode arg() : { @@ -368,9 +369,9 @@ FunctionNode scalarOrTensorFunction() : ) } -ExpressionNode tensorFunction() : +TensorFunctionNode tensorFunction() : { - ExpressionNode tensorExpression; + TensorFunctionNode tensorExpression; } { ( @@ -395,7 +396,7 @@ ExpressionNode tensorFunction() : { return tensorExpression; } } -ExpressionNode tensorMap() : +TensorFunctionNode tensorMap() : { ExpressionNode tensor; LambdaFunctionNode doubleMapper; @@ -403,10 +404,10 @@ ExpressionNode tensorMap() : { <MAP> <LBRACE> tensor = expression() <COMMA> doubleMapper = lambdaFunction() <RBRACE> { return new TensorFunctionNode(new Map(TensorFunctionNode.wrap(tensor), - doubleMapper.asDoubleUnaryOperator())); } + doubleMapper.asDoubleUnaryOperator())); } } -ExpressionNode tensorReduce() : +TensorFunctionNode tensorReduce() : { ExpressionNode tensor; Reduce.Aggregator aggregator; @@ -417,7 +418,7 @@ ExpressionNode tensorReduce() : { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } -ExpressionNode tensorReduceComposites() : +TensorFunctionNode tensorReduceComposites() : { ExpressionNode tensor; Reduce.Aggregator aggregator; @@ -429,7 +430,7 @@ ExpressionNode tensorReduceComposites() : { return new TensorFunctionNode(new Reduce(TensorFunctionNode.wrap(tensor), aggregator, dimensions)); } } -ExpressionNode tensorJoin() : +TensorFunctionNode tensorJoin() : { ExpressionNode tensor1, tensor2; LambdaFunctionNode doubleJoiner; @@ -441,7 +442,7 @@ ExpressionNode tensorJoin() : doubleJoiner.asDoubleBinaryOperator())); } } -ExpressionNode tensorRename() : +TensorFunctionNode tensorRename() : { ExpressionNode tensor; List<String> fromDimensions, toDimensions; @@ -454,7 +455,7 @@ ExpressionNode tensorRename() : { return new TensorFunctionNode(new Rename(TensorFunctionNode.wrap(tensor), fromDimensions, toDimensions)); } } -ExpressionNode tensorConcat() : +TensorFunctionNode tensorConcat() : { ExpressionNode tensor1, tensor2; String dimension; @@ -466,10 +467,10 @@ ExpressionNode tensorConcat() : dimension)); } } -ExpressionNode tensorGenerate() : +TensorFunctionNode tensorGenerate() : { TensorType type; - ExpressionNode expression; + TensorFunctionNode expression; } { <TENSOR> type = tensorType() @@ -480,16 +481,16 @@ ExpressionNode tensorGenerate() : { return expression; } } -ExpressionNode tensorGenerateBody(TensorType type) : +TensorFunctionNode tensorGenerateBody(TensorType type) : { ExpressionNode generator; } { <LBRACE> generator = expression() <RBRACE> - { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); } + { return new TensorFunctionNode(Generate.bound(type, TensorFunctionNode.wrapScalar(generator))); } } -ExpressionNode tensorRange() : +TensorFunctionNode tensorRange() : { TensorType type; } @@ -498,7 +499,7 @@ ExpressionNode tensorRange() : { return new TensorFunctionNode(new Range(type)); } } -ExpressionNode tensorDiag() : +TensorFunctionNode tensorDiag() : { TensorType type; } @@ -507,7 +508,7 @@ ExpressionNode tensorDiag() : { return new TensorFunctionNode(new Diag(type)); } } -ExpressionNode tensorRandom() : +TensorFunctionNode tensorRandom() : { TensorType type; } @@ -516,7 +517,7 @@ ExpressionNode tensorRandom() : { return new TensorFunctionNode(new Random(type)); } } -ExpressionNode tensorL1Normalize() : +TensorFunctionNode tensorL1Normalize() : { ExpressionNode tensor; String dimension; @@ -526,7 +527,7 @@ ExpressionNode tensorL1Normalize() : { return new TensorFunctionNode(new L1Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } -ExpressionNode tensorL2Normalize() : +TensorFunctionNode tensorL2Normalize() : { ExpressionNode tensor; String dimension; @@ -536,7 +537,7 @@ ExpressionNode tensorL2Normalize() : { return new TensorFunctionNode(new L2Normalize(TensorFunctionNode.wrap(tensor), dimension)); } } -ExpressionNode tensorMatmul() : +TensorFunctionNode tensorMatmul() : { ExpressionNode tensor1, tensor2; String dimension; @@ -548,7 +549,7 @@ ExpressionNode tensorMatmul() : dimension)); } } -ExpressionNode tensorSoftmax() : +TensorFunctionNode tensorSoftmax() : { ExpressionNode tensor; String dimension; @@ -558,7 +559,7 @@ ExpressionNode tensorSoftmax() : { return new TensorFunctionNode(new Softmax(TensorFunctionNode.wrap(tensor), dimension)); } } -ExpressionNode tensorXwPlusB() : +TensorFunctionNode tensorXwPlusB() : { ExpressionNode tensor1, tensor2, tensor3; String dimension; @@ -574,7 +575,7 @@ ExpressionNode tensorXwPlusB() : dimension)); } } -ExpressionNode tensorArgmax() : +TensorFunctionNode tensorArgmax() : { ExpressionNode tensor; String dimension; @@ -584,7 +585,7 @@ ExpressionNode tensorArgmax() : { return new TensorFunctionNode(new Argmax(TensorFunctionNode.wrap(tensor), dimension)); } } -ExpressionNode tensorArgmin() : +TensorFunctionNode tensorArgmin() : { ExpressionNode tensor; String dimension; @@ -811,20 +812,20 @@ ConstantNode constantPrimitive() : ( <INTEGER> { value = token.image; } | <FLOAT> { value = token.image; } | <STRING> { value = token.image; } ) - { return new ConstantNode(Value.parse(sign + value),sign + value); } + { return new ConstantNode(com.yahoo.searchlib.rankingexpression.evaluation.Value.parse(sign + value),sign + value); } } -Value primitiveValue() : +com.yahoo.searchlib.rankingexpression.evaluation.Value primitiveValue() : { String sign = ""; } { ( <SUB> { sign = "-";} ) ? ( <INTEGER> | <FLOAT> | <STRING> ) - { return Value.parse(sign + token.image); } + { return com.yahoo.searchlib.rankingexpression.evaluation.Value.parse(sign + token.image); } } -ExpressionNode tensorValueBody(TensorType type) : +TensorFunctionNode tensorValueBody(TensorType type) : { DynamicTensor dynamicTensor; } @@ -846,7 +847,7 @@ DynamicTensor mappedTensorValueBody(TensorType type) : ( tensorCell(type, cells))* ( <COMMA> tensorCell(type, cells))* <RCURLY> - { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } DynamicTensor indexedTensorValueBody(TensorType type) : @@ -859,7 +860,7 @@ DynamicTensor indexedTensorValueBody(TensorType type) : ( (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* ( <COMMA> (<LSQUARE>)* value = expression() (<RSQUARE>)* { cells.add(value); } )* // <RSQUARE> - { return DynamicTensor.from(type, TensorFunctionNode.wrap(cells)); } + { return DynamicTensor.from(type, TensorFunctionNode.wrapScalars(cells)); } } void tensorCell(TensorType type, java.util.Map cells) : @@ -882,4 +883,50 @@ void labelAndDimension(TensorAddress.Builder addressBuilder) : { dimension = identifier() <COLON> label = tag() { addressBuilder.add(dimension, label); } +} + +void labelAndDimensionValues(List addressValues) : +{ + String dimension; + Value.DimensionValue dimensionValue; +} +{ + dimension = identifier() <COLON> dimensionValue = dimensionValue(Optional.of(dimension)) + { addressValues.add(dimensionValue); } +} + +/** A tensor address (possibly on short form) represented as a list because the tensor type is not available */ +List valueAddress() : +{ + List dimensionValues = new ArrayList(); + ExpressionNode valueExpression; + Value.DimensionValue dimensionValue; +} +{ + ( + ( <LSQUARE> ( valueExpression = expression() { dimensionValues.add(new Value.DimensionValue(TensorFunctionNode.wrapScalar(valueExpression))); } ) <RSQUARE> ) + | + LOOKAHEAD(3) ( <LCURLY> + ( labelAndDimensionValues(dimensionValues))+ + ( <COMMA> labelAndDimensionValues(dimensionValues))* + <RCURLY> + ) + | + ( <LCURLY> dimensionValue = dimensionValue(Optional.empty()) { dimensionValues.add(dimensionValue); } <RCURLY> ) + ) + { return dimensionValues;} +} + +Value.DimensionValue dimensionValue(Optional dimensionName) : +{ + ExpressionNode value; +} +{ + value = expression() + { + if (value instanceof ReferenceNode && ((ReferenceNode)value).reference().isIdentifier()) + return new Value.DimensionValue(dimensionName, ((ReferenceNode)value).reference().name()); + else + return new Value.DimensionValue(dimensionName, TensorFunctionNode.wrapScalar(value)); + } }
\ No newline at end of file diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java index e7024b87452..26fcec9efba 100755 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/RankingExpressionTestCase.java @@ -61,7 +61,7 @@ public class RankingExpressionTestCase { ReferenceNode input = new ReferenceNode("input"); ReferenceNode constant = new ReferenceNode("constant"); ArithmeticNode product = new ArithmeticNode(input, ArithmeticOperator.MULTIPLY, constant); - Reduce sum = new Reduce(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum); + Reduce<Reference> sum = new Reduce<>(new TensorFunctionNode.ExpressionTensorFunction(product), Reduce.Aggregator.sum); RankingExpression expression = new RankingExpression(new TensorFunctionNode(sum)); RankingExpression expected = new RankingExpression("sum(input * constant)"); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java index 6064035702e..99047aeb79d 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/evaluation/EvaluationTestCase.java @@ -317,6 +317,12 @@ public class EvaluationTestCase { tester.assertEvaluates("{ {x:0,y:0,z:0}:1, {x:0,y:0,z:1}:0, {x:0,y:1,z:0}:0, {x:0,y:1,z:1}:0, {x:1,y:0,z:0}:0, {x:1,y:0,z:1}:0, {x:1,y:1,z:0}:0, {x:1,y:1,z:1}:1, }", "diag(x[2],y[2],z[2])"); tester.assertEvaluates("6", "reduce(random(x[2],y[3]), count)"); + // tensor value + tester.assertEvaluates("3.0", "tensor0{x:1}", "{ {x:0}:1, {x:1}:3 }"); + tester.assertEvaluates("1.2", "tensor0{key:foo,x:0}", true, "{ {key:foo,x:0}:1.2, {key:bar,x:0}:3 }"); + tester.assertEvaluates("3.0", "tensor0{bar}", true, "{ {x:foo}:1, {x:bar}:3 }"); + tester.assertEvaluates("3.3", "tensor0[2]", "tensor(values[4]):[1.1, 2.2, 3.3, 4.4]]"); + // composite functions tester.assertEvaluates("{ {x:0}:0.25, {x:1}:0.75 }", "l1_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }"); tester.assertEvaluates("{ {x:0}:0.31622776601683794, {x:1}:0.9486832980505138 }", "l2_normalize(tensor0, x)", "{ {x:0}:1, {x:1}:3 }"); @@ -349,6 +355,18 @@ public class EvaluationTestCase { tester.assertEvaluates("0", "reduce(join(tensor0, tensor1, f(x,y) (if(x > y, 1.0, 0.0))), sum, tag) == reduce(tensor0, count, tag)", "tensor(tag{}):{{tag:tag1}:10, {tag:tag2}:20}", "{25}"); + tester.assertEvaluates("500", + "join(tensor0, tensor1, f(x,y) (x*y)){tag2}", + "tensor(tag{}):{{tag:tag1}:10, {tag:tag2}:20}", "{25}"); + tester.assertEvaluates("tensor(j[3]):[3, 3, 3]", + "tensor(j[3])(tensor0[2])", + "tensor(values[5]):[1, 2, 3, 4, 5]"); + tester.assertEvaluates("tensor(j[3]):[5, 4, 3]", + "tensor(j[3])(tensor0[4-j])", + "tensor(values[5]):[1, 2, 3, 4, 5]"); + tester.assertEvaluates("tensor(j[2]):[6, 5]", + "tensor(j[2])(tensor0{key:bar,i:2-j})", + "tensor(key{},i[5]):{{key:foo,i:0}:1,{key:foo,i:1}:2,{key:foo,i:2}:2,{key:bar,i:0}:4,{key:bar,i:1}:5,{key:bar,i:2}:6}"); // tensor result dimensions are given from argument dimensions, not the resulting values tester.assertEvaluates("tensor(x{}):{}", "tensor0 * tensor1", "{ {x:0}:1 }", "tensor(x{}):{ {x:1}:1 }"); diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 4ec1a5a234b..59474021de2 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1448,12 +1448,12 @@ "public void <init>()", "public void put(java.lang.String, com.yahoo.tensor.Tensor)", "public com.yahoo.tensor.TensorType getType(java.lang.String)", - "public com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)", + "public com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", "public com.yahoo.tensor.Tensor getTensor(java.lang.String)" ], "fields": [] }, - "com.yahoo.tensor.evaluation.TypeContext$Name": { + "com.yahoo.tensor.evaluation.Name": { "superClass": "java.lang.Object", "interfaces": [], "attributes": [ @@ -1477,7 +1477,7 @@ "abstract" ], "methods": [ - "public abstract com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.TypeContext$Name)", + "public abstract com.yahoo.tensor.TensorType getType(com.yahoo.tensor.evaluation.Name)", "public abstract com.yahoo.tensor.TensorType getType(java.lang.String)" ], "fields": [] @@ -1620,6 +1620,8 @@ ], "methods": [ "public void <init>(com.yahoo.tensor.TensorType, java.util.function.Function)", + "public static com.yahoo.tensor.functions.Generate free(com.yahoo.tensor.TensorType, java.util.function.Function)", + "public static com.yahoo.tensor.functions.Generate bound(com.yahoo.tensor.TensorType, com.yahoo.tensor.functions.ScalarFunction)", "public java.util.List arguments()", "public com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)", "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", @@ -2506,6 +2508,47 @@ ], "fields": [] }, + "com.yahoo.tensor.functions.Value$DimensionValue": { + "superClass": "java.lang.Object", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(java.lang.String, java.lang.String)", + "public void <init>(java.lang.String, int)", + "public void <init>(int)", + "public void <init>(java.lang.String)", + "public void <init>(com.yahoo.tensor.functions.ScalarFunction)", + "public void <init>(java.util.Optional, java.lang.String)", + "public void <init>(java.util.Optional, com.yahoo.tensor.functions.ScalarFunction)", + "public void <init>(java.lang.String, com.yahoo.tensor.functions.ScalarFunction)", + "public java.util.Optional dimension()", + "public java.util.Optional label()", + "public java.util.Optional index()", + "public java.lang.String toString()" + ], + "fields": [] + }, + "com.yahoo.tensor.functions.Value": { + "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction", + "interfaces": [], + "attributes": [ + "public" + ], + "methods": [ + "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)", + "public java.util.List arguments()", + "public com.yahoo.tensor.functions.Value withArguments(java.util.List)", + "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()", + "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)", + "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)", + "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)", + "public java.lang.String toString()", + "public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)" + ], + "fields": [] + }, "com.yahoo.tensor.functions.XwPlusB": { "superClass": "com.yahoo.tensor.functions.CompositeTensorFunction", "interfaces": [], diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index afd82751137..b8ef84cabb7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -1,6 +1,7 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.evaluation.TypeContext; import com.yahoo.tensor.functions.Argmax; import com.yahoo.tensor.functions.Argmin; import com.yahoo.tensor.functions.Concat; @@ -37,7 +38,7 @@ import java.util.function.Function; * Each cell is is identified by its <i>address</i>, which consists of a set of dimension-label pairs which defines * the location of that cell. Both dimensions and labels are string on the form of an identifier or integer. * <p> - * The size of the set of dimensions of a tensor is called its <i>order</i>. + * The size of the set of dimensions of a tensor is called its <i>rank</i>. * <p> * In contrast to regular mathematical formulations of tensors, this definition of a tensor allows <i>sparseness</i> * as there is no built-in notion of a contiguous space, and even in cases where a space is implied (such as when @@ -144,25 +145,25 @@ public interface Tensor { // ----------------- Primitive tensor functions default Tensor map(DoubleUnaryOperator mapper) { - return new com.yahoo.tensor.functions.Map(new ConstantTensor(this), mapper).evaluate(); + return new com.yahoo.tensor.functions.Map<>(new ConstantTensor<>(this), mapper).evaluate(); } /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */ default Tensor reduce(Reduce.Aggregator aggregator, String ... dimensions) { - return new Reduce(new ConstantTensor(this), aggregator, Arrays.asList(dimensions)).evaluate(); + return new Reduce<>(new ConstantTensor<>(this), aggregator, Arrays.asList(dimensions)).evaluate(); } /** Aggregates cells over a set of dimensions, or over all dimensions if no dimensions are specified */ default Tensor reduce(Reduce.Aggregator aggregator, List<String> dimensions) { - return new Reduce(new ConstantTensor(this), aggregator, dimensions).evaluate(); + return new Reduce<>(new ConstantTensor<>(this), aggregator, dimensions).evaluate(); } default Tensor join(Tensor argument, DoubleBinaryOperator combinator) { - return new Join(new ConstantTensor(this), new ConstantTensor(argument), combinator).evaluate(); + return new Join<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), combinator).evaluate(); } default Tensor rename(String fromDimension, String toDimension) { - return new Rename(new ConstantTensor(this), Collections.singletonList(fromDimension), - Collections.singletonList(toDimension)).evaluate(); + return new Rename<>(new ConstantTensor<>(this), Collections.singletonList(fromDimension), + Collections.singletonList(toDimension)).evaluate(); } default Tensor concat(double argument, String dimension) { @@ -170,50 +171,50 @@ public interface Tensor { } default Tensor concat(Tensor argument, String dimension) { - return new Concat(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); + return new Concat<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), dimension).evaluate(); } default Tensor rename(List<String> fromDimensions, List<String> toDimensions) { - return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); + return new Rename<>(new ConstantTensor<>(this), fromDimensions, toDimensions).evaluate(); } static Tensor generate(TensorType type, Function<List<Long>, Double> valueSupplier) { - return new Generate(type, valueSupplier).evaluate(); + return new Generate<>(type, valueSupplier).evaluate(); } // ----------------- Composite tensor functions which have a defined primitive mapping default Tensor l1Normalize(String dimension) { - return new L1Normalize(new ConstantTensor(this), dimension).evaluate(); + return new L1Normalize<>(new ConstantTensor<>(this), dimension).evaluate(); } default Tensor l2Normalize(String dimension) { - return new L2Normalize(new ConstantTensor(this), dimension).evaluate(); + return new L2Normalize<>(new ConstantTensor<>(this), dimension).evaluate(); } default Tensor matmul(Tensor argument, String dimension) { - return new Matmul(new ConstantTensor(this), new ConstantTensor(argument), dimension).evaluate(); + return new Matmul<>(new ConstantTensor<>(this), new ConstantTensor<>(argument), dimension).evaluate(); } default Tensor softmax(String dimension) { - return new Softmax(new ConstantTensor(this), dimension).evaluate(); + return new Softmax<>(new ConstantTensor<>(this), dimension).evaluate(); } default Tensor xwPlusB(Tensor w, Tensor b, String dimension) { - return new XwPlusB(new ConstantTensor(this), new ConstantTensor(w), new ConstantTensor(b), dimension).evaluate(); + return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate(); } default Tensor argmax(String dimension) { - return new Argmax(new ConstantTensor(this), dimension).evaluate(); + return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate(); } - default Tensor argmin(String dimension) { return new Argmin(new ConstantTensor(this), dimension).evaluate(); } + default Tensor argmin(String dimension) { return new Argmin<>(new ConstantTensor<>(this), dimension).evaluate(); } - static Tensor diag(TensorType type) { return new Diag(type).evaluate(); } + static Tensor diag(TensorType type) { return new Diag<>(type).evaluate(); } - static Tensor random(TensorType type) { return new Random(type).evaluate(); } + static Tensor random(TensorType type) { return new Random<>(type).evaluate(); } - static Tensor range(TensorType type) { return new Range(type).evaluate(); } + static Tensor range(TensorType type) { return new Range<>(type).evaluate(); } // ----------------- Composite tensor functions mapped to primitives here on the fly diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 30f7185959c..3812dd26370 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -110,7 +110,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return Long.parseLong(labels[i]); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected a long label in " + this + " at position " + i); + throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 9ec105f8174..6e6b42cc1cd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.Tensor; * * @author bratseth */ -public interface EvaluationContext<NAMETYPE extends TypeContext.Name> extends TypeContext<NAMETYPE> { +public interface EvaluationContext<NAMETYPE extends Name> extends TypeContext<NAMETYPE> { /** Returns the tensor bound to this name, or null if none */ Tensor getTensor(String name); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index e302e317418..f684987476f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -9,7 +9,7 @@ import java.util.HashMap; /** * @author bratseth */ -public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> { +public class MapEvaluationContext<NAMETYPE extends Name> implements EvaluationContext<NAMETYPE> { private final java.util.Map<String, Tensor> bindings = new HashMap<>(); @@ -17,14 +17,14 @@ public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> @Override public TensorType getType(String name) { - return getType(new Name(name)); + Tensor tensor = bindings.get(name); + if (tensor == null) return null; + return tensor.type(); } @Override - public TensorType getType(Name name) { - Tensor tensor = bindings.get(name.toString()); - if (tensor == null) return null; - return tensor.type(); + public TensorType getType(NAMETYPE name) { + return getType(name.name()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/Name.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/Name.java new file mode 100644 index 00000000000..9033af1d7ec --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/Name.java @@ -0,0 +1,28 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.evaluation; + +/** A name which is just a string. Names are value objects. */ +public class Name { + + private final String name; + + public Name(String name) { + this.name = name; + } + + public String name() { return name; } + + @Override + public String toString() { return name; } + + @Override + public int hashCode() { return name.hashCode(); } + + @Override + public boolean equals(Object other) { + if (other == this) return true; + if ( ! (other instanceof Name)) return false; + return ((Name)other).name.equals(this.name); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index 1437fd91974..84d82b624ba 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -8,7 +8,7 @@ import com.yahoo.tensor.TensorType; * * @author bratseth */ -public interface TypeContext<NAMETYPE extends TypeContext.Name> { +public interface TypeContext<NAMETYPE extends Name> { /** * Returns the type of the tensor with this name. @@ -26,31 +26,5 @@ public interface TypeContext<NAMETYPE extends TypeContext.Name> { */ TensorType getType(String name); - /** A name which is just a string. Names are value objects. */ - class Name { - - private final String name; - - public Name(String name) { - this.name = name; - } - - public String name() { return name; } - - @Override - public String toString() { return name; } - - @Override - public int hashCode() { return name.hashCode(); } - - @Override - public boolean equals(Object other) { - if (other == this) return true; - if ( ! (other instanceof Name)) return false; - return ((Name)other).name.equals(this.name); - } - - } - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index c1cfa319664..8ea82aa4a79 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -16,7 +16,7 @@ import java.util.Optional; * * @author bratseth */ -public class VariableTensor extends PrimitiveTensorFunction { +public class VariableTensor<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { private final String name; private final Optional<TensorType> requiredType; @@ -33,16 +33,16 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { return this; } + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { TensorType givenType = context.getType(name); if (givenType == null) return null; verifyType(givenType); @@ -50,7 +50,7 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = context.getTensor(name); if (tensor == null) return null; verifyType(tensor.type()); @@ -67,4 +67,5 @@ public class VariableTensor extends PrimitiveTensorFunction { throw new IllegalArgumentException("Variable '" + name + "' must be compatible with " + requiredType.get() + " but was " + givenType); } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 3478061b32c..a365f0f4bdc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -1,38 +1,40 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.Name; + import java.util.Collections; import java.util.List; /** * @author bratseth */ -public class Argmax extends CompositeTensorFunction { +public class Argmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public Argmax(TensorFunction argument, String dimension) { + public Argmax(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size()); - return new Argmax(arguments.get(0), dimension); + return new Argmax<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(primitiveArgument, - new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension), - ScalarFunctions.equal()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(primitiveArgument, + new Reduce<>(primitiveArgument, Reduce.Aggregator.max, dimension), + ScalarFunctions.equal()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index ba5b3c3e4b2..32ccdf51336 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -1,38 +1,40 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.Name; + import java.util.Collections; import java.util.List; /** * @author bratseth */ -public class Argmin extends CompositeTensorFunction { +public class Argmin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public Argmin(TensorFunction argument, String dimension) { + public Argmin(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size()); - return new Argmin(arguments.get(0), dimension); + return new Argmin<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(primitiveArgument, - new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension), - ScalarFunctions.equal()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(primitiveArgument, + new Reduce<>(primitiveArgument, Reduce.Aggregator.min, dimension), + ScalarFunctions.equal()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 5dd2cc442aa..eacc4493035 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; /** @@ -12,17 +13,17 @@ import com.yahoo.tensor.evaluation.TypeContext; * * @author bratseth */ -public abstract class CompositeTensorFunction extends TensorFunction { +public abstract class CompositeTensorFunction<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> { /** Finds the type this produces by first converting it to a primitive function */ @Override - public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public final TensorType type(TypeContext<NAMETYPE> context) { return toPrimitive().type(context); } /** Evaluates this by first converting it to a primitive function */ @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return toPrimitive().evaluate(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 42c6fe2f4aa..fff2ddaf320 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -8,6 +8,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.Arrays; @@ -23,12 +24,12 @@ import java.util.stream.Collectors; * * @author bratseth */ -public class Concat extends PrimitiveTensorFunction { +public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argumentA, argumentB; + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final String dimension; - public Concat(TensorFunction argumentA, TensorFunction argumentB, String dimension) { + public Concat(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, String dimension) { Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); Objects.requireNonNull(dimension, "The dimension cannot be null"); @@ -38,18 +39,18 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 2) throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); - return new Concat(arguments.get(0), arguments.get(1), dimension); + return new Concat<>(arguments.get(0), arguments.get(1), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Concat(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Concat<>(argumentA.toPrimitive(), argumentB.toPrimitive(), dimension); } @Override @@ -58,7 +59,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return type(argumentA.type(context), argumentB.type(context)); } @@ -86,7 +87,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 7c1ce068c90..bb7481f7c64 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -4,6 +4,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; @@ -14,7 +15,7 @@ import java.util.List; * * @author bratseth */ -public class ConstantTensor extends PrimitiveTensorFunction { +public class ConstantTensor<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { private final Tensor constant; @@ -27,23 +28,23 @@ public class ConstantTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } + public TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } @Override public String toString(ToStringContext context) { return constant.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index e302f6606e7..203331a1c0d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; @@ -14,7 +15,7 @@ import java.util.stream.Stream; * * @author bratseth */ -public class Diag extends CompositeTensorFunction { +public class Diag<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorType type; private final Function<List<Long>, Double> diagFunction; @@ -25,18 +26,18 @@ public class Diag extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Generate(type, diagFunction); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Generate<>(type, diagFunction); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index b8b644f8b49..416940a60eb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -7,20 +7,19 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Map; -import java.util.function.Function; /** * A function which is a tensor whose values are computed by individual lambda functions on evaluation. * * @author bratseth */ -public abstract class DynamicTensor extends PrimitiveTensorFunction { +public abstract class DynamicTensor<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { private final TensorType type; @@ -29,20 +28,20 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + public TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 0) throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } TensorType type() { return type; } @@ -54,26 +53,26 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { abstract String contentToString(ToStringContext context); /** Creates a dynamic tensor function. The cell addresses must match the type. */ - public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) { - return new MappedDynamicTensor(type, cells); + public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { + return new MappedDynamicTensor<>(type, cells); } /** Creates a dynamic tensor function for a bound, indexed tensor */ - public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) { - return new IndexedDynamicTensor(type, cells); + public static <NAMETYPE extends Name> DynamicTensor<NAMETYPE> from(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { + return new IndexedDynamicTensor<>(type, cells); } - private static class MappedDynamicTensor extends DynamicTensor { + private static class MappedDynamicTensor<NAMETYPE extends Name> extends DynamicTensor<NAMETYPE> { - private final ImmutableMap<TensorAddress, ScalarFunction> cells; + private final ImmutableMap<TensorAddress, ScalarFunction<NAMETYPE>> cells; - MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) { + MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { super(type); this.cells = ImmutableMap.copyOf(cells); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); for (var cell : cells.entrySet()) builder.cell(cell.getKey(), cell.getValue().apply(context)); @@ -101,11 +100,11 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } - private static class IndexedDynamicTensor extends DynamicTensor { + private static class IndexedDynamicTensor<NAMETYPE extends Name> extends DynamicTensor<NAMETYPE> { - private final List<ScalarFunction> cells; + private final List<ScalarFunction<NAMETYPE>> cells; - IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) { + IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { super(type); if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + @@ -114,7 +113,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); for (int i = 0; i < cells.size(); i++) builder.cellByDirectIndex(i, cells.get(i).apply(context)); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index 83cba3479e2..e5095178be7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -6,11 +6,13 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.function.Function; /** @@ -18,25 +20,49 @@ import java.util.function.Function; * * @author bratseth */ -public class Generate extends PrimitiveTensorFunction { +public class Generate<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { private final TensorType type; - private final Function<List<Long>, Double> generator; + + // One of these are null + private final Function<List<Long>, Double> freeGenerator; + private final ScalarFunction<NAMETYPE> boundGenerator; + + /** The same as Generate.free */ + public Generate(TensorType type, Function<List<Long>, Double> generator) { + this(type, Objects.requireNonNull(generator), null); + } /** - * Creates a generated tensor + * Creates a generated tensor from a free function * * @param type the type of the tensor * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, Function<List<Long>, Double> generator) { + public static <NAMETYPE extends Name> Generate<NAMETYPE> free(TensorType type, Function<List<Long>, Double> generator) { + return new Generate<>(type, Objects.requireNonNull(generator), null); + } + + /** + * Creates a generated tensor from a bound function + * + * @param type the type of the tensor + * @param generator the function generating values from a list of numbers specifying the indexes of the + * tensor cell which will receive the value + * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound + */ + public static <NAMETYPE extends Name> Generate<NAMETYPE> bound(TensorType type, ScalarFunction<NAMETYPE> generator) { + return new Generate<>(type, null, Objects.requireNonNull(generator)); + } + + private Generate(TensorType type, Function<List<Long>, Double> freeGenerator, ScalarFunction<NAMETYPE> boundGenerator) { Objects.requireNonNull(type, "The argument tensor type cannot be null"); - Objects.requireNonNull(generator, "The argument function cannot be null"); validateType(type); this.type = type; - this.generator = generator; + this.freeGenerator = freeGenerator; + this.boundGenerator = boundGenerator; } private void validateType(TensorType type) { @@ -46,28 +72,29 @@ public class Generate extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + public TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); + GenerateContext generateContext = new GenerateContext(type, context); for (int i = 0; i < indexes.size(); i++) { indexes.next(); - builder.cell(generator.apply(indexes.toList()), indexes.indexesForReading()); + builder.cell(generateContext.apply(indexes), indexes.indexesForReading()); } return builder.build(); } @@ -80,6 +107,70 @@ public class Generate extends PrimitiveTensorFunction { } @Override - public String toString(ToStringContext context) { return type + "(" + generator + ")"; } + public String toString(ToStringContext context) { return type + "(" + generatorToString(context) + ")"; } + + private String generatorToString(ToStringContext context) { + if (freeGenerator != null) + return freeGenerator.toString(); + else + return boundGenerator.toString(context); + } + + /** + * A context for generating all the values of a tensor produced by evaluating Generate. + * This returns all the current index values as variables and falls back to delivering from the given + * evaluation context. + */ + private class GenerateContext implements EvaluationContext<NAMETYPE> { + + private final TensorType type; + private final EvaluationContext<NAMETYPE> context; + + private IndexedTensor.Indexes indexes; + + GenerateContext(TensorType type, EvaluationContext<NAMETYPE> context) { + this.type = type; + this.context = context; + } + + @SuppressWarnings("unchecked") + double apply(IndexedTensor.Indexes indexes) { + if (freeGenerator != null) { + return freeGenerator.apply(indexes.toList()); + } + else { + this.indexes = indexes; + return boundGenerator.apply(this); + } + } + + @Override + public Tensor getTensor(String name) { + Optional<Integer> index = type.indexOfDimension(name); + if (index.isPresent()) // this is the name of a dimension + return Tensor.from(indexes.indexesForReading()[index.get()]); + else + return context.getTensor(name); + } + + @Override + public TensorType getType(NAMETYPE name) { + Optional<Integer> index = type.indexOfDimension(name.name()); + if (index.isPresent()) // this is the name of a dimension + return TensorType.empty; + else + return context.getType(name); + } + + @Override + public TensorType getType(String name) { + Optional<Integer> index = type.indexOfDimension(name); + if (index.isPresent()) // this is the name of a dimension + return TensorType.empty; + else + return context.getType(name); + } + + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 2939b964f04..1e0eaa7fad3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -10,6 +10,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.ArrayList; @@ -31,12 +32,12 @@ import java.util.function.DoubleBinaryOperator; * * @author bratseth */ -public class Join extends PrimitiveTensorFunction { +public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argumentA, argumentB; + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final DoubleBinaryOperator combinator; - public Join(TensorFunction argumentA, TensorFunction argumentB, DoubleBinaryOperator combinator) { + public Join(TensorFunction<NAMETYPE> argumentA, TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator) { Objects.requireNonNull(argumentA, "The first argument tensor cannot be null"); Objects.requireNonNull(argumentB, "The second argument tensor cannot be null"); Objects.requireNonNull(combinator, "The combinator function cannot be null"); @@ -53,18 +54,18 @@ public class Join extends PrimitiveTensorFunction { public DoubleBinaryOperator combinator() { return combinator; } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size()); - return new Join(arguments.get(0), arguments.get(1), combinator); + return new Join<>(arguments.get(0), arguments.get(1), combinator); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); } @Override @@ -73,12 +74,12 @@ public class Join extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index 7939457a101..ed4da6678ce 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -1,39 +1,41 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.Name; + import java.util.Collections; import java.util.List; /** * @author bratseth */ -public class L1Normalize extends CompositeTensorFunction { +public class L1Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public L1Normalize(TensorFunction argument, String dimension) { + public L1Normalize(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size()); - return new L1Normalize(arguments.get(0), dimension); + return new L1Normalize<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); // join(x, reduce(x, "avg", "dimension"), f(x,y) (x / y)) - return new Join(primitiveArgument, - new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension), - ScalarFunctions.divide()); + return new Join<>(primitiveArgument, + new Reduce<>(primitiveArgument, Reduce.Aggregator.sum, dimension), + ScalarFunctions.divide()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index 40edb8ba23f..93b2b377176 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -1,40 +1,42 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.Name; + import java.util.Collections; import java.util.List; /** * @author bratseth */ -public class L2Normalize extends CompositeTensorFunction { +public class L2Normalize<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public L2Normalize(TensorFunction argument, String dimension) { + public L2Normalize(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size()); - return new L2Normalize(arguments.get(0), dimension); + return new L2Normalize<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(primitiveArgument, - new Map(new Reduce(new Map(primitiveArgument, ScalarFunctions.square()), - Reduce.Aggregator.sum, - dimension), - ScalarFunctions.sqrt()), + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(primitiveArgument, + new Map<>(new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.square()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.sqrt()), ScalarFunctions.divide()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 016c60c6897..0ddf0bb4e63 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -5,6 +5,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; @@ -18,12 +19,12 @@ import java.util.function.DoubleUnaryOperator; * * @author bratseth */ -public class Map extends PrimitiveTensorFunction { +public class Map<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final DoubleUnaryOperator mapper; - public Map(TensorFunction argument, DoubleUnaryOperator mapper) { + public Map(TensorFunction<NAMETYPE> argument, DoubleUnaryOperator mapper) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(mapper, "The argument function cannot be null"); this.argument = argument; @@ -32,31 +33,31 @@ public class Map extends PrimitiveTensorFunction { public static TensorType outputType(TensorType inputType) { return inputType; } - public TensorFunction argument() { return argument; } + public TensorFunction<NAMETYPE> argument() { return argument; } public DoubleUnaryOperator mapper() { return mapper; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size()); - return new Map(arguments.get(0), mapper); + return new Map<>(arguments.get(0), mapper); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Map(argument.toPrimitive(), mapper); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Map<>(argument.toPrimitive(), mapper); } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return argument.type(context); } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 7c65afc98f9..54bfdd4a732 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -3,18 +3,19 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; import java.util.List; /** * @author bratseth */ -public class Matmul extends CompositeTensorFunction { +public class Matmul<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument1, argument2; + private final TensorFunction<NAMETYPE> argument1, argument2; private final String dimension; - public Matmul(TensorFunction argument1, TensorFunction argument2, String dimension) { + public Matmul(TensorFunction<NAMETYPE> argument1, TensorFunction<NAMETYPE> argument2, String dimension) { this.argument1 = argument1; this.argument2 = argument2; this.dimension = dimension; @@ -25,22 +26,22 @@ public class Matmul extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argument1, argument2); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size()); - return new Matmul(arguments.get(0), arguments.get(1), dimension); + return new Matmul<>(arguments.get(0), arguments.get(1), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument1 = argument1.toPrimitive(); - TensorFunction primitiveArgument2 = argument2.toPrimitive(); - return new Reduce(new Join(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - dimension); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument1 = argument1.toPrimitive(); + TensorFunction<NAMETYPE> primitiveArgument2 = argument2.toPrimitive(); + return new Reduce<>(new Join<>(primitiveArgument1, primitiveArgument2, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java index e2aae39f11f..99117bb250e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/PrimitiveTensorFunction.java @@ -1,6 +1,8 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor.functions; +import com.yahoo.tensor.evaluation.Name; + /** * A primitive tensor function is a tensor function which cannot be expressed in terms of other tensor functions. * All tensor implementations must implement all primitive tensor functions. @@ -8,6 +10,6 @@ package com.yahoo.tensor.functions; * * @author bratseth */ -public abstract class PrimitiveTensorFunction extends TensorFunction { +public abstract class PrimitiveTensorFunction<NAMETYPE extends Name> extends TensorFunction<NAMETYPE> { } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 7175c91ed33..b459b1a8ddd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; @@ -13,7 +14,7 @@ import java.util.stream.Stream; * * @author bratseth */ -public class Random extends CompositeTensorFunction { +public class Random<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorType type; @@ -22,18 +23,18 @@ public class Random extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Generate(type, ScalarFunctions.random()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Generate<>(type, ScalarFunctions.random()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index d951ec9ccbd..00d0e4b4818 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; @@ -15,7 +16,7 @@ import java.util.stream.Stream; * * @author bratseth */ -public class Range extends CompositeTensorFunction { +public class Range<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { private final TensorType type; private final Function<List<Long>, Double> rangeFunction; @@ -26,18 +27,18 @@ public class Range extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Generate(type, rangeFunction); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Generate<>(type, rangeFunction); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 017dc3920e6..1eb09a603fa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; @@ -24,21 +25,21 @@ import java.util.Set; * * @author bratseth */ -public class Reduce extends PrimitiveTensorFunction { +public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { public enum Aggregator { avg, count, prod, sum, max, min; } - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final List<String> dimensions; private final Aggregator aggregator; /** Creates a reduce function reducing all dimensions */ - public Reduce(TensorFunction argument, Aggregator aggregator) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator) { this(argument, aggregator, Collections.emptyList()); } /** Creates a reduce function reducing a single dimension */ - public Reduce(TensorFunction argument, Aggregator aggregator, String dimension) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, String dimension) { this(argument, aggregator, Collections.singletonList(dimension)); } @@ -51,7 +52,7 @@ public class Reduce extends PrimitiveTensorFunction { * producing a dimensionless tensor (a scalar). * @throws IllegalArgumentException if any of the tensor dimensions are not present in the input tensor */ - public Reduce(TensorFunction argument, Aggregator aggregator, List<String> dimensions) { + public Reduce(TensorFunction<NAMETYPE> argument, Aggregator aggregator, List<String> dimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(aggregator, "The aggregator cannot be null"); Objects.requireNonNull(dimensions, "The dimensions cannot be null"); @@ -70,25 +71,25 @@ public class Reduce extends PrimitiveTensorFunction { return b.build(); } - public TensorFunction argument() { return argument; } + public TensorFunction<NAMETYPE> argument() { return argument; } Aggregator aggregator() { return aggregator; } List<String> dimensions() { return dimensions; } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); - return new Reduce(arguments.get(0), aggregator, dimensions); + return new Reduce<>(arguments.get(0), aggregator, dimensions); } @Override - public PrimitiveTensorFunction toPrimitive() { - return new Reduce(argument.toPrimitive(), aggregator, dimensions); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new Reduce<>(argument.toPrimitive(), aggregator, dimensions); } @Override @@ -104,7 +105,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context), dimensions); } @@ -118,7 +119,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { return evaluate(this.argument.evaluate(context), dimensions, aggregator); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index 1134e8177ad..83807a20ec9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -7,7 +7,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; -import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.evaluation.Name; import java.util.Arrays; import java.util.List; @@ -26,19 +26,19 @@ import java.util.stream.Collectors; * * @author lesters */ -public class ReduceJoin extends CompositeTensorFunction { +public class ReduceJoin<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argumentA, argumentB; + private final TensorFunction<NAMETYPE> argumentA, argumentB; private final DoubleBinaryOperator combinator; private final Reduce.Aggregator aggregator; private final List<String> dimensions; - public ReduceJoin(Reduce reduce, Join join) { + public ReduceJoin(Reduce<NAMETYPE> reduce, Join<NAMETYPE> join) { this(join.arguments().get(0), join.arguments().get(1), join.combinator(), reduce.aggregator(), reduce.dimensions()); } - public ReduceJoin(TensorFunction argumentA, - TensorFunction argumentB, + public ReduceJoin(TensorFunction<NAMETYPE> argumentA, + TensorFunction<NAMETYPE> argumentB, DoubleBinaryOperator combinator, Reduce.Aggregator aggregator, List<String> dimensions) { @@ -50,25 +50,25 @@ public class ReduceJoin extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("ReduceJoin must have 2 arguments, got " + arguments.size()); - return new ReduceJoin(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions); + return new ReduceJoin<>(arguments.get(0), arguments.get(1), combinator, aggregator, dimensions); } @Override - public PrimitiveTensorFunction toPrimitive() { - Join join = new Join(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); - return new Reduce(join, aggregator, dimensions); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + Join<NAMETYPE> join = new Join<>(argumentA.toPrimitive(), argumentB.toPrimitive(), combinator); + return new Reduce<>(join, aggregator, dimensions); } @Override - public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public final Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index 5694684956e..275b546c0aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -6,6 +6,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.Collections; @@ -20,18 +21,18 @@ import java.util.Objects; * * @author bratseth */ -public class Rename extends PrimitiveTensorFunction { +public class Rename<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final List<String> fromDimensions; private final List<String> toDimensions; private final Map<String, String> fromToMap; - public Rename(TensorFunction argument, String fromDimension, String toDimension) { + public Rename(TensorFunction<NAMETYPE> argument, String fromDimension, String toDimension) { this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); } - public Rename(TensorFunction argument, List<String> fromDimensions, List<String> toDimensions) { + public Rename(TensorFunction<NAMETYPE> argument, List<String> fromDimensions, List<String> toDimensions) { Objects.requireNonNull(argument, "The argument tensor cannot be null"); Objects.requireNonNull(fromDimensions, "The 'from' dimensions cannot be null"); Objects.requireNonNull(toDimensions, "The 'to' dimensions cannot be null"); @@ -57,20 +58,20 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size()); - return new Rename(arguments.get(0), fromDimensions, toDimensions); + return new Rename<>(arguments.get(0), fromDimensions, toDimensions); } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } @@ -82,7 +83,7 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = argument.evaluate(context); TensorType renamedType = type(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java index c6a244b64df..07b3658fb58 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunction.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import java.util.function.Function; @@ -10,10 +11,10 @@ import java.util.function.Function; * * @author bratseth */ -public interface ScalarFunction extends Function<EvaluationContext<?>, Double> { +public interface ScalarFunction<NAMETYPE extends Name> extends Function<EvaluationContext<NAMETYPE>, Double> { @Override - Double apply(EvaluationContext<?> context); + Double apply(EvaluationContext<NAMETYPE> context); default String toString(ToStringContext context) { return toString(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index bd732cdc11e..755711a4d44 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; import java.util.Collections; import java.util.List; @@ -10,12 +11,12 @@ import java.util.List; /** * @author bratseth */ -public class Softmax extends CompositeTensorFunction { +public class Softmax<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction argument; + private final TensorFunction<NAMETYPE> argument; private final String dimension; - public Softmax(TensorFunction argument, String dimension) { + public Softmax(TensorFunction<NAMETYPE> argument, String dimension) { this.argument = argument; this.dimension = dimension; } @@ -25,23 +26,23 @@ public class Softmax extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return Collections.singletonList(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size()); - return new Softmax(arguments.get(0), dimension); + return new Softmax<>(arguments.get(0), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveArgument = argument.toPrimitive(); - return new Join(new Map(primitiveArgument, ScalarFunctions.exp()), - new Reduce(new Map(primitiveArgument, ScalarFunctions.exp()), - Reduce.Aggregator.sum, - dimension), - ScalarFunctions.divide()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveArgument = argument.toPrimitive(); + return new Join<>(new Map<>(primitiveArgument, ScalarFunctions.exp()), + new Reduce<>(new Map<>(primitiveArgument, ScalarFunctions.exp()), + Reduce.Aggregator.sum, + dimension), + ScalarFunctions.divide()); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 810651bbcfb..b4c5dedbf4e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -5,6 +5,7 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; import java.util.List; @@ -16,17 +17,17 @@ import java.util.List; * * @author bratseth */ -public abstract class TensorFunction { +public abstract class TensorFunction<NAMETYPE extends Name> { /** Returns the function arguments of this node in the order they are applied */ - public abstract List<TensorFunction> arguments(); + public abstract List<TensorFunction<NAMETYPE>> arguments(); /** * Returns a copy of this tensor function with the arguments replaced by the given list of arguments. * * @throws IllegalArgumentException if the argument list has the wrong size for this function */ - public abstract TensorFunction withArguments(List<TensorFunction> arguments); + public abstract TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments); /** * Translate this function - and all of its arguments recursively - @@ -34,24 +35,24 @@ public abstract class TensorFunction { * * @return a tree of primitive functions implementing this */ - public abstract PrimitiveTensorFunction toPrimitive(); + public abstract PrimitiveTensorFunction<NAMETYPE> toPrimitive(); /** * Evaluates this tensor. * * @param context a context which must be passed to all nested functions when evaluating */ - public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context); + public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context); /** * Returns the type of the tensor this produces given the input types in the context * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context); + public abstract TensorType type(TypeContext<NAMETYPE> context); /** Evaluate with no context */ - public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } + public final Tensor evaluate() { return evaluate(new MapEvaluationContext<>()); } /** * Return a string representation of this context. diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java new file mode 100644 index 00000000000..cb14711c0dd --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java @@ -0,0 +1,185 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.google.common.annotations.Beta; +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.List; +import java.util.Objects; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Returns the value of a cell of a tensor (as a rank 0 tensor). + * + * @author bratseth + */ +@Beta +public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argument; + private final List<DimensionValue<NAMETYPE>> cellAddress; + + /** + * Creates a value function + * + * @param argument the tensor to return a cell value from + * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress + * because those require a type, but a type is not resolved until this is evaluated + */ + public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) { + this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); + if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty())) + throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " + + "Specify dimension names explicitly"); + this.cellAddress = cellAddress; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } + + @Override + public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); + return new Value<NAMETYPE>(arguments.get(0), cellAddress); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor tensor = argument.evaluate(context); + if (tensor.type().rank() != cellAddress.size()) + throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() + + " to a tensor of type " + tensor.type()); + TensorAddress.Builder b = new TensorAddress.Builder(tensor.type()); + for (int i = 0; i < cellAddress.size(); i++) { + if (cellAddress.get(i).label().isPresent()) + b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), + cellAddress.get(i).label().get()); + else + b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), + String.valueOf(cellAddress.get(i).index().get().apply(context).intValue())); + } + return Tensor.from(tensor.get(b.build())); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return new TensorType.Builder(argument.type(context).valueType()).build(); + } + + @Override + public String toString(ToStringContext context) { + return toString(); + } + + @Override + public String toString() { + if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) { + if (cellAddress.get(0).index().isPresent()) + return "[" + cellAddress.get(0).index().get() + "]"; + else + return "{" + cellAddress.get(0).label() + "}"; + } + else { + return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}"; + } + } + + public static class DimensionValue<NAMETYPE extends Name> { + + private final Optional<String> dimension; + + /** The label of this, or null if index is set */ + private final String label; + + /** The function returning the index of this, or null if label is set */ + private final ScalarFunction<NAMETYPE> index; + + public DimensionValue(String dimension, String label) { + this(Optional.of(dimension), label, null); + } + + public DimensionValue(String dimension, int index) { + this(Optional.of(dimension), null, new ConstantScalarFunction<>(index)); + } + + public DimensionValue(int index) { + this(Optional.empty(), null, new ConstantScalarFunction<>(index)); + } + + public DimensionValue(String label) { + this(Optional.empty(), label, null); + } + + public DimensionValue(ScalarFunction<NAMETYPE> index) { + this(Optional.empty(), null, index); + } + + public DimensionValue(Optional<String> dimension, String label) { + this(dimension, label, null); + } + + public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { + this(dimension, null, index); + } + + public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { + this(Optional.of(dimension), null, index); + } + + private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { + this.dimension = dimension; + this.label = label; + this.index = index; + } + + /** + * Returns the given name of the dimension, or null if dense form is used, such that name + * must be inferred from order + */ + public Optional<String> dimension() { return dimension; } + + /** Returns the label for this dimension or empty if it is provided by an index function */ + public Optional<String> label() { return Optional.ofNullable(label); } + + /** Returns the index expression for this dimension, or empty if it is not a number */ + public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } + + @Override + public String toString() { + StringBuilder b = new StringBuilder(); + dimension.ifPresent(d -> b.append(d).append(":")); + if (label != null) + b.append(label); + else + b.append(index); + return b.toString(); + } + + } + + private static class ConstantScalarFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> { + + private final Double value; + + public ConstantScalarFunction(int value) { + this.value = (double)value; + } + + @Override + public Double apply(EvaluationContext<NAMETYPE> context) { + return value; + } + + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 4c0748ee39a..53e23674617 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -2,18 +2,19 @@ package com.yahoo.tensor.functions; import com.google.common.collect.ImmutableList; +import com.yahoo.tensor.evaluation.Name; import java.util.List; /** * @author bratseth */ -public class XwPlusB extends CompositeTensorFunction { +public class XwPlusB<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { - private final TensorFunction x, w, b; + private final TensorFunction<NAMETYPE> x, w, b; private final String dimension; - public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) { + public XwPlusB(TensorFunction<NAMETYPE> x, TensorFunction<NAMETYPE> w, TensorFunction<NAMETYPE> b, String dimension) { this.x = x; this.w = w; this.b = b; @@ -21,25 +22,25 @@ public class XwPlusB extends CompositeTensorFunction { } @Override - public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); } + public List<TensorFunction<NAMETYPE>> arguments() { return ImmutableList.of(x, w, b); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if ( arguments.size() != 3) throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size()); - return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension); + return new XwPlusB<>(arguments.get(0), arguments.get(1), arguments.get(2), dimension); } @Override - public PrimitiveTensorFunction toPrimitive() { - TensorFunction primitiveX = x.toPrimitive(); - TensorFunction primitiveW = w.toPrimitive(); - TensorFunction primitiveB = b.toPrimitive(); - return new Join(new Reduce(new Join(primitiveX, primitiveW, ScalarFunctions.multiply()), - Reduce.Aggregator.sum, - dimension), - primitiveB, - ScalarFunctions.add()); + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorFunction<NAMETYPE> primitiveX = x.toPrimitive(); + TensorFunction<NAMETYPE> primitiveW = w.toPrimitive(); + TensorFunction<NAMETYPE> primitiveB = b.toPrimitive(); + return new Join<>(new Reduce<>(new Join<>(primitiveX, primitiveW, ScalarFunctions.multiply()), + Reduce.Aggregator.sum, + dimension), + primitiveB, + ScalarFunctions.add()); } @Override diff --git a/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java index 439aac5578a..c334c58042c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/MatrixDotProductBenchmark.java @@ -2,17 +2,16 @@ package com.yahoo.tensor; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Reduce; import com.yahoo.tensor.functions.TensorFunction; -import java.util.ArrayList; import java.util.Collections; import java.util.List; import java.util.Random; -import java.util.stream.Collectors; /** * Microbenchmark of a "dot product" of two mapped rank 2 tensors @@ -42,10 +41,10 @@ public class MatrixDotProductBenchmark { private double dotProduct(Tensor tensor, List<Tensor> tensors) { double largest = Double.MIN_VALUE; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), - Reduce.Aggregator.sum).toPrimitive(); - MapEvaluationContext context = new MapEvaluationContext(); + TensorFunction<Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor), + new VariableTensor<>("argument"), (a, b) -> a * b), + Reduce.Aggregator.sum).toPrimitive(); + MapEvaluationContext<Name> context = new MapEvaluationContext<>(); for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index 7b856dde2d5..b3c6fbc6862 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -2,6 +2,7 @@ package com.yahoo.tensor; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; @@ -49,10 +50,10 @@ public class TensorFunctionBenchmark { private double dotProduct(Tensor tensor, List<Tensor> tensors) { double largest = Double.MIN_VALUE; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), - Reduce.Aggregator.sum).toPrimitive(); - MapEvaluationContext context = new MapEvaluationContext(); + TensorFunction<Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor), + new VariableTensor<>("argument"), (a, b) -> a * b), + Reduce.Aggregator.sum).toPrimitive(); + MapEvaluationContext<Name> context = new MapEvaluationContext<>(); for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index c6fbb9c009d..11365531019 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -3,6 +3,7 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.VariableTensor; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; @@ -307,10 +308,10 @@ public class TensorTestCase { private double dotProduct(Tensor tensor, List<Tensor> tensors) { double sum = 0; - TensorFunction dotProductFunction = new Reduce(new Join(new ConstantTensor(tensor), - new VariableTensor("argument"), (a, b) -> a * b), - Reduce.Aggregator.sum).toPrimitive(); - MapEvaluationContext context = new MapEvaluationContext(); + TensorFunction<Name> dotProductFunction = new Reduce<>(new Join<>(new ConstantTensor<>(tensor), + new VariableTensor<>("argument"), (a, b) -> a * b), + Reduce.Aggregator.sum).toPrimitive(); + MapEvaluationContext<Name> context = new MapEvaluationContext<>(); for (Tensor tensorElement : tensors) { // tensors.size() = 1 for larger tensor context.put("argument", tensorElement); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java index eafa5c4addf..0476fe1c757 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ConcatTestCase.java @@ -101,8 +101,8 @@ public class ConcatTestCase { private void assertConcat(String expectedType, String expected, Tensor a, Tensor b, String dimension) { Tensor expectedAsTensor = Tensor.from(expected); - TensorType inferredType = new Concat(new ConstantTensor(a), new ConstantTensor(b), dimension) - .type(new MapEvaluationContext()); + TensorType inferredType = new Concat<>(new ConstantTensor<>(a), new ConstantTensor<>(b), dimension) + .type(new MapEvaluationContext<>()); Tensor result = a.concat(b, dimension); if (expectedType != null) diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java index 925da9d3c89..e16b7b90a1d 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/DynamicTensorTestCase.java @@ -5,11 +5,11 @@ import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; import org.junit.Test; import java.util.Collections; import java.util.List; -import java.util.function.Function; import static org.junit.Assert.assertEquals; @@ -21,27 +21,27 @@ public class DynamicTensorTestCase { @Test public void testDynamicTensorFunction() { TensorType dense = TensorType.fromSpec("tensor(x[3])"); - DynamicTensor t1 = DynamicTensor.from(dense, - List.of(new Constant(1), new Constant(2), new Constant(3))); + DynamicTensor<Name> t1 = DynamicTensor.from(dense, + List.of(new Constant(1), new Constant(2), new Constant(3))); assertEquals(Tensor.from(dense, "[1, 2, 3]"), t1.evaluate()); assertEquals("tensor(x[3]):{{x:0}:1.0,{x:1}:2.0,{x:2}:3.0}", t1.toString()); TensorType sparse = TensorType.fromSpec("tensor(x{})"); - DynamicTensor t2 = DynamicTensor.from(sparse, - Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), - new Constant(5))); + DynamicTensor<Name> t2 = DynamicTensor.from(sparse, + Collections.singletonMap(new TensorAddress.Builder(sparse).add("x", "a").build(), + new Constant(5))); assertEquals(Tensor.from(sparse, "{{x:a}:5}"), t2.evaluate()); assertEquals("tensor(x{}):{{x:a}:5.0}", t2.toString()); } - private static class Constant implements ScalarFunction { + private static class Constant implements ScalarFunction<Name> { private final double value; public Constant(double value) { this.value = value; } @Override - public Double apply(EvaluationContext<?> evaluationContext) { return value; } + public Double apply(EvaluationContext<Name> evaluationContext) { return value; } @Override public String toString() { return String.valueOf(value); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java index e37bee2d990..e6560242d5c 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/TensorFunctionTestCase.java @@ -2,6 +2,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; import org.junit.Test; import static org.junit.Assert.assertEquals; @@ -15,14 +16,14 @@ public class TensorFunctionTestCase { @Test public void testTranslation() { assertTranslated("join(tensor(x{}):{{x:1}:1.0}, reduce(tensor(x{}):{{x:1}:1.0}, sum, x), f(a,b)(a / b))", - new L1Normalize(new ConstantTensor("{{x:1}:1.0}"), "x")); + new L1Normalize<>(new ConstantTensor<>("{{x:1}:1.0}"), "x")); assertTranslated("tensor(x[2],y[3],z[4])((x==y)*(y==z))", - new Diag(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); + new Diag<>(new TensorType.Builder().indexed("y",3).indexed("x",2).indexed("z",4).build())); assertTranslated("join(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, reduce(tensor(x{}):{{x:1}:1.0,{x:3}:5.0,{x:9}:3.0}, max, x), f(a,b)(a==b))", - new Argmax(new ConstantTensor("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); + new Argmax<>(new ConstantTensor<>("{ {x:1}:1, {x:3}:5, {x:9}:3 }"), "x")); } - private void assertTranslated(String expectedTranslation, TensorFunction inputFunction) { + private void assertTranslated(String expectedTranslation, TensorFunction<Name> inputFunction) { assertEquals(expectedTranslation, inputFunction.toPrimitive().toString()); } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java new file mode 100644 index 00000000000..7127abde016 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java @@ -0,0 +1,66 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import org.junit.Test; + +import java.util.List; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +/** + * @author bratseth + */ +public class ValueTestCase { + + private static final double delta = 0.000001; + + @Test + public void testValueFunctionGeneralForm() { + Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); + Tensor result = new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>("key", "bar"), + new Value.DimensionValue<>("x", 0))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(2.3, result.asDouble(), delta); + } + + @Test + public void testValueFunctionSingleMappedDimension() { + Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }"); + Tensor result = new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>("foo"))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(1.4, result.asDouble(), delta); + } + + @Test + public void testValueFunctionSingleIndexedDimension() { + Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]"); + Tensor result = new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>(2))) + .evaluate(); + assertEquals(0, result.type().rank()); + assertEquals(3.3, result.asDouble(), delta); + } + + @Test + public void testValueFunctionShortFormWithMultipleDimensionsIsNotAllowed() { + try { + Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }"); + new Value<>(new ConstantTensor<>(input), + List.of(new Value.DimensionValue<>("bar"), + new Value.DimensionValue<>(0))) + .evaluate(); + fail("Expected exception"); + } + catch (IllegalArgumentException e) { + assertEquals("Short form of cell addresses is only supported with a single dimension: Specify dimension names explicitly", + e.getMessage()); + } + } + +} |