diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java | 157 |
1 files changed, 93 insertions, 64 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java index 0a881c0a290..0325753d2e0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java @@ -19,10 +19,10 @@ import java.util.stream.Collectors; * @author bratseth */ @Beta -public class Value extends PrimitiveTensorFunction { +public class Value<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> { - private final TensorFunction argument; - private final List<DimensionValue> cellAddress; + private final TensorFunction<NAMETYPE> argument; + private final List<DimensionValue<NAMETYPE>> cellAddress; /** * Creates a value function @@ -31,7 +31,7 @@ public class Value extends PrimitiveTensorFunction { * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress * because those require a type, but a type is not resolved until this is evaluated */ - public Value(TensorFunction argument, List<DimensionValue> cellAddress) { + public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) { this.argument = Objects.requireNonNull(argument, "Argument cannot be null"); if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty())) throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " + @@ -40,34 +40,38 @@ public class Value extends PrimitiveTensorFunction { } @Override - public List<TensorFunction> arguments() { return List.of(argument); } + public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); } @Override - public Value withArguments(List<TensorFunction> arguments) { + public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { if (arguments.size() != 1) throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size()); - return new Value(arguments.get(0), cellAddress); + return new Value<NAMETYPE>(arguments.get(0), cellAddress); } @Override - public PrimitiveTensorFunction toPrimitive() { return this; } + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; } @Override - public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) { + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor tensor = argument.evaluate(context); if (tensor.type().rank() != cellAddress.size()) throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() + " to a tensor of type " + tensor.type()); TensorAddress.Builder b = new TensorAddress.Builder(tensor.type()); for (int i = 0; i < cellAddress.size(); i++) { - b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), - cellAddress.get(i).label()); + if (cellAddress.get(i).label().isPresent()) + b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), + cellAddress.get(i).label().get()); + else + b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()), + String.valueOf(cellAddress.get(i).index().get().apply(context).intValue())); } return Tensor.from(tensor.get(b.build())); } @Override - public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) { + public TensorType type(TypeContext<NAMETYPE> context) { return new TensorType.Builder(argument.type(context).valueType()).build(); } @@ -87,69 +91,94 @@ public class Value extends PrimitiveTensorFunction { else { return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}"; } - } + } + + public static class DimensionValue<NAMETYPE extends TypeContext.Name> { + + private final Optional<String> dimension; - public static class DimensionValue { + /** The label of this, or null if index is set */ + private final String label; - private final Optional<String> dimension; + /** The function returning the index of this, or null if label is set */ + private final ScalarFunction<NAMETYPE> index; + + public DimensionValue(String dimension, String label) { + this(Optional.of(dimension), label, null); + } + + public DimensionValue(String dimension, int index) { + this(Optional.of(dimension), null, new ConstantScalarFunction<>(index)); + } - /** The label of this. Always available, whether or not index is */ - private final String label; + public DimensionValue(int index) { + this(Optional.empty(), null, new ConstantScalarFunction<>(index)); + } - /** The index of this, or empty if this is a non-integer label */ - private final Optional<Integer> index; + public DimensionValue(String label) { + this(Optional.empty(), label, null); + } - public DimensionValue(String dimension, String label) { - this(Optional.of(dimension), label, indexOrEmpty(label)); - } + public DimensionValue(ScalarFunction<NAMETYPE> index) { + this(Optional.empty(), null, index); + } - public DimensionValue(String dimension, int index) { - this(Optional.of(dimension), String.valueOf(index), Optional.of(index)); - } + public DimensionValue(Optional<String> dimension, String label) { + this(dimension, label, null); + } - public DimensionValue(int index) { - this(Optional.empty(), String.valueOf(index), Optional.of(index)); - } + public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) { + this(dimension, null, index); + } - public DimensionValue(String label) { - this(Optional.empty(), label, indexOrEmpty(label)); - } + public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) { + this(Optional.of(dimension), null, index); + } - private DimensionValue(Optional<String> dimension, String label, Optional<Integer> index) { + private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) { this.dimension = dimension; this.label = label; this.index = index; - } - - /** - * Returns the given name of the dimension, or null if dense form is used, such that name - * must be inferred from order - */ - public Optional<String> dimension() { return dimension; } - - /** Returns the label or index for this dimension as a string */ - public String label() { return label; } - - /** Returns the index for this dimension, or empty if it is not a number */ - Optional<Integer> index() { return index; } - - @Override - public String toString() { - if (dimension.isPresent()) - return dimension.get() + ":" + label; - else - return label; - } - - private static Optional<Integer> indexOrEmpty(String label) { - try { - return Optional.of(Integer.parseInt(label)); - } - catch (IllegalArgumentException e) { - return Optional.empty(); - } - } - - } + } + + /** + * Returns the given name of the dimension, or null if dense form is used, such that name + * must be inferred from order + */ + public Optional<String> dimension() { return dimension; } + + /** Returns the label for this dimension or empty if it is provided by an index function */ + public Optional<String> label() { return Optional.ofNullable(label); } + + /** Returns the index expression for this dimension, or empty if it is not a number */ + public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); } + + @Override + public String toString() { + StringBuilder b = new StringBuilder(); + dimension.ifPresent(d -> b.append(d).append(":")); + if (label != null) + b.append(label); + else + b.append(index); + return b.toString(); + } + + } + + private static class ConstantScalarFunction<NAMETYPE extends TypeContext.Name> implements ScalarFunction<NAMETYPE> { + + private final Double value; + + public ConstantScalarFunction(int value) { + this.value = (double)value; + } + + @Override + public Double apply(EvaluationContext<NAMETYPE> context) { + return value; + } + + } } |