diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java | 36 |
1 files changed, 17 insertions, 19 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java index a75e49c6402..6830ec50c5f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DynamicTensor.java @@ -20,7 +20,7 @@ import java.util.function.Function; * * @author bratseth */ -public abstract class DynamicTensor extends PrimitiveTensorFunction { +public abstract class DynamicTensor<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { private final TensorType type; @@ -29,20 +29,20 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { return type; } + public TensorType type(TypeContext<NAMETYPE> context) { return type; } @Override - public List<TensorFunction> arguments() { return Collections.emptyList(); } + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.emptyList(); } @Override - public TensorFunction withArguments(List<TensorFunction> arguments) { + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 0) throw new IllegalArgumentException("Dynamic tensors must have 0 arguments, got " + arguments.size()); return this; } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } TensorType type() { return type; } @@ -54,27 +54,26 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { abstract String contentToString(ToStringContext context); /** Creates a dynamic tensor function. The cell addresses must match the type. */ - public static DynamicTensor from(TensorType type, Map<TensorAddress, ScalarFunction> cells) { - return new MappedDynamicTensor(type, cells); + public static <NAMETYPE extends TypeContext.Name> DynamicTensor<NAMETYPE> from(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { + return new MappedDynamicTensor<>(type, cells); } /** Creates a dynamic tensor function for a bound, indexed tensor */ - public static DynamicTensor from(TensorType type, List<ScalarFunction> cells) { - return new IndexedDynamicTensor(type, cells); + public static <NAMETYPE extends TypeContext.Name> DynamicTensor<NAMETYPE> from(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { + return new IndexedDynamicTensor<>(type, cells); } - private static class MappedDynamicTensor extends DynamicTensor { + private static class MappedDynamicTensor<NAMETYPE extends TypeContext.Name> extends DynamicTensor<NAMETYPE> { - private final ImmutableMap<TensorAddress, ScalarFunction> cells; + private final ImmutableMap<TensorAddress, ScalarFunction<NAMETYPE>> cells; - MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction> cells) { + MappedDynamicTensor(TensorType type, Map<TensorAddress, ScalarFunction<NAMETYPE>> cells) { super(type); this.cells = ImmutableMap.copyOf(cells); } @Override - @SuppressWarnings("unchecked") - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor.Builder builder = Tensor.Builder.of(type()); for (var cell : cells.entrySet()) builder.cell(cell.getKey(), cell.getValue().apply(context)); @@ -102,11 +101,11 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } - private static class IndexedDynamicTensor extends DynamicTensor { + private static class IndexedDynamicTensor<NAMETYPE extends TypeContext.Name> extends DynamicTensor<NAMETYPE> { - private final List<ScalarFunction> cells; + private final List<ScalarFunction<NAMETYPE>> cells; - IndexedDynamicTensor(TensorType type, List<ScalarFunction> cells) { + IndexedDynamicTensor(TensorType type, List<ScalarFunction<NAMETYPE>> cells) { super(type); if ( ! type.dimensions().stream().allMatch(d -> d.type() == TensorType.Dimension.Type.indexedBound)) throw new IllegalArgumentException("A dynamic tensor can only be created from a list if the type has " + @@ -115,8 +114,7 @@ public abstract class DynamicTensor extends PrimitiveTensorFunction { } @Override - @SuppressWarnings("unchecked") - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { IndexedTensor.BoundBuilder builder = (IndexedTensor.BoundBuilder)Tensor.Builder.of(type()); for (int i = 0; i < cells.size(); i++) builder.cellByDirectIndex(i, cells.get(i).apply(context)); |