summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java157
1 files changed, 93 insertions, 64 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
index 0a881c0a290..0325753d2e0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
@@ -19,10 +19,10 @@ import java.util.stream.Collectors;
* @author bratseth
*/
@Beta
-public class Value extends PrimitiveTensorFunction {
+public class Value<NAMETYPE extends TypeContext.Name> extends PrimitiveTensorFunction<NAMETYPE> {
- private final TensorFunction argument;
- private final List<DimensionValue> cellAddress;
+ private final TensorFunction<NAMETYPE> argument;
+ private final List<DimensionValue<NAMETYPE>> cellAddress;
/**
* Creates a value function
@@ -31,7 +31,7 @@ public class Value extends PrimitiveTensorFunction {
* @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress
* because those require a type, but a type is not resolved until this is evaluated
*/
- public Value(TensorFunction argument, List<DimensionValue> cellAddress) {
+ public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) {
this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " +
@@ -40,34 +40,38 @@ public class Value extends PrimitiveTensorFunction {
}
@Override
- public List<TensorFunction> arguments() { return List.of(argument); }
+ public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }
@Override
- public Value withArguments(List<TensorFunction> arguments) {
+ public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
if (arguments.size() != 1)
throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
- return new Value(arguments.get(0), cellAddress);
+ return new Value<NAMETYPE>(arguments.get(0), cellAddress);
}
@Override
- public PrimitiveTensorFunction toPrimitive() { return this; }
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
@Override
- public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor tensor = argument.evaluate(context);
if (tensor.type().rank() != cellAddress.size())
throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() +
" to a tensor of type " + tensor.type());
TensorAddress.Builder b = new TensorAddress.Builder(tensor.type());
for (int i = 0; i < cellAddress.size(); i++) {
- b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
- cellAddress.get(i).label());
+ if (cellAddress.get(i).label().isPresent())
+ b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
+ cellAddress.get(i).label().get());
+ else
+ b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
+ String.valueOf(cellAddress.get(i).index().get().apply(context).intValue()));
}
return Tensor.from(tensor.get(b.build()));
}
@Override
- public <NAMETYPE extends TypeContext.Name> TensorType type(TypeContext<NAMETYPE> context) {
+ public TensorType type(TypeContext<NAMETYPE> context) {
return new TensorType.Builder(argument.type(context).valueType()).build();
}
@@ -87,69 +91,94 @@ public class Value extends PrimitiveTensorFunction {
else {
return "{" + cellAddress.stream().map(i -> i.toString()).collect(Collectors.joining(", ")) + "}";
}
- }
+ }
+
+ public static class DimensionValue<NAMETYPE extends TypeContext.Name> {
+
+ private final Optional<String> dimension;
- public static class DimensionValue {
+ /** The label of this, or null if index is set */
+ private final String label;
- private final Optional<String> dimension;
+ /** The function returning the index of this, or null if label is set */
+ private final ScalarFunction<NAMETYPE> index;
+
+ public DimensionValue(String dimension, String label) {
+ this(Optional.of(dimension), label, null);
+ }
+
+ public DimensionValue(String dimension, int index) {
+ this(Optional.of(dimension), null, new ConstantScalarFunction<>(index));
+ }
- /** The label of this. Always available, whether or not index is */
- private final String label;
+ public DimensionValue(int index) {
+ this(Optional.empty(), null, new ConstantScalarFunction<>(index));
+ }
- /** The index of this, or empty if this is a non-integer label */
- private final Optional<Integer> index;
+ public DimensionValue(String label) {
+ this(Optional.empty(), label, null);
+ }
- public DimensionValue(String dimension, String label) {
- this(Optional.of(dimension), label, indexOrEmpty(label));
- }
+ public DimensionValue(ScalarFunction<NAMETYPE> index) {
+ this(Optional.empty(), null, index);
+ }
- public DimensionValue(String dimension, int index) {
- this(Optional.of(dimension), String.valueOf(index), Optional.of(index));
- }
+ public DimensionValue(Optional<String> dimension, String label) {
+ this(dimension, label, null);
+ }
- public DimensionValue(int index) {
- this(Optional.empty(), String.valueOf(index), Optional.of(index));
- }
+ public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
+ this(dimension, null, index);
+ }
- public DimensionValue(String label) {
- this(Optional.empty(), label, indexOrEmpty(label));
- }
+ public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
+ this(Optional.of(dimension), null, index);
+ }
- private DimensionValue(Optional<String> dimension, String label, Optional<Integer> index) {
+ private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
this.dimension = dimension;
this.label = label;
this.index = index;
- }
-
- /**
- * Returns the given name of the dimension, or null if dense form is used, such that name
- * must be inferred from order
- */
- public Optional<String> dimension() { return dimension; }
-
- /** Returns the label or index for this dimension as a string */
- public String label() { return label; }
-
- /** Returns the index for this dimension, or empty if it is not a number */
- Optional<Integer> index() { return index; }
-
- @Override
- public String toString() {
- if (dimension.isPresent())
- return dimension.get() + ":" + label;
- else
- return label;
- }
-
- private static Optional<Integer> indexOrEmpty(String label) {
- try {
- return Optional.of(Integer.parseInt(label));
- }
- catch (IllegalArgumentException e) {
- return Optional.empty();
- }
- }
-
- }
+ }
+
+ /**
+ * Returns the given name of the dimension, or null if dense form is used, such that name
+ * must be inferred from order
+ */
+ public Optional<String> dimension() { return dimension; }
+
+ /** Returns the label for this dimension or empty if it is provided by an index function */
+ public Optional<String> label() { return Optional.ofNullable(label); }
+
+ /** Returns the index expression for this dimension, or empty if it is not a number */
+ public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }
+
+ @Override
+ public String toString() {
+ StringBuilder b = new StringBuilder();
+ dimension.ifPresent(d -> b.append(d).append(":"));
+ if (label != null)
+ b.append(label);
+ else
+ b.append(index);
+ return b.toString();
+ }
+
+ }
+
+ private static class ConstantScalarFunction<NAMETYPE extends TypeContext.Name> implements ScalarFunction<NAMETYPE> {
+
+ private final Double value;
+
+ public ConstantScalarFunction(int value) {
+ this.value = (double)value;
+ }
+
+ @Override
+ public Double apply(EvaluationContext<NAMETYPE> context) {
+ return value;
+ }
+
+ }
}