aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/Tensor.java24
1 files changed, 6 insertions, 18 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index 3d4536d9249..c9be90bf20b 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -1,25 +1,9 @@
// Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
-import com.yahoo.tensor.functions.Argmax;
-import com.yahoo.tensor.functions.Argmin;
-import com.yahoo.tensor.functions.CellCast;
-import com.yahoo.tensor.functions.Concat;
-import com.yahoo.tensor.functions.ConstantTensor;
-import com.yahoo.tensor.functions.Diag;
-import com.yahoo.tensor.functions.Generate;
-import com.yahoo.tensor.functions.Join;
-import com.yahoo.tensor.functions.L1Normalize;
-import com.yahoo.tensor.functions.L2Normalize;
-import com.yahoo.tensor.functions.Matmul;
-import com.yahoo.tensor.functions.Merge;
-import com.yahoo.tensor.functions.Random;
-import com.yahoo.tensor.functions.Range;
-import com.yahoo.tensor.functions.Reduce;
-import com.yahoo.tensor.functions.Rename;
-import com.yahoo.tensor.functions.Softmax;
-import com.yahoo.tensor.functions.XwPlusB;
+import com.yahoo.tensor.functions.*;
import com.yahoo.text.Ascii7BitMatcher;
import java.util.ArrayList;
@@ -210,6 +194,10 @@ public interface Tensor {
return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate();
}
+ default Tensor expand(String dimension) {
+ return new Expand<>(new ConstantTensor<>(this), dimension).evaluate();
+ }
+
default Tensor argmax(String dimension) {
return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate();
}