summary refs log tree commit diff stats
path: root/vespajlib/src
diff options
context:
space:
mode:
author Jon Bratseth <bratseth@yahoo-inc.com> 2016-12-16 10:45:53 +0100
committer Jon Bratseth <bratseth@yahoo-inc.com> 2016-12-16 10:45:53 +0100
commit 89f78acf41592f4a16b390a13c3e763907012c4f (patch)
tree 3c3125f6755d30d128698ffdc7c4823805ec9a1b /vespajlib/src
parent b51e97ba22be851e7ad028edc0eaf62251988931 (diff)
Cleanup
Diffstat (limited to 'vespajlib/src')
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/Tensor.java 64
-rw-r--r-- vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java 1
-rw-r--r-- vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java 6
-rw-r--r-- vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java 17
4 files changed, 46 insertions, 42 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
index bd193ebc78c..4c666e675d0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/Tensor.java
@@ -17,6 +17,7 @@ import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
+import java.util.Optional;
import java.util.function.DoubleBinaryOperator;
import java.util.function.DoubleUnaryOperator;
import java.util.function.Function;
@@ -213,24 +214,8 @@ public interface Tensor {
* @param type the type of the tensor to return
* @param tensorString the tensor on the standard tensor string format
*/
- // TODO: Allow type:value syntax also when a type is sent as a separate argument
static Tensor from(TensorType type, String tensorString) {
- // TODO: Rewrite next 10 lines to 2
- boolean containsIndexedDimensions = false;
- boolean containsMappedDimensions = false;
- for (TensorType.Dimension dimension : type.dimensions()) {
- switch (dimension.type()) {
- case indexedBound: case indexedUnbound: containsIndexedDimensions = true; break;
- case mapped : containsMappedDimensions = true; break;
- default: throw new RuntimeException("Unknown dimension type: " + dimension);
- }
- }
- if (containsIndexedDimensions && containsMappedDimensions)
- throw new IllegalArgumentException("Mixed dimension types are not supported, got: " + type);
- if (containsMappedDimensions)
- return MappedTensor.from(type, tensorString);
- else // indexed or none
- return IndexedTensor.from(type, tensorString);
+ return from(tensorString, Optional.of(type));
}
/**
@@ -241,7 +226,7 @@ public interface Tensor {
* @param tensorString the tensor on the standard tensor string format
*/
static Tensor from(String tensorType, String tensorString) {
- return from(TensorType.fromSpec(tensorType), tensorString);
+ return from(tensorString, Optional.of(TensorType.fromSpec(tensorType)));
}
/**
@@ -249,18 +234,29 @@ public interface Tensor {
* If a type is not specified it is derived from the first cell of the tensor
*/
static Tensor from(String tensorString) {
+ return from(tensorString, Optional.empty());
+ }
+ /** Parses a tensor string; a present type must agree with the string, an absent one is derived from the string */
+ static Tensor from(String tensorString, Optional<TensorType> type) {
tensorString = tensorString.trim();
try {
if (tensorString.startsWith("tensor(")) {
int colonIndex = tensorString.indexOf(':');
- String typeSpec = tensorString.substring(0, colonIndex);
- String valueSpec = tensorString.substring(colonIndex + 1);
- return from(TensorTypeParser.fromSpec(typeSpec), valueSpec);
+ String typeString = tensorString.substring(0, colonIndex);
+ String valueString = tensorString.substring(colonIndex + 1);
+ TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
+ if (type.isPresent() && ! type.get().equals(typeFromString))
+ throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
+ "passed type " + type.get()); // unwrap so the message shows the type, not "Optional[...]"
+ return fromValueString(valueString, typeFromString);
}
else if (tensorString.startsWith("{")) {
- return from(typeFromCellString(tensorString), tensorString);
+ return fromValueString(tensorString, type.isPresent() ? type.get() : typeFromValueString(tensorString)); // derive the type only when none was passed
}
else {
+ if (type.isPresent() && ! type.get().equals(TensorType.empty))
+ throw new IllegalArgumentException("Got zero-dimensional tensor '" + tensorString +
+ "', but the passed type is not empty: " + type.get());
return IndexedTensor.Builder.of(TensorType.empty).cell(Double.parseDouble(tensorString)).build();
}
}
@@ -270,12 +266,24 @@ public interface Tensor {
}
}
+ /** Builds a tensor of the given type from a value string, choosing the implementation from the dimension kinds */
+ static Tensor fromValueString(String tensorCellString, TensorType type) {
+ boolean hasIndexed = type.dimensions().stream().anyMatch(TensorType.Dimension::isIndexed);
+ boolean hasMapped = ! type.dimensions().stream().allMatch(TensorType.Dimension::isIndexed);
+ if (hasIndexed && hasMapped)
+ throw new IllegalArgumentException("Mixed dimension types are not supported, got: " + type);
+ // No mapped dimension means the type is purely indexed or has no dimensions at all
+ return hasMapped ? MappedTensor.from(type, tensorCellString)
+ : IndexedTensor.from(type, tensorCellString);
+ }
+
/** Derive the tensor type from the first address string in the given tensor string */
- static TensorType typeFromCellString(String s) {
+ static TensorType typeFromValueString(String s) {
s = s.substring(1).trim(); // remove tensor start
- int firstKeyOrEmptyTensorEnd = s.indexOf('}');
- String addressBody = s.substring(0, firstKeyOrEmptyTensorEnd).trim();
+ int firstKeyOrTensorEnd = s.indexOf('}');
+ String addressBody = s.substring(0, firstKeyOrTensorEnd).trim();
if (addressBody.isEmpty()) return TensorType.empty; // Empty tensor
+ if ( ! addressBody.startsWith("{")) return TensorType.empty; // Single value tensor
addressBody = addressBody.substring(1); // remove key start
if (addressBody.isEmpty()) return TensorType.empty; // Empty key
@@ -300,10 +308,10 @@ public interface Tensor {
boolean containsMapped = type.dimensions().stream().anyMatch( d -> ! d.isIndexed());
if (containsIndexed && containsMapped)
throw new IllegalArgumentException("Combining indexed and mapped dimensions is not supported yet");
- if (containsIndexed)
- return IndexedTensor.Builder.of(type);
- else
+ if (containsMapped)
return new MappedTensor.Builder(type);
+ else // indexed or empty
+ return IndexedTensor.Builder.of(type);
}
/** Return a cell builder */
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 66af7edb2ab..22ddcc33c92 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -146,7 +146,6 @@ public class Reduce extends PrimitiveTensorFunction {
private Tensor reduceIndexedVector(IndexedTensor argument) {
ValueAggregator valueAggregator = ValueAggregator.ofType(aggregator);
- System.out.println("Reducing " + argument);
for (int i = 0; i < argument.length(0); i++)
valueAggregator.aggregate(argument.get(i));
return IndexedTensor.Builder.of(TensorType.empty).cell((valueAggregator.aggregatedValue())).build();
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
index d561f37a316..da4b300c3ab 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/IndexedTensorTestCase.java
@@ -21,17 +21,15 @@ public class IndexedTensorTestCase {
@Test
public void testEmpty() {
Tensor empty = Tensor.Builder.of(TensorType.empty).build();
- // assertTrue(empty instanceof IndexedTensor); TODO: Fix that and test these same things for a single-value tensor
+ assertTrue(empty instanceof IndexedTensor);
assertTrue(empty.cells().isEmpty());
assertEquals("{}", empty.toString());
Tensor emptyFromString = Tensor.from(TensorType.empty, "{}");
assertEquals("{}", Tensor.from(TensorType.empty, "{}").toString());
assertTrue(emptyFromString.cells().isEmpty());
assertTrue(emptyFromString instanceof IndexedTensor);
- // assertEquals(empty, emptyFromString);
+ assertEquals(empty, emptyFromString);
- // TODO: Equality of different tensor types
-
Tensor singleValue = Tensor.Builder.of(TensorType.empty).cell(TensorAddress.empty, 3.5).build();
assertEquals("{3.5}", singleValue.toString());
assertEquals("{3.5}", Tensor.from(TensorType.empty, "{3.5}").toString());
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
index 5ce53f2604b..dc0b3c47c62 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TensorFunctionBenchmark.java
@@ -19,7 +19,7 @@ public class TensorFunctionBenchmark {
private final static Random random = new Random();
public double benchmark(int iterations, List<Tensor> modelVectors, TensorType.Dimension.Type dimensionType) {
- Tensor queryVector = generateVectors(1, 300, dimensionType).get(0);
+ Tensor queryVector = vectors(1, 300, dimensionType).get(0);
dotProduct(queryVector, modelVectors, Math.max(iterations/10, 10)); // warmup
long startTime = System.currentTimeMillis();
dotProduct(queryVector, modelVectors, iterations);
@@ -51,8 +51,7 @@ public class TensorFunctionBenchmark {
return largest;
}
- private static List<Tensor> generateVectors(int vectorCount, int vectorSize,
- TensorType.Dimension.Type dimensionType) {
+ private static List<Tensor> vectors(int vectorCount, int vectorSize, TensorType.Dimension.Type dimensionType) {
List<Tensor> tensors = new ArrayList<>();
TensorType type = new TensorType.Builder().dimension("x", dimensionType).build();
for (int i = 0; i < vectorCount; i++) {
@@ -65,8 +64,7 @@ public class TensorFunctionBenchmark {
return tensors;
}
- private static List<Tensor> generateMatrix(int vectorCount, int vectorSize,
- TensorType.Dimension.Type dimensionType) {
+ private static List<Tensor> matrix(int vectorCount, int vectorSize, TensorType.Dimension.Type dimensionType) {
TensorType type = new TensorType.Builder().dimension("i", dimensionType).dimension("x", dimensionType).build();
Tensor.Builder builder = Tensor.Builder.of(type);
for (int i = 0; i < vectorCount; i++) {
@@ -88,11 +86,11 @@ public class TensorFunctionBenchmark {
// - After adding type: 300 ms
// - After sorting dimensions: 100 ms
// - After special-casing single space: 2.4 ms
- time = new TensorFunctionBenchmark().benchmark(5000, generateVectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
+ time = new TensorFunctionBenchmark().benchmark(5000, vectors(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
System.out.printf("Mapped vectors, time per join: %1$8.3f ms\n", time);
// Initial: 760 ms
// - After special-casing subspace: 15 ms
- time = new TensorFunctionBenchmark().benchmark(500, generateMatrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
+ time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.mapped), TensorType.Dimension.Type.mapped);
System.out.printf("Mapped matrix, time per join: %1$8.3f ms\n", time);
// ---------------- Indexed:
@@ -100,11 +98,12 @@ public class TensorFunctionBenchmark {
// - After special casing join: 3.6 ms
// - After special-casing reduce: 0.80 ms
// - After create IndexedTensor without builder: 0.41 ms
- time = new TensorFunctionBenchmark().benchmark(5000, generateVectors(100, 300, TensorType.Dimension.Type.indexedUnbound),TensorType.Dimension.Type.indexedUnbound);
+ // - After double-array backing: 0.09 ms
+ time = new TensorFunctionBenchmark().benchmark(10000, vectors(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound);
System.out.printf("Indexed vectors, time per join: %1$8.3f ms\n", time);
// Initial: 3500 ms
// - After special-casing subspace: 28 ms
- time = new TensorFunctionBenchmark().benchmark(500, generateMatrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound);
+ time = new TensorFunctionBenchmark().benchmark(500, matrix(100, 300, TensorType.Dimension.Type.indexedUnbound), TensorType.Dimension.Type.indexedUnbound);
System.out.printf("Indexed matrix, time per join: %1$8.3f ms\n", time);
}