diff options
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
6 files changed, 42 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index faa0ca36cb6..d4affe0ef9b 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -67,7 +67,7 @@ public class Concat extends PrimitiveTensorFunction { DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); - int aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); + long aDimensionLength = aIndexed.type().indexOfDimension(dimension).map(d -> aIndexed.dimensionSizes().size(d)).orElseThrow(RuntimeException::new); int[] aToIndexes = mapIndexes(a.type(), concatType); int[] bToIndexes = mapIndexes(b.type(), concatType); concatenateTo(aIndexed, bIndexed, aDimensionLength, concatType, aToIndexes, bToIndexes, builder); @@ -75,7 +75,7 @@ public class Concat extends PrimitiveTensorFunction { return builder.build(); } - private void concatenateTo(IndexedTensor a, IndexedTensor b, int offset, TensorType concatType, + private void concatenateTo(IndexedTensor a, IndexedTensor b, long offset, TensorType concatType, int[] aToIndexes, int[] bToIndexes, Tensor.Builder builder) { Set<String> otherADimensions = a.type().dimensionNames().stream().filter(d -> !d.equals(dimension)).collect(Collectors.toSet()); for (Iterator<IndexedTensor.SubspaceIterator> ia = a.subspaceIterator(otherADimensions); ia.hasNext();) { @@ -129,8 +129,8 @@ public class Concat extends PrimitiveTensorFunction { DimensionSizes.Builder concatSizes = new DimensionSizes.Builder(concatType.dimensions().size()); for (int i = 0; i < concatSizes.dimensions(); i++) { String currentDimension = concatType.dimensions().get(i).name(); - int aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0); - int bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0); + long aSize = a.type().indexOfDimension(currentDimension).map(d -> a.dimensionSizes().size(d)).orElse(0L); + long bSize = b.type().indexOfDimension(currentDimension).map(d -> b.dimensionSizes().size(d)).orElse(0L); if (currentDimension.equals(concatDimension)) concatSizes.set(i, aSize + bSize); else if (aSize != 0 && bSize != 0 && aSize!=bSize ) @@ -148,8 +148,8 @@ public class Concat extends PrimitiveTensorFunction { * (in some other dimension than the concat dimension) */ private TensorAddress combineAddresses(TensorAddress a, int[] aToIndexes, TensorAddress b, int[] bToIndexes, - TensorType concatType, int concatOffset, String concatDimension) { - int[] combinedLabels = new int[concatType.dimensions().size()]; + TensorType concatType, long concatOffset, String concatDimension) { + long[] combinedLabels = new long[concatType.dimensions().size()]; Arrays.fill(combinedLabels, -1); int concatDimensionIndex = concatType.indexOfDimension(concatDimension).get(); mapContent(a, combinedLabels, aToIndexes, concatDimensionIndex, concatOffset); // note: This sets a nonsensical value in the concat dimension @@ -179,15 +179,15 @@ public class Concat extends PrimitiveTensorFunction { * @return true if the mapping was successful, false if one of the destination positions was * occupied by a different value */ - private boolean mapContent(TensorAddress from, int[] to, int[] indexMap, int concatDimension, int concatOffset) { + private boolean mapContent(TensorAddress from, long[] to, int[] indexMap, int concatDimension, long concatOffset) { for (int i = 0; i < from.size(); i++) { int toIndex = indexMap[i]; if (concatDimension == toIndex) { - to[toIndex] = from.intLabel(i) + concatOffset; + to[toIndex] = from.numericLabel(i) + concatOffset; } else { - if (to[toIndex] != -1 && to[toIndex] != from.intLabel(i)) return false; - to[toIndex] = from.intLabel(i); + if (to[toIndex] != -1 && to[toIndex] != from.numericLabel(i)) return false; + to[toIndex] = from.numericLabel(i); } } return true; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java index c75d8ee4753..653be8dacf0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Diag.java @@ -17,7 +17,7 @@ import java.util.stream.Stream; public class Diag extends CompositeTensorFunction { private final TensorType type; - private final Function<List<Integer>, Double> diagFunction; + private final Function<List<Long>, Double> diagFunction; public Diag(TensorType type) { this.type = type; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java index e42d25197e2..ef2770c04f5 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Generate.java @@ -22,17 +22,17 @@ import java.util.function.Function; public class Generate extends PrimitiveTensorFunction { private final TensorType type; - private final Function<List<Integer>, Double> generator; + private final Function<List<Long>, Double> generator; /** * Creates a generated tensor * * @param type the type of the tensor - * @param generator the function generating values from a list of ints specifying the indexes of the + * @param generator the function generating values from a list of numbers specifying the indexes of the * tensor cell which will receive the value * @throws IllegalArgumentException if any of the tensor dimensions are not indexed bound */ - public Generate(TensorType type, Function<List<Integer>, Double> generator) { + public Generate(TensorType type, Function<List<Long>, Double> generator) { Objects.requireNonNull(type, "The argument tensor type cannot be null"); Objects.requireNonNull(generator, "The argument function cannot be null"); validateType(type); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index ff887e3e9a6..174a8e4c435 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -56,8 +56,8 @@ public class Join extends PrimitiveTensorFunction { if (aDim.name().equals(bDim.name())) { // include if (aDim.isIndexed() && bDim.isIndexed()) { if (aDim.size().isPresent() || bDim.size().isPresent()) - typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Integer.MAX_VALUE), - bDim.size().orElse(Integer.MAX_VALUE))); + typeBuilder.indexed(aDim.name(), Math.min(aDim.size().orElse(Long.MAX_VALUE), + bDim.size().orElse(Long.MAX_VALUE))); else typeBuilder.indexed(aDim.name()); } @@ -118,11 +118,11 @@ public class Join extends PrimitiveTensorFunction { } private Tensor indexedVectorJoin(IndexedTensor a, IndexedTensor b, TensorType type) { - int joinedLength = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); + long joinedRank = Math.min(a.dimensionSizes().size(0), b.dimensionSizes().size(0)); Iterator<Double> aIterator = a.valueIterator(); Iterator<Double> bIterator = b.valueIterator(); - IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedLength).build()); - for (int i = 0; i < joinedLength; i++) + IndexedTensor.Builder builder = IndexedTensor.Builder.of(type, new DimensionSizes.Builder(1).set(0, joinedRank).build()); + for (int i = 0; i < joinedRank; i++) builder.cell(combinator.applyAsDouble(aIterator.next(), bIterator.next()), i); return builder.build(); } @@ -169,10 +169,10 @@ public class Join extends PrimitiveTensorFunction { return builder.build(); } - private void joinSubspaces(Iterator<Double> subspace, int subspaceSize, - Iterator<Tensor.Cell> superspace, int superspaceSize, + private void joinSubspaces(Iterator<Double> subspace, long subspaceSize, + Iterator<Tensor.Cell> superspace, long superspaceSize, boolean reversedArgumentOrder, IndexedTensor.Builder builder) { - int joinedLength = Math.min(subspaceSize, superspaceSize); + long joinedLength = Math.min(subspaceSize, superspaceSize); if (reversedArgumentOrder) { for (int i = 0; i < joinedLength; i++) { Tensor.Cell supercell = superspace.next(); @@ -281,7 +281,7 @@ public class Join extends PrimitiveTensorFunction { PartialAddress.Builder builder = new PartialAddress.Builder(retainDimensions.size()); for (int i = 0; i < addressType.dimensions().size(); i++) if (retainDimensions.contains(addressType.dimensions().get(i).name())) - builder.add(addressType.dimensions().get(i).name(), address.intLabel(i)); + builder.add(addressType.dimensions().get(i).name(), address.numericLabel(i)); return builder.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java index a56f82b026a..8e7f4e4c773 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Range.java @@ -18,7 +18,7 @@ import java.util.stream.Stream; public class Range extends CompositeTensorFunction { private final TensorType type; - private final Function<List<Integer>, Double> rangeFunction; + private final Function<List<Long>, Double> rangeFunction; public Range(TensorType type) { this.type = type; diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java index fb5029fbfd6..f1dadba2a29 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ScalarFunctions.java @@ -14,8 +14,8 @@ import java.util.stream.Collectors; /** * Factory of scalar Java functions. * The purpose of this is to embellish anonymous functions with a runtime type - * such that they can be inspected and will return a parseable toString. - * + * such that they can be inspected and will return a parsable toString. + * * @author bratseth */ @Beta @@ -31,9 +31,9 @@ public class ScalarFunctions { public static DoubleUnaryOperator sqrt() { return new Sqrt(); } public static DoubleUnaryOperator square() { return new Square(); } - public static Function<List<Integer>, Double> random() { return new Random(); } - public static Function<List<Integer>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } - public static Function<List<Integer>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } + public static Function<List<Long>, Double> random() { return new Random(); } + public static Function<List<Long>, Double> equal(List<String> argumentNames) { return new EqualElements(argumentNames); } + public static Function<List<Long>, Double> sum(List<String> argumentNames) { return new SumElements(argumentNames); } // Binary operators ----------------------------------------------------------------------------- @@ -60,7 +60,7 @@ public class ScalarFunctions { public static class Multiply implements DoubleBinaryOperator { @Override - public double applyAsDouble(double left, double right) { return left * right; } + public double applyAsDouble(double left, double right) { return left * right; } @Override public String toString() { return "f(a,b)(a * b)"; } } @@ -100,26 +100,26 @@ public class ScalarFunctions { // Variable-length operators ----------------------------------------------------------------------------- - public static class EqualElements implements Function<List<Integer>, Double> { - private final ImmutableList<String> argumentNames; + public static class EqualElements implements Function<List<Long>, Double> { + private final ImmutableList<String> argumentNames; private EqualElements(List<String> argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @Override - public Double apply(List<Integer> values) { + public Double apply(List<Long> values) { if (values.isEmpty()) return 1.0; - for (Integer value : values) + for (Long value : values) if ( ! value.equals(values.get(0))) return 0.0; return 1.0; } @Override - public String toString() { + public String toString() { if (argumentNames.size() == 0) return "1"; if (argumentNames.size() == 1) return "1"; if (argumentNames.size() == 2) return argumentNames.get(0) + "==" + argumentNames.get(1); - + StringBuilder b = new StringBuilder(); for (int i = 0; i < argumentNames.size() -1; i++) { b.append("(").append(argumentNames.get(i)).append("==").append(argumentNames.get(i+1)).append(")"); @@ -130,25 +130,25 @@ public class ScalarFunctions { } } - public static class Random implements Function<List<Integer>, Double> { + public static class Random implements Function<List<Long>, Double> { @Override - public Double apply(List<Integer> values) { + public Double apply(List<Long> values) { return ThreadLocalRandom.current().nextDouble(); } @Override public String toString() { return "random"; } } - public static class SumElements implements Function<List<Integer>, Double> { + public static class SumElements implements Function<List<Long>, Double> { private final ImmutableList<String> argumentNames; private SumElements(List<String> argumentNames) { this.argumentNames = ImmutableList.copyOf(argumentNames); } @Override - public Double apply(List<Integer> values) { - int sum = 0; - for (Integer value : values) + public Double apply(List<Long> values) { + long sum = 0; + for (Long value : values) sum += value; return (double)sum; } |