From 3af27a9e69f4a9be0d5029394bc4ea4828081c6f Mon Sep 17 00:00:00 2001 From: Henning Baldersheim Date: Thu, 18 Jan 2024 11:23:41 +0100 Subject: - Make an tensor.impl package that can be used from other tensor packages allowing bypass of defensive strategies in the public interfaces. - Move private static inner classes TensorAddress.NumericTensorAddress and TensorAddress.StringTensorAddress to tensor.impl package. - Use the StringTensorAddress.of from Reduce to avoid defensive array copy. --- .../main/java/com/yahoo/tensor/TensorAddress.java | 94 ++-------------------- .../java/com/yahoo/tensor/functions/Reduce.java | 3 +- .../yahoo/tensor/impl/NumericTensorAddress.java | 59 ++++++++++++++ .../com/yahoo/tensor/impl/StringTensorAddress.java | 52 ++++++++++++ .../java/com/yahoo/tensor/impl/package-info.java | 6 ++ 5 files changed, 125 insertions(+), 89 deletions(-) create mode 100644 vespajlib/src/main/java/com/yahoo/tensor/impl/NumericTensorAddress.java create mode 100644 vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java create mode 100644 vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index a1cb278c75a..f841b7757fb 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -1,10 +1,12 @@ // Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.tensor; +import com.yahoo.tensor.impl.NumericTensorAddress; +import com.yahoo.tensor.impl.StringTensorAddress; + import java.util.Arrays; import java.util.Objects; import java.util.Optional; -import java.util.stream.Collectors; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -14,18 +16,16 @@ import java.util.stream.Collectors; */ public abstract class TensorAddress implements Comparable { - private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); - public static TensorAddress of(String[] labels) { - return new StringTensorAddress(labels); + return StringTensorAddress.of(labels); } public static TensorAddress ofLabels(String ... labels) { - return new StringTensorAddress(labels); + return StringTensorAddress.of(labels); } public static TensorAddress of(long ... labels) { - return new NumericTensorAddress(labels); + return NumericTensorAddress.of(labels); } /** Returns the number of labels in this */ @@ -101,88 +101,6 @@ public abstract class TensorAddress implements Comparable { return "'" + label + "'"; } - private static String[] createSmallIndexesAsStrings(int count) { - String [] asStrings = new String[count]; - for (int i = 0; i < count; i++) { - asStrings[i] = String.valueOf(i); - } - return asStrings; - } - - private static String asString(long index) { - return ((index >= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); - } - - private static final class StringTensorAddress extends TensorAddress { - - private final String[] labels; - - private StringTensorAddress(String ... labels) { - this.labels = Arrays.copyOf(labels, labels.length); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return labels[i]; } - - @Override - public long numericLabel(int i) { - try { - return Long.parseLong(labels[i]); - } - catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); - } - } - - @Override - public TensorAddress withLabel(int index, long label) { - String[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = TensorAddress.asString(label); - return new StringTensorAddress(labels); - } - - - @Override - public String toString() { - return "cell address (" + String.join(",", labels) + ")"; - } - - } - - private static final class NumericTensorAddress extends TensorAddress { - - private final long[] labels; - - private NumericTensorAddress(long[] labels) { - this.labels = Arrays.copyOf(labels, labels.length); - } - - @Override - public int size() { return labels.length; } - - @Override - public String label(int i) { return TensorAddress.asString(labels[i]); } - - @Override - public long numericLabel(int i) { return labels[i]; } - - @Override - public TensorAddress withLabel(int index, long label) { - long[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = label; - return new NumericTensorAddress(labels); - } - - @Override - public String toString() { - return "cell address (" + Arrays.stream(labels).mapToObj(TensorAddress::asString).collect(Collectors.joining(",")) + ")"; - } - - } - /** Builder of a tensor address */ public static class Builder { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 5171cf1e472..fe20c41174a 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; +import com.yahoo.tensor.impl.StringTensorAddress; import java.util.ArrayList; import java.util.Collections; @@ -164,7 +165,7 @@ public class Reduce extends PrimitiveTensorFunction= 0) && (index < SMALL_INDEXES.length)) ? SMALL_INDEXES[(int)index] : String.valueOf(index); + } + +} + diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java new file mode 100644 index 00000000000..ca54494a19c --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/StringTensorAddress.java @@ -0,0 +1,52 @@ +package com.yahoo.tensor.impl; + +import com.yahoo.tensor.TensorAddress; + +import java.util.Arrays; + +public final class StringTensorAddress extends TensorAddress { + + private final String[] labels; + + private StringTensorAddress(String [] labels) { + this.labels = labels; + } + + public static StringTensorAddress of(String[] labels) { + return new StringTensorAddress(Arrays.copyOf(labels, labels.length)); + } + + public static StringTensorAddress unsafeOf(String[] labels) { + return new StringTensorAddress(labels); + } + + @Override + public int size() { return labels.length; } + + @Override + public String label(int i) { return labels[i]; } + + @Override + public long numericLabel(int i) { + try { + return Long.parseLong(labels[i]); + } + catch (NumberFormatException e) { + throw new IllegalArgumentException("Expected an integer label in " + this + " at position " + i + " but got '" + labels[i] + "'"); + } + } + + @Override + public TensorAddress withLabel(int index, long label) { + String[] labels = Arrays.copyOf(this.labels, this.labels.length); + labels[index] = NumericTensorAddress.asString(label); + return new StringTensorAddress(labels); + } + + + @Override + public String toString() { + return "cell address (" + String.join(",", labels) + ")"; + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java new file mode 100644 index 00000000000..6b004bf2d02 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/tensor/impl/package-info.java @@ -0,0 +1,6 @@ +// Copyright Vespa.ai. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. + +@ExportPackage +package com.yahoo.tensor.impl; + +import com.yahoo.osgi.annotation.ExportPackage; \ No newline at end of file -- cgit v1.2.3