diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-02-17 18:03:48 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-02-17 18:03:48 +0100 |
commit | 8c0b40b3dc30d7820f88017a590cd3e506ee2aa9 (patch) | |
tree | 6edd4b4af48e1f27271e1b1935ee6eab2f0daede /vespajlib/src/main | |
parent | d7e1e3b5b24b0f9f0e3dfcc6d1e37d442f1de4e8 (diff) |
Propagate type information
Diffstat (limited to 'vespajlib/src/main')
12 files changed, 39 insertions, 22 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index 74457a163fd..b9394da31e3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -18,6 +18,11 @@ public class MapEvaluationContext implements EvaluationContext<TypeContext.Name> public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override + public TensorType getType(String name) { + return getType(new Name(name)); + } + + @Override public TensorType getType(Name name) { Tensor tensor = bindings.get(name.toString()); if (tensor == null) return null; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index 3674b373db0..ff2e6318b37 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -18,6 +18,14 @@ public interface TypeContext<NAMETYPE extends TypeContext.Name> { */ TensorType getType(NAMETYPE name); + /** + * Returns the type of the tensor with this name by converting from a string name. + * + * @return returns the type of the tensor which will be returned by calling getTensor(name) + * or null if getTensor will return null. + */ + TensorType getType(String name); + /** A name which is just a string. Names are value objects. */ class Name { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 5f809a3d2b1..acb2363cba4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -44,15 +44,15 @@ public class VariableTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { - TensorType givenType = context.getType(new TypeContext.Name(name)); + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + TensorType givenType = context.getType(name); if (givenType == null) return null; verifyType(givenType); return givenType; } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = context.getTensor(name); if (tensor == null) return null; verifyType(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 2109b730e1a..bfc0938abcc 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -18,10 +18,14 @@ public abstract class CompositeTensorFunction extends TensorFunction { /** Finds the type this produces by first converting it to a primitive function */ @Override - public final TensorType type(TypeContext context) { return toPrimitive().type(context); } + public final <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + return toPrimitive().type(context); + } /** Evaluates this by first converting it to a primitive function */ @Override - public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } + public final <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + return toPrimitive().evaluate(context); + } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index c77ed1c0526..a073053bec8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -60,7 +60,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argumentA.type(context), argumentB.type(context)); } @@ -74,7 +74,7 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); a = ensureIndexedDimension(dimension, a); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 50b479da168..a43de297b9a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -42,10 +42,10 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { return constant.type(); } + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return constant.type(); } @Override - public Tensor evaluate(EvaluationContext context) { return constant; } + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { return constant; } @Override public String toString(ToStringContext context) { return constant.toString(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e70d1de3db7..edfa8253eb9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -61,10 +61,10 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { return type; } + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); for (int i = 0; i < indexes.size(); i++) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 7812c985091..17e1c103ea3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -95,12 +95,12 @@ public class Join extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); TensorType joinedType = new TensorType.Builder(a.type(), b.type()).build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index 53504868ff2..4a338e5501e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -53,12 +53,12 @@ public class Map extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return argument.type(context); } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); for (Iterator<Tensor.Cell> i = argument.cellIterator(); i.hasNext(); ) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 416b74e7f94..e045effbe7e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -101,7 +101,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } @@ -115,7 +115,7 @@ public class Reduce extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) throw new IllegalArgumentException("Cannot reduce " + argument + " over dimensions " + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index de3d2be265a..af4492ca1e4 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -72,7 +72,7 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override - public TensorType type(TypeContext context) { + public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type(argument.type(context)); } @@ -84,7 +84,7 @@ public class Rename extends PrimitiveTensorFunction { } @Override - public Tensor evaluate(EvaluationContext context) { + public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = argument.evaluate(context); TensorType renamedType = type(tensor.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 78ab09c7820..e805e9d87bb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -43,14 +43,14 @@ public abstract class TensorFunction { * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract Tensor evaluate(EvaluationContext context); + public abstract <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context); /** * Returns the type of the tensor this produces given the input types in the context * * @param context a context which must be passed to all nexted functions when evaluating */ - public abstract TensorType type(TypeContext context); + public abstract <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context); /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } @@ -58,7 +58,7 @@ public abstract class TensorFunction { /** * Return a string representation of this context. * - * @param context a context which must be passed to all nexted functions when requesting the string value + * @param context a context which must be passed to all nested functions when requesting the string value */ public abstract String toString(ToStringContext context); |