aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-10-06 11:07:14 +0200
committerLester Solbakken <lesters@oath.com>2021-10-06 11:07:14 +0200
commit393b8284e615e251d1477ccbeee88398ed3388fd (patch)
treef2b6df06cdf6cb26867ba44d7298d7d7876ff8a8 /vespajlib
parent1fdf04b093030529563020ccea7894f24848d2a0 (diff)
Add expand function to tensor class
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java24
2 files changed, 7 insertions, 18 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 8387a3d1a83..c426195bc37 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1180,6 +1180,7 @@
"public com.yahoo.tensor.Tensor matmul(com.yahoo.tensor.Tensor, java.lang.String)",
"public com.yahoo.tensor.Tensor softmax(java.lang.String)",
"public com.yahoo.tensor.Tensor xwPlusB(com.yahoo.tensor.Tensor, com.yahoo.tensor.Tensor, java.lang.String)",
+ "public com.yahoo.tensor.Tensor expand(java.lang.String)",
"public com.yahoo.tensor.Tensor argmax(java.lang.String)",
"public com.yahoo.tensor.Tensor argmin(java.lang.String)",
"public static com.yahoo.tensor.Tensor diag(com.yahoo.tensor.TensorType)",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 3d4536d9249..c9be90bf20b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -1,25 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Argmax;
-import com.yahoo.tensor.functions.Argmin;
-import com.yahoo.tensor.functions.CellCast;
-import com.yahoo.tensor.functions.Concat;
-import com.yahoo.tensor.functions.ConstantTensor;
-import com.yahoo.tensor.functions.Diag;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.L1Normalize;
-import com.yahoo.tensor.functions.L2Normalize;
-import com.yahoo.tensor.functions.Matmul;
-import com.yahoo.tensor.functions.Merge;
-import com.yahoo.tensor.functions.Random;
-import com.yahoo.tensor.functions.Range;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.Softmax;
-import com.yahoo.tensor.functions.XwPlusB;
+import com.yahoo.tensor.functions.*;
import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
@@ -210,6 +194,10 @@ public interface Tensor {
return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate();
}
+ default Tensor expand(String dimension) {
+ return new Expand<>(new ConstantTensor<>(this), dimension).evaluate();
+ }
+
default Tensor argmax(String dimension) {
return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate();
}