summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
authorArne Juul <arnej@vespa.ai>2023-11-02 08:47:23 +0000
committerArne Juul <arnej@vespa.ai>2023-11-02 19:54:24 +0000
commitbd9d7a9f74d41f2e88694aa2f1629ced0bca6428 (patch)
treeaf40320eae453618b6c00b854f2cf5d72d17e26e /vespajlib/src/main/java/com/yahoo/tensor/functions
parent96f6abe9caa338074ee39cb2fd566d3efff464c9 (diff)
add reference implementation of MapSubspaces
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java55
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java146
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java2
3 files changed, 202 insertions, 1 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java
new file mode 100644
index 00000000000..b6655a15361
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java
@@ -0,0 +1,55 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * A function suitable for use in MapSubspaces
+ *
+ * @author arnej
+ */
+class DenseSubspaceFunction<NAMETYPE extends Name> {
+
+ private final String argName;
+ private final TensorFunction<NAMETYPE> function;
+
+ public DenseSubspaceFunction(String argName, TensorFunction<NAMETYPE> function) {
+ this.argName = argName;
+ this.function = function;
+ }
+
+ Tensor map(Tensor subspace) {
+ var context = new MapEvaluationContext<NAMETYPE>();
+ context.put(argName, subspace);
+ return function.evaluate(context);
+ }
+
+ class MyTypeContext implements TypeContext<NAMETYPE> {
+ private final TensorType subspaceType;
+ MyTypeContext(TensorType subspaceType) { this.subspaceType = subspaceType; }
+ public TensorType getType(NAMETYPE name) { return getType(name.name()); }
+ public TensorType getType(String name) { return argName.equals(name) ? subspaceType : null; }
+ }
+
+ TensorType outputType(TensorType subspaceType) {
+ var context = new MyTypeContext(subspaceType);
+ var result = function.type(context);
+ if (result.mappedSubtype().rank() > 0) {
+ throw new IllegalArgumentException("function used in map_subspaces type had mapped dimensions: " + result);
+ }
+ return result;
+ }
+
+ public String toString() {
+ return "f(" + argName + ")(" + function + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
new file mode 100644
index 00000000000..c87ef42976d
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java
@@ -0,0 +1,146 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorAddress;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Iterator;
+import java.util.List;
+import java.util.List;
+import java.util.Map;
+import java.util.Objects;
+
+/**
+ * The <i>map_subspaces</i> tensor function transforms each dense subspace in a (mixed) tensor
+ *
+ * @author arnej
+ */
+public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> {
+
+ private final TensorFunction<NAMETYPE> argument;
+ private final DenseSubspaceFunction<NAMETYPE> function;
+
+ private MapSubspaces(TensorFunction<NAMETYPE> argument, DenseSubspaceFunction<NAMETYPE> function) {
+ this.argument = argument;
+ this.function = function;
+ }
+ public MapSubspaces(TensorFunction<NAMETYPE> argument, String functionArg, TensorFunction<NAMETYPE> function) {
+ this(argument, new DenseSubspaceFunction<>(functionArg, function));
+ Objects.requireNonNull(argument, "The argument cannot be null");
+ Objects.requireNonNull(functionArg, "The functionArg cannot be null");
+ Objects.requireNonNull(function, "The function cannot be null");
+ }
+
+ private TensorType outputType(TensorType inputType) {
+ var m = inputType.mappedSubtype();
+ var d = function.outputType(inputType.indexedSubtype());
+ if (m.rank() == 0) {
+ return d;
+ }
+ if (d.rank() == 0) {
+ return TypeResolver.map(m); // decay cell type
+ }
+ TensorType.Value cellType = d.valueType();
+ Map<String, TensorType.Dimension> dims = new HashMap<>();
+ for (var dim : m.dimensions()) {
+ dims.put(dim.name(), dim);
+ }
+ for (var dim : d.dimensions()) {
+ var old = dims.put(dim.name(), dim);
+ if (old != null) {
+ throw new IllegalArgumentException("dimension name collision in map_subspaces: " + m + " vs " + d);
+ }
+ }
+ return new TensorType(cellType, dims.values());
+ }
+
+ public TensorFunction<NAMETYPE> argument() { return argument; }
+
+ @Override
+ public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); }
+
+ @Override
+ public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) {
+ if ( arguments.size() != 1)
+ throw new IllegalArgumentException("MapSubspaces must have 1 argument, got " + arguments.size());
+ return new MapSubspaces<NAMETYPE>(arguments.get(0), function);
+ }
+
+ @Override
+ public PrimitiveTensorFunction<NAMETYPE> toPrimitive() {
+ return new MapSubspaces<>(argument.toPrimitive(), function);
+ }
+
+ @Override
+ public TensorType type(TypeContext<NAMETYPE> context) {
+ return outputType(argument.type(context));
+ }
+
+ @Override
+ public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
+ Tensor input = argument().evaluate(context);
+ TensorType inputType = input.type();
+ TensorType inputTypeMapped = inputType.mappedSubtype();
+ TensorType inputTypeDense = inputType.indexedSubtype();
+ Map<TensorAddress, Tensor.Builder> builders = new HashMap<>();
+ for (Iterator<Tensor.Cell> iter = input.cellIterator(); iter.hasNext(); ) {
+ var cell = iter.next();
+ var fullAddr = cell.getKey();
+ var mapAddrBuilder = new TensorAddress.Builder(inputTypeMapped);
+ var idxAddrBuilder = new TensorAddress.Builder(inputTypeDense);
+ for (int i = 0; i < inputType.dimensions().size(); i++) {
+ var dim = inputType.dimensions().get(i);
+ if (dim.isMapped()) {
+ mapAddrBuilder.add(dim.name(), fullAddr.label(i));
+ } else {
+ idxAddrBuilder.add(dim.name(), fullAddr.label(i));
+ }
+ }
+ var mapAddr = mapAddrBuilder.build();
+ var builder = builders.computeIfAbsent(mapAddr, k -> Tensor.Builder.of(inputTypeDense));
+ var idxAddr = idxAddrBuilder.build();
+ builder.cell(idxAddr, cell.getValue());
+ }
+ TensorType outputType = outputType(input.type());
+ TensorType denseOutputType = outputType.indexedSubtype();
+ var denseOutputDims = denseOutputType.dimensions();
+ Tensor.Builder builder = Tensor.Builder.of(outputType);
+ for (var entry : builders.entrySet()) {
+ TensorAddress mappedAddr = entry.getKey();
+ Tensor denseInput = entry.getValue().build();
+ Tensor denseOutput = function.map(denseInput);
+ // XXX check denseOutput.type().dimensions()
+ for (Iterator<Tensor.Cell> iter = denseOutput.cellIterator(); iter.hasNext(); ) {
+ var cell = iter.next();
+ var denseAddr = cell.getKey();
+ var addrBuilder = new TensorAddress.Builder(outputType);
+ for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) {
+ var dim = inputTypeMapped.dimensions().get(i);
+ addrBuilder.add(dim.name(), mappedAddr.label(i));
+ }
+ for (int i = 0; i < denseOutputDims.size(); i++) {
+ var dim = denseOutputDims.get(i);
+ addrBuilder.add(dim.name(), denseAddr.label(i));
+ }
+ builder.cell(addrBuilder.build(), cell.getValue());
+ }
+ }
+ return builder.build();
+ }
+
+ @Override
+ public String toString(ToStringContext<NAMETYPE> context) {
+ return "map_subspaces(" + argument.toString(context) + ", " + function + ")";
+ }
+
+ @Override
+ public int hashCode() { return Objects.hash("map_subspaces", argument, function); }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
index 05c5f412c39..790c956f9c8 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java
@@ -44,7 +44,7 @@ public abstract class TensorFunction<NAMETYPE extends Name> {
*
* @param context a context which must be passed to all nested functions when evaluating
*/
- public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context);
+ public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context);
/**
* Returns the type of the tensor this produces given the input types in the context