diff options
author | Lester Solbakken <lesters@oath.com> | 2021-10-06 11:07:14 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2021-10-06 11:07:14 +0200 |
commit | 393b8284e615e251d1477ccbeee88398ed3388fd (patch) | |
tree | f2b6df06cdf6cb26867ba44d7298d7d7876ff8a8 /vespajlib | |
parent | 1fdf04b093030529563020ccea7894f24848d2a0 (diff) |
Add expand function to tensor class
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/abi-spec.json | 1 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 24 |
2 files changed, 7 insertions, 18 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 8387a3d1a83..c426195bc37 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1180,6 +1180,7 @@ "public com.yahoo.tensor.Tensor matmul(com.yahoo.tensor.Tensor, java.lang.String)", "public com.yahoo.tensor.Tensor softmax(java.lang.String)", "public com.yahoo.tensor.Tensor xwPlusB(com.yahoo.tensor.Tensor, com.yahoo.tensor.Tensor, java.lang.String)", + "public com.yahoo.tensor.Tensor expand(java.lang.String)", "public com.yahoo.tensor.Tensor argmax(java.lang.String)", "public com.yahoo.tensor.Tensor argmin(java.lang.String)", "public static com.yahoo.tensor.Tensor diag(com.yahoo.tensor.TensorType)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 3d4536d9249..c9be90bf20b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -1,25 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Argmax; -import com.yahoo.tensor.functions.Argmin; -import com.yahoo.tensor.functions.CellCast; -import com.yahoo.tensor.functions.Concat; -import com.yahoo.tensor.functions.ConstantTensor; -import com.yahoo.tensor.functions.Diag; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.L1Normalize; -import com.yahoo.tensor.functions.L2Normalize; -import com.yahoo.tensor.functions.Matmul; -import com.yahoo.tensor.functions.Merge; -import com.yahoo.tensor.functions.Random; -import com.yahoo.tensor.functions.Range; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.Softmax; -import com.yahoo.tensor.functions.XwPlusB; +import com.yahoo.tensor.functions.*; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -210,6 +194,10 @@ public interface Tensor { return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate(); } + default Tensor expand(String dimension) { + return new Expand<>(new ConstantTensor<>(this), dimension).evaluate(); + } + default Tensor argmax(String dimension) { return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate(); } |