diff options
author | Jon Bratseth <bratseth@gmail.com> | 2023-01-27 09:38:43 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@gmail.com> | 2023-01-27 09:38:43 +0100 |
commit | 35a1ad6eb3d59c9945cdfe8486f57e3f75b3091c (patch) | |
tree | da3ca9d331d3d67060f6d6f9450f06e5be1a411b /vespajlib | |
parent | 7f923d43611071bf41fcac0c0ccac9eda16bb00c (diff) |
Support embedding an array to a mixed 2d tensor
Diffstat (limited to 'vespajlib')
-rw-r--r-- | vespajlib/abi-spec.json | 3 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java | 8 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java | 30 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorType.java | 18 |
4 files changed, 39 insertions, 20 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 7f4a19b029d..418f3ed5911 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -976,6 +976,7 @@ "public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)", "public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])", "public com.yahoo.tensor.MixedTensor build()", + "public static com.yahoo.tensor.MixedTensor$BoundBuilder of(com.yahoo.tensor.TensorType)", "public bridge synthetic com.yahoo.tensor.Tensor build()" ], "fields" : [ ] @@ -1026,6 +1027,7 @@ "public com.yahoo.tensor.MixedTensor build()", "public void trackBounds(com.yahoo.tensor.TensorAddress)", "public com.yahoo.tensor.TensorType createBoundType()", + "public static com.yahoo.tensor.MixedTensor$UnboundBuilder of(com.yahoo.tensor.TensorType)", "public bridge synthetic com.yahoo.tensor.Tensor build()" ], "fields" : [ ] @@ -1466,6 +1468,7 @@ "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", "public com.yahoo.tensor.TensorType$Value valueType()", "public com.yahoo.tensor.TensorType mappedSubtype()", + "public com.yahoo.tensor.TensorType indexedSubtype()", "public int rank()", "public java.util.List dimensions()", "public java.util.Set dimensionNames()", diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 2027dcfb60f..33e83c00e74 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -305,6 +305,10 @@ public class MixedTensor implements Tensor { return new MixedTensor(type, builder, indexBuilder.build()); } + public static BoundBuilder of(TensorType type) { + return new BoundBuilder(type); + } + } /** @@ -371,6 +375,10 @@ public class MixedTensor implements Tensor { return typeBuilder.build(); } + public static UnboundBuilder of(TensorType type) { + return new UnboundBuilder(type); + } + } /** diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 5636150bca1..d5c3b1340f1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -13,21 +13,8 @@ import java.util.stream.Collectors; * @author bratseth */ public abstract class TensorAddress implements Comparable<TensorAddress> { - private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); - private static String [] createSmallIndexesAsStrings(int count) { - String [] asStrings = new String[count]; - for (int i = 0; i < count; i++) { - asStrings[i] = String.valueOf(i); - } - return asStrings; - } - private static String asString(int index) { - return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[index] : String.valueOf(index); - } - private static String asString(long index) { - return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index); - } + private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); public static TensorAddress of(String[] labels) { return new StringTensorAddress(labels); @@ -86,8 +73,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { @Override public boolean equals(Object o) { if (o == this) return true; - if ( ! (o instanceof TensorAddress)) return false; - TensorAddress other = (TensorAddress)o; + if ( ! (o instanceof TensorAddress other)) return false; if (other.size() != this.size()) return false; for (int i = 0; i < this.size(); i++) if ( ! Objects.equals(this.label(i), other.label(i))) @@ -115,6 +101,18 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return "'" + label + "'"; } + private static String[] createSmallIndexesAsStrings(int count) { + String [] asStrings = new String[count]; + for (int i = 0; i < count; i++) { + asStrings[i] = String.valueOf(i); + } + return asStrings; + } + + private static String asString(long index) { + return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index); + } + private static final class StringTensorAddress extends TensorAddress { private final String[] labels; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 36693280183..57d276f278e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -88,6 +88,7 @@ public class TensorType { private final List<Dimension> dimensions; private final TensorType mappedSubtype; + private final TensorType indexedSubtype; public TensorType(Value valueType, Collection<Dimension> dimensions) { this.valueType = valueType; @@ -95,12 +96,18 @@ public class TensorType { Collections.sort(dimensionList); this.dimensions = List.copyOf(dimensionList); - if (dimensionList.stream().allMatch(d -> d.isIndexed())) + if (dimensionList.stream().allMatch(d -> d.isIndexed())) { mappedSubtype = empty; - else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) + indexedSubtype = this; + } + else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) { mappedSubtype = this; - else - mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).toList()); + indexedSubtype = empty; + } + else { + mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> !d.isIndexed()).toList()); + indexedSubtype = new TensorType(valueType, dimensions.stream().filter(Dimension::isIndexed).toList()); + } } static public Value combinedValueType(TensorType ... types) { @@ -135,6 +142,9 @@ public class TensorType { /** The type representing the mapped subset of dimensions of this. */ public TensorType mappedSubtype() { return mappedSubtype; } + /** The type representing the indexed subset of dimensions of this. */ + public TensorType indexedSubtype() { return indexedSubtype; } + /** Returns the number of dimensions of this: dimensions().size() */ public int rank() { return dimensions.size(); } |