summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-10-06 11:07:14 +0200
committerLester Solbakken <lesters@oath.com>2021-10-06 11:07:14 +0200
commit393b8284e615e251d1477ccbeee88398ed3388fd (patch)
treef2b6df06cdf6cb26867ba44d7298d7d7876ff8a8 /vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
parent1fdf04b093030529563020ccea7894f24848d2a0 (diff)
Add expand function to tensor class
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/Tensor.java')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java24
1 files changed, 6 insertions, 18 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 3d4536d9249..c9be90bf20b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -1,25 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Argmax;
-import com.yahoo.tensor.functions.Argmin;
-import com.yahoo.tensor.functions.CellCast;
-import com.yahoo.tensor.functions.Concat;
-import com.yahoo.tensor.functions.ConstantTensor;
-import com.yahoo.tensor.functions.Diag;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.L1Normalize;
-import com.yahoo.tensor.functions.L2Normalize;
-import com.yahoo.tensor.functions.Matmul;
-import com.yahoo.tensor.functions.Merge;
-import com.yahoo.tensor.functions.Random;
-import com.yahoo.tensor.functions.Range;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.Softmax;
-import com.yahoo.tensor.functions.XwPlusB;
+import com.yahoo.tensor.functions.*;
import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
@@ -210,6 +194,10 @@ public interface Tensor {
return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate();
}
+ default Tensor expand(String dimension) {
+ return new Expand<>(new ConstantTensor<>(this), dimension).evaluate();
+ }
+
default Tensor argmax(String dimension) {
return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate();
}