diff options
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/abi-spec.json | 1 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 24 |
2 files changed, 7 insertions, 18 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 8387a3d1a83..c426195bc37 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -1180,6 +1180,7 @@ "public com.yahoo.tensor.Tensor matmul(com.yahoo.tensor.Tensor, java.lang.String)", "public com.yahoo.tensor.Tensor softmax(java.lang.String)", "public com.yahoo.tensor.Tensor xwPlusB(com.yahoo.tensor.Tensor, com.yahoo.tensor.Tensor, java.lang.String)", + "public com.yahoo.tensor.Tensor expand(java.lang.String)", "public com.yahoo.tensor.Tensor argmax(java.lang.String)", "public com.yahoo.tensor.Tensor argmin(java.lang.String)", "public static com.yahoo.tensor.Tensor diag(com.yahoo.tensor.TensorType)", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 3d4536d9249..c9be90bf20b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -1,25 +1,9 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; -import com.yahoo.tensor.functions.Argmax; -import com.yahoo.tensor.functions.Argmin; -import com.yahoo.tensor.functions.CellCast; -import com.yahoo.tensor.functions.Concat; -import com.yahoo.tensor.functions.ConstantTensor; -import com.yahoo.tensor.functions.Diag; -import com.yahoo.tensor.functions.Generate; -import com.yahoo.tensor.functions.Join; -import com.yahoo.tensor.functions.L1Normalize; -import com.yahoo.tensor.functions.L2Normalize; -import com.yahoo.tensor.functions.Matmul; -import com.yahoo.tensor.functions.Merge; -import com.yahoo.tensor.functions.Random; -import com.yahoo.tensor.functions.Range; -import com.yahoo.tensor.functions.Reduce; -import com.yahoo.tensor.functions.Rename; -import com.yahoo.tensor.functions.Softmax; -import com.yahoo.tensor.functions.XwPlusB; +import com.yahoo.tensor.functions.*; import com.yahoo.text.Ascii7BitMatcher; import java.util.ArrayList; @@ -210,6 +194,10 @@ public interface Tensor { return new XwPlusB<>(new ConstantTensor<>(this), new ConstantTensor<>(w), new ConstantTensor<>(b), dimension).evaluate(); } + default Tensor expand(String dimension) { + return new Expand<>(new ConstantTensor<>(this), dimension).evaluate(); + } + default Tensor argmax(String dimension) { return new Argmax<>(new ConstantTensor<>(this), dimension).evaluate(); } |