summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-27 09:38:43 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-27 09:38:43 +0100
commit35a1ad6eb3d59c9945cdfe8486f57e3f75b3091c (patch)
treeda3ca9d331d3d67060f6d6f9450f06e5be1a411b /vespajlib
parent7f923d43611071bf41fcac0c0ccac9eda16bb00c (diff)
Support embedding an array to a mixed 2d tensor
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java30
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java18
4 files changed, 39 insertions, 20 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 7f4a19b029d..418f3ed5911 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -976,6 +976,7 @@
"public com.yahoo.tensor.Tensor$Builder cell(com.yahoo.tensor.TensorAddress, double)",
"public com.yahoo.tensor.Tensor$Builder block(com.yahoo.tensor.TensorAddress, double[])",
"public com.yahoo.tensor.MixedTensor build()",
+ "public static com.yahoo.tensor.MixedTensor$BoundBuilder of(com.yahoo.tensor.TensorType)",
"public bridge synthetic com.yahoo.tensor.Tensor build()"
],
"fields" : [ ]
@@ -1026,6 +1027,7 @@
"public com.yahoo.tensor.MixedTensor build()",
"public void trackBounds(com.yahoo.tensor.TensorAddress)",
"public com.yahoo.tensor.TensorType createBoundType()",
+ "public static com.yahoo.tensor.MixedTensor$UnboundBuilder of(com.yahoo.tensor.TensorType)",
"public bridge synthetic com.yahoo.tensor.Tensor build()"
],
"fields" : [ ]
@@ -1466,6 +1468,7 @@
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
"public com.yahoo.tensor.TensorType$Value valueType()",
"public com.yahoo.tensor.TensorType mappedSubtype()",
+ "public com.yahoo.tensor.TensorType indexedSubtype()",
"public int rank()",
"public java.util.List dimensions()",
"public java.util.Set dimensionNames()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 2027dcfb60f..33e83c00e74 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -305,6 +305,10 @@ public class MixedTensor implements Tensor {
return new MixedTensor(type, builder, indexBuilder.build());
}
+ public static BoundBuilder of(TensorType type) {
+ return new BoundBuilder(type);
+ }
+
}
/**
@@ -371,6 +375,10 @@ public class MixedTensor implements Tensor {
return typeBuilder.build();
}
+ public static UnboundBuilder of(TensorType type) {
+ return new UnboundBuilder(type);
+ }
+
}
/**
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 5636150bca1..d5c3b1340f1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -13,21 +13,8 @@ import java.util.stream.Collectors;
* @author bratseth
*/
public abstract class TensorAddress implements Comparable<TensorAddress> {
- private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
- private static String [] createSmallIndexesAsStrings(int count) {
- String [] asStrings = new String[count];
- for (int i = 0; i < count; i++) {
- asStrings[i] = String.valueOf(i);
- }
- return asStrings;
- }
- private static String asString(int index) {
- return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[index] : String.valueOf(index);
- }
- private static String asString(long index) {
- return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
- }
+ private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
public static TensorAddress of(String[] labels) {
return new StringTensorAddress(labels);
@@ -86,8 +73,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
@Override
public boolean equals(Object o) {
if (o == this) return true;
- if ( ! (o instanceof TensorAddress)) return false;
- TensorAddress other = (TensorAddress)o;
+ if ( ! (o instanceof TensorAddress other)) return false;
if (other.size() != this.size()) return false;
for (int i = 0; i < this.size(); i++)
if ( ! Objects.equals(this.label(i), other.label(i)))
@@ -115,6 +101,18 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return "'" + label + "'";
}
+ private static String[] createSmallIndexesAsStrings(int count) {
+ String [] asStrings = new String[count];
+ for (int i = 0; i < count; i++) {
+ asStrings[i] = String.valueOf(i);
+ }
+ return asStrings;
+ }
+
+ private static String asString(long index) {
+ return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
+ }
+
private static final class StringTensorAddress extends TensorAddress {
private final String[] labels;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 36693280183..57d276f278e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -88,6 +88,7 @@ public class TensorType {
private final List<Dimension> dimensions;
private final TensorType mappedSubtype;
+ private final TensorType indexedSubtype;
public TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
@@ -95,12 +96,18 @@ public class TensorType {
Collections.sort(dimensionList);
this.dimensions = List.copyOf(dimensionList);
- if (dimensionList.stream().allMatch(d -> d.isIndexed()))
+ if (dimensionList.stream().allMatch(d -> d.isIndexed())) {
mappedSubtype = empty;
- else if (dimensionList.stream().noneMatch(d -> d.isIndexed()))
+ indexedSubtype = this;
+ }
+ else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) {
mappedSubtype = this;
- else
- mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).toList());
+ indexedSubtype = empty;
+ }
+ else {
+ mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> !d.isIndexed()).toList());
+ indexedSubtype = new TensorType(valueType, dimensions.stream().filter(Dimension::isIndexed).toList());
+ }
}
static public Value combinedValueType(TensorType ... types) {
@@ -135,6 +142,9 @@ public class TensorType {
/** The type representing the mapped subset of dimensions of this. */
public TensorType mappedSubtype() { return mappedSubtype; }
+ /** The type representing the indexed subset of dimensions of this. */
+ public TensorType indexedSubtype() { return indexedSubtype; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }