diff options
author | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 11:23:41 +0100 |
---|---|---|
committer | Henning Baldersheim <balder@yahoo-inc.com> | 2024-01-18 11:23:41 +0100 |
commit | 3af27a9e69f4a9be0d5029394bc4ea4828081c6f (patch) | |
tree | 7eb2882585b119821351cca6ffe69dba794d0cad /vespajlib/src/main/java/com/yahoo | |
parent | e9f4d3f1dec104acaac8b83459a6d7d4656f33ad (diff) |
- Make an tensor.impl package that can be used from other tensor packages allowing bypass of defensive strategies
in the public interfaces.
- Move private static inner classes TensorAddress.NumericTensorAddress and TensorAddress.StringTensorAddress
to tensor.impl package.
- Use the StringTensorAddress.of from Reduce to avoid defensive array copy.
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo')
5 files changed, 125 insertions, 89 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index a1cb278c75a..f841b7757fb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,10 +1,12 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.impl.NumericTensorAddress; +import com.yahoo.tensor.impl.StringTensorAddress; + import java.util.Arrays; import java.util.Objects; import java.util.Optional; -import java.util.stream.Collectors; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -14,18 +16,16 @@ import java.util.stream.Collectors; */ public abstract class TensorAddress implements Comparable<TensorAddress> { - private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); - public static TensorAddress of(String[] labels) { - return new StringTensorAddress(labels); + return StringTensorAddress.of(labels); } public static TensorAddress ofLabels(String ... labels) { - return new StringTensorAddress(labels); + return StringTensorAddress.of(labels); } public static TensorAddress of(long ... labels) { - return new NumericTensorAddress(labels); + return NumericTensorAddress.of(labels); } /** Returns the number of labels in this */ @@ -101,88 +101,6 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { return "'" + label + "'"; } - private static String[] createSmallIndexesAsStrings(int count) { - String [] asStrings = new String[count]; - for (int i = 0; i < count; i++) { - asStrings[i] = String.valueOf(i); - } - return asStrings; - } - - private static String asString(long index) { - return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); - } - - private static final class StringTensorAddress extends TensorAddress { - - private final String[] labels; - - private StringTensorAddress(String ... labels) { - this.labels = Arrays.copyOf(labels, labels.length); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return labels[i]; } - - @Override - public long numericLabel(int i) { - try { - return Long.parseLong(labels[i]); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); - } - } - - @Override - public TensorAddress withLabel(int index, long label) { - String[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = TensorAddress.asString(label); - return new StringTensorAddress(labels); - } - - - @Override - public String toString() { - return "cell address (" + String.join(",", labels) + ")"; - } - - } - - private static final class NumericTensorAddress extends TensorAddress { - - private final long[] labels; - - private NumericTensorAddress(long[] labels) { - this.labels = Arrays.copyOf(labels, labels.length); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return TensorAddress.asString(labels[i]); } - - @Override - public long numericLabel(int i) { return labels[i]; } - - @Override - public TensorAddress withLabel(int index, long label) { - long[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = label; - return new NumericTensorAddress(labels); - } - - @Override - public String toString() { - return "cell address (" + Arrays.stream(labels).mapToObj(TensorAddress::asString).collect(Collectors.joining(",")) + ")"; - } - - } - /** Builder of a tensor address */ public static class Builder { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 5171cf1e472..fe20c41174a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; import java.util.Collections; @@ -164,7 +165,7 @@ public class Reduce<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET int reducedLabelIndex = 0; for (int toKeep : indexesToKeep) reducedLabels[reducedLabelIndex++] = address.label(toKeep); - return TensorAddress.of(reducedLabels); + return StringTensorAddress.unsafeOf(reducedLabels); } private static Tensor reduceAllGeneral(Tensor argument, Aggregator aggregator) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java new file mode 100644 index 00000000000..983074c9c90 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java @@ -0,0 +1,59 @@ +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import java.util.Arrays; +import java.util.stream.Collectors; + +public final class NumericTensorAddress extends TensorAddress { + private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); + + private final long[] labels; + + private static String[] createSmallIndexesAsStrings(int count) { + String [] asStrings = new String[count]; + for (int i = 0; i < count; i++) { + asStrings[i] = String.valueOf(i); + } + return asStrings; + } + + private NumericTensorAddress(long[] labels) { + this.labels = labels; + } + + public static NumericTensorAddress of(long ... labels) { + return new NumericTensorAddress(Arrays.copyOf(labels, labels.length)); + } + + public static NumericTensorAddress unsafeOf(long ... labels) { + return new NumericTensorAddress(labels); + } + + @Override + public int size() { return labels.length; } + + @Override + public String label(int i) { return asString(labels[i]); } + + @Override + public long numericLabel(int i) { return labels[i]; } + + @Override + public TensorAddress withLabel(int index, long label) { + long[] labels = Arrays.copyOf(this.labels, this.labels.length); + labels[index] = label; + return new NumericTensorAddress(labels); + } + + @Override + public String toString() { + return "cell address (" + Arrays.stream(labels).mapToObj(NumericTensorAddress::asString).collect(Collectors.joining(",")) + ")"; + } + + public static String asString(long index) { + return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); + } + +} + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java new file mode 100644 index 00000000000..ca54494a19c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java @@ -0,0 +1,52 @@ +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import java.util.Arrays; + +public final class StringTensorAddress extends TensorAddress { + + private final String[] labels; + + private StringTensorAddress(String [] labels) { + this.labels = labels; + } + + public static StringTensorAddress of(String[] labels) { + return new StringTensorAddress(Arrays.copyOf(labels, labels.length)); + } + + public static StringTensorAddress unsafeOf(String[] labels) { + return new StringTensorAddress(labels); + } + + @Override + public int size() { return labels.length; } + + @Override + public String label(int i) { return labels[i]; } + + @Override + public long numericLabel(int i) { + try { + return Long.parseLong(labels[i]); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); + } + } + + @Override + public TensorAddress withLabel(int index, long label) { + String[] labels = Arrays.copyOf(this.labels, this.labels.length); + labels[index] = NumericTensorAddress.asString(label); + return new StringTensorAddress(labels); + } + + + @Override + public String toString() { + return "cell address (" + String.join(",", labels) + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java new file mode 100644 index 00000000000..6b004bf2d02 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java @@ -0,0 +1,6 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +@ExportPackage +package com.yahoo.tensor.impl; + +import com.yahoo.osgi.annotation.ExportPackage;
\ No newline at end of file |