summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2019-12-09 14:31:43 -0800
committerGitHub <noreply@github.com>2019-12-09 14:31:43 -0800
commit4192f7e82258b6fc6165230cdf1910c384a138c1 (patch)
treead80d958954597289c8d838112d6cf092fbaffc3 /vespajlib
parent676dd43a12db59f96536aa6d8a45369d24d17404 (diff)
parent7ef64a61b4f04a400428fe58ed2475aa37c43d39 (diff)
Merge pull request #11528 from vespa-engine/bratseth/tensor-slice
Generalized Slice tensor function.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json93
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java25
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java98
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java266
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java189
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java138
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java75
7 files changed, 559 insertions, 325 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index cea58d565c2..e991173805f 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1036,6 +1036,7 @@
"methods": [
"public void <init>(int)",
"public void add(java.lang.String, long)",
+ "public void add(java.lang.String, java.lang.String)",
"public com.yahoo.tensor.PartialAddress build()"
],
"fields": []
@@ -1046,7 +1047,15 @@
"attributes": [
"public"
],
- "methods": [],
+ "methods": [
+ "public java.lang.String dimension(int)",
+ "public long numericLabel(java.lang.String)",
+ "public java.lang.String label(java.lang.String)",
+ "public java.lang.String label(int)",
+ "public int size()",
+ "public com.yahoo.tensor.TensorAddress asAddress(com.yahoo.tensor.TensorType)",
+ "public java.lang.String toString()"
+ ],
"fields": []
},
"com.yahoo.tensor.Tensor$Builder$CellBuilder": {
@@ -2465,6 +2474,47 @@
],
"fields": []
},
+ "com.yahoo.tensor.functions.Slice$DimensionValue": {
+ "superClass": "java.lang.Object",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(java.lang.String, java.lang.String)",
+ "public void <init>(java.lang.String, int)",
+ "public void <init>(int)",
+ "public void <init>(java.lang.String)",
+ "public void <init>(com.yahoo.tensor.functions.ScalarFunction)",
+ "public void <init>(java.util.Optional, java.lang.String)",
+ "public void <init>(java.util.Optional, com.yahoo.tensor.functions.ScalarFunction)",
+ "public void <init>(java.lang.String, com.yahoo.tensor.functions.ScalarFunction)",
+ "public java.util.Optional dimension()",
+ "public java.util.Optional label()",
+ "public java.util.Optional index()",
+ "public java.lang.String toString()",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
+ ],
+ "fields": []
+ },
+ "com.yahoo.tensor.functions.Slice": {
+ "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
+ "interfaces": [],
+ "attributes": [
+ "public"
+ ],
+ "methods": [
+ "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
+ "public java.util.List arguments()",
+ "public com.yahoo.tensor.functions.Slice withArguments(java.util.List)",
+ "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
+ "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
+ "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
+ "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
+ "public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)"
+ ],
+ "fields": []
+ },
"com.yahoo.tensor.functions.Softmax": {
"superClass": "com.yahoo.tensor.functions.CompositeTensorFunction",
"interfaces": [],
@@ -2531,47 +2581,6 @@
],
"fields": []
},
- "com.yahoo.tensor.functions.Value$DimensionValue": {
- "superClass": "java.lang.Object",
- "interfaces": [],
- "attributes": [
- "public"
- ],
- "methods": [
- "public void <init>(java.lang.String, java.lang.String)",
- "public void <init>(java.lang.String, int)",
- "public void <init>(int)",
- "public void <init>(java.lang.String)",
- "public void <init>(com.yahoo.tensor.functions.ScalarFunction)",
- "public void <init>(java.util.Optional, java.lang.String)",
- "public void <init>(java.util.Optional, com.yahoo.tensor.functions.ScalarFunction)",
- "public void <init>(java.lang.String, com.yahoo.tensor.functions.ScalarFunction)",
- "public java.util.Optional dimension()",
- "public java.util.Optional label()",
- "public java.util.Optional index()",
- "public java.lang.String toString()",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)"
- ],
- "fields": []
- },
- "com.yahoo.tensor.functions.Value": {
- "superClass": "com.yahoo.tensor.functions.PrimitiveTensorFunction",
- "interfaces": [],
- "attributes": [
- "public"
- ],
- "methods": [
- "public void <init>(com.yahoo.tensor.functions.TensorFunction, java.util.List)",
- "public java.util.List arguments()",
- "public com.yahoo.tensor.functions.Value withArguments(java.util.List)",
- "public com.yahoo.tensor.functions.PrimitiveTensorFunction toPrimitive()",
- "public com.yahoo.tensor.Tensor evaluate(com.yahoo.tensor.evaluation.EvaluationContext)",
- "public com.yahoo.tensor.TensorType type(com.yahoo.tensor.evaluation.TypeContext)",
- "public java.lang.String toString(com.yahoo.tensor.functions.ToStringContext)",
- "public bridge synthetic com.yahoo.tensor.functions.TensorFunction withArguments(java.util.List)"
- ],
- "fields": []
- },
"com.yahoo.tensor.functions.XwPlusB": {
"superClass": "com.yahoo.tensor.functions.CompositeTensorFunction",
"interfaces": [],
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 1f3c373c1e8..1cde1fcdbb7 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -54,9 +54,8 @@ public class MixedTensor implements Tensor {
public double get(TensorAddress address) {
long cellIndex = index.indexOf(address);
Cell cell = cells.get((int)cellIndex);
- if (!address.equals(cell.getKey())) {
- throw new IllegalStateException("Unable to find correct cell by direct index.");
- }
+ if ( ! address.equals(cell.getKey()))
+ throw new IllegalStateException("Unable to find correct cell in " + this + " by direct index " + address);
return cell.getValue();
}
@@ -375,9 +374,8 @@ public class MixedTensor implements Tensor {
public long indexOf(TensorAddress address) {
TensorAddress sparsePart = sparsePartialAddress(address);
- if ( ! sparseMap.containsKey(sparsePart)) {
- throw new IllegalArgumentException("Address not found");
- }
+ if ( ! sparseMap.containsKey(sparsePart))
+ throw new IllegalArgumentException("Address subspace " + sparsePart + " not found in " + this);
long base = sparseMap.get(sparsePart);
long offset = denseOffset(address);
return base + offset;
@@ -414,7 +412,7 @@ public class MixedTensor implements Tensor {
TensorType.Dimension dimension = type.dimensions().get(i);
if (dimension.isIndexed()) {
denseSubspaceSize *= dimension.size().orElseThrow(() ->
- new IllegalArgumentException("Unknown size of indexed dimension."));
+ new IllegalArgumentException("Unknown size of indexed dimension"));
}
}
}
@@ -422,15 +420,13 @@ public class MixedTensor implements Tensor {
}
private TensorAddress sparsePartialAddress(TensorAddress address) {
- if (type.dimensions().size() != address.size()) {
- throw new IllegalArgumentException("Tensor type and address are not of same size.");
- }
+ if (type.dimensions().size() != address.size())
+ throw new IllegalArgumentException("Tensor type of " + this + " is not the same size as " + address);
TensorAddress.Builder builder = new TensorAddress.Builder(sparseType);
for (int i = 0; i < type.dimensions().size(); ++i) {
TensorType.Dimension dimension = type.dimensions().get(i);
- if (!dimension.isIndexed()) {
+ if ( ! dimension.isIndexed())
builder.add(dimension.name(), address.label(i));
- }
}
return builder.build();
}
@@ -488,6 +484,11 @@ public class MixedTensor implements Tensor {
return TensorAddress.of(labels);
}
+ @Override
+ public String toString() {
+ return "indexes into " + type;
+ }
+
}
public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
index 9c41d5aad68..4eca9c47402 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java
@@ -1,6 +1,8 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import java.util.Arrays;
+
/**
* An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors
* dimensions.
@@ -13,10 +15,10 @@ package com.yahoo.tensor;
// - We can add support for string labels later without breaking the API
public class PartialAddress {
- // Two arrays which contains corresponding dimension=label pairs.
+ // Two arrays which contains corresponding dimension:label pairs.
// The sizes of these are always equal.
private final String[] dimensionNames;
- private final long[] labels;
+ private final Object[] labels;
private PartialAddress(Builder builder) {
this.dimensionNames = builder.dimensionNames;
@@ -25,23 +27,99 @@ public class PartialAddress {
builder.labels = null;
}
- /** Returns the int label of this dimension, or -1 if no label is specified for it */
- long numericLabel(String dimensionName) {
+ public String dimension(int i) {
+ return dimensionNames[i];
+ }
+
+ /** Returns the numeric label of this dimension, or -1 if no label is specified for it */
+ public long numericLabel(String dimensionName) {
for (int i = 0; i < dimensionNames.length; i++)
if (dimensionNames[i].equals(dimensionName))
- return labels[i];
+ return asLong(labels[i]);
return -1;
}
+ /** Returns the label of this dimension, or null if no label is specified for it */
+ public String label(String dimensionName) {
+ for (int i = 0; i < dimensionNames.length; i++)
+ if (dimensionNames[i].equals(dimensionName))
+ return labels[i].toString();
+ return null;
+ }
+
+ /**
+ * Returns the label at position i
+ *
+ * @throws IllegalArgumentException if i is out of bounds
+ */
+ public String label(int i) {
+ if (i >= size())
+ throw new IllegalArgumentException("No label at position " + i + " in " + this);
+ return labels[i].toString();
+ }
+
+ public int size() { return dimensionNames.length; }
+
+ /** Returns this as an address in the given tensor type */
+ // We need the type here not just for validation but because this must map to the dimension order given by the type
+ public TensorAddress asAddress(TensorType type) {
+ if (type.rank() != size())
+ throw new IllegalArgumentException(type + " has a different rank than " + this);
+ if (Arrays.stream(labels).allMatch(l -> l instanceof Long)) {
+ long[] numericLabels = new long[labels.length];
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ long label = numericLabel(type.dimensions().get(i).name());
+ if (label < 0)
+ throw new IllegalArgumentException(type + " dimension names does not match " + this);
+ numericLabels[i] = label;
+ }
+ return TensorAddress.of(numericLabels);
+ }
+ else {
+ String[] stringLabels = new String[labels.length];
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ String label = label(type.dimensions().get(i).name());
+ if (label == null)
+ throw new IllegalArgumentException(type + " dimension names does not match " + this);
+ stringLabels[i] = label;
+ }
+ return TensorAddress.of(stringLabels);
+ }
+ }
+
+ private long asLong(Object label) {
+ if (label instanceof Long) {
+ return (Long) label;
+ }
+ else {
+ try {
+ return Long.parseLong(label.toString());
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Label '" + label + "' is not numeric");
+ }
+ }
+ }
+
+ @Override
+ public String toString() {
+ StringBuilder b = new StringBuilder("Partial address {");
+ for (int i = 0; i < dimensionNames.length; i++)
+ b.append(dimensionNames[i]).append(":").append(label(i)).append(", ");
+ if (size() > 0)
+ b.setLength(b.length() - 2);
+ return b.toString();
+ }
+
public static class Builder {
private String[] dimensionNames;
- private long[] labels;
+ private Object[] labels;
private int index = 0;
public Builder(int size) {
dimensionNames = new String[size];
- labels = new long[size];
+ labels = new Object[size];
}
public void add(String dimensionName, long label) {
@@ -50,6 +128,12 @@ public class PartialAddress {
index++;
}
+ public void add(String dimensionName, String label) {
+ dimensionNames[index] = dimensionName;
+ labels[index] = label;
+ index++;
+ }
+
public PartialAddress build() {
return new PartialAddress(this);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
new file mode 100644
index 00000000000..4d3989b8782
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Slice.java
@@ -0,0 +1,266 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.google.common.annotations.Beta;
+import com.yahoo.tensor.PartialAddress;
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Iterator;
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Returns a subspace of a tensor
+ *
+ * @author bratseth
+ */
+@Beta
+public class Slice<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> argument;
+ private final List<DimensionValue<NAMETYPE>> subspaceAddress;
+
+ /**
+ * Creates a value function
+ *
+ * @param argument the tensor to return a cell value from
+ * @param subspaceAddress a description of the address of the cell to return the value of. This is not a TensorAddress
+ * because those require a type, but a type is not resolved until this is evaluated
+ */
+ public Slice(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> subspaceAddress) {
+ this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
+ if (subspaceAddress.size() > 1 && subspaceAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
+ throw new IllegalArgumentException("Short form of subspace addresses is only supported with a single dimension: " +
+ "Specify dimension names explicitly instead");
+ this.subspaceAddress = subspaceAddress;
+ }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }
+
+ @Override
+ public Slice<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if (arguments.size() != 1)
+ throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
+ return new Slice<>(arguments.get(0), subspaceAddress);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor tensor = argument.evaluate(context);
+ TensorType resultType = resultType(tensor.type());
+
+ PartialAddress subspaceAddress = subspaceToAddress(tensor.type(), context);
+ if (resultType.rank() == 0) // shortcut common case
+ return Tensor.from(tensor.get(subspaceAddress.asAddress(tensor.type())));
+
+ Tensor.Builder b = Tensor.Builder.of(resultType);
+ for (Iterator<Tensor.Cell> i = tensor.cellIterator(); i.hasNext(); ) {
+ Tensor.Cell cell = i.next();
+ if (matches(subspaceAddress, cell.getKey(), tensor.type()))
+ b.cell(remaining(resultType, cell.getKey(), tensor.type()), cell.getValue());
+ }
+ return b.build();
+ }
+
+ private PartialAddress subspaceToAddress(TensorType type, EvaluationContext<NAMETYPE> context) {
+ PartialAddress.Builder b = new PartialAddress.Builder(subspaceAddress.size());
+ for (int i = 0; i < subspaceAddress.size(); i++) {
+ if (subspaceAddress.get(i).label().isPresent())
+ b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()),
+ subspaceAddress.get(i).label().get());
+ else
+ b.add(subspaceAddress.get(i).dimension().orElse(type.dimensions().get(i).name()),
+ subspaceAddress.get(i).index().get().apply(context).intValue());
+ }
+ return b.build();
+ }
+
+ private boolean matches(PartialAddress subspaceAddress,
+ TensorAddress address, TensorType type) {
+ for (int i = 0; i < subspaceAddress.size(); i++) {
+ String label = address.label(type.indexOfDimension(subspaceAddress.dimension(i)).get());
+ if ( ! label.equals(subspaceAddress.label(i)))
+ return false;
+ }
+ return true;
+ }
+
+ /** Returns the subset of the given address which is present in the subspace type */
+ private TensorAddress remaining(TensorType subspaceType, TensorAddress address, TensorType type) {
+ TensorAddress.Builder b = new TensorAddress.Builder(subspaceType);
+ for (int i = 0; i < address.size(); i++) {
+ String dimension = type.dimensions().get(i).name();
+ if (subspaceType.dimension(type.dimensions().get(i).name()).isPresent())
+ b.add(dimension, address.label(i));
+ }
+ return b.build();
+ }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return resultType(argument.type(context));
+ }
+
+ private TensorType resultType(TensorType argumentType) {
+ TensorType.Builder b = new TensorType.Builder();
+
+ // Special case where a single indexed or mapped dimension is sliced
+ if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
+ if (subspaceAddress.get(0).index().isPresent()) {
+ if (argumentType.dimensions().stream().filter(d -> d.isIndexed()).count() > 1)
+ throw new IllegalArgumentException(this + " slices a single indexed dimension, cannot be applied " +
+ " to " + argumentType + ", which have multiple");
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if ( ! dimension.isIndexed())
+ b.dimension(dimension);
+ }
+ }
+ else {
+ if (argumentType.dimensions().stream().filter(d -> ! d.isIndexed()).count() > 1)
+ throw new IllegalArgumentException(this + " slices a single mapped dimension, cannot be applied " +
+ " to " + argumentType + ", which have multiple");
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if (dimension.isIndexed())
+ b.dimension(dimension);
+ }
+
+ }
+ }
+ else { // general slicing
+ Set<String> slicedDimensions = subspaceAddress.stream().map(d -> d.dimension().get()).collect(Collectors.toSet());
+ for (TensorType.Dimension dimension : argumentType.dimensions()) {
+ if (slicedDimensions.contains(dimension.name()))
+ slicedDimensions.remove(dimension.name());
+ else
+ b.dimension(dimension);
+ }
+ if ( ! slicedDimensions.isEmpty())
+ throw new IllegalArgumentException(this + " slices " + slicedDimensions + " which are not present in " +
+ argumentType);
+ }
+ return b.build();
+ }
+
+ @Override
+ public String toString(ToStringContext context) {
+ StringBuilder b = new StringBuilder(argument.toString(context));
+ if (subspaceAddress.size() == 1 && subspaceAddress.get(0).dimension().isEmpty()) {
+ if (subspaceAddress.get(0).index().isPresent())
+ b.append("[").append(subspaceAddress.get(0).index().get().toString(context)).append("]");
+ else
+ b.append("{").append(subspaceAddress.get(0).label().get()).append("}");
+ }
+ else {
+ b.append("{").append(subspaceAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}");
+ }
+ return b.toString();
+ }
+
+ public static class DimensionValue<NAMETYPE extends Name> {
+
+ private final Optional<String> dimension;
+
+ /** The label of this, or null if index is set */
+ private final String label;
+
+ /** The function returning the index of this, or null if label is set */
+ private final ScalarFunction<NAMETYPE> index;
+
+ public DimensionValue(String dimension, String label) {
+ this(Optional.of(dimension), label, null);
+ }
+
+ public DimensionValue(String dimension, int index) {
+ this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
+ }
+
+ public DimensionValue(int index) {
+ this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
+ }
+
+ public DimensionValue(String label) {
+ this(Optional.empty(), label, null);
+ }
+
+ public DimensionValue(ScalarFunction<NAMETYPE> index) {
+ this(Optional.empty(), null, index);
+ }
+
+ public DimensionValue(Optional<String> dimension, String label) {
+ this(dimension, label, null);
+ }
+
+ public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
+ this(dimension, null, index);
+ }
+
+ public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
+ this(Optional.of(dimension), null, index);
+ }
+
+ private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
+ this.dimension = dimension;
+ this.label = label;
+ this.index = index;
+ }
+
+ /**
+ * Returns the given name of the dimension, or null if dense form is used, such that name
+ * must be inferred from order
+ */
+ public Optional<String> dimension() { return dimension; }
+
+ /** Returns the label for this dimension or empty if it is provided by an index function */
+ public Optional<String> label() { return Optional.ofNullable(label); }
+
+ /** Returns the index expression for this dimension, or empty if it is not a number */
+ public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }
+
+ @Override
+ public String toString() {
+ return toString(null);
+ }
+
+ public String toString(ToStringContext context) {
+ StringBuilder b = new StringBuilder();
+ dimension.ifPresent(d -> b.append(d).append(":"));
+ if (label != null)
+ b.append(label);
+ else
+ b.append(index.toString(context));
+ return b.toString();
+ }
+
+ }
+
+ private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {
+
+ private final int value;
+
+ public ConstantIntegerFunction(int value) {
+ this.value = value;
+ }
+
+ @Override
+ public Double apply(EvaluationContext<NAMETYPE> context) {
+ return (double)value;
+ }
+
+ @Override
+ public String toString() { return String.valueOf(value); }
+
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
deleted file mode 100644
index 37a54807673..00000000000
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Value.java
+++ /dev/null
@@ -1,189 +0,0 @@
-// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor.functions;
-
-import com.google.common.annotations.Beta;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorAddress;
-import com.yahoo.tensor.TensorType;
-import com.yahoo.tensor.evaluation.EvaluationContext;
-import com.yahoo.tensor.evaluation.Name;
-import com.yahoo.tensor.evaluation.TypeContext;
-
-import java.util.List;
-import java.util.Objects;
-import java.util.Optional;
-import java.util.stream.Collectors;
-
-/**
- * Returns the value of a cell of a tensor (as a rank 0 tensor).
- *
- * @author bratseth
- */
-@Beta
-public class Value<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
-
- private final TensorFunction<NAMETYPE> argument;
- private final List<DimensionValue<NAMETYPE>> cellAddress;
-
- /**
- * Creates a value function
- *
- * @param argument the tensor to return a cell value from
- * @param cellAddress a description of the address of the cell to return the value of. This is not a TensorAddress
- * because those require a type, but a type is not resolved until this is evaluated
- */
- public Value(TensorFunction<NAMETYPE> argument, List<DimensionValue<NAMETYPE>> cellAddress) {
- this.argument = Objects.requireNonNull(argument, "Argument cannot be null");
- if (cellAddress.size() > 1 && cellAddress.stream().anyMatch(c -> c.dimension().isEmpty()))
- throw new IllegalArgumentException("Short form of cell addresses is only supported with a single dimension: " +
- "Specify dimension names explicitly");
- this.cellAddress = cellAddress;
- }
-
- @Override
- public List<TensorFunction<NAMETYPE>> arguments() { return List.of(argument); }
-
- @Override
- public Value<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
- if (arguments.size() != 1)
- throw new IllegalArgumentException("Value takes exactly one argument but got " + arguments.size());
- return new Value<>(arguments.get(0), cellAddress);
- }
-
- @Override
- public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { return this; }
-
- @Override
- public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
- Tensor tensor = argument.evaluate(context);
- if (tensor.type().rank() != cellAddress.size())
- throw new IllegalArgumentException("Type/address size mismatch: Cannot address a value with " + toString() +
- " to a tensor of type " + tensor.type());
- TensorAddress.Builder b = new TensorAddress.Builder(tensor.type());
- for (int i = 0; i < cellAddress.size(); i++) {
- if (cellAddress.get(i).label().isPresent())
- b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
- cellAddress.get(i).label().get());
- else
- b.add(cellAddress.get(i).dimension().orElse(tensor.type().dimensions().get(i).name()),
- String.valueOf(cellAddress.get(i).index().get().apply(context).intValue()));
- }
- return Tensor.from(tensor.get(b.build()));
- }
-
- @Override
- public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType.Builder(argument.type(context).valueType()).build();
- }
-
- @Override
- public String toString(ToStringContext context) {
- StringBuilder b = new StringBuilder(argument.toString(context));
- if (cellAddress.size() == 1 && cellAddress.get(0).dimension().isEmpty()) {
- if (cellAddress.get(0).index().isPresent())
- b.append("[").append(cellAddress.get(0).index().get().toString(context)).append("]");
- else
- b.append("{").append(cellAddress.get(0).label().get()).append("}");
- }
- else {
- b.append("{").append(cellAddress.stream().map(i -> i.toString(context)).collect(Collectors.joining(", "))).append("}");
- }
- return b.toString();
- }
-
- public static class DimensionValue<NAMETYPE extends Name> {
-
- private final Optional<String> dimension;
-
- /** The label of this, or null if index is set */
- private final String label;
-
- /** The function returning the index of this, or null if label is set */
- private final ScalarFunction<NAMETYPE> index;
-
- public DimensionValue(String dimension, String label) {
- this(Optional.of(dimension), label, null);
- }
-
- public DimensionValue(String dimension, int index) {
- this(Optional.of(dimension), null, new ConstantIntegerFunction<>(index));
- }
-
- public DimensionValue(int index) {
- this(Optional.empty(), null, new ConstantIntegerFunction<>(index));
- }
-
- public DimensionValue(String label) {
- this(Optional.empty(), label, null);
- }
-
- public DimensionValue(ScalarFunction<NAMETYPE> index) {
- this(Optional.empty(), null, index);
- }
-
- public DimensionValue(Optional<String> dimension, String label) {
- this(dimension, label, null);
- }
-
- public DimensionValue(Optional<String> dimension, ScalarFunction<NAMETYPE> index) {
- this(dimension, null, index);
- }
-
- public DimensionValue(String dimension, ScalarFunction<NAMETYPE> index) {
- this(Optional.of(dimension), null, index);
- }
-
- private DimensionValue(Optional<String> dimension, String label, ScalarFunction<NAMETYPE> index) {
- this.dimension = dimension;
- this.label = label;
- this.index = index;
- }
-
- /**
- * Returns the given name of the dimension, or null if dense form is used, such that name
- * must be inferred from order
- */
- public Optional<String> dimension() { return dimension; }
-
- /** Returns the label for this dimension or empty if it is provided by an index function */
- public Optional<String> label() { return Optional.ofNullable(label); }
-
- /** Returns the index expression for this dimension, or empty if it is not a number */
- public Optional<ScalarFunction<NAMETYPE>> index() { return Optional.ofNullable(index); }
-
- @Override
- public String toString() {
- return toString(null);
- }
-
- public String toString(ToStringContext context) {
- StringBuilder b = new StringBuilder();
- dimension.ifPresent(d -> b.append(d).append(":"));
- if (label != null)
- b.append(label);
- else
- b.append(index.toString(context));
- return b.toString();
- }
-
- }
-
- private static class ConstantIntegerFunction<NAMETYPE extends Name> implements ScalarFunction<NAMETYPE> {
-
- private final int value;
-
- public ConstantIntegerFunction(int value) {
- this.value = value;
- }
-
- @Override
- public Double apply(EvaluationContext<NAMETYPE> context) {
- return (double)value;
- }
-
- @Override
- public String toString() { return String.valueOf(value); }
-
- }
-
-}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java
new file mode 100644
index 00000000000..55e6151f7e9
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/tensor/functions/SliceTestCase.java
@@ -0,0 +1,138 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import org.junit.Test;
+
+import java.util.List;
+
+import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.fail;
+
+/**
+ * @author bratseth
+ */
+public class SliceTestCase {
+
+ private static final double delta = 0.000001;
+
+ @Test
+ public void testSliceFunctionGeneralFormToRank0() {
+ Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("key", "bar"),
+ new Slice.DimensionValue<>("x", 0)))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(2.3, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToRank0ReverseDimensionOrder() {
+ Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("x", 0),
+ new Slice.DimensionValue<>("key", "bar")))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(2.3, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToIndexedRank2to1() {
+ Tensor input = Tensor.from("tensor(key{},x[2]):{ {key:foo,x:0}:1.3, {key:foo,x:1}:1.4, {key:bar,x:0}:2.3, {key:bar,x:1}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("key", "bar")))
+ .evaluate();
+ assertEquals(1, result.type().rank());
+ assertEquals(Tensor.from("tensor(x[2]):[2.3, 2.4]]"), result);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToMappedRank2to1() {
+ Tensor input = Tensor.from("tensor(key{},x[2]):{ {key:foo,x:0}:1.3, {key:foo,x:1}:1.4, {key:bar,x:0}:2.3, {key:bar,x:1}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("x", 0)))
+ .evaluate();
+ assertEquals(1, result.type().rank());
+ assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.3, {key:bar}:2.3}"), result);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToMappedRank3to1() {
+ Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("x", 1),
+ new Slice.DimensionValue<>("y", 0)))
+ .evaluate();
+ assertEquals(1, result.type().rank());
+ assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.4, {key:bar}:2.4}"), result);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToMappedRank3to1ReverseDimensionOrder() {
+ Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("y", 0),
+ new Slice.DimensionValue<>("x", 1)))
+ .evaluate();
+ assertEquals(1, result.type().rank());
+ assertEquals(Tensor.from("tensor(key{}):{{key:foo}:1.4, {key:bar}:2.4}"), result);
+ }
+
+ @Test
+ public void testSliceFunctionGeneralFormToMappedRank3to2() {
+ Tensor input = Tensor.from("tensor(key{},x[2],y[1]):{ {key:foo,x:0,y:0}:1.3, {key:foo,x:1,y:0}:1.4, {key:bar,x:0,y:0}:2.3, {key:bar,x:1,y:0}:2.4 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("x", 1)))
+ .evaluate();
+ assertEquals(2, result.type().rank());
+ assertEquals(Tensor.from("tensor(key{},y[1]):{{key:foo,y:0}:1.4, {key:bar,y:0}:2.4}"), result);
+ }
+
+ @Test
+ public void testSliceFunctionSingleMappedDimensionToRank0() {
+ Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("foo")))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(1.4, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testSliceFunctionSingleIndexedDimensionToRank0() {
+ Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
+ Tensor result = new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>(2)))
+ .evaluate();
+ assertEquals(0, result.type().rank());
+ assertEquals(3.3, result.asDouble(), delta);
+ }
+
+ @Test
+ public void testSliceFunctionShortFormWithMultipleDimensionsIsNotAllowed() {
+ try {
+ Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
+ new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>("bar"),
+ new Slice.DimensionValue<>(0)))
+ .evaluate();
+ fail("Expected exception");
+ }
+ catch (IllegalArgumentException e) {
+ assertEquals("Short form of subspace addresses is only supported with a single dimension: Specify dimension names explicitly instead",
+ e.getMessage());
+ }
+ }
+
+ @Test
+ public void testToString() {
+ Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
+ assertEquals("tensor(key[3]):[1.1, 2.2, 3.3][2]",
+ new Slice<>(new ConstantTensor<>(input),
+ List.of(new Slice.DimensionValue<>(2)))
+ .toString());
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
deleted file mode 100644
index 227fbffbaa8..00000000000
--- a/vespajlib/src/test/java/com/yahoo/tensor/functions/ValueTestCase.java
+++ /dev/null
@@ -1,75 +0,0 @@
-// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package com.yahoo.tensor.functions;
-
-import com.yahoo.tensor.Tensor;
-import org.junit.Test;
-
-import java.util.List;
-
-import static org.junit.Assert.assertEquals;
-import static org.junit.Assert.fail;
-
-/**
- * @author bratseth
- */
-public class ValueTestCase {
-
- private static final double delta = 0.000001;
-
- @Test
- public void testValueFunctionGeneralForm() {
- Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
- Tensor result = new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>("key", "bar"),
- new Value.DimensionValue<>("x", 0)))
- .evaluate();
- assertEquals(0, result.type().rank());
- assertEquals(2.3, result.asDouble(), delta);
- }
-
- @Test
- public void testValueFunctionSingleMappedDimension() {
- Tensor input = Tensor.from("tensor(key{}):{ {key:foo}:1.4, {key:bar}:2.3 }");
- Tensor result = new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>("foo")))
- .evaluate();
- assertEquals(0, result.type().rank());
- assertEquals(1.4, result.asDouble(), delta);
- }
-
- @Test
- public void testValueFunctionSingleIndexedDimension() {
- Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
- Tensor result = new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>(2)))
- .evaluate();
- assertEquals(0, result.type().rank());
- assertEquals(3.3, result.asDouble(), delta);
- }
-
- @Test
- public void testValueFunctionShortFormWithMultipleDimensionsIsNotAllowed() {
- try {
- Tensor input = Tensor.from("tensor(key{},x{}):{ {key:foo,x:0}:1.4, {key:bar,x:0}:2.3 }");
- new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>("bar"),
- new Value.DimensionValue<>(0)))
- .evaluate();
- fail("Expected exception");
- }
- catch (IllegalArgumentException e) {
- assertEquals("Short form of cell addresses is only supported with a single dimension: Specify dimension names explicitly",
- e.getMessage());
- }
- }
-
- @Test
- public void testToString() {
- Tensor input = Tensor.from("tensor(key[3]):[1.1, 2.2, 3.3]");
- assertEquals("tensor(key[3]):[1.1, 2.2, 3.3][2]",
- new Value<>(new ConstantTensor<>(input),
- List.of(new Value.DimensionValue<>(2)))
- .toString());
- }
-
-}