diff options
author | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-16 10:45:53 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@yahoo-inc.com> | 2016-12-16 10:45:53 +0100 |
commit | 89f78acf41592f4a16b390a13c3e763907012c4f (patch) | |
tree | 3c3125f6755d30d128698ffdc7c4823805ec9a1b /vespajlib/src | |
parent | b51e97ba22be851e7ad028edc0eaf62251988931 (diff) |
Cleanup
Diffstat (limited to 'vespajlib/src')
4 files changed, 46 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java index bd193ebc78c..4c666e675d0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java @@ -17,6 +17,7 @@ import java.util.Arrays; import java.util.Collections; import java.util.List; import java.util.Map; +import java.util.Optional; import java.util.function.DoubleBinaryOperator; import java.util.function.DoubleUnaryOperator; import java.util.function.Function; @@ -213,24 +214,8 @@ public interface Tensor { * @param type the type of the tensor to return * @param tensorString the tensor on the standard tensor string format */ - // TODO: Allow type:value syntax also when a type is sent as a separate argument static Tensor from(TensorType type, String tensorString) { - // TODO: Rewrite next 10 lines to 2 - boolean containsIndexedDimensions = false; - boolean containsMappedDimensions = false; - for (TensorType.Dimension dimension : type.dimensions()) { - switch (dimension.type()) { - case indexedBound: case indexedUnbound: containsIndexedDimensions = true; break; - case mapped : containsMappedDimensions = true; break; - default: throw new RuntimeException("Unknown dimension type: " + dimension); - } - } - if (containsIndexedDimensions && containsMappedDimensions) - throw new IllegalArgumentException("Mixed dimension types are not supported, got: " + type); - if (containsMappedDimensions) - return MappedTensor.from(type, tensorString); - else // indexed or none - return IndexedTensor.from(type, tensorString); + return from(tensorString, Optional.of(type)); } /** @@ -241,7 +226,7 @@ public interface Tensor { * @param tensorString the tensor on the standard tensor string format */ static Tensor from(String tensorType, String tensorString) { - return from(TensorType.fromSpec(tensorType), tensorString); + return from(tensorString, Optional.of(TensorType.fromSpec(tensorType))); } /** @@ -249,18 +234,29 @@ public interface Tensor { * If a type is not specified it is derived from the first cell of the tensor */ static Tensor from(String tensorString) { + return from(tensorString, Optional.empty()); + } + + static Tensor from(String tensorString, Optional<TensorType> type) { tensorString = tensorString.trim(); try { if (tensorString.startsWith("tensor(")) { int colonIndex = tensorString.indexOf(':'); - String typeSpec = tensorString.substring(0, colonIndex); - String valueSpec = tensorString.substring(colonIndex + 1); - return from(TensorTypeParser.fromSpec(typeSpec), valueSpec); + String typeString = tensorString.substring(0, colonIndex); + String valueString = tensorString.substring(colonIndex + 1); + TensorType typeFromString = TensorTypeParser.fromSpec(typeString); + if (type.isPresent() && ! type.get().equals(typeFromString)) + throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + + "passed type " + type); + return fromValueString(valueString, typeFromString); } else if (tensorString.startsWith("{")) { - return from(typeFromCellString(tensorString), tensorString); + return fromValueString(tensorString, type.orElse(typeFromValueString(tensorString))); } else { + if (type.isPresent() && ! type.get().equals(TensorType.empty)) + throw new IllegalArgumentException("Got zero-dimensional tensor '" + tensorString + + "but type is not empty but " + type.get()); return IndexedTensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build(); } } @@ -270,12 +266,24 @@ public interface Tensor { } } + static Tensor fromValueString(String tensorCellString, TensorType type) { + boolean containsIndexedDimensions = type.dimensions().stream().anyMatch(d -> d.isIndexed()); + boolean containsMappedDimensions = type.dimensions().stream().anyMatch(d -> !d.isIndexed()); + if (containsIndexedDimensions && containsMappedDimensions) + throw new IllegalArgumentException("Mixed dimension types are not supported, got: " + type); + if (containsMappedDimensions) + return MappedTensor.from(type, tensorCellString); + else // indexed or none + return IndexedTensor.from(type, tensorCellString); + } + /** Derive the tensor type from the first address string in the given tensor string */ - static TensorType typeFromCellString(String s) { + static TensorType typeFromValueString(String s) { s = s.substring(1).trim(); // remove tensor start - int firstKeyOrEmptyTensorEnd = s.indexOf('}'); - String addressBody = s.substring(0, firstKeyOrEmptyTensorEnd).trim(); + int firstKeyOrTensorEnd = s.indexOf('}'); + String addressBody = s.substring(0, firstKeyOrTensorEnd).trim(); if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor + if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor addressBody = addressBody.substring(1); // remove key start if (addressBody.isEmpty()) return TensorType.empty; // Empty key @@ -300,10 +308,10 @@ public interface Tensor { boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed()); if (containsIndexed && containsMapped) throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet"); - if (containsIndexed) - return IndexedTensor.Builder.of(type); - else + if (containsMapped) return new MappedTensor.Builder(type); + else // indexed or empty + return IndexedTensor.Builder.of(type); } /** Return a cell builder */ diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 66af7edb2ab..22ddcc33c92 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -146,7 +146,6 @@ public class Reduce extends PrimitiveTensorFunction { private Tensor reduceIndexedVector(IndexedTensor argument) { ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator); - System.out.println("Reducing " + argument); for (int i = 0; i < argument.length(0); i++) valueAggregator.aggregate(argument.get(i)); return IndexedTensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build(); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java index d561f37a316..da4b300c3ab 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java @@ -21,17 +21,15 @@ public class IndexedTensorTestCase { @Test public void testEmpty() { Tensor empty = Tensor.Builder.of(TensorType.empty).build(); - // assertTrue(empty instanceof IndexedTensor); TODO: Fix that and test these same things for a single-value tensor + assertTrue(empty instanceof IndexedTensor); assertTrue(empty.cells().isEmpty()); assertEquals("{}", empty.toString()); Tensor emptyFromString = Tensor.from(TensorType.empty, "{}"); assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString()); assertTrue(emptyFromString.cells().isEmpty()); assertTrue(emptyFromString instanceof IndexedTensor); - // assertEquals(empty, emptyFromString); + assertEquals(empty, emptyFromString); - // TODO: Equality of different tensor types - Tensor singleValue = Tensor.Builder.of(TensorType.empty).cell(TensorAddress.empty, 3.5).build(); assertEquals("{3.5}", singleValue.toString()); assertEquals("{3.5}", Tensor.from(TensorType.empty, "{3.5}").toString()); diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java index 5ce53f2604b..dc0b3c47c62 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java @@ -19,7 +19,7 @@ public class TensorFunctionBenchmark { private final static Random random = new Random(); public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType) { - Tensor queryVector = generateVectors(1, 300, dimensionType).get(0); + Tensor queryVector = vectors(1, 300, dimensionType).get(0); dotProduct(queryVector, modelVectors, Math.max(iterations/10, 10)); // warmup long startTime = System.currentTimeMillis(); dotProduct(queryVector, modelVectors, iterations); @@ -51,8 +51,7 @@ public class TensorFunctionBenchmark { return largest; } - private static List<Tensor> generateVectors(int vectorCount, int vectorSize, - TensorType.Dimension.Type dimensionType) { + private static List<Tensor> vectors(int vectorCount, int vectorSize, TensorType.Dimension.Type dimensionType) { List<Tensor> tensors = new ArrayList<>(); TensorType type = new TensorType.Builder().dimension("x", dimensionType).build(); for (int i = 0; i < vectorCount; i++) { @@ -65,8 +64,7 @@ public class TensorFunctionBenchmark { return tensors; } - private static List<Tensor> generateMatrix(int vectorCount, int vectorSize, - TensorType.Dimension.Type dimensionType) { + private static List<Tensor> matrix(int vectorCount, int vectorSize, TensorType.Dimension.Type dimensionType) { TensorType type = new TensorType.Builder().dimension("i", dimensionType).dimension("x", dimensionType).build(); Tensor.Builder builder = Tensor.Builder.of(type); for (int i = 0; i < vectorCount; i++) { @@ -88,11 +86,11 @@ public class TensorFunctionBenchmark { // - After adding type: 300 ms // - After sorting dimensions: 100 ms // - After special-casing single space: 2.4 ms - time = new TensorFunctionBenchmark().benchmark(5000, generateVectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped); + time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped); System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time); // Initial: 760 ms // - After special-casing subspace: 15 ms - time = new TensorFunctionBenchmark().benchmark(500, generateMatrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped); + time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped); System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time); // ---------------- Indexed: @@ -100,11 +98,12 @@ public class TensorFunctionBenchmark { // - After special casing join: 3.6 ms // - After special-casing reduce: 0.80 ms // - After create IndexedTensor without builder: 0.41 ms - time = new TensorFunctionBenchmark().benchmark(5000, generateVectors(100, 300, TensorType.Dimension.Type.indexedUnbound),TensorType.Dimension.Type.indexedUnbound); + // - After double-array backing: 0.09 ms + time = new TensorFunctionBenchmark().benchmark(10000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound); System.out.printf("Indexed vectors, time per join: %1$8.3f ms\n", time); // Initial: 3500 ms // - After special-casing subspace: 28 ms - time = new TensorFunctionBenchmark().benchmark(500, generateMatrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound); + time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound); System.out.printf("Indexed matrix, time per join: %1$8.3f ms\n", time); } |