diff options
author | Arne Juul <arnej@vespa.ai> | 2023-11-02 08:47:23 +0000 |
---|---|---|
committer | Arne Juul <arnej@vespa.ai> | 2023-11-02 19:54:24 +0000 |
commit | bd9d7a9f74d41f2e88694aa2f1629ced0bca6428 (patch) | |
tree | af40320eae453618b6c00b854f2cf5d72d17e26e /vespajlib/src/main/java | |
parent | 96f6abe9caa338074ee39cb2fd566d3efff464c9 (diff) |
add reference implementation of MapSubspaces
Diffstat (limited to 'vespajlib/src/main/java')
5 files changed, 210 insertions, 3 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 084eaf2bf98..7f890a9ec51 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -80,7 +80,7 @@ public class TensorType { }; /** The empty tensor type - which is the same as a double */ - public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList()); + public static final TensorType empty = new TensorType(); private final Value valueType; @@ -90,6 +90,13 @@ public class TensorType { private final TensorType mappedSubtype; private final TensorType indexedSubtype; + private TensorType() { + this.valueType = Value.DOUBLE; + this.dimensions = List.of(); + this.mappedSubtype = this; + this.indexedSubtype = this; + } + public TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; List<Dimension> dimensionList = new ArrayList<>(dimensions); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java index d1407bf3d9b..d875f1ef4eb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/evaluation/TypeContext.java @@ -26,5 +26,4 @@ public interface TypeContext<NAMETYPE extends Name> { */ TensorType getType(String name); - } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java new file mode 100644 index 00000000000..b6655a15361 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java @@ -0,0 +1,55 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Optional; +import java.util.function.Function; + +/** + * A function suitable for use in MapSubspaces + * + * @author arnej + */ +class DenseSubspaceFunction<NAMETYPE extends Name> { + + private final String argName; + private final TensorFunction<NAMETYPE> function; + + public DenseSubspaceFunction(String argName, TensorFunction<NAMETYPE> function) { + this.argName = argName; + this.function = function; + } + + Tensor map(Tensor subspace) { + var context = new MapEvaluationContext<NAMETYPE>(); + context.put(argName, subspace); + return function.evaluate(context); + } + + class MyTypeContext implements TypeContext<NAMETYPE> { + private final TensorType subspaceType; + MyTypeContext(TensorType subspaceType) { this.subspaceType = subspaceType; } + public TensorType getType(NAMETYPE name) { return getType(name.name()); } + public TensorType getType(String name) { return argName.equals(name) ? subspaceType : null; } + } + + TensorType outputType(TensorType subspaceType) { + var context = new MyTypeContext(subspaceType); + var result = function.type(context); + if (result.mappedSubtype().rank() > 0) { + throw new IllegalArgumentException("function used in map_subspaces type had mapped dimensions: " + result); + } + return result; + } + + public String toString() { + return "f(" + argName + ")(" + function + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java new file mode 100644 index 00000000000..c87ef42976d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/MapSubspaces.java @@ -0,0 +1,146 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorAddress; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.List; +import java.util.Map; +import java.util.Objects; + +/** + * The <i>map_subspaces</i> tensor function transforms each dense subspace in a (mixed) tensor + * + * @author arnej + */ +public class MapSubspaces<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argument; + private final DenseSubspaceFunction<NAMETYPE> function; + + private MapSubspaces(TensorFunction<NAMETYPE> argument, DenseSubspaceFunction<NAMETYPE> function) { + this.argument = argument; + this.function = function; + } + public MapSubspaces(TensorFunction<NAMETYPE> argument, String functionArg, TensorFunction<NAMETYPE> function) { + this(argument, new DenseSubspaceFunction<>(functionArg, function)); + Objects.requireNonNull(argument, "The argument cannot be null"); + Objects.requireNonNull(functionArg, "The functionArg cannot be null"); + Objects.requireNonNull(function, "The function cannot be null"); + } + + private TensorType outputType(TensorType inputType) { + var m = inputType.mappedSubtype(); + var d = function.outputType(inputType.indexedSubtype()); + if (m.rank() == 0) { + return d; + } + if (d.rank() == 0) { + return TypeResolver.map(m); // decay cell type + } + TensorType.Value cellType = d.valueType(); + Map<String, TensorType.Dimension> dims = new HashMap<>(); + for (var dim : m.dimensions()) { + dims.put(dim.name(), dim); + } + for (var dim : d.dimensions()) { + var old = dims.put(dim.name(), dim); + if (old != null) { + throw new IllegalArgumentException("dimension name collision in map_subspaces: " + m + " vs " + d); + } + } + return new TensorType(cellType, dims.values()); + } + + public TensorFunction<NAMETYPE> argument() { return argument; } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if ( arguments.size() != 1) + throw new IllegalArgumentException("MapSubspaces must have 1 argument, got " + arguments.size()); + return new MapSubspaces<NAMETYPE>(arguments.get(0), function); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + return new MapSubspaces<>(argument.toPrimitive(), function); + } + + @Override + public TensorType type(TypeContext<NAMETYPE> context) { + return outputType(argument.type(context)); + } + + @Override + public Tensor evaluate(EvaluationContext<NAMETYPE> context) { + Tensor input = argument().evaluate(context); + TensorType inputType = input.type(); + TensorType inputTypeMapped = inputType.mappedSubtype(); + TensorType inputTypeDense = inputType.indexedSubtype(); + Map<TensorAddress, Tensor.Builder> builders = new HashMap<>(); + for (Iterator<Tensor.Cell> iter = input.cellIterator(); iter.hasNext(); ) { + var cell = iter.next(); + var fullAddr = cell.getKey(); + var mapAddrBuilder = new TensorAddress.Builder(inputTypeMapped); + var idxAddrBuilder = new TensorAddress.Builder(inputTypeDense); + for (int i = 0; i < inputType.dimensions().size(); i++) { + var dim = inputType.dimensions().get(i); + if (dim.isMapped()) { + mapAddrBuilder.add(dim.name(), fullAddr.label(i)); + } else { + idxAddrBuilder.add(dim.name(), fullAddr.label(i)); + } + } + var mapAddr = mapAddrBuilder.build(); + var builder = builders.computeIfAbsent(mapAddr, k -> Tensor.Builder.of(inputTypeDense)); + var idxAddr = idxAddrBuilder.build(); + builder.cell(idxAddr, cell.getValue()); + } + TensorType outputType = outputType(input.type()); + TensorType denseOutputType = outputType.indexedSubtype(); + var denseOutputDims = denseOutputType.dimensions(); + Tensor.Builder builder = Tensor.Builder.of(outputType); + for (var entry : builders.entrySet()) { + TensorAddress mappedAddr = entry.getKey(); + Tensor denseInput = entry.getValue().build(); + Tensor denseOutput = function.map(denseInput); + // XXX check denseOutput.type().dimensions() + for (Iterator<Tensor.Cell> iter = denseOutput.cellIterator(); iter.hasNext(); ) { + var cell = iter.next(); + var denseAddr = cell.getKey(); + var addrBuilder = new TensorAddress.Builder(outputType); + for (int i = 0; i < inputTypeMapped.dimensions().size(); i++) { + var dim = inputTypeMapped.dimensions().get(i); + addrBuilder.add(dim.name(), mappedAddr.label(i)); + } + for (int i = 0; i < denseOutputDims.size(); i++) { + var dim = denseOutputDims.get(i); + addrBuilder.add(dim.name(), denseAddr.label(i)); + } + builder.cell(addrBuilder.build(), cell.getValue()); + } + } + return builder.build(); + } + + @Override + public String toString(ToStringContext<NAMETYPE> context) { + return "map_subspaces(" + argument.toString(context) + ", " + function + ")"; + } + + @Override + public int hashCode() { return Objects.hash("map_subspaces", argument, function); } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java index 05c5f412c39..790c956f9c8 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/TensorFunction.java @@ -44,7 +44,7 @@ public abstract class TensorFunction<NAMETYPE extends Name> { * * @param context a context which must be passed to all nested functions when evaluating */ - public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context); + public abstract Tensor evaluate(EvaluationContext<NAMETYPE> context); /** * Returns the type of the tensor this produces given the input types in the context |