diff options
author | Arne Juul <arnej@vespa.ai> | 2023-11-02 08:47:23 +0000 |
---|---|---|
committer | Arne Juul <arnej@vespa.ai> | 2023-11-02 19:54:24 +0000 |
commit | bd9d7a9f74d41f2e88694aa2f1629ced0bca6428 (patch) | |
tree | af40320eae453618b6c00b854f2cf5d72d17e26e /vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java | |
parent | 96f6abe9caa338074ee39cb2fd566d3efff464c9 (diff) |
add reference implementation of MapSubspaces
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java | 55 |
1 files changed, 55 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java new file mode 100644 index 00000000000..b6655a15361 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java @@ -0,0 +1,55 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.Tensor; +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.EvaluationContext; +import com.yahoo.tensor.evaluation.MapEvaluationContext; +import com.yahoo.tensor.evaluation.Name; +import com.yahoo.tensor.evaluation.TypeContext; + +import java.util.Optional; +import java.util.function.Function; + +/** + * A function suitable for use in MapSubspaces + * + * @author arnej + */ +class DenseSubspaceFunction<NAMETYPE extends Name> { + + private final String argName; + private final TensorFunction<NAMETYPE> function; + + public DenseSubspaceFunction(String argName, TensorFunction<NAMETYPE> function) { + this.argName = argName; + this.function = function; + } + + Tensor map(Tensor subspace) { + var context = new MapEvaluationContext<NAMETYPE>(); + context.put(argName, subspace); + return function.evaluate(context); + } + + class MyTypeContext implements TypeContext<NAMETYPE> { + private final TensorType subspaceType; + MyTypeContext(TensorType subspaceType) { this.subspaceType = subspaceType; } + public TensorType getType(NAMETYPE name) { return getType(name.name()); } + public TensorType getType(String name) { return argName.equals(name) ? subspaceType : null; } + } + + TensorType outputType(TensorType subspaceType) { + var context = new MyTypeContext(subspaceType); + var result = function.type(context); + if (result.mappedSubtype().rank() > 0) { + throw new IllegalArgumentException("function used in map_subspaces type had mapped dimensions: " + result); + } + return result; + } + + public String toString() { + return "f(" + argName + ")(" + function + ")"; + } + +} |