diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-01-29 14:51:23 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-01-29 14:51:23 +0100 |
commit | 1b4fde01d98bf724a54b6c1cfe3ffa4b29aec90e (patch) | |
tree | 20a127542b004eceb94e4d1344b3446df8092bd2 /vespajlib | |
parent | 28e3545728977a0be82159b8f278be8e772cb59b (diff) |
Propagate type information
Diffstat (limited to 'vespajlib')
22 files changed, 156 insertions, 87 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java index 3db661f8a23..e18b77a0434 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/EvaluationContext.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; /** * An evaluation context which is passed down to all nested functions during evaluation. @@ -12,6 +13,14 @@ import com.yahoo.tensor.Tensor; @Beta public interface EvaluationContext { + /** + * Returns tye type of the tensor with this name. + * + * @return returns the type of the tensor which will be returned by calling getTensor(name) + * or null if getTensor will return null. + */ + TensorType getTensorType(String name); + /** Returns the tensor bound to this name, or null if none */ Tensor getTensor(String name); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java index db8a66a5fa2..6bdfe8f19b6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/MapEvaluationContext.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import java.util.HashMap; @@ -19,6 +20,13 @@ public class MapEvaluationContext implements EvaluationContext { public void put(String name, Tensor tensor) { bindings.put(name, tensor); } @Override + public TensorType getTensorType(String name) { + Tensor tensor = bindings.get(name); + if (tensor == null) return null; + return tensor.type(); + } + + @Override public Tensor getTensor(String name) { return bindings.get(name); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java index 1f6ad050368..6c149724aca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/VariableTensor.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.evaluation; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.PrimitiveTensorFunction; import com.yahoo.tensor.functions.TensorFunction; import com.yahoo.tensor.functions.ToStringContext; @@ -25,15 +26,18 @@ public class VariableTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { return this; } + public TensorFunction withArguments(List<TensorFunction> arguments) { return this; } @Override public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { return context.getTensorType(name); } + + @Override public Tensor evaluate(EvaluationContext context) { return context.getTensor(name); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java index 10f53670826..93365d20966 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmax.java @@ -14,17 +14,17 @@ public class Argmax extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Argmax(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmax must have 1 argument, got " + arguments.size()); return new Argmax(arguments.get(0), dimension); @@ -37,7 +37,7 @@ public class Argmax extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.max, dimension), ScalarFunctions.equal()); } - + @Override public String toString(ToStringContext context) { return "argmax(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java index d324aec53e9..e598cdf8a98 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Argmin.java @@ -14,17 +14,17 @@ public class Argmin extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Argmin(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Argmin must have 1 argument, got " + arguments.size()); return new Argmin(arguments.get(0), dimension); @@ -37,7 +37,7 @@ public class Argmin extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.min, dimension), ScalarFunctions.equal()); } - + @Override public String toString(ToStringContext context) { return "argmin(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java index 191c7988443..0c43caef05c 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CompositeTensorFunction.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; /** @@ -14,6 +15,10 @@ import com.yahoo.tensor.evaluation.EvaluationContext; @Beta public abstract class CompositeTensorFunction extends TensorFunction { + /** Finds the type this produces by first converting it to a primitive function */ + @Override + public final TensorType type(EvaluationContext context) { return toPrimitive().type(context); } + /** Evaluates this by first converting it to a primitive function */ @Override public final Tensor evaluate(EvaluationContext context) { return toPrimitive().evaluate(context); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index d4affe0ef9b..cc8067224c7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -34,10 +34,10 @@ public class Concat extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if (arguments.size() != 2) throw new IllegalArgumentException("Concat must have 2 arguments, got " + arguments.size()); return new Concat(arguments.get(0), arguments.get(1), dimension); @@ -54,6 +54,20 @@ public class Concat extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return type(argumentA.type(context), argumentB.type(context)); + } + + /** Returns the type resulting from concatenating a and b */ + private TensorType type(TensorType a, TensorType b) { + TensorType.Builder builder = new TensorType.Builder(a, b); + if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size + builder.set(TensorType.Dimension.indexed(dimension, a.dimension(dimension).get().size().get() + + b.dimension(dimension).get().size().get())); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); @@ -63,7 +77,7 @@ public class Concat extends PrimitiveTensorFunction { IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; - TensorType concatType = concatType(a, b); + TensorType concatType = type(a.type(), b.type()); DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); @@ -115,15 +129,6 @@ public class Concat extends PrimitiveTensorFunction { } - /** Returns the type resulting from concatenating a and b */ - private TensorType concatType(Tensor a, Tensor b) { - TensorType.Builder builder = new TensorType.Builder(a.type(), b.type()); - if (builder.getDimension(dimension).get().size().isPresent()) // both types have size: correct to concat size - builder.set(TensorType.Dimension.indexed(dimension, a.type().dimension(dimension).get().size().get() + - b.type().dimension(dimension).get().size().get())); - return builder.build(); - } - /** Returns the concrete (not type) dimension sizes resulting from combining a and b */ private DimensionSizes concatSize(TensorType concatType, IndexedTensor a, IndexedTensor b, String concatDimension) { DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java index 14ed38718ce..4a6d656142f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ConstantTensor.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import java.util.Collections; @@ -27,10 +28,10 @@ public class ConstantTensor extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("ConstantTensor must have 0 arguments, got " + arguments.size()); return this; @@ -40,6 +41,9 @@ public class ConstantTensor extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { return constant.type(); } + + @Override public Tensor evaluate(EvaluationContext context) { return constant; } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index 653be8dacf0..e302f6606e7 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -25,10 +25,10 @@ public class Diag extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Diag must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index ef2770c04f5..ff9589bd6ae 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -47,10 +47,10 @@ public class Generate extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Generate must have 0 arguments, got " + arguments.size()); return this; @@ -60,6 +60,9 @@ public class Generate extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { return type; } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor.Builder builder = Tensor.Builder.of(type); IndexedTensor.Indexes indexes = IndexedTensor.Indexes.of(dimensionSizes(type)); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 174a8e4c435..01c681bfb36 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -70,15 +70,13 @@ public class Join extends PrimitiveTensorFunction { return typeBuilder.build(); } - public TensorFunction argumentA() { return argumentA; } - public TensorFunction argumentB() { return argumentB; } public DoubleBinaryOperator combinator() { return combinator; } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argumentA, argumentB); } + public List<TensorFunction> arguments() { return ImmutableList.of(argumentA, argumentB); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Join must have 2 arguments, got " + arguments.size()); return new Join(arguments.get(0), arguments.get(1), combinator); @@ -95,6 +93,11 @@ public class Join extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return new TensorType.Builder(argumentA.type(context), argumentB.type(context)).build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java index 91a9c6d1b27..d7f7ae59d62 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L1Normalize.java @@ -14,17 +14,17 @@ public class L1Normalize extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public L1Normalize(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L1Normalize must have 1 argument, got " + arguments.size()); return new L1Normalize(arguments.get(0), dimension); @@ -38,7 +38,7 @@ public class L1Normalize extends CompositeTensorFunction { new Reduce(primitiveArgument, Reduce.Aggregator.sum, dimension), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "l1_normalize(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java index bdf8921f81d..e2c526760bd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/L2Normalize.java @@ -14,17 +14,17 @@ public class L2Normalize extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public L2Normalize(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("L2Normalize must have 1 argument, got " + arguments.size()); return new L2Normalize(arguments.get(0), dimension); @@ -40,7 +40,7 @@ public class L2Normalize extends CompositeTensorFunction { ScalarFunctions.sqrt()), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "l2_normalize(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java index a5e1a016a41..e5440b56c54 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Map.java @@ -2,8 +2,6 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableMap; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; @@ -39,10 +37,10 @@ public class Map extends PrimitiveTensorFunction { public DoubleUnaryOperator mapper() { return mapper; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Map must have 1 argument, got " + arguments.size()); return new Map(arguments.get(0), mapper); @@ -54,6 +52,11 @@ public class Map extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return argument.type(context); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor argument = argument().evaluate(context); Tensor.Builder builder = Tensor.Builder.of(argument.type()); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java index 4071917c2b5..935e4761cfe 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Matmul.java @@ -27,10 +27,10 @@ public class Matmul extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(argument1, argument2); } + public List<TensorFunction> arguments() { return ImmutableList.of(argument1, argument2); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 2) throw new IllegalArgumentException("Matmul must have 2 arguments, got " + arguments.size()); return new Matmul(arguments.get(0), arguments.get(1), dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java index 958ef85d1dc..1475f7f4ac1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Random.java @@ -28,10 +28,10 @@ public class Random extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Random must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index 8e7f4e4c773..d951ec9ccbd 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -26,10 +26,10 @@ public class Range extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return Collections.emptyList(); } + public List<TensorFunction> arguments() { return Collections.emptyList(); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 0) throw new IllegalArgumentException("Range must have 0 arguments, got " + arguments.size()); return this; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index de9f90a5804..591a6e4649e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -73,10 +73,10 @@ public class Reduce extends PrimitiveTensorFunction { public TensorFunction argument() { return argument; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Reduce must have 1 argument, got " + arguments.size()); return new Reduce(arguments.get(0), aggregator, dimensions); @@ -100,6 +100,19 @@ public class Reduce extends PrimitiveTensorFunction { } @Override + public TensorType type(EvaluationContext context) { + return type(argument.type(context)); + } + + private TensorType type(TensorType argumentType) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : argumentType.dimensions()) + if ( ! dimensions.contains(dimension.name())) // keep + builder.dimension(dimension); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor argument = this.argument.evaluate(context); if ( ! dimensions.isEmpty() && ! argument.type().dimensionNames().containsAll(dimensions)) @@ -113,12 +126,7 @@ public class Reduce extends PrimitiveTensorFunction { else return reduceAllGeneral(argument); - // Reduce type - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension : argument.type().dimensions()) - if ( ! dimensions.contains(dimension.name())) // keep - builder.dimension(dimension); - TensorType reducedType = builder.build(); + TensorType reducedType = type(argument.type()); // Reduce cells Map<TensorAddress, ValueAggregator> aggregatingCells = new HashMap<>(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index ec9b762a41c..6a9b8d68b38 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -26,6 +26,7 @@ public class Rename extends PrimitiveTensorFunction { private final TensorFunction argument; private final List<String> fromDimensions; private final List<String> toDimensions; + private final Map<String, String> fromToMap; public Rename(TensorFunction argument, String fromDimension, String toDimension) { this(argument, ImmutableList.of(fromDimension), ImmutableList.of(toDimension)); @@ -43,13 +44,24 @@ public class Rename extends PrimitiveTensorFunction { this.argument = argument; this.fromDimensions = ImmutableList.copyOf(fromDimensions); this.toDimensions = ImmutableList.copyOf(toDimensions); + this.fromToMap = fromToMap(fromDimensions, toDimensions); + } + + public List<String> fromDimensions() { return fromDimensions; } + public List<String> toDimensions() { return toDimensions; } + + private static Map<String, String> fromToMap(List<String> fromDimensions, List<String> toDimensions) { + Map<String, String> map = new HashMap<>(); + for (int i = 0; i < fromDimensions.size(); i++) + map.put(fromDimensions.get(i), toDimensions.get(i)); + return map; } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Rename must have 1 argument, got " + arguments.size()); return new Rename(arguments.get(0), fromDimensions, toDimensions); @@ -59,11 +71,22 @@ public class Rename extends PrimitiveTensorFunction { public PrimitiveTensorFunction toPrimitive() { return this; } @Override + public TensorType type(EvaluationContext context) { + return type(argument.type(context)); + } + + private TensorType type(TensorType type) { + TensorType.Builder builder = new TensorType.Builder(); + for (TensorType.Dimension dimension : type.dimensions()) + builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); + return builder.build(); + } + + @Override public Tensor evaluate(EvaluationContext context) { Tensor tensor = argument.evaluate(context); - Map<String, String> fromToMap = fromToMap(); - TensorType renamedType = rename(tensor.type(), fromToMap); + TensorType renamedType = type(tensor.type()); // an array which lists the index of each label in the renamed type int[] toIndexes = new int[tensor.type().dimensions().size()]; @@ -82,13 +105,6 @@ public class Rename extends PrimitiveTensorFunction { return builder.build(); } - private TensorType rename(TensorType type, Map<String, String> fromToMap) { - TensorType.Builder builder = new TensorType.Builder(); - for (TensorType.Dimension dimension : type.dimensions()) - builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); - return builder.build(); - } - private TensorAddress rename(TensorAddress address, int[] toIndexes) { String[] reorderedLabels = new String[toIndexes.length]; for (int i = 0; i < toIndexes.length; i++) @@ -102,13 +118,6 @@ public class Rename extends PrimitiveTensorFunction { toVectorString(fromDimensions) + ", " + toVectorString(toDimensions) + ")"; } - private Map<String, String> fromToMap() { - Map<String, String> map = new HashMap<>(); - for (int i = 0; i < fromDimensions.size(); i++) - map.put(fromDimensions.get(i), toDimensions.get(i)); - return map; - } - private String toVectorString(List<String> elements) { if (elements.size() == 1) return elements.get(0); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java index c856b548180..32cff5ac84a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Softmax.java @@ -16,21 +16,21 @@ public class Softmax extends CompositeTensorFunction { private final TensorFunction argument; private final String dimension; - + public Softmax(TensorFunction argument, String dimension) { this.argument = argument; this.dimension = dimension; } - + public static TensorType outputType(TensorType inputType, String dimension) { return Reduce.outputType(inputType, ImmutableList.of(dimension)); } @Override - public List<TensorFunction> functionArguments() { return Collections.singletonList(argument); } + public List<TensorFunction> arguments() { return Collections.singletonList(argument); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 1) throw new IllegalArgumentException("Softmax must have 1 argument, got " + arguments.size()); return new Softmax(arguments.get(0), dimension); @@ -45,7 +45,7 @@ public class Softmax extends CompositeTensorFunction { dimension), ScalarFunctions.divide()); } - + @Override public String toString(ToStringContext context) { return "softmax(" + argument.toString(context) + ", " + dimension + ")"; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 533a46f87fe..3f6dfae6222 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.google.common.annotations.Beta; import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.MapEvaluationContext; @@ -19,14 +20,14 @@ import java.util.List; public abstract class TensorFunction { /** Returns the function arguments of this node in the order they are applied */ - public abstract List<TensorFunction> functionArguments(); + public abstract List<TensorFunction> arguments(); /** * Returns a copy of this tensor function with the arguments replaced by the given list of arguments. * * @throws IllegalArgumentException if the argument list has the wrong size for this function */ - public abstract TensorFunction replaceArguments(List<TensorFunction> arguments); + public abstract TensorFunction withArguments(List<TensorFunction> arguments); /** * Translate this function - and all of its arguments recursively - @@ -43,6 +44,13 @@ public abstract class TensorFunction { */ public abstract Tensor evaluate(EvaluationContext context); + /** + * Returns the type of the tensor this produces given the input types in the context + * + * @param context a context which must be passed to all nexted functions when evaluating + */ + public abstract TensorType type(EvaluationContext context); + /** Evaluate with no context */ public final Tensor evaluate() { return evaluate(new MapEvaluationContext()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java index 2464be981f5..78ff0731566 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/XwPlusB.java @@ -14,7 +14,7 @@ public class XwPlusB extends CompositeTensorFunction { private final TensorFunction x, w, b; private final String dimension; - + public XwPlusB(TensorFunction x, TensorFunction w, TensorFunction b, String dimension) { this.x = x; this.w = w; @@ -23,10 +23,10 @@ public class XwPlusB extends CompositeTensorFunction { } @Override - public List<TensorFunction> functionArguments() { return ImmutableList.of(x, w, b); } + public List<TensorFunction> arguments() { return ImmutableList.of(x, w, b); } @Override - public TensorFunction replaceArguments(List<TensorFunction> arguments) { + public TensorFunction withArguments(List<TensorFunction> arguments) { if ( arguments.size() != 3) throw new IllegalArgumentException("XwPlusB must have 3 arguments, got " + arguments.size()); return new XwPlusB(arguments.get(0), arguments.get(1), arguments.get(2), dimension); @@ -43,7 +43,7 @@ public class XwPlusB extends CompositeTensorFunction { primitiveB, ScalarFunctions.add()); } - + @Override public String toString(ToStringContext context) { return "xw_plus_b(" + x.toString(context) + ", " + |