diff options
Diffstat (limited to 'vespajlib/src/main/java/com')
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/Tensor.java | 16 | ||||
-rw-r--r-- | vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java | 21 |
2 files changed, 27 insertions, 10 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 2ad3212c424..f97e137af83 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -243,7 +243,7 @@ public interface Tensor { default Tensor notEqual(Tensor argument) { return join(argument, (a, b) -> ( a != b ? 1.0 : 0.0)); } default Tensor approxEqual(Tensor argument) { return join(argument, (a, b) -> ( approxEquals(a,b) ? 1.0 : 0.0)); } default Tensor bit(Tensor argument) { return join(argument, (a,b) -> ((int)b < 8 && (int)b >= 0 && ((int)a & (1 << (int)b)) != 0) ? 1.0 : 0.0); } - default Tensor hamming(Tensor argument) { return join(argument, (a,b) -> Hamming.hamming(a,b)); } + default Tensor hamming(Tensor argument) { return join(argument, Hamming::hamming); } default Tensor avg() { return avg(Collections.emptyList()); } default Tensor avg(String dimension) { return avg(Collections.singletonList(dimension)); } @@ -466,9 +466,12 @@ public interface Tensor { class Cell implements Map.Entry<TensorAddress, Double> { private final TensorAddress address; - private final Number value; + private final double value; Cell(TensorAddress address, Number value) { + this(address, value.doubleValue()); + } + Cell(TensorAddress address, double value) { this.address = address; this.value = value; } @@ -485,7 +488,7 @@ public interface Tensor { /** Returns the value as a double */ @Override - public Double getValue() { return value.doubleValue(); } + public Double getValue() { return value; } /** Returns the value as a float */ public float getFloatValue() { return getValue().floatValue(); } @@ -501,8 +504,7 @@ public interface Tensor { @Override public boolean equals(Object o) { if (o == this) return true; - if ( ! ( o instanceof Map.Entry)) return false; - Map.Entry<?,?> other = (Map.Entry)o; + if ( ! ( o instanceof Map.Entry<?,?> other)) return false; if ( ! this.getValue().equals(other.getValue())) return false; if ( ! this.getKey().equals(other.getKey())) return false; return true; @@ -531,7 +533,7 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type) { - boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed()); + boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); @@ -543,7 +545,7 @@ public interface Tensor { /** Creates a suitable builder for the given type */ static Builder of(TensorType type, DimensionSizes dimensionSizes) { - boolean containsIndexed = type.dimensions().stream().anyMatch(d -> d.isIndexed()); + boolean containsIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed); boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); if (containsIndexed && containsMapped) return MixedTensor.Builder.of(type); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index 342aca5fb3d..5636150bca1 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -13,6 +13,21 @@ import java.util.stream.Collectors; * @author bratseth */ public abstract class TensorAddress implements Comparable<TensorAddress> { + private static final String [] SMALL_INDEXES = createSmallIndexesAsStrings(1000); + + private static String [] createSmallIndexesAsStrings(int count) { + String [] asStrings = new String[count]; + for (int i = 0; i < count; i++) { + asStrings[i] = String.valueOf(i); + } + return asStrings; + } + private static String asString(int index) { + return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[index] : String.valueOf(index); + } + private static String asString(long index) { + return (index < SMALL_INDEXES.length) ? SMALL_INDEXES[(int)index] : String.valueOf(index); + } public static TensorAddress of(String[] labels) { return new StringTensorAddress(labels); @@ -127,7 +142,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { @Override public TensorAddress withLabel(int index, long label) { String[] labels = Arrays.copyOf(this.labels, this.labels.length); - labels[index] = String.valueOf(label); + labels[index] = TensorAddress.asString(label); return new StringTensorAddress(labels); } @@ -151,7 +166,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { public int size() { return labels.length; } @Override - public String label(int i) { return String.valueOf(labels[i]); } + public String label(int i) { return TensorAddress.asString(labels[i]); } @Override public long numericLabel(int i) { return labels[i]; } @@ -165,7 +180,7 @@ public abstract class TensorAddress implements Comparable<TensorAddress> { @Override public String toString() { - return "cell address (" + Arrays.stream(labels).mapToObj(String::valueOf).collect(Collectors.joining(",")) + ")"; + return "cell address (" + Arrays.stream(labels).mapToObj(TensorAddress::asString).collect(Collectors.joining(",")) + ")"; } } |