aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorHenning Baldersheim <balder@yahoo-inc.com>2024-01-18 11:23:41 +0100
committerHenning Baldersheim <balder@yahoo-inc.com>2024-01-18 11:23:41 +0100
commit3af27a9e69f4a9be0d5029394bc4ea4828081c6f (patch)
tree7eb2882585b119821351cca6ffe69dba794d0cad /vespajlib
parente9f4d3f1dec104acaac8b83459a6d7d4656f33ad (diff)
- Make an tensor.impl package that can be used from other tensor packages allowing bypass of defensive strategies
in the public interfaces. - Move private static inner classes TensorAddress.NumericTensorAddress and TensorAddress.StringTensorAddress to tensor.impl package. - Use the StringTensorAddress.of from Reduce to avoid defensive array copy.
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java94
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java59
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java52
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java6
5 files changed, 125 insertions, 89 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
index a1cb278c75a..f841b7757fb 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java
@@ -1,10 +1,12 @@
// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.tensor;
+import com.yahoo.tensor.impl.NumericTensorAddress;
+import com.yahoo.tensor.impl.StringTensorAddress;
+
import java.util.Arrays;
import java.util.Objects;
import java.util.Optional;
-import java.util.stream.Collectors;
/**
* An immutable address to a tensor cell. This simply supplies a value to each dimension
@@ -14,18 +16,16 @@ import java.util.stream.Collectors;
*/
public abstract class TensorAddress implements Comparable<TensorAddress> {
- private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
-
public static TensorAddress of(String[] labels) {
- return new StringTensorAddress(labels);
+ return StringTensorAddress.of(labels);
}
public static TensorAddress ofLabels(String ... labels) {
- return new StringTensorAddress(labels);
+ return StringTensorAddress.of(labels);
}
public static TensorAddress of(long ... labels) {
- return new NumericTensorAddress(labels);
+ return NumericTensorAddress.of(labels);
}
/** Returns the number of labels in this */
@@ -101,88 +101,6 @@ public abstract class TensorAddress implements Comparable<TensorAddress> {
return "'" + label + "'";
}
- private static String[] createSmallIndexesAsStrings(int count) {
- String [] asStrings = new String[count];
- for (int i = 0; i < count; i++) {
- asStrings[i] = String.valueOf(i);
- }
- return asStrings;
- }
-
- private static String asString(long index) {
- return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
- }
-
- private static final class StringTensorAddress extends TensorAddress {
-
- private final String[] labels;
-
- private StringTensorAddress(String ... labels) {
- this.labels = Arrays.copyOf(labels, labels.length);
- }
-
- @Override
- public int size() { return labels.length; }
-
- @Override
- public String label(int i) { return labels[i]; }
-
- @Override
- public long numericLabel(int i) {
- try {
- return Long.parseLong(labels[i]);
- }
- catch (NumberFormatException e) {
- throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'");
- }
- }
-
- @Override
- public TensorAddress withLabel(int index, long label) {
- String[] labels = Arrays.copyOf(this.labels, this.labels.length);
- labels[index] = TensorAddress.asString(label);
- return new StringTensorAddress(labels);
- }
-
-
- @Override
- public String toString() {
- return "cell address (" + String.join(",", labels) + ")";
- }
-
- }
-
- private static final class NumericTensorAddress extends TensorAddress {
-
- private final long[] labels;
-
- private NumericTensorAddress(long[] labels) {
- this.labels = Arrays.copyOf(labels, labels.length);
- }
-
- @Override
- public int size() { return labels.length; }
-
- @Override
- public String label(int i) { return TensorAddress.asString(labels[i]); }
-
- @Override
- public long numericLabel(int i) { return labels[i]; }
-
- @Override
- public TensorAddress withLabel(int index, long label) {
- long[] labels = Arrays.copyOf(this.labels, this.labels.length);
- labels[index] = label;
- return new NumericTensorAddress(labels);
- }
-
- @Override
- public String toString() {
- return "cell address (" + Arrays.stream(labels).mapToObj(TensorAddress::asString).collect(Collectors.joining(",")) + ")";
- }
-
- }
-
/** Builder of a tensor address */
public static class Builder {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 5171cf1e472..fe20c41174a 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
+import com.yahoo.tensor.impl.StringTensorAddress;
import java.util.ArrayList;
import java.util.Collections;
@@ -164,7 +165,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
int reducedLabelIndex = 0;
for (int toKeep : indexesToKeep)
reducedLabels[reducedLabelIndex++] = address.label(toKeep);
- return TensorAddress.of(reducedLabels);
+ return StringTensorAddress.unsafeOf(reducedLabels);
}
private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java
new file mode 100644
index 00000000000..983074c9c90
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java
@@ -0,0 +1,59 @@
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.Arrays;
+import java.util.stream.Collectors;
+
+public final class NumericTensorAddress extends TensorAddress {
+ private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000);
+
+ private final long[] labels;
+
+ private static String[] createSmallIndexesAsStrings(int count) {
+ String [] asStrings = new String[count];
+ for (int i = 0; i < count; i++) {
+ asStrings[i] = String.valueOf(i);
+ }
+ return asStrings;
+ }
+
+ private NumericTensorAddress(long[] labels) {
+ this.labels = labels;
+ }
+
+ public static NumericTensorAddress of(long ... labels) {
+ return new NumericTensorAddress(Arrays.copyOf(labels, labels.length));
+ }
+
+ public static NumericTensorAddress unsafeOf(long ... labels) {
+ return new NumericTensorAddress(labels);
+ }
+
+ @Override
+ public int size() { return labels.length; }
+
+ @Override
+ public String label(int i) { return asString(labels[i]); }
+
+ @Override
+ public long numericLabel(int i) { return labels[i]; }
+
+ @Override
+ public TensorAddress withLabel(int index, long label) {
+ long[] labels = Arrays.copyOf(this.labels, this.labels.length);
+ labels[index] = label;
+ return new NumericTensorAddress(labels);
+ }
+
+ @Override
+ public String toString() {
+ return "cell address (" + Arrays.stream(labels).mapToObj(NumericTensorAddress::asString).collect(Collectors.joining(",")) + ")";
+ }
+
+ public static String asString(long index) {
+ return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index);
+ }
+
+}
+
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java
new file mode 100644
index 00000000000..ca54494a19c
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java
@@ -0,0 +1,52 @@
+package com.yahoo.tensor.impl;
+
+import com.yahoo.tensor.TensorAddress;
+
+import java.util.Arrays;
+
+public final class StringTensorAddress extends TensorAddress {
+
+ private final String[] labels;
+
+ private StringTensorAddress(String [] labels) {
+ this.labels = labels;
+ }
+
+ public static StringTensorAddress of(String[] labels) {
+ return new StringTensorAddress(Arrays.copyOf(labels, labels.length));
+ }
+
+ public static StringTensorAddress unsafeOf(String[] labels) {
+ return new StringTensorAddress(labels);
+ }
+
+ @Override
+ public int size() { return labels.length; }
+
+ @Override
+ public String label(int i) { return labels[i]; }
+
+ @Override
+ public long numericLabel(int i) {
+ try {
+ return Long.parseLong(labels[i]);
+ }
+ catch (NumberFormatException e) {
+ throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'");
+ }
+ }
+
+ @Override
+ public TensorAddress withLabel(int index, long label) {
+ String[] labels = Arrays.copyOf(this.labels, this.labels.length);
+ labels[index] = NumericTensorAddress.asString(label);
+ return new StringTensorAddress(labels);
+ }
+
+
+ @Override
+ public String toString() {
+ return "cell address (" + String.join(",", labels) + ")";
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java
new file mode 100644
index 00000000000..6b004bf2d02
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java
@@ -0,0 +1,6 @@
+// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+
+@ExportPackage
+package com.yahoo.tensor.impl;
+
+import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file