From 35d59981840614bf4b877714ee88e273816c46d2 Mon Sep 17 00:00:00 2001 From: Jon Bratseth Date: Tue, 19 Dec 2017 23:02:04 +0100 Subject: Use longs for dimensions lengths in all API's This is to be able to support tensor dimensions with more than 2B elements in the future without API change. --- .../config/application/api/ApplicationPackage.java | 22 ++- .../prelude/searcher/ValidateSortingSearcher.java | 8 +- .../integration/tensorflow/TensorConverter.java | 2 +- .../rule/GeneratorLambdaFunctionNode.java | 10 +- .../src/main/javacc/RankingExpressionParser.jj | 2 +- .../main/java/com/yahoo/tensor/DimensionSizes.java | 18 +- .../main/java/com/yahoo/tensor/IndexedTensor.java | 194 +++++++++++---------- .../main/java/com/yahoo/tensor/MappedTensor.java | 4 +- .../main/java/com/yahoo/tensor/MixedTensor.java | 73 ++++---- .../main/java/com/yahoo/tensor/PartialAddress.java | 30 ++-- .../src/main/java/com/yahoo/tensor/Tensor.java | 10 +- .../main/java/com/yahoo/tensor/TensorAddress.java | 40 ++--- .../src/main/java/com/yahoo/tensor/TensorType.java | 18 +- .../java/com/yahoo/tensor/functions/Concat.java | 20 +-- .../main/java/com/yahoo/tensor/functions/Diag.java | 2 +- .../java/com/yahoo/tensor/functions/Generate.java | 6 +- .../main/java/com/yahoo/tensor/functions/Join.java | 18 +- .../java/com/yahoo/tensor/functions/Range.java | 2 +- .../yahoo/tensor/functions/ScalarFunctions.java | 36 ++-- .../tensor/serialization/DenseBinaryFormat.java | 6 +- .../tensor/serialization/MixedBinaryFormat.java | 18 +- .../tensor/serialization/SparseBinaryFormat.java | 13 +- .../test/java/com/yahoo/tensor/TensorTestCase.java | 7 +- 23 files changed, 279 insertions(+), 280 deletions(-) diff --git a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java index 480d4d05451..cd825767565 100644 --- a/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java +++ b/config-model-api/src/main/java/com/yahoo/config/application/api/ApplicationPackage.java @@ -3,11 +3,10 @@ package com.yahoo.config.application.api; import com.yahoo.config.provision.AllocatedHosts; import com.yahoo.config.provision.Version; -import com.yahoo.config.provision.Zone; import com.yahoo.config.provision.ZoneId; -import com.yahoo.path.Path; import com.yahoo.io.IOUtils; import com.yahoo.io.reader.NamedReader; +import com.yahoo.path.Path; import com.yahoo.text.XML; import com.yahoo.vespa.config.ConfigDefinitionKey; import org.w3c.dom.Element; @@ -15,8 +14,17 @@ import org.xml.sax.SAXException; import javax.xml.parsers.ParserConfigurationException; import javax.xml.transform.TransformerException; -import java.io.*; -import java.util.*; +import java.io.File; +import java.io.IOException; +import java.io.InputStreamReader; +import java.io.Reader; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; import java.util.jar.JarEntry; import java.util.jar.JarFile; @@ -229,9 +237,9 @@ public interface ApplicationPackage { throw new UnsupportedOperationException("This application package cannot write its metadata"); } - /** - * Returns the single host allocation info of this, or an empty map if no allocation is available - * + /** + * Returns the single host allocation info of this, or an empty map if no allocation is available + * * @deprecated please use #getAllocatedHosts */ // TODO: Remove on Vespa 7 diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java index 15a8a670a2e..8091397237d 100644 --- a/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/searcher/ValidateSortingSearcher.java @@ -25,7 +25,7 @@ import static com.yahoo.prelude.querytransform.NormalizingSearcher.ACCENT_REMOVA * Check sorting specification makes sense to the search cluster before * passing it on to the backend. * - * @author Steinar Knutsen + * @author Steinar Knutsen */ @Before(PhaseNames.BACKEND) @After(ACCENT_REMOVAL) @@ -118,6 +118,7 @@ public class ValidateSortingSearcher extends Searcher { for (Sorting.FieldOrder f : l) { String name = f.getFieldName(); if ("[rank]".equals(name) || "[docid]".equals(name)) { + // built-in constants - ok } else if (names.containsKey(name)) { AttributesConfig.Attribute attrConfig = names.get(name); if (attrConfig != null) { @@ -166,18 +167,13 @@ public class ValidateSortingSearcher extends Searcher { locale = "en_US"; } - // getLogger().info("locale = " + locale + " attrConfig.sortlocale.value() = " + attrConfig.sortlocale.value() + " query.getLanguage() = " + query.getModel().getLanguage()); - // getLogger().info("locale = " + locale); - Sorting.UcaSorter.Strength strength = sorter.getStrength(); if (sorter.getStrength() == Sorting.UcaSorter.Strength.UNDEFINED) { strength = config2Strength(attrConfig.sortstrength()); } if ((sorter.getStrength() == Sorting.UcaSorter.Strength.UNDEFINED) || (sorter.getLocale() == null) || sorter.getLocale().isEmpty()) { - // getLogger().info("locale = " + locale + " strength = " + strength.toString()); sorter.setLocale(locale, strength); } - //getLogger().info("locale = " + locale + " strength = " + strength.toString() + "decompose = " + sorter.getDecomposition()); } } else { return ErrorMessage.createInvalidQueryParameter("The cluster " + getClusterName() + " has attribute config for field: " + name); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java index df43225c333..1960cf94591 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/TensorConverter.java @@ -26,7 +26,7 @@ public class TensorConverter { int dimensionIndex = 0; for (long dimensionSize : shape) { if (dimensionSize == 0) dimensionSize = 1; // TensorFlow ... - b.indexed("d" + (dimensionIndex++), (int) dimensionSize); + b.indexed("d" + (dimensionIndex++), dimensionSize); } return b.build(); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java index d366c9bfbe5..9da1ba40144 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/GeneratorLambdaFunctionNode.java @@ -1,7 +1,6 @@ // Copyright 2017 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.rule; -import com.google.common.collect.ImmutableList; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -10,7 +9,6 @@ import com.yahoo.tensor.TensorType; import java.util.Collections; import java.util.Deque; import java.util.List; -import java.util.function.*; /** * A tensor generating function, whose arguments are determined by a tensor type @@ -57,14 +55,14 @@ public class GeneratorLambdaFunctionNode extends CompositeNode { /** * Returns this as an operator which converts a list of integers into a double */ - public IntegerListToDoubleLambda asIntegerListToDoubleOperator() { - return new IntegerListToDoubleLambda(); + public LongListToDoubleLambda asLongListToDoubleOperator() { + return new LongListToDoubleLambda(); } - private class IntegerListToDoubleLambda implements java.util.function.Function, Double> { + private class LongListToDoubleLambda implements java.util.function.Function, Double> { @Override - public Double apply(List arguments) { + public Double apply(List arguments) { MapContext context = new MapContext(); for (int i = 0; i < type.dimensions().size(); i++) context.put(type.dimensions().get(i).name(), arguments.get(i)); diff --git a/searchlib/src/main/javacc/RankingExpressionParser.jj b/searchlib/src/main/javacc/RankingExpressionParser.jj index 7821ab88b86..541738db8e0 100755 --- a/searchlib/src/main/javacc/RankingExpressionParser.jj +++ b/searchlib/src/main/javacc/RankingExpressionParser.jj @@ -467,7 +467,7 @@ ExpressionNode tensorGenerate() : } { type = tensorTypeArgument() generator = expression() - { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asIntegerListToDoubleOperator())); } + { return new TensorFunctionNode(new Generate(type, new GeneratorLambdaFunctionNode(type, generator).asLongListToDoubleOperator())); } } ExpressionNode tensorRange() : diff --git a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java index f6237a1977a..01bf082d32f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/DimensionSizes.java @@ -13,7 +13,7 @@ import java.util.Arrays; @Beta public final class DimensionSizes { - private final int[] sizes; + private final long[] sizes; private DimensionSizes(Builder builder) { this.sizes = builder.sizes; @@ -25,15 +25,15 @@ public final class DimensionSizes { * * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one */ - public int size(int dimensionIndex) { return sizes[dimensionIndex]; } + public long size(int dimensionIndex) { return sizes[dimensionIndex]; } /** Returns the number of dimensions this provides the size of */ public int dimensions() { return sizes.length; } /** Returns the product of the sizes of this */ - public int totalSize() { - int productSize = 1; - for (int dimensionSize : sizes ) + public long totalSize() { + long productSize = 1; + for (long dimensionSize : sizes ) productSize *= dimensionSize; return productSize; } @@ -54,13 +54,13 @@ public final class DimensionSizes { */ public final static class Builder { - private int[] sizes; + private long[] sizes; public Builder(int dimensions) { - this.sizes = new int[dimensions]; + this.sizes = new long[dimensions]; } - public Builder set(int dimensionIndex, int size) { + public Builder set(int dimensionIndex, long size) { sizes[dimensionIndex] = size; return this; } @@ -70,7 +70,7 @@ public final class DimensionSizes { * * @throws IndexOutOfBoundsException if the index is larger than the number of dimensions in this tensor minus one */ - public int size(int dimensionIndex) { return sizes[dimensionIndex]; } + public long size(int dimensionIndex) { return sizes[dimensionIndex]; } /** Returns the number of dimensions this provides the size of */ public int dimensions() { return sizes.length; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java index 6b0d769de9f..7130c053e9f 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/IndexedTensor.java @@ -38,7 +38,7 @@ public class IndexedTensor implements Tensor { } @Override - public int size() { + public long size() { return values.length; } @@ -55,10 +55,10 @@ public class IndexedTensor implements Tensor { /** Returns an iterator over all the cells in this tensor which matches the given partial address */ // TODO: Move up to Tensor and create a mixed tensor which can implement it (and subspace iterators) efficiently public SubspaceIterator cellIterator(PartialAddress partialAddress, DimensionSizes iterationSizes) { - int[] startAddress = new int[type().dimensions().size()]; + long[] startAddress = new long[type().dimensions().size()]; List iterateDimensions = new ArrayList<>(); for (int i = 0; i < type().dimensions().size(); i++) { - int partialAddressLabel = partialAddress.intLabel(type.dimensions().get(i).name()); + long partialAddressLabel = partialAddress.numericLabel(type.dimensions().get(i).name()); if (partialAddressLabel >= 0) // iterate at this label startAddress[i] = partialAddressLabel; else // iterate over this dimension @@ -102,8 +102,8 @@ public class IndexedTensor implements Tensor { * @param indexes the indexes into the dimensions of this. Must be one number per dimension of this * @throws IndexOutOfBoundsException if any of the indexes are out of bound or a wrong number of indexes are given */ - public double get(int ... indexes) { - return values[toValueIndex(indexes, dimensionSizes)]; + public double get(long ... indexes) { + return values[(int)toValueIndex(indexes, dimensionSizes)]; } /** Returns the value at this address, or NaN if there is no value at this address */ @@ -111,20 +111,20 @@ public class IndexedTensor implements Tensor { public double get(TensorAddress address) { // optimize for fast lookup within bounds: try { - return values[toValueIndex(address, dimensionSizes)]; + return values[(int)toValueIndex(address, dimensionSizes)]; } catch (IndexOutOfBoundsException e) { return Double.NaN; } } - private double get(int valueIndex) { return values[valueIndex]; } + private double get(long valueIndex) { return values[(int)valueIndex]; } - private static int toValueIndex(int[] indexes, DimensionSizes sizes) { + private static long toValueIndex(long[] indexes, DimensionSizes sizes) { if (indexes.length == 1) return indexes[0]; // for speed if (indexes.length == 0) return 0; // for speed - int valueIndex = 0; + long valueIndex = 0; for (int i = 0; i < indexes.length; i++) { if (indexes[i] >= sizes.size(i)) { throw new IndexOutOfBoundsException(); @@ -134,21 +134,21 @@ public class IndexedTensor implements Tensor { return valueIndex; } - private static int toValueIndex(TensorAddress address, DimensionSizes sizes) { + private static long toValueIndex(TensorAddress address, DimensionSizes sizes) { if (address.isEmpty()) return 0; - int valueIndex = 0; + long valueIndex = 0; for (int i = 0; i < address.size(); i++) { - if (address.intLabel(i) >= sizes.size(i)) { + if (address.numericLabel(i) >= sizes.size(i)) { throw new IndexOutOfBoundsException(); } - valueIndex += productOfDimensionsAfter(i, sizes) * address.intLabel(i); + valueIndex += productOfDimensionsAfter(i, sizes) * address.numericLabel(i); } return valueIndex; } - private static int productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { - int product = 1; + private static long productOfDimensionsAfter(int afterIndex, DimensionSizes sizes) { + long product = 1; for (int i = afterIndex + 1; i < sizes.dimensions(); i++) product *= sizes.size(i); return product; @@ -168,9 +168,9 @@ public class IndexedTensor implements Tensor { ImmutableMap.Builder builder = new ImmutableMap.Builder<>(); Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); - for (int i = 0; i < values.length; i++) { + for (long i = 0; i < values.length; i++) { indexes.next(); - builder.put(indexes.toAddress(), values[i]); + builder.put(indexes.toAddress(), values[(int)i]); } return builder.build(); } @@ -213,7 +213,7 @@ public class IndexedTensor implements Tensor { throw new IllegalArgumentException(sizes.dimensions() + " is the wrong number of dimensions " + "for " + type); for (int i = 0; i < sizes.dimensions(); i++ ) { - Optional size = type.dimensions().get(i).size(); + Optional size = type.dimensions().get(i).size(); if (size.isPresent() && size.get() < sizes.size(i)) throw new IllegalArgumentException("Size of dimension " + type.dimensions().get(i).name() + " is " + sizes.size(i) + @@ -223,7 +223,7 @@ public class IndexedTensor implements Tensor { return new BoundBuilder(type, sizes); } - public abstract Builder cell(double value, int ... indexes); + public abstract Builder cell(double value, long ... indexes); @Override public TensorType type() { return type; } @@ -255,12 +255,12 @@ public class IndexedTensor implements Tensor { if ( sizes.dimensions() != type.dimensions().size()) throw new IllegalArgumentException("Must have a dimension size entry for each dimension in " + type); this.sizes = sizes; - values = new double[sizes.totalSize()]; + values = new double[(int)sizes.totalSize()]; } @Override - public BoundBuilder cell(double value, int ... indexes) { - values[toValueIndex(indexes, sizes)] = value; + public BoundBuilder cell(double value, long ... indexes) { + values[(int)toValueIndex(indexes, sizes)] = value; return this; } @@ -271,7 +271,7 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(TensorAddress address, double value) { - values[toValueIndex(address, sizes)] = value; + values[(int)toValueIndex(address, sizes)] = value; return this; } @@ -286,9 +286,9 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(Cell cell, double value) { - int directIndex = cell.getDirectIndex(); + long directIndex = cell.getDirectIndex(); if (directIndex >= 0) // optimization - values[directIndex] = value; + values[(int)directIndex] = value; else super.cell(cell, value); return this; @@ -299,8 +299,8 @@ public class IndexedTensor implements Tensor { * This requires knowledge of the internal layout of cells in this implementation, and should therefore * probably not be used (but when it can be used it is fast). */ - public void cellByDirectIndex(int index, double value) { - values[index] = value; + public void cellByDirectIndex(long index, double value) { + values[(int)index] = value; } } @@ -326,13 +326,13 @@ public class IndexedTensor implements Tensor { return new IndexedTensor(type, new DimensionSizes.Builder(type.dimensions().size()).build(), new double[] {(Double) firstDimension.get(0) }); DimensionSizes dimensionSizes = findDimensionSizes(firstDimension); - double[] values = new double[dimensionSizes.totalSize()]; + double[] values = new double[(int)dimensionSizes.totalSize()]; fillValues(0, 0, firstDimension, dimensionSizes, values); return new IndexedTensor(type, dimensionSizes, values); } private DimensionSizes findDimensionSizes(List firstDimension) { - List dimensionSizeList = new ArrayList<>(type.dimensions().size()); + List dimensionSizeList = new ArrayList<>(type.dimensions().size()); findDimensionSizes(0, dimensionSizeList, firstDimension); DimensionSizes.Builder b = new DimensionSizes.Builder(type.dimensions().size()); // may be longer than the list but that's correct for (int i = 0; i < b.dimensions(); i++) { @@ -343,9 +343,9 @@ public class IndexedTensor implements Tensor { } @SuppressWarnings("unchecked") - private void findDimensionSizes(int currentDimensionIndex, List dimensionSizes, List currentDimension) { + private void findDimensionSizes(int currentDimensionIndex, List dimensionSizes, List currentDimension) { if (currentDimensionIndex == dimensionSizes.size()) - dimensionSizes.add(currentDimension.size()); + dimensionSizes.add((long)currentDimension.size()); else if (dimensionSizes.get(currentDimensionIndex) != currentDimension.size()) throw new IllegalArgumentException("Missing values in dimension " + type.dimensions().get(currentDimensionIndex) + " in " + type); @@ -356,16 +356,16 @@ public class IndexedTensor implements Tensor { } @SuppressWarnings("unchecked") - private void fillValues(int currentDimensionIndex, int offset, List currentDimension, + private void fillValues(int currentDimensionIndex, long offset, List currentDimension, DimensionSizes sizes, double[] values) { if (currentDimensionIndex < sizes.dimensions() - 1) { // recurse to next dimension - for (int i = 0; i < currentDimension.size(); i++) + for (long i = 0; i < currentDimension.size(); i++) fillValues(currentDimensionIndex + 1, offset + productOfDimensionsAfter(currentDimensionIndex, sizes) * i, - (List) currentDimension.get(i), sizes, values); + (List) currentDimension.get((int)i), sizes, values); } else { // last dimension - fill values - for (int i = 0; i < currentDimension.size(); i++) { - values[offset + i] = nullAsZero((Double)currentDimension.get(i)); // fill missing values as zero + for (long i = 0; i < currentDimension.size(); i++) { + values[(int)(offset + i)] = nullAsZero((Double)currentDimension.get((int)i)); // fill missing values as zero } } } @@ -382,9 +382,9 @@ public class IndexedTensor implements Tensor { @Override public Builder cell(TensorAddress address, double value) { - int[] indexes = new int[address.size()]; + long[] indexes = new long[address.size()]; for (int i = 0; i < address.size(); i++) { - indexes[i] = address.intLabel(i); + indexes[i] = address.numericLabel(i); } cell(value, indexes); return this; @@ -399,7 +399,7 @@ public class IndexedTensor implements Tensor { */ @SuppressWarnings("unchecked") @Override - public Builder cell(double value, int... indexes) { + public Builder cell(double value, long... indexes) { if (indexes.length != type.dimensions().size()) throw new IllegalArgumentException("Wrong number of indexes (" + indexes.length + ") for " + type); @@ -414,18 +414,18 @@ public class IndexedTensor implements Tensor { for (int dimensionIndex = 0; dimensionIndex < indexes.length; dimensionIndex++) { ensureCapacity(indexes[dimensionIndex], currentValues); if (dimensionIndex == indexes.length - 1) { // last dimension - currentValues.set(indexes[dimensionIndex], value); + currentValues.set((int)indexes[dimensionIndex], value); } else { - if (currentValues.get(indexes[dimensionIndex]) == null) - currentValues.set(indexes[dimensionIndex], new ArrayList<>()); - currentValues = (List) currentValues.get(indexes[dimensionIndex]); + if (currentValues.get((int)indexes[dimensionIndex]) == null) + currentValues.set((int)indexes[dimensionIndex], new ArrayList<>()); + currentValues = (List) currentValues.get((int)indexes[dimensionIndex]); } } return this; } /** Fill the given list with nulls if necessary to make sure it has a (possibly null) value at the given index */ - private void ensureCapacity(int index, List list) { + private void ensureCapacity(long index, List list) { while (list.size() <= index) list.add(list.size(), null); } @@ -434,7 +434,7 @@ public class IndexedTensor implements Tensor { private final class CellIterator implements Iterator { - private int count = 0; + private long count = 0; private final Indexes indexes = Indexes.of(dimensionSizes, dimensionSizes, values.length); private final LazyCell reusedCell = new LazyCell(indexes, Double.NaN); @@ -456,7 +456,7 @@ public class IndexedTensor implements Tensor { private final class ValueIterator implements Iterator { - private int count = 0; + private long count = 0; @Override public boolean hasNext() { @@ -466,7 +466,7 @@ public class IndexedTensor implements Tensor { @Override public Double next() { try { - return values[count++]; + return values[(int)count++]; } catch (IndexOutOfBoundsException e) { throw new NoSuchElementException("No element at position " + count); @@ -479,7 +479,7 @@ public class IndexedTensor implements Tensor { private final Indexes superindexes; - /** Those indexes this should iterate over */ + /** The indexes this should iterate over */ private final List subdimensionIndexes; /** @@ -488,7 +488,7 @@ public class IndexedTensor implements Tensor { */ private final DimensionSizes iterateSizes; - private int count = 0; + private long count = 0; private SuperspaceIterator(Set superdimensionNames, DimensionSizes iterateSizes) { this.iterateSizes = iterateSizes; @@ -533,11 +533,11 @@ public class IndexedTensor implements Tensor { * This may be any subset of the dimensions given by address and dimensionSizes. */ private final List iterateDimensions; - private final int[] address; + private final long[] address; private final DimensionSizes iterateSizes; private Indexes indexes; - private int count = 0; + private long count = 0; /** A lazy cell for reuse */ private final LazyCell reusedCell; @@ -554,7 +554,7 @@ public class IndexedTensor implements Tensor { * This is treated as immutable. * @param address the address of the first cell of this subspace. */ - private SubspaceIterator(List iterateDimensions, int[] address, DimensionSizes iterateSizes) { + private SubspaceIterator(List iterateDimensions, long[] address, DimensionSizes iterateSizes) { this.iterateDimensions = iterateDimensions; this.address = address; this.iterateSizes = iterateSizes; @@ -563,7 +563,7 @@ public class IndexedTensor implements Tensor { } /** Returns the total number of cells in this subspace */ - public int size() { + public long size() { return indexes.size(); } @@ -605,7 +605,7 @@ public class IndexedTensor implements Tensor { } @Override - int getDirectIndex() { return indexes.toIterationValueIndex(); } + long getDirectIndex() { return indexes.toIterationValueIndex(); } @Override public TensorAddress getKey() { @@ -630,7 +630,7 @@ public class IndexedTensor implements Tensor { private final DimensionSizes iterationSizes; - protected final int[] indexes; + protected final long[] indexes; public static Indexes of(DimensionSizes sizes) { return of(sizes, sizes); @@ -640,7 +640,7 @@ public class IndexedTensor implements Tensor { return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions())); } - private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int size) { + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long size) { return of(sourceSizes, iterateSizes, completeIterationOrder(iterateSizes.dimensions()), size); } @@ -648,15 +648,15 @@ public class IndexedTensor implements Tensor { return of(sourceSizes, iterateSizes, iterateDimensions, computeSize(iterateSizes, iterateDimensions)); } - private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, int size) { - return of(sourceSizes, iterateSizes, iterateDimensions, new int[iterateSizes.dimensions()], size); + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, long size) { + return of(sourceSizes, iterateSizes, iterateDimensions, new long[iterateSizes.dimensions()], size); } - private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, int[] initialIndexes) { + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, long[] initialIndexes) { return of(sourceSizes, iterateSizes, iterateDimensions, initialIndexes, computeSize(iterateSizes, iterateDimensions)); } - private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, int[] initialIndexes, int size) { + private static Indexes of(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, long[] initialIndexes, long size) { if (size == 0) { return new EmptyIndexes(sourceSizes, iterateSizes, initialIndexes); // we're told explicitly there are truly no values available } @@ -684,14 +684,14 @@ public class IndexedTensor implements Tensor { return iterationDimensions; } - private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, int[] indexes) { + private Indexes(DimensionSizes sourceSizes, DimensionSizes iterationSizes, long[] indexes) { this.sourceSizes = sourceSizes; this.iterationSizes = iterationSizes; this.indexes = indexes; } - private static int computeSize(DimensionSizes sizes, List iterateDimensions) { - int size = 1; + private static long computeSize(DimensionSizes sizes, List iterateDimensions) { + long size = 1; for (int iterateDimension : iterateDimensions) size *= sizes.size(iterateDimension); return size; @@ -702,25 +702,25 @@ public class IndexedTensor implements Tensor { return TensorAddress.of(indexes); } - public int[] indexesCopy() { + public long[] indexesCopy() { return Arrays.copyOf(indexes, indexes.length); } /** Returns a copy of the indexes of this which must not be modified */ - public int[] indexesForReading() { return indexes; } + public long[] indexesForReading() { return indexes; } - int toSourceValueIndex() { + long toSourceValueIndex() { return IndexedTensor.toValueIndex(indexes, sourceSizes); } - int toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); } + long toIterationValueIndex() { return IndexedTensor.toValueIndex(indexes, iterationSizes); } DimensionSizes dimensionSizes() { return iterationSizes; } /** Returns an immutable list containing a copy of the indexes in this */ - public List toList() { - ImmutableList.Builder builder = new ImmutableList.Builder<>(); - for (int index : indexes) + public List toList() { + ImmutableList.Builder builder = new ImmutableList.Builder<>(); + for (long index : indexes) builder.add(index); return builder.build(); } @@ -730,7 +730,7 @@ public class IndexedTensor implements Tensor { return "indexes " + Arrays.toString(indexes); } - public abstract int size(); + public abstract long size(); public abstract void next(); @@ -738,12 +738,12 @@ public class IndexedTensor implements Tensor { private final static class EmptyIndexes extends Indexes { - private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) { + private EmptyIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) { super(sourceSizes, iterateSizes, indexes); } @Override - public int size() { return 0; } + public long size() { return 0; } @Override public void next() {} @@ -752,12 +752,12 @@ public class IndexedTensor implements Tensor { private final static class SingleValueIndexes extends Indexes { - private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, int[] indexes) { + private SingleValueIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, long[] indexes) { super(sourceSizes, iterateSizes, indexes); } @Override - public int size() { return 1; } + public long size() { return 1; } @Override public void next() {} @@ -766,11 +766,11 @@ public class IndexedTensor implements Tensor { private static class MultiDimensionIndexes extends Indexes { - private final int size; + private final long size; private final List iterateDimensions; - private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, int[] initialIndexes, int size) { + private MultiDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, List iterateDimensions, long[] initialIndexes, long size) { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimensions = iterateDimensions; this.size = size; @@ -781,7 +781,7 @@ public class IndexedTensor implements Tensor { /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override - public int size() { + public long size() { return size; } @@ -806,36 +806,38 @@ public class IndexedTensor implements Tensor { /** In this case we can reuse the source index computation for the iteration index */ private final static class EqualSizeMultiDimensionIndexes extends MultiDimensionIndexes { - private int lastComputedSourceValueIndex = -1; + private long lastComputedSourceValueIndex = -1; - private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List iterateDimensions, int[] initialIndexes, int size) { + private EqualSizeMultiDimensionIndexes(DimensionSizes sizes, List iterateDimensions, long[] initialIndexes, long size) { super(sizes, sizes, iterateDimensions, initialIndexes, size); } - int toSourceValueIndex() { + @Override + long toSourceValueIndex() { return lastComputedSourceValueIndex = super.toSourceValueIndex(); } // NOTE: We assume the source index always gets computed first. Otherwise using this will produce a runtime exception - int toIterationValueIndex() { return lastComputedSourceValueIndex; } + @Override + long toIterationValueIndex() { return lastComputedSourceValueIndex; } } /** In this case we can keep track of indexes using a step instead of using the more elaborate computation */ private final static class SingleDimensionIndexes extends Indexes { - private final int size; + private final long size; private final int iterateDimension; /** Maintain this directly as an optimization for 1-d iteration */ - private int currentSourceValueIndex, currentIterationValueIndex; + private long currentSourceValueIndex, currentIterationValueIndex; /** The iteration step in the value index space */ - private final int sourceStep, iterationStep; + private final long sourceStep, iterationStep; private SingleDimensionIndexes(DimensionSizes sourceSizes, DimensionSizes iterateSizes, - int iterateDimension, int[] initialIndexes, int size) { + int iterateDimension, long[] initialIndexes, long size) { super(sourceSizes, iterateSizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; @@ -850,7 +852,7 @@ public class IndexedTensor implements Tensor { /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override - public int size() { + public long size() { return size; } @@ -868,28 +870,28 @@ public class IndexedTensor implements Tensor { } @Override - int toSourceValueIndex() { return currentSourceValueIndex; } + long toSourceValueIndex() { return currentSourceValueIndex; } @Override - int toIterationValueIndex() { return currentIterationValueIndex; } + long toIterationValueIndex() { return currentIterationValueIndex; } } /** In this case we only need to keep track of one index */ private final static class EqualSizeSingleDimensionIndexes extends Indexes { - private final int size; + private final long size; private final int iterateDimension; /** Maintain this directly as an optimization for 1-d iteration */ - private int currentValueIndex; + private long currentValueIndex; /** The iteration step in the value index space */ - private final int step; + private final long step; private EqualSizeSingleDimensionIndexes(DimensionSizes sizes, - int iterateDimension, int[] initialIndexes, int size) { + int iterateDimension, long[] initialIndexes, long size) { super(sizes, sizes, initialIndexes); this.iterateDimension = iterateDimension; this.size = size; @@ -902,7 +904,7 @@ public class IndexedTensor implements Tensor { /** Returns the number of values this will iterate over - i.e the product if the iterating dimension sizes */ @Override - public int size() { + public long size() { return size; } @@ -919,10 +921,10 @@ public class IndexedTensor implements Tensor { } @Override - int toSourceValueIndex() { return currentValueIndex; } + long toSourceValueIndex() { return currentValueIndex; } @Override - int toIterationValueIndex() { return currentValueIndex; } + long toIterationValueIndex() { return currentValueIndex; } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java index aba61478e69..15993072c37 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MappedTensor.java @@ -29,7 +29,7 @@ public class MappedTensor implements Tensor { public TensorType type() { return type; } @Override - public int size() { return cells.size(); } + public long size() { return cells.size(); } @Override public double get(TensorAddress address) { return cells.getOrDefault(address, Double.NaN); } @@ -80,7 +80,7 @@ public class MappedTensor implements Tensor { } @Override - public Builder cell(double value, int... labels) { + public Builder cell(double value, long... labels) { cells.put(TensorAddress.of(labels), value); return this; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 9a751e078e0..0c9ed769c0d 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -47,13 +47,13 @@ public class MixedTensor implements Tensor { /** Returns the size of the tensor measured in number of cells */ @Override - public int size() { return cells.size(); } + public long size() { return cells.size(); } /** Returns the value at the given address */ @Override public double get(TensorAddress address) { - int cellIndex = index.indexOf(address); - Cell cell = cells.get(cellIndex); + long cellIndex = index.indexOf(address); + Cell cell = cells.get((int)cellIndex); if (!address.equals(cell.getKey())) { throw new IllegalStateException("Unable to find correct cell by direct index."); } @@ -113,7 +113,7 @@ public class MixedTensor implements Tensor { } /** Returns the size of dense subspaces */ - public int denseSubspaceSize() { + public long denseSubspaceSize() { return index.denseSubspaceSize(); } @@ -148,7 +148,7 @@ public class MixedTensor implements Tensor { } @Override - public Tensor.Builder cell(double value, int... labels) { + public Tensor.Builder cell(double value, long... labels) { throw new UnsupportedOperationException("Not implemented."); } @@ -179,13 +179,13 @@ public class MixedTensor implements Tensor { index = indexBuilder.index(); } - public int denseSubspaceSize() { + public long denseSubspaceSize() { return index.denseSubspaceSize(); } private double[] denseSubspace(TensorAddress sparsePartial) { if (!denseSubspaceMap.containsKey(sparsePartial)) { - denseSubspaceMap.put(sparsePartial, new double[denseSubspaceSize()]); + denseSubspaceMap.put(sparsePartial, new double[(int)denseSubspaceSize()]); } return denseSubspaceMap.get(sparsePartial); } @@ -193,21 +193,21 @@ public class MixedTensor implements Tensor { @Override public Tensor.Builder cell(TensorAddress address, double value) { TensorAddress sparsePart = index.sparsePartialAddress(address); - int denseOffset = index.denseOffset(address); + long denseOffset = index.denseOffset(address); double[] denseSubspace = denseSubspace(sparsePart); - denseSubspace[denseOffset] = value; + denseSubspace[(int)denseOffset] = value; return this; } public Tensor.Builder block(TensorAddress sparsePart, double[] values) { double[] denseSubspace = denseSubspace(sparsePart); - System.arraycopy(values, 0, denseSubspace, 0, denseSubspaceSize()); + System.arraycopy(values, 0, denseSubspace, 0, (int)denseSubspaceSize()); return this; } @Override public MixedTensor build() { - int count = 0; + long count = 0; ImmutableList.Builder builder = new ImmutableList.Builder<>(); for (Map.Entry entry : denseSubspaceMap.entrySet()) { @@ -215,9 +215,9 @@ public class MixedTensor implements Tensor { indexBuilder.put(sparsePart, count); double[] denseSubspace = entry.getValue(); - for (int offset = 0; offset < denseSubspace.length; ++offset) { + for (long offset = 0; offset < denseSubspace.length; ++offset) { TensorAddress cellAddress = index.addressOf(sparsePart, offset); - double value = denseSubspace[offset]; + double value = denseSubspace[(int)offset]; builder.add(new Cell(cellAddress, value)); count++; } @@ -239,12 +239,12 @@ public class MixedTensor implements Tensor { public static class UnboundBuilder extends Builder { private Map cells; - private final int[] dimensionBounds; + private final long[] dimensionBounds; private UnboundBuilder(TensorType type) { super(type); cells = new HashMap<>(); - dimensionBounds = new int[type.dimensions().size()]; + dimensionBounds = new long[type.dimensions().size()]; } @Override @@ -268,7 +268,7 @@ public class MixedTensor implements Tensor { for (int i = 0; i < type.dimensions().size(); ++i) { TensorType.Dimension dimension = type.dimensions().get(i); if (dimension.isIndexed()) { - dimensionBounds[i] = Math.max(address.intLabel(i), dimensionBounds[i]); + dimensionBounds[i] = Math.max(address.numericLabel(i), dimensionBounds[i]); } } } @@ -280,7 +280,7 @@ public class MixedTensor implements Tensor { if (!dimension.isIndexed()) { typeBuilder.mapped(dimension.name()); } else { - int size = dimension.size().orElse(dimensionBounds[i] + 1); + long size = dimension.size().orElse(dimensionBounds[i] + 1); typeBuilder.indexed(dimension.name(), size); } } @@ -303,8 +303,8 @@ public class MixedTensor implements Tensor { private final List mappedDimensions; private final List indexedDimensions; - private ImmutableMap sparseMap; - private int denseSubspaceSize = -1; + private ImmutableMap sparseMap; + private long denseSubspaceSize = -1; private Index(TensorType type) { this.type = type; @@ -314,26 +314,27 @@ public class MixedTensor implements Tensor { this.denseType = createPartialType(indexedDimensions); } - public int indexOf(TensorAddress address) { + public long indexOf(TensorAddress address) { TensorAddress sparsePart = sparsePartialAddress(address); - if (!sparseMap.containsKey(sparsePart)) { + if ( ! sparseMap.containsKey(sparsePart)) { throw new IllegalArgumentException("Address not found"); } - int base = sparseMap.get(sparsePart); - int offset = denseOffset(address); + long base = sparseMap.get(sparsePart); + long offset = denseOffset(address); return base + offset; } public static class Builder { + private final Index index; - private final ImmutableMap.Builder builder; + private final ImmutableMap.Builder builder; public Builder(TensorType type) { index = new Index(type); builder = new ImmutableMap.Builder<>(); } - public void put(TensorAddress address, int index) { + public void put(TensorAddress address, long index) { builder.put(address, index); } @@ -347,7 +348,7 @@ public class MixedTensor implements Tensor { } } - public int denseSubspaceSize() { + public long denseSubspaceSize() { if (denseSubspaceSize == -1) { denseSubspaceSize = 1; for (int i = 0; i < type.dimensions().size(); ++i) { @@ -375,13 +376,13 @@ public class MixedTensor implements Tensor { return builder.build(); } - private int denseOffset(TensorAddress address) { - int innerSize = 1; - int offset = 0; + private long denseOffset(TensorAddress address) { + long innerSize = 1; + long offset = 0; for (int i = type.dimensions().size(); --i >= 0; ) { TensorType.Dimension dimension = type.dimensions().get(i); if (dimension.isIndexed()) { - int label = address.intLabel(i); + long label = address.numericLabel(i); offset += label * innerSize; innerSize *= dimension.size().orElseThrow(() -> new IllegalArgumentException("Unknown size of indexed dimension.")); @@ -390,18 +391,18 @@ public class MixedTensor implements Tensor { return offset; } - private TensorAddress denseOffsetToAddress(int denseOffset) { + private TensorAddress denseOffsetToAddress(long denseOffset) { if (denseOffset < 0 || denseOffset > denseSubspaceSize) { throw new IllegalArgumentException("Offset out of bounds"); } - int restSize = denseOffset; - int innerSize = denseSubspaceSize; - int[] labels = new int[indexedDimensions.size()]; + long restSize = denseOffset; + long innerSize = denseSubspaceSize; + long[] labels = new long[indexedDimensions.size()]; for (int i = 0; i < labels.length; ++i) { TensorType.Dimension dimension = indexedDimensions.get(i); - int dimensionSize = dimension.size().orElseThrow(() -> + long dimensionSize = dimension.size().orElseThrow(() -> new IllegalArgumentException("Unknown size of indexed dimension.")); innerSize /= dimensionSize; @@ -411,7 +412,7 @@ public class MixedTensor implements Tensor { return TensorAddress.of(labels); } - private TensorAddress addressOf(TensorAddress sparsePart, int denseOffset) { + private TensorAddress addressOf(TensorAddress sparsePart, long denseOffset) { TensorAddress densePart = denseOffsetToAddress(denseOffset); String[] labels = new String[type.dimensions().size()]; int mappedIndex = 0; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java index e3398850373..23ef0772aea 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/PartialAddress.java @@ -6,11 +6,11 @@ import com.google.common.annotations.Beta; /** * An address to a subset of a tensors' cells, specifying a label for some but not necessarily all of the tensors * dimensions. - * + * * @author bratseth */ -// Implementation notes: -// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation. +// Implementation notes: +// - These are created in inner (though not inner-most) loops so they are implemented with minimal allocation. // We also avoid non-essential error checking. // - We can add support for string labels later without breaking the API @Beta @@ -19,7 +19,7 @@ public class PartialAddress { // Two arrays which contains corresponding dimension=label pairs. // The sizes of these are always equal. private final String[] dimensionNames; - private final int[] labels; + private final long[] labels; private PartialAddress(Builder builder) { this.dimensionNames = builder.dimensionNames; @@ -27,36 +27,36 @@ public class PartialAddress { builder.dimensionNames = null; // invalidate builder to safely take over array ownership builder.labels = null; } - + /** Returns the int label of this dimension, or -1 if no label is specified for it */ - int intLabel(String dimensionName) { + long numericLabel(String dimensionName) { for (int i = 0; i < dimensionNames.length; i++) if (dimensionNames[i].equals(dimensionName)) return labels[i]; return -1; } - + public static class Builder { private String[] dimensionNames; - private int[] labels; + private long[] labels; private int index = 0; - + public Builder(int size) { dimensionNames = new String[size]; - labels = new int[size]; + labels = new long[size]; } - - public void add(String dimensionName, int label) { + + public void add(String dimensionName, long label) { dimensionNames[index] = dimensionName; labels[index] = label; index++; } - + public PartialAddress build() { return new PartialAddress(this); } - + } - + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index 1b60e01cf7e..0c948f1fbee 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -59,7 +59,7 @@ public interface Tensor { default boolean isEmpty() { return size() == 0; } /** Returns the number of cells in this */ - int size(); + long size(); /** Returns the value of a cell, or NaN if this cell does not exist/have no value */ double get(TensorAddress address); @@ -124,7 +124,7 @@ public interface Tensor { return new Rename(new ConstantTensor(this), fromDimensions, toDimensions).evaluate(); } - static Tensor generate(TensorType type, Function, Double> valueSupplier) { + static Tensor generate(TensorType type, Function, Double> valueSupplier) { return new Generate(type, valueSupplier).evaluate(); } @@ -333,7 +333,7 @@ public interface Tensor { * This is for optimizations mapping between tensors where this is possible without creating a * TensorAddress. */ - int getDirectIndex() { return -1; } + long getDirectIndex() { return -1; } @Override public Double getValue() { return value; } @@ -396,7 +396,7 @@ public interface Tensor { Builder cell(TensorAddress address, double value); /** Add a cell */ - Builder cell(double value, int ... labels); + Builder cell(double value, long ... labels); /** * Add a cell @@ -425,7 +425,7 @@ public interface Tensor { return this; } - public CellBuilder label(String dimension, int label) { + public CellBuilder label(String dimension, long label) { return label(dimension, String.valueOf(label)); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java index ff1202463f2..38553497478 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorAddress.java @@ -2,16 +2,10 @@ package com.yahoo.tensor; import com.google.common.annotations.Beta; -import com.google.common.collect.ImmutableList; -import java.util.ArrayList; import java.util.Arrays; -import java.util.Collections; -import java.util.HashSet; -import java.util.List; import java.util.Objects; import java.util.Optional; -import java.util.Set; /** * An immutable address to a tensor cell. This simply supplies a value to each dimension @@ -26,8 +20,8 @@ public abstract class TensorAddress implements Comparable { return new StringTensorAddress(labels); } - public static TensorAddress of(int ... labels) { - return new IntTensorAddress(labels); + public static TensorAddress of(long ... labels) { + return new NumericTensorAddress(labels); } /** Returns the number of labels in this */ @@ -41,14 +35,14 @@ public abstract class TensorAddress implements Comparable { public abstract String label(int i); /** - * Returns the i'th label in this as an int. - * Prefer this if you know that this is an integer address, but not otherwise. + * Returns the i'th label in this as a long. + * Prefer this if you know that this is a numeric address, but not otherwise. * * @throws IllegalArgumentException if there is no label at this index */ - public abstract int intLabel(int i); + public abstract long numericLabel(int i); - public abstract TensorAddress withLabel(int labelIndex, int label); + public abstract TensorAddress withLabel(int labelIndex, long label); public final boolean isEmpty() { return size() == 0; } @@ -110,17 +104,17 @@ public abstract class TensorAddress implements Comparable { public String label(int i) { return labels[i]; } @Override - public int intLabel(int i) { + public long numericLabel(int i) { try { - return Integer.parseInt(labels[i]); + return Long.parseLong(labels[i]); } catch (NumberFormatException e) { - throw new IllegalArgumentException("Expected an int label in " + this + " at position " + i); + throw new IllegalArgumentException("Expected a long label in " + this + " at position " + i); } } @Override - public TensorAddress withLabel(int index, int label) { + public TensorAddress withLabel(int index, long label) { String[] labels = Arrays.copyOf(this.labels, this.labels.length); labels[index] = String.valueOf(label); return new StringTensorAddress(labels); @@ -133,11 +127,11 @@ public abstract class TensorAddress implements Comparable { } - private static final class IntTensorAddress extends TensorAddress { + private static final class NumericTensorAddress extends TensorAddress { - private final int[] labels; + private final long[] labels; - private IntTensorAddress(int[] labels) { + private NumericTensorAddress(long[] labels) { this.labels = Arrays.copyOf(labels, labels.length); } @@ -148,13 +142,13 @@ public abstract class TensorAddress implements Comparable { public String label(int i) { return String.valueOf(labels[i]); } @Override - public int intLabel(int i) { return labels[i]; } + public long numericLabel(int i) { return labels[i]; } @Override - public TensorAddress withLabel(int index, int label) { - int[] labels = Arrays.copyOf(this.labels, this.labels.length); + public TensorAddress withLabel(int index, long label) { + long[] labels = Arrays.copyOf(this.labels, this.labels.length); labels[index] = label; - return new IntTensorAddress(labels); + return new NumericTensorAddress(labels); } @Override diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index 914d853aeca..b396f831de0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -139,7 +139,7 @@ public class TensorType { public final String name() { return name; } /** Returns the size of this dimension if it is bound, empty otherwise */ - public abstract Optional size(); + public abstract Optional size(); public abstract Type type(); @@ -189,7 +189,7 @@ public class TensorType { return this.name.compareTo(other.name); } - public static Dimension indexed(String name, int size) { + public static Dimension indexed(String name, long size) { return new IndexedBoundDimension(name, size); } @@ -197,17 +197,19 @@ public class TensorType { public static class IndexedBoundDimension extends TensorType.Dimension { - private final Integer size; + private final Long size; - private IndexedBoundDimension(String name, int size) { + private IndexedBoundDimension(String name, long size) { super(name); if (size < 1) throw new IllegalArgumentException("Size of bound dimension '" + name + "' must be at least 1"); + if (size > Integer.MAX_VALUE) + throw new IllegalArgumentException("Size of bound dimension '" + name + "' cannot be larger than " + Integer.MAX_VALUE); this.size = size; } @Override - public Optional size() { return Optional.of(size); } + public Optional size() { return Optional.of(size); } @Override public Type type() { return Type.indexedBound; } @@ -248,7 +250,7 @@ public class TensorType { } @Override - public Optional size() { return Optional.empty(); } + public Optional size() { return Optional.empty(); } @Override public Type type() { return Type.indexedUnbound; } @@ -269,7 +271,7 @@ public class TensorType { } @Override - public Optional size() { return Optional.empty(); } + public Optional size() { return Optional.empty(); } @Override public Type type() { return Type.mapped; } @@ -357,7 +359,7 @@ public class TensorType { * * @throws IllegalArgumentException if the dimension is already present */ - public Builder indexed(String name, int size) { return add(new IndexedBoundDimension(name, size)); } + public Builder indexed(String name, long size) { return add(new IndexedBoundDimension(name, size)); } /** * Adds an unbound indexed dimension to this diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index faa0ca36cb6..d4affe0ef9b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -67,7 +67,7 @@ public class Concat extends PrimitiveTensorFunction { DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); - int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); + long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); int[] aToIndexes = mapIndexes(a.type(), concatType); int[] bToIndexes = mapIndexes(b.type(), concatType); concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); @@ -75,7 +75,7 @@ public class Concat extends PrimitiveTensorFunction { return builder.build(); } - private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, + private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { Set otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); for (Iterator ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { @@ -129,8 +129,8 @@ public class Concat extends PrimitiveTensorFunction { DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); for (int i = 0; i < concatSizes.dimensions(); i++) { String currentDimension = concatType.dimensions().get(i).name(); - int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0); - int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0); + long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L); + long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L); if (currentDimension.equals(concatDimension)) concatSizes.set(i, aSize + bSize); else if (aSize != 0 && bSize != 0 && aSize!=bSize ) @@ -148,8 +148,8 @@ public class Concat extends PrimitiveTensorFunction { * (in some other dimension than the concat dimension) */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType concatType, int concatOffset, String concatDimension) { - int[] combinedLabels = new int[concatType.dimensions().size()]; + TensorType concatType, long concatOffset, String concatDimension) { + long[] combinedLabels = new long[concatType.dimensions().size()]; Arrays.fill(combinedLabels, -1); int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension @@ -179,15 +179,15 @@ public class Concat extends PrimitiveTensorFunction { * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) { + private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; if (concatDimension == toIndex) { - to[toIndex] = from.intLabel(i) + concatOffset; + to[toIndex] = from.numericLabel(i) + concatOffset; } else { - if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false; - to[toIndex] = from.intLabel(i); + if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + to[toIndex] = from.numericLabel(i); } } return true; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index c75d8ee4753..653be8dacf0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -17,7 +17,7 @@ import java.util.stream.Stream; public class Diag extends CompositeTensorFunction { private final TensorType type; - private final Function, Double> diagFunction; + private final Function, Double> diagFunction; public Diag(TensorType type) { this.type = type; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e42d25197e2..ef2770c04f5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -22,17 +22,17 @@ import java.util.function.Function; public class Generate extends PrimitiveTensorFunction { private final TensorType type; - private final Function, Double> generator; + private final Function, Double> generator; /** * Creates a generated tensor * * @param type the type of the tensor - * @param generator the function generating values from a list of ints specifying the indexes of the + * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, Function, Double> generator) { + public Generate(TensorType type, Function, Double> generator) { Objects.requireNonNull(type, "The argument tensor type cannot be null"); Objects.requireNonNull(generator, "The argument function cannot be null"); validateType(type); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index ff887e3e9a6..174a8e4c435 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -56,8 +56,8 @@ public class Join extends PrimitiveTensorFunction { if (aDim.name().equals(bDim.name())) { // include if (aDim.isIndexed() && bDim.isIndexed()) { if (aDim.size().isPresent() || bDim.size().isPresent()) - typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), - bDim.size().orElse(Integer.MAX_VALUE))); + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Long.MAX_VALUE), + bDim.size().orElse(Long.MAX_VALUE))); else typeBuilder.indexed(aDim.name()); } @@ -118,11 +118,11 @@ public class Join extends PrimitiveTensorFunction { } private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { - int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator aIterator = a.valueIterator(); Iterator bIterator = b.valueIterator(); - IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build()); - for (int i = 0; i < joinedLength; i++) + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); + for (int i = 0; i < joinedRank; i++) builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i); return builder.build(); } @@ -169,10 +169,10 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private void joinSubspaces(Iterator subspace, int subspaceSize, - Iterator superspace, int superspaceSize, + private void joinSubspaces(Iterator subspace, long subspaceSize, + Iterator superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder) { - int joinedLength = Math.min(subspaceSize, superspaceSize); + long joinedLength = Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); @@ -281,7 +281,7 @@ public class Join extends PrimitiveTensorFunction { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) if (retainDimensions.contains(addressType.dimensions().get(i).name())) - builder.add(addressType.dimensions().get(i).name(), address.intLabel(i)); + builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); return builder.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index a56f82b026a..8e7f4e4c773 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -18,7 +18,7 @@ import java.util.stream.Stream; public class Range extends CompositeTensorFunction { private final TensorType type; - private final Function, Double> rangeFunction; + private final Function, Double> rangeFunction; public Range(TensorType type) { this.type = type; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index fb5029fbfd6..f1dadba2a29 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -14,8 +14,8 @@ import java.util.stream.Collectors; /** * Factory of scalar Java functions. * The purpose of this is to embellish anonymous functions with a runtime type - * such that they can be inspected and will return a parseable toString. - * + * such that they can be inspected and will return a parsable toString. + * * @author bratseth */ @Beta @@ -31,9 +31,9 @@ public class ScalarFunctions { public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } - public static Function, Double> random() { return new Random(); } - public static Function, Double> equal(List argumentNames) { return new EqualElements(argumentNames); } - public static Function, Double> sum(List argumentNames) { return new SumElements(argumentNames); } + public static Function, Double> random() { return new Random(); } + public static Function, Double> equal(List argumentNames) { return new EqualElements(argumentNames); } + public static Function, Double> sum(List argumentNames) { return new SumElements(argumentNames); } // Binary operators ----------------------------------------------------------------------------- @@ -60,7 +60,7 @@ public class ScalarFunctions { public static class Multiply implements DoubleBinaryOperator { @Override - public double applyAsDouble(double left, double right) { return left * right; } + public double applyAsDouble(double left, double right) { return left * right; } @Override public String toString() { return "f(a,b)(a * b)"; } } @@ -100,26 +100,26 @@ public class ScalarFunctions { // Variable-length operators ----------------------------------------------------------------------------- - public static class EqualElements implements Function, Double> { - private final ImmutableList argumentNames; + public static class EqualElements implements Function, Double> { + private final ImmutableList argumentNames; private EqualElements(List argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @Override - public Double apply(List values) { + public Double apply(List values) { if (values.isEmpty()) return 1.0; - for (Integer value : values) + for (Long value : values) if ( ! value.equals(values.get(0))) return 0.0; return 1.0; } @Override - public String toString() { + public String toString() { if (argumentNames.size() == 0) return "1"; if (argumentNames.size() == 1) return "1"; if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1); - + StringBuilder b = new StringBuilder(); for (int i = 0; i < argumentNames.size() -1; i++) { b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")"); @@ -130,25 +130,25 @@ public class ScalarFunctions { } } - public static class Random implements Function, Double> { + public static class Random implements Function, Double> { @Override - public Double apply(List values) { + public Double apply(List values) { return ThreadLocalRandom.current().nextDouble(); } @Override public String toString() { return "random"; } } - public static class SumElements implements Function, Double> { + public static class SumElements implements Function, Double> { private final ImmutableList argumentNames; private SumElements(List argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @Override - public Double apply(List values) { - int sum = 0; - for (Integer value : values) + public Double apply(List values) { + long sum = 0; + for (Long value : values) sum += value; return (double)sum; } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java index aabb53d1c67..1e830bac461 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/DenseBinaryFormat.java @@ -36,7 +36,7 @@ public class DenseBinaryFormat implements BinaryFormat { buffer.putInt1_4Bytes(tensor.type().dimensions().size()); for (int i = 0; i < tensor.type().dimensions().size(); i++) { buffer.putUtf8String(tensor.type().dimensions().get(i).name()); - buffer.putInt1_4Bytes(tensor.dimensionSizes().size(i)); + buffer.putInt1_4Bytes((int)tensor.dimensionSizes().size(i)); // XXX: Size truncation } } @@ -71,7 +71,7 @@ public class DenseBinaryFormat implements BinaryFormat { int dimensionCount = buffer.getInt1_4Bytes(); TensorType.Builder builder = new TensorType.Builder(); for (int i = 0; i < dimensionCount; i++) - builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); + builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation return builder.build(); } @@ -84,7 +84,7 @@ public class DenseBinaryFormat implements BinaryFormat { } private void decodeCells(DimensionSizes sizes, GrowableByteBuffer buffer, IndexedTensor.BoundBuilder builder) { - for (int i = 0; i < sizes.totalSize(); i++) + for (long i = 0; i < sizes.totalSize(); i++) builder.cellByDirectIndex(i, buffer.getDouble()); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index 61dfa888567..34e6cccf0f0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -46,16 +46,16 @@ class MixedBinaryFormat implements BinaryFormat { buffer.putInt1_4Bytes(denseDimensions.size()); for (TensorType.Dimension dimension : denseDimensions) { buffer.putUtf8String(dimension.name()); - buffer.putInt1_4Bytes(dimension.size().orElseThrow(() -> - new IllegalArgumentException("Unknown size of indexed dimension."))); + buffer.putInt1_4Bytes((int)dimension.size().orElseThrow(() -> + new IllegalArgumentException("Unknown size of indexed dimension.")).longValue()); // XXX: Size truncation } } private void encodeCells(GrowableByteBuffer buffer, MixedTensor tensor) { List sparseDimensions = tensor.type().dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); - int denseSubspaceSize = tensor.denseSubspaceSize(); + long denseSubspaceSize = tensor.denseSubspaceSize(); if (sparseDimensions.size() > 0) { - buffer.putInt1_4Bytes(tensor.size() / denseSubspaceSize); + buffer.putInt1_4Bytes((int)(tensor.size() / denseSubspaceSize)); // XXX: Size truncation } Iterator cellIterator = tensor.cellIterator(); while (cellIterator.hasNext()) { @@ -98,7 +98,7 @@ class MixedBinaryFormat implements BinaryFormat { } int numIndexedDimensions = buffer.getInt1_4Bytes(); for (int i = 0; i < numIndexedDimensions; ++i) { - builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); + builder.indexed(buffer.getUtf8String(), buffer.getInt1_4Bytes()); // XXX: Size truncation } return builder.build(); } @@ -106,21 +106,21 @@ class MixedBinaryFormat implements BinaryFormat { private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) { List sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); TensorType sparseType = MixedTensor.createPartialType(sparseDimensions); - int denseSubspaceSize = builder.denseSubspaceSize(); + long denseSubspaceSize = builder.denseSubspaceSize(); int numBlocks = 1; if (sparseDimensions.size() > 0) { numBlocks = buffer.getInt1_4Bytes(); } - double[] denseSubspace = new double[denseSubspaceSize]; + double[] denseSubspace = new double[(int)denseSubspaceSize]; for (int i = 0; i < numBlocks; ++i) { TensorAddress.Builder sparseAddress = new TensorAddress.Builder(sparseType); for (TensorType.Dimension sparseDimension : sparseDimensions) { sparseAddress.add(sparseDimension.name(), buffer.getUtf8String()); } - for (int denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) { - denseSubspace[denseOffset] = buffer.getDouble(); + for (long denseOffset = 0; denseOffset < denseSubspaceSize; denseOffset++) { + denseSubspace[(int)denseOffset] = buffer.getDouble(); } builder.block(sparseAddress.build(), denseSubspace); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java index 19969506eca..0cd3ff77aca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/SparseBinaryFormat.java @@ -3,13 +3,14 @@ package com.yahoo.tensor.serialization; import com.google.common.annotations.Beta; import com.yahoo.io.GrowableByteBuffer; -import com.yahoo.tensor.MappedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; -import com.yahoo.text.Utf8; -import java.util.*; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Optional; /** * Implementation of a sparse binary format for a tensor on the form: @@ -39,7 +40,7 @@ class SparseBinaryFormat implements BinaryFormat { } private void encodeCells(GrowableByteBuffer buffer, Tensor tensor) { - buffer.putInt1_4Bytes(tensor.size()); + buffer.putInt1_4Bytes((int)tensor.size()); // XXX: Size truncation for (Iterator i = tensor.cellIterator(); i.hasNext(); ) { Map.Entry cell = i.next(); encodeAddress(buffer, cell.getKey()); @@ -79,8 +80,8 @@ class SparseBinaryFormat implements BinaryFormat { } private void decodeCells(GrowableByteBuffer buffer, Tensor.Builder builder, TensorType type) { - int numCells = buffer.getInt1_4Bytes(); - for (int i = 0; i < numCells; ++i) { + long numCells = buffer.getInt1_4Bytes(); // XXX: Size truncation + for (long i = 0; i < numCells; ++i) { Tensor.Builder.CellBuilder cellBuilder = builder.cell(); decodeAddress(buffer, cellBuilder, type); cellBuilder.value(buffer.getDouble()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java index 693b0f09351..38a8329bff1 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorTestCase.java @@ -4,7 +4,6 @@ package com.yahoo.tensor; import com.google.common.collect.ImmutableList; import com.yahoo.tensor.evaluation.MapEvaluationContext; import com.yahoo.tensor.evaluation.VariableTensor; -import com.yahoo.tensor.functions.Argmax; import com.yahoo.tensor.functions.ConstantTensor; import com.yahoo.tensor.functions.Join; import com.yahoo.tensor.functions.Reduce; @@ -12,14 +11,12 @@ import com.yahoo.tensor.functions.TensorFunction; import org.junit.Test; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Set; -import java.util.stream.Collectors; -import static org.junit.Assert.assertEquals; import static com.yahoo.tensor.TensorType.Dimension.Type; +import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -99,7 +96,7 @@ public class TensorTestCase { ImmutableList.of("y", "x"))); assertEquals(Tensor.from("{ {x:0,y:0}:0, {x:0,y:1}:0, {x:1,y:0}:0, {x:1,y:1}:1, {x:2,y:0}:0, {x:2,y:1}:2, }"), Tensor.generate(new TensorType.Builder().indexed("x", 3).indexed("y", 2).build(), - (List indexes) -> (double)indexes.get(0)*indexes.get(1))); + (List indexes) -> (double)indexes.get(0)*indexes.get(1))); assertEquals(Tensor.from("{ {x:0,y:0,z:0}:0, {x:0,y:1,z:0}:1, {x:1,y:0,z:0}:1, {x:1,y:1,z:0}:2, {x:2,y:0,z:0}:2, {x:2,y:1,z:0}:3, "+ " {x:0,y:0,z:1}:1, {x:0,y:1,z:1}:2, {x:1,y:0,z:1}:2, {x:1,y:1,z:1}:3, {x:2,y:0,z:1}:3, {x:2,y:1,z:1}:4 }"), Tensor.range(new TensorType.Builder().indexed("x", 3).indexed("y", 2).indexed("z", 2).build())); -- cgit v1.2.3