summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java
diff options
context:
space:
mode:
authorArne Juul <arnej@vespa.ai>2023-11-02 08:47:23 +0000
committerArne Juul <arnej@vespa.ai>2023-11-02 19:54:24 +0000
commitbd9d7a9f74d41f2e88694aa2f1629ced0bca6428 (patch)
treeaf40320eae453618b6c00b854f2cf5d72d17e26e /vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java
parent96f6abe9caa338074ee39cb2fd566d3efff464c9 (diff)
add reference implementation of MapSubspaces
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java55
1 files changed, 55 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java
new file mode 100644
index 00000000000..b6655a15361
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/DenseSubspaceFunction.java
@@ -0,0 +1,55 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.tensor.functions;
+
+import com.yahoo.tensor.Tensor;
+import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.evaluation.EvaluationContext;
+import com.yahoo.tensor.evaluation.MapEvaluationContext;
+import com.yahoo.tensor.evaluation.Name;
+import com.yahoo.tensor.evaluation.TypeContext;
+
+import java.util.Optional;
+import java.util.function.Function;
+
+/**
+ * A function suitable for use in MapSubspaces
+ *
+ * @author arnej
+ */
+class DenseSubspaceFunction<NAMETYPE extends Name> {
+
+ private final String argName;
+ private final TensorFunction<NAMETYPE> function;
+
+ public DenseSubspaceFunction(String argName, TensorFunction<NAMETYPE> function) {
+ this.argName = argName;
+ this.function = function;
+ }
+
+ Tensor map(Tensor subspace) {
+ var context = new MapEvaluationContext<NAMETYPE>();
+ context.put(argName, subspace);
+ return function.evaluate(context);
+ }
+
+ class MyTypeContext implements TypeContext<NAMETYPE> {
+ private final TensorType subspaceType;
+ MyTypeContext(TensorType subspaceType) { this.subspaceType = subspaceType; }
+ public TensorType getType(NAMETYPE name) { return getType(name.name()); }
+ public TensorType getType(String name) { return argName.equals(name) ? subspaceType : null; }
+ }
+
+ TensorType outputType(TensorType subspaceType) {
+ var context = new MyTypeContext(subspaceType);
+ var result = function.type(context);
+ if (result.mappedSubtype().rank() > 0) {
+ throw new IllegalArgumentException("function used in map_subspaces type had mapped dimensions: " + result);
+ }
+ return result;
+ }
+
+ public String toString() {
+ return "f(" + argName + ")(" + function + ")";
+ }
+
+}