diff options
author | Lester Solbakken <lesters@oath.com> | 2021-10-06 10:52:56 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-10-06 10:52:56 +0200 |
commit | bcd003d3253a5e51c19149dcc8fa44e8fd526adb (patch) | |
tree | 8c8505ae2a075996a4724da106337262398ad72e /vespajlib/src | |
parent | 4de0026c1065403d028d7157abb571830603e6c9 (diff) |
Add non-primitive tensor expand function
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java | 48 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java | 16 |
2 files changed, 64 insertions, 0 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java new file mode 100644 index 00000000000..8fc246a7d9d --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Expand.java @@ -0,0 +1,48 @@ +// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.tensor.functions; + +import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.evaluation.Name; + +import java.util.Collections; +import java.util.List; + +/** + * The <i>expand</i> tensor function returns a tensor with a new dimension of + * size 1 is added, equivalent to "tensor * tensor(dim_name[1])(1)". + * + * @author lesters + */ +public class Expand<NAMETYPE extends Name> extends CompositeTensorFunction<NAMETYPE> { + + private final TensorFunction<NAMETYPE> argument; + private final String dimensionName; + + public Expand(TensorFunction<NAMETYPE> argument, String dimension) { + this.argument = argument; + this.dimensionName = dimension; + } + + @Override + public List<TensorFunction<NAMETYPE>> arguments() { return Collections.singletonList(argument); } + + @Override + public TensorFunction<NAMETYPE> withArguments(List<TensorFunction<NAMETYPE>> arguments) { + if (arguments.size() != 1) + throw new IllegalArgumentException("Expand must have 1 argument, got " + arguments.size()); + return new Expand<>(arguments.get(0), dimensionName); + } + + @Override + public PrimitiveTensorFunction<NAMETYPE> toPrimitive() { + TensorType type = new TensorType.Builder(TensorType.Value.INT8).indexed(dimensionName, 1).build(); + Generate<NAMETYPE> expansion = new Generate<>(type, ScalarFunctions.constant(1.0)); + return new Join<>(expansion, argument, ScalarFunctions.multiply()); + } + + @Override + public String toString(ToStringContext context) { + return "expand(" + argument.toString(context) + ", " + dimensionName + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index d6fcd17b8fb..4c2b64244e5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -66,6 +66,7 @@ public class ScalarFunctions { public static Function<List<Long>, Double> random() { return new Random(); } public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } + public static Function<List<Long>, Double> constant(double value) { return new Constant(value); } // Binary operators ----------------------------------------------------------------------------- @@ -493,4 +494,19 @@ public class ScalarFunctions { } } + public static class Constant implements Function<List<Long>, Double> { + private final double value; + + public Constant(double value) { + this.value = value; + } + @Override + public Double apply(List<Long> values) { + return value; + } + @Override + public String toString() { return Double.toString(value); } + } + + } |