aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib/src
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
commit5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch)
tree2b65d4f48b92bf7ec846b3efd5d5259244bc234a /vespajlib/src
parent6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff)
Add tensor value type
Diffstat (limited to 'vespajlib/src')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java10
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java4
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java27
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java22
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java7
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java8
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java2
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java4
-rw-r--r--vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java10
-rw-r--r--vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java38
12 files changed, 74 insertions, 67 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
index 08878edeb83..c06cb2a0986 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java
@@ -319,7 +319,7 @@ public class MixedTensor implements Tensor {
}
public TensorType createBoundType() {
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(type().valueType());
for (int i = 0; i < type.dimensions().size(); ++i) {
TensorType.Dimension dimension = type.dimensions().get(i);
if (!dimension.isIndexed()) {
@@ -355,8 +355,8 @@ public class MixedTensor implements Tensor {
this.type = type;
this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
this.indexedDimensions = type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList());
- this.sparseType = createPartialType(mappedDimensions);
- this.denseType = createPartialType(indexedDimensions);
+ this.sparseType = createPartialType(type.valueType(), mappedDimensions);
+ this.denseType = createPartialType(type.valueType(), indexedDimensions);
}
public long indexOf(TensorAddress address) {
@@ -476,8 +476,8 @@ public class MixedTensor implements Tensor {
}
- public static TensorType createPartialType(List<TensorType.Dimension> dimensions) {
- TensorType.Builder builder = new TensorType.Builder();
+ public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) {
+ TensorType.Builder builder = new TensorType.Builder(valueType);
for (TensorType.Dimension dimension : dimensions) {
builder.set(dimension);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
index 998f3170aa0..45a9992c9ad 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java
@@ -18,7 +18,7 @@ class TensorParser {
TensorType typeFromString = TensorTypeParser.fromSpec(typeString);
if (type.isPresent() && ! type.get().equals(typeFromString))
throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " +
- "passed type " + type);
+ "passed type " + type.get());
return tensorFromValueString(valueString, typeFromString);
}
else if (tensorString.startsWith("{")) {
@@ -48,7 +48,7 @@ class TensorParser {
addressBody = addressBody.substring(1); // remove key start
if (addressBody.isEmpty()) return TensorType.empty; // Empty key
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE);
for (String elementString : addressBody.split(",")) {
String[] pair = elementString.split(":");
if (pair.length != 2)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index bded55405c0..5bd44cbc327 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -25,8 +25,29 @@ import java.util.stream.Collectors;
public class TensorType {
/** The permissible cell value types. Default is double. */
- // Types added here must also be added to TensorTypeParser.parseValueTypeSpec
- public enum Value { DOUBLE, FLOAT};
+ public enum Value {
+
+ // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below
+ DOUBLE, FLOAT;
+
+ public static Value largestOf(List<Value> values) {
+ if (values.isEmpty()) return Value.DOUBLE; // Default
+ Value largest = null;
+ for (Value value : values) {
+ if (largest == null)
+ largest = value;
+ else
+ largest = largestOf(largest, value);
+ }
+ return largest;
+ }
+
+ public static Value largestOf(Value value1, Value value2) {
+ if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE;
+ return FLOAT;
+ }
+
+ };
/** The empty tensor type - which is the same as a double */
public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList());
@@ -170,7 +191,7 @@ public class TensorType {
if (this.equals(other)) return Optional.of(this); // shortcut
if (this.dimensions.size() != other.dimensions.size()) return Optional.empty();
- Builder b = new Builder();
+ Builder b = new Builder(TensorType.Value.largestOf(valueType, other.valueType));
for (int i = 0; i < dimensions.size(); i++) {
Dimension thisDim = this.dimensions().get(i);
Dimension otherDim = other.dimensions().get(i);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
index a5733f1cc4c..d5f77be0dd0 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java
@@ -13,6 +13,7 @@ import java.util.regex.Pattern;
* Class for parsing a tensor type spec.
*
* @author geirst
+ * @author bratseth
*/
public class TensorTypeParser {
@@ -54,17 +55,24 @@ public class TensorTypeParser {
return new TensorType.Builder(valueType, dimensions).build();
}
+ public static TensorType.Value toValueType(String valueTypeString) {
+ switch (valueTypeString) {
+ case "double" : return TensorType.Value.DOUBLE;
+ case "float" : return TensorType.Value.FLOAT;
+ default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" +
+ " but was '" + valueTypeString + "'");
+ }
+ }
+
private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) {
if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">"))
throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>"));
- String valueType = valueTypeSpec.substring(1, valueTypeSpec.length() - 1);
- switch (valueType) {
- case "double" : return TensorType.Value.DOUBLE;
- case "float" : return TensorType.Value.FLOAT;
- default : throw formatException(fullSpecString,
- "Value type must be either 'double' or 'float'" +
- " but was '" + valueType + "'");
+ try {
+ return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1));
+ }
+ catch (IllegalArgumentException e) {
+ throw formatException(fullSpecString, e.getMessage());
}
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index 91ab4f9d046..a0a257bb909 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction {
if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build();
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType())
+ .indexed(dimensionName, 1)
+ .build())
+ .cell(1,0)
+ .build();
return tensor.multiply(unitTensor);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 62ee471fcf4..062e0d92e80 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -386,13 +386,12 @@ public class Join extends PrimitiveTensorFunction {
return true;
}
- /**
- * Returns common dimension of a and b as a new tensor type
- */
+ /** Returns common dimension of a and b as a new tensor type */
private static TensorType commonDimensions(Tensor a, Tensor b) {
- TensorType.Builder typeBuilder = new TensorType.Builder();
TensorType aType = a.type();
TensorType bType = b.type();
+ TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(),
+ bType.valueType()));
for (int i = 0; i < aType.dimensions().size(); ++i) {
TensorType.Dimension aDim = aType.dimensions().get(i);
for (int j = 0; j < bType.dimensions().size(); ++j) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
index 54d7710c9dc..017dc3920e6 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java
@@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) {
- if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder b = new TensorType.Builder();
+ TensorType.Builder b = new TensorType.Builder(inputType.valueType());
+ if (reduceDimensions.isEmpty()) return b.build(); // means reduce all
for (TensorType.Dimension dimension : inputType.dimensions()) {
if ( ! reduceDimensions.contains(dimension.name()))
b.dimension(dimension);
@@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction {
}
private static TensorType type(TensorType argumentType, List<String> dimensions) {
- if (dimensions.isEmpty()) return TensorType.empty; // means reduce all
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(argumentType.valueType());
+ if (dimensions.isEmpty()) return builder.build(); // means reduce all
for (TensorType.Dimension dimension : argumentType.dimensions())
if ( ! dimensions.contains(dimension.name())) // keep
builder.dimension(dimension);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index b268e33b418..db950e6c8b9 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -268,7 +268,8 @@ public class ReduceJoin extends CompositeTensorFunction {
}
private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(),
+ b.type().valueType()));
for (TensorType.Dimension aDim : a.type().dimensions()) {
for (TensorType.Dimension bDim : b.type().dimensions()) {
if (aDim.name().equals(bDim.name())) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
index e18af235d59..5694684956e 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java
@@ -75,7 +75,7 @@ public class Rename extends PrimitiveTensorFunction {
}
private TensorType type(TensorType type) {
- TensorType.Builder builder = new TensorType.Builder();
+ TensorType.Builder builder = new TensorType.Builder(type.valueType());
for (TensorType.Dimension dimension : type.dimensions())
builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name())));
return builder.build();
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
index acaeb3ef5ba..284dfea2141 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java
@@ -78,7 +78,7 @@ class MixedBinaryFormat implements BinaryFormat {
TensorType serializedType = decodeType(buffer);
if ( ! serializedType.isAssignableTo(type))
throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType +
- " cannot be assigned to type " + type);
+ " cannot be assigned to type " + type);
}
else {
type = decodeType(buffer);
@@ -103,7 +103,7 @@ class MixedBinaryFormat implements BinaryFormat {
private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) {
List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList());
- TensorType sparseType = MixedTensor.createPartialType(sparseDimensions);
+ TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions);
long denseSubspaceSize = builder.denseSubspaceSize();
int numBlocks = 1;
diff --git a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
index 9602bdb8d94..f6fed9d33ed 100644
--- a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java
@@ -69,16 +69,6 @@ public class BoundingBoxParserTestCase {
all1234(parser);
}
- /**
- * Tests various legal inputs and print the output
- */
- @Test
- public void testPrint() {
- String here = "n=63.418417 E=10.433033 S=37.7 W=-122.02";
- parser = new BoundingBoxParser(here);
- System.out.println(here+" -> "+parser);
- }
-
@Test
public void testGeoPlanetExample() {
/* example XML:
diff --git a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
index e8ceab44c78..7cf4bddaa01 100644
--- a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java
@@ -57,7 +57,6 @@ public class BinaryFormatTestCase {
@Test
public void testZigZagConversion() {
- System.out.println("test zigzag conversion");
assertThat(encode_zigzag(0), is((long)0));
assertThat(decode_zigzag(encode_zigzag(0)), is(0L));
@@ -88,7 +87,6 @@ public class BinaryFormatTestCase {
@Test
public void testDoubleConversion() {
- System.out.println("test double conversion");
assertThat(encode_double(0.0), is(0L));
assertThat(decode_double(encode_double(0.0)), is(0.0));
@@ -116,7 +114,6 @@ public class BinaryFormatTestCase {
@Test
public void testTypeAndMetaMangling() {
- System.out.println("test type and meta mangling");
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (int meta = 0; meta < META_LIMIT; ++meta) {
byte mangled = encode_type_and_meta(type, meta);
@@ -126,10 +123,8 @@ public class BinaryFormatTestCase {
}
}
- // was testCmprUlong
@Test
- public void testCmprLong() {
- System.out.println("test compressed long");
+ public void testCompressedLong() {
{
long value = 0;
byte[] wanted = { 0 };
@@ -217,11 +212,8 @@ public class BinaryFormatTestCase {
// testWriteBytes -> buffered IO test
// testReadByte -> buffered IO test
// testReadBytes -> buffered IO test
-
@Test
- public void testTypeAndSize() {
- System.out.println("test type and size conversion");
-
+ public void testTypeAndSizeConversion() {
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (long size = 0; size < 500; ++size) {
BufferedOutput expect = new BufferedOutput();
@@ -271,8 +263,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testTypeAndBytes() {
- System.out.println("test encoding and decoding of type and bytes");
+ public void testEncodingAndDecodingOfTypeAndBytes() {
for (byte type = 0; type < TYPE_LIMIT; ++type) {
for (int n = 0; n < MAX_NUM_SIZE; ++n) {
for (int pre = 0; (pre == 0) || (pre < n); ++pre) {
@@ -307,9 +298,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testEmpty() {
- System.out.println("test encoding empty slime");
-
+ public void testEncodingEmptySlime() {
Slime slime = new Slime();
BufferedOutput expect = new BufferedOutput();
expect.put((byte)0); // num symbols
@@ -321,8 +310,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testBasic() {
- System.out.println("test encoding slime holding a single basic value");
+ public void testEncodingSlimeHoldingASingleBasicValue() {
{
Slime slime = new Slime();
slime.setBool(false);
@@ -427,8 +415,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testArray() {
- System.out.println("test encoding slime holding an array of various basic values");
+ public void testEncodingSlimeArray() {
Slime slime = new Slime();
Cursor c = slime.setArray();
byte[] data = { 'd', 'a', 't', 'a' };
@@ -452,8 +439,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testObject() {
- System.out.println("test encoding slime holding an object of various basic values");
+ public void testEncodingSlimeObject() {
Slime slime = new Slime();
Cursor c = slime.setObject();
byte[] data = { 'd', 'a', 't', 'a' };
@@ -478,8 +464,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testNesting() {
- System.out.println("test encoding slime holding a more complex structure");
+ public void testEncodingComplexSlimeStructure() {
Slime slime = new Slime();
Cursor c1 = slime.setObject();
c1.setLong("bar", 10);
@@ -503,8 +488,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testSymbolReuse() {
- System.out.println("test encoding slime reusing symbols");
+ public void testEncodingSlimeReusingSymbols() {
Slime slime = new Slime();
Cursor c1 = slime.setArray();
{
@@ -533,8 +517,7 @@ public class BinaryFormatTestCase {
}
@Test
- public void testOptionalDecodeOrder() {
- System.out.println("test decoding slime with different symbol order");
+ public void testDecodingSlimeWithDifferentSymbolOrder() {
byte[] data = {
5, // num symbols
1, 'd', 1, 'e', 1, 'f', 1, 'b', 1, 'c', // symbol table
@@ -564,4 +547,5 @@ public class BinaryFormatTestCase {
assertThat(c.field("f").asData(), is(expd));
assertThat(c.entry(5).valid(), is(false)); // not ARRAY
}
+
}