summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@gmail.com>2023-01-27 09:38:43 +0100
committerJon Bratseth <bratseth@gmail.com>2023-01-27 09:38:43 +0100
commit35a1ad6eb3d59c9945cdfe8486f57e3f75b3091c (patch)
treeda3ca9d331d3d67060f6d6f9450f06e5be1a411b /vespajlib/src/main/java/com
parent7f923d43611071bf41fcac0c0ccac9eda16bb00c (diff)
Support embedding an array to a mixed 2d tensor
Diffstat (limited to 'vespajlib/src/main/java/com')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java30
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java18
3 files changed, 36 insertions, 20 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 2027dcfb60f..33e83c00e74 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -305,6 +305,10 @@ public class MixedTensor implements Tensor {
return new MixedTensor(type, builder, indexBuilder.build());
}
+ public static BoundBuilder of(TensorType type) {
+ return new BoundBuilder(type);
+ }
+
}
/**
@@ -371,6 +375,10 @@ public class MixedTensor implements Tensor {
return typeBuilder.build();
}
+ public static UnboundBuilder of(TensorType type) {
+ return new UnboundBuilder(type);
+ }
+
}
/**
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index 5636150bca1..d5c3b1340f1 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -13,21 +13,8 @@ import java.util.stream.Collectors;
* @author bratseth
*/
public abstract class TensorAddress implements Comparable<TensorAddress> {
- private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
- private static String [] createSmallIndexesAsStrings(int count) {
- String [] asStrings = new String[count];
- for (int i = 0; i < count; i++) {
- asStrings[i] = String.valueOf(i);
- }
- return asStrings;
- }
- private static String asString(int index) {
- return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[index] : String.valueOf(index);
- }
- private static String asString(long index) {
- return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
- }
+ private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
public static TensorAddress of(String[] labels) {
return new StringTensorAddress(labels);
@@ -86,8 +73,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
@Override
public boolean equals(Object o) {
if (o == this) return true;
- if ( ! (o instanceof TensorAddress)) return false;
- TensorAddress other = (TensorAddress)o;
+ if ( ! (o instanceof TensorAddress other)) return false;
if (other.size() != this.size()) return false;
for (int i = 0; i < this.size(); i++)
if ( ! Objects.equals(this.label(i), other.label(i)))
@@ -115,6 +101,18 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return "'" + label + "'";
}
+ private static String[] createSmallIndexesAsStrings(int count) {
+ String [] asStrings = new String[count];
+ for (int i = 0; i < count; i++) {
+ asStrings[i] = String.valueOf(i);
+ }
+ return asStrings;
+ }
+
+ private static String asString(long index) {
+ return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
+ }
+
private static final class StringTensorAddress extends TensorAddress {
private final String[] labels;
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 36693280183..57d276f278e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -88,6 +88,7 @@ public class TensorType {
private final List<Dimension> dimensions;
private final TensorType mappedSubtype;
+ private final TensorType indexedSubtype;
public TensorType(Value valueType, Collection<Dimension> dimensions) {
this.valueType = valueType;
@@ -95,12 +96,18 @@ public class TensorType {
Collections.sort(dimensionList);
this.dimensions = List.copyOf(dimensionList);
- if (dimensionList.stream().allMatch(d -> d.isIndexed()))
+ if (dimensionList.stream().allMatch(d -> d.isIndexed())) {
mappedSubtype = empty;
- else if (dimensionList.stream().noneMatch(d -> d.isIndexed()))
+ indexedSubtype = this;
+ }
+ else if (dimensionList.stream().noneMatch(d -> d.isIndexed())) {
mappedSubtype = this;
- else
- mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> ! d.isIndexed()).toList());
+ indexedSubtype = empty;
+ }
+ else {
+ mappedSubtype = new TensorType(valueType, dimensions.stream().filter(d -> !d.isIndexed()).toList());
+ indexedSubtype = new TensorType(valueType, dimensions.stream().filter(Dimension::isIndexed).toList());
+ }
}
static public Value combinedValueType(TensorType ... types) {
@@ -135,6 +142,9 @@ public class TensorType {
/** The type representing the mapped subset of dimensions of this. */
public TensorType mappedSubtype() { return mappedSubtype; }
+ /** The type representing the indexed subset of dimensions of this. */
+ public TensorType indexedSubtype() { return indexedSubtype; }
+
/** Returns the number of dimensions of this: dimensions().size() */
public int rank() { return dimensions.size(); }