diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
commit | 5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch) | |
tree | 2b65d4f48b92bf7ec846b3efd5d5259244bc234a /vespajlib | |
parent | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff) |
Add tensor value type
Diffstat (limited to 'vespajlib')
13 files changed, 87 insertions, 79 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 239efa0f89c..b071566ae31 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -947,7 +947,7 @@ "public java.lang.String toString()", "public boolean equals(java.lang.Object)", "public long denseSubspaceSize()", - "public static com.yahoo.tensor.TensorType createPartialType(java.util.List)" + "public static com.yahoo.tensor.TensorType createPartialType(com.yahoo.tensor.TensorType$Value, java.util.List)" ], "fields": [] }, @@ -1162,11 +1162,11 @@ ], "methods": [ "public void <init>()", - "public void <init>(com.yahoo.tensor.TensorType$ValueType)", + "public void <init>(com.yahoo.tensor.TensorType$Value)", "public varargs void <init>(com.yahoo.tensor.TensorType[])", - "public varargs void <init>(com.yahoo.tensor.TensorType$ValueType, com.yahoo.tensor.TensorType[])", + "public varargs void <init>(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType[])", "public void <init>(java.lang.Iterable)", - "public void <init>(com.yahoo.tensor.TensorType$ValueType, java.lang.Iterable)", + "public void <init>(com.yahoo.tensor.TensorType$Value, java.lang.Iterable)", "public int rank()", "public com.yahoo.tensor.TensorType$Builder set(com.yahoo.tensor.TensorType$Dimension)", "public com.yahoo.tensor.TensorType$Builder indexed(java.lang.String, long)", @@ -1270,7 +1270,7 @@ ], "fields": [] }, - "com.yahoo.tensor.TensorType$ValueType": { + "com.yahoo.tensor.TensorType$Value": { "superClass": "java.lang.Enum", "interfaces": [], "attributes": [ @@ -1279,12 +1279,14 @@ "enum" ], "methods": [ - "public static com.yahoo.tensor.TensorType$ValueType[] values()", - "public static com.yahoo.tensor.TensorType$ValueType valueOf(java.lang.String)" + "public static com.yahoo.tensor.TensorType$Value[] values()", + "public static com.yahoo.tensor.TensorType$Value valueOf(java.lang.String)", + "public static com.yahoo.tensor.TensorType$Value largestOf(java.util.List)", + "public static com.yahoo.tensor.TensorType$Value largestOf(com.yahoo.tensor.TensorType$Value, com.yahoo.tensor.TensorType$Value)" ], "fields": [ - "public static final enum com.yahoo.tensor.TensorType$ValueType DOUBLE", - "public static final enum com.yahoo.tensor.TensorType$ValueType FLOAT" + "public static final enum com.yahoo.tensor.TensorType$Value DOUBLE", + "public static final enum com.yahoo.tensor.TensorType$Value FLOAT" ] }, "com.yahoo.tensor.TensorType": { @@ -1294,9 +1296,8 @@ "public" ], "methods": [ - "public final com.yahoo.tensor.TensorType$ValueType valueType()", - "public final com.yahoo.tensor.TensorType valueType(com.yahoo.tensor.TensorType$ValueType)", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", + "public com.yahoo.tensor.TensorType$Value valueType()", "public int rank()", "public java.util.List dimensions()", "public java.util.Set dimensionNames()", @@ -1325,7 +1326,7 @@ "methods": [ "public void <init>()", "public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)", - "public static java.util.List dimensionsFromSpec(java.lang.String)" + "public static com.yahoo.tensor.TensorType$Value toValueType(java.lang.String)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java index 08878edeb83..c06cb2a0986 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/MixedTensor.java @@ -319,7 +319,7 @@ public class MixedTensor implements Tensor { } public TensorType createBoundType() { - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(type().valueType()); for (int i = 0; i < type.dimensions().size(); ++i) { TensorType.Dimension dimension = type.dimensions().get(i); if (!dimension.isIndexed()) { @@ -355,8 +355,8 @@ public class MixedTensor implements Tensor { this.type = type; this.mappedDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); this.indexedDimensions = type.dimensions().stream().filter(d -> d.isIndexed()).collect(Collectors.toList()); - this.sparseType = createPartialType(mappedDimensions); - this.denseType = createPartialType(indexedDimensions); + this.sparseType = createPartialType(type.valueType(), mappedDimensions); + this.denseType = createPartialType(type.valueType(), indexedDimensions); } public long indexOf(TensorAddress address) { @@ -476,8 +476,8 @@ public class MixedTensor implements Tensor { } - public static TensorType createPartialType(List<TensorType.Dimension> dimensions) { - TensorType.Builder builder = new TensorType.Builder(); + public static TensorType createPartialType(TensorType.Value valueType, List<TensorType.Dimension> dimensions) { + TensorType.Builder builder = new TensorType.Builder(valueType); for (TensorType.Dimension dimension : dimensions) { builder.set(dimension); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java index 998f3170aa0..45a9992c9ad 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorParser.java @@ -18,7 +18,7 @@ class TensorParser { TensorType typeFromString = TensorTypeParser.fromSpec(typeString); if (type.isPresent() && ! type.get().equals(typeFromString)) throw new IllegalArgumentException("Got tensor with type string '" + typeString + "', but was " + - "passed type " + type); + "passed type " + type.get()); return tensorFromValueString(valueString, typeFromString); } else if (tensorString.startsWith("{")) { @@ -48,7 +48,7 @@ class TensorParser { addressBody = addressBody.substring(1); // remove key start if (addressBody.isEmpty()) return TensorType.empty; // Empty key - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(TensorType.Value.DOUBLE); for (String elementString : addressBody.split(",")) { String[] pair = elementString.split(":"); if (pair.length != 2) diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java index bded55405c0..5bd44cbc327 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java @@ -25,8 +25,29 @@ import java.util.stream.Collectors; public class TensorType { /** The permissible cell value types. Default is double. */ - // Types added here must also be added to TensorTypeParser.parseValueTypeSpec - public enum Value { DOUBLE, FLOAT}; + public enum Value { + + // Types added must also be added to TensorTypeParser.parseValueTypeSpec, serialization, and largestOf below + DOUBLE, FLOAT; + + public static Value largestOf(List<Value> values) { + if (values.isEmpty()) return Value.DOUBLE; // Default + Value largest = null; + for (Value value : values) { + if (largest == null) + largest = value; + else + largest = largestOf(largest, value); + } + return largest; + } + + public static Value largestOf(Value value1, Value value2) { + if (value1 == DOUBLE || value2 == DOUBLE) return DOUBLE; + return FLOAT; + } + + }; /** The empty tensor type - which is the same as a double */ public static final TensorType empty = new TensorType(Value.DOUBLE, Collections.emptyList()); @@ -170,7 +191,7 @@ public class TensorType { if (this.equals(other)) return Optional.of(this); // shortcut if (this.dimensions.size() != other.dimensions.size()) return Optional.empty(); - Builder b = new Builder(); + Builder b = new Builder(TensorType.Value.largestOf(valueType, other.valueType)); for (int i = 0; i < dimensions.size(); i++) { Dimension thisDim = this.dimensions().get(i); Dimension otherDim = other.dimensions().get(i); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java index a5733f1cc4c..d5f77be0dd0 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorTypeParser.java @@ -13,6 +13,7 @@ import java.util.regex.Pattern; * Class for parsing a tensor type spec. * * @author geirst + * @author bratseth */ public class TensorTypeParser { @@ -54,17 +55,24 @@ public class TensorTypeParser { return new TensorType.Builder(valueType, dimensions).build(); } + public static TensorType.Value toValueType(String valueTypeString) { + switch (valueTypeString) { + case "double" : return TensorType.Value.DOUBLE; + case "float" : return TensorType.Value.FLOAT; + default : throw new IllegalArgumentException("Value type must be either 'double' or 'float'" + + " but was '" + valueTypeString + "'"); + } + } + private static TensorType.Value parseValueTypeSpec(String valueTypeSpec, String fullSpecString) { if ( ! valueTypeSpec.startsWith("<") || ! valueTypeSpec.endsWith(">")) throw formatException(fullSpecString, Optional.of("Value type spec must be enclosed in <>")); - String valueType = valueTypeSpec.substring(1, valueTypeSpec.length() - 1); - switch (valueType) { - case "double" : return TensorType.Value.DOUBLE; - case "float" : return TensorType.Value.FLOAT; - default : throw formatException(fullSpecString, - "Value type must be either 'double' or 'float'" + - " but was '" + valueType + "'"); + try { + return toValueType(valueTypeSpec.substring(1, valueTypeSpec.length() - 1)); + } + catch (IllegalArgumentException e) { + throw formatException(fullSpecString, e.getMessage()); } } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index 91ab4f9d046..a0a257bb909 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -141,7 +141,11 @@ public class Concat extends PrimitiveTensorFunction { if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed())) throw new IllegalArgumentException("Concat requires an indexed tensor, " + "but got a tensor with type " + tensor.type()); - Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder().indexed(dimensionName, 1).build()).cell(1,0).build(); + Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType()) + .indexed(dimensionName, 1) + .build()) + .cell(1,0) + .build(); return tensor.multiply(unitTensor); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 62ee471fcf4..062e0d92e80 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -386,13 +386,12 @@ public class Join extends PrimitiveTensorFunction { return true; } - /** - * Returns common dimension of a and b as a new tensor type - */ + /** Returns common dimension of a and b as a new tensor type */ private static TensorType commonDimensions(Tensor a, Tensor b) { - TensorType.Builder typeBuilder = new TensorType.Builder(); TensorType aType = a.type(); TensorType bType = b.type(); + TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(), + bType.valueType())); for (int i = 0; i < aType.dimensions().size(); ++i) { TensorType.Dimension aDim = aType.dimensions().get(i); for (int j = 0; j < bType.dimensions().size(); ++j) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java index 54d7710c9dc..017dc3920e6 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Reduce.java @@ -61,8 +61,8 @@ public class Reduce extends PrimitiveTensorFunction { } public static TensorType outputType(TensorType inputType, List<String> reduceDimensions) { - if (reduceDimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder b = new TensorType.Builder(); + TensorType.Builder b = new TensorType.Builder(inputType.valueType()); + if (reduceDimensions.isEmpty()) return b.build(); // means reduce all for (TensorType.Dimension dimension : inputType.dimensions()) { if ( ! reduceDimensions.contains(dimension.name())) b.dimension(dimension); @@ -109,8 +109,8 @@ public class Reduce extends PrimitiveTensorFunction { } private static TensorType type(TensorType argumentType, List<String> dimensions) { - if (dimensions.isEmpty()) return TensorType.empty; // means reduce all - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(argumentType.valueType()); + if (dimensions.isEmpty()) return builder.build(); // means reduce all for (TensorType.Dimension dimension : argumentType.dimensions()) if ( ! dimensions.contains(dimension.name())) // keep builder.dimension(dimension); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java index b268e33b418..db950e6c8b9 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java @@ -268,7 +268,8 @@ public class ReduceJoin extends CompositeTensorFunction { } private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(), + b.type().valueType())); for (TensorType.Dimension aDim : a.type().dimensions()) { for (TensorType.Dimension bDim : b.type().dimensions()) { if (aDim.name().equals(bDim.name())) { diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java index e18af235d59..5694684956e 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Rename.java @@ -75,7 +75,7 @@ public class Rename extends PrimitiveTensorFunction { } private TensorType type(TensorType type) { - TensorType.Builder builder = new TensorType.Builder(); + TensorType.Builder builder = new TensorType.Builder(type.valueType()); for (TensorType.Dimension dimension : type.dimensions()) builder.dimension(dimension.withName(fromToMap.getOrDefault(dimension.name(), dimension.name()))); return builder.build(); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java index acaeb3ef5ba..284dfea2141 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/serialization/MixedBinaryFormat.java @@ -78,7 +78,7 @@ class MixedBinaryFormat implements BinaryFormat { TensorType serializedType = decodeType(buffer); if ( ! serializedType.isAssignableTo(type)) throw new IllegalArgumentException("Type/instance mismatch: A tensor of type " + serializedType + - " cannot be assigned to type " + type); + " cannot be assigned to type " + type); } else { type = decodeType(buffer); @@ -103,7 +103,7 @@ class MixedBinaryFormat implements BinaryFormat { private void decodeCells(GrowableByteBuffer buffer, MixedTensor.BoundBuilder builder, TensorType type) { List<TensorType.Dimension> sparseDimensions = type.dimensions().stream().filter(d -> !d.isIndexed()).collect(Collectors.toList()); - TensorType sparseType = MixedTensor.createPartialType(sparseDimensions); + TensorType sparseType = MixedTensor.createPartialType(type.valueType(), sparseDimensions); long denseSubspaceSize = builder.denseSubspaceSize(); int numBlocks = 1; diff --git a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java index 9602bdb8d94..f6fed9d33ed 100644 --- a/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/geo/BoundingBoxParserTestCase.java @@ -69,16 +69,6 @@ public class BoundingBoxParserTestCase { all1234(parser); } - /** - * Tests various legal inputs and print the output - */ - @Test - public void testPrint() { - String here = "n=63.418417 E=10.433033 S=37.7 W=-122.02"; - parser = new BoundingBoxParser(here); - System.out.println(here+" -> "+parser); - } - @Test public void testGeoPlanetExample() { /* example XML: diff --git a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java index e8ceab44c78..7cf4bddaa01 100644 --- a/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/slime/BinaryFormatTestCase.java @@ -57,7 +57,6 @@ public class BinaryFormatTestCase { @Test public void testZigZagConversion() { - System.out.println("test zigzag conversion"); assertThat(encode_zigzag(0), is((long)0)); assertThat(decode_zigzag(encode_zigzag(0)), is(0L)); @@ -88,7 +87,6 @@ public class BinaryFormatTestCase { @Test public void testDoubleConversion() { - System.out.println("test double conversion"); assertThat(encode_double(0.0), is(0L)); assertThat(decode_double(encode_double(0.0)), is(0.0)); @@ -116,7 +114,6 @@ public class BinaryFormatTestCase { @Test public void testTypeAndMetaMangling() { - System.out.println("test type and meta mangling"); for (byte type = 0; type < TYPE_LIMIT; ++type) { for (int meta = 0; meta < META_LIMIT; ++meta) { byte mangled = encode_type_and_meta(type, meta); @@ -126,10 +123,8 @@ public class BinaryFormatTestCase { } } - // was testCmprUlong @Test - public void testCmprLong() { - System.out.println("test compressed long"); + public void testCompressedLong() { { long value = 0; byte[] wanted = { 0 }; @@ -217,11 +212,8 @@ public class BinaryFormatTestCase { // testWriteBytes -> buffered IO test // testReadByte -> buffered IO test // testReadBytes -> buffered IO test - @Test - public void testTypeAndSize() { - System.out.println("test type and size conversion"); - + public void testTypeAndSizeConversion() { for (byte type = 0; type < TYPE_LIMIT; ++type) { for (long size = 0; size < 500; ++size) { BufferedOutput expect = new BufferedOutput(); @@ -271,8 +263,7 @@ public class BinaryFormatTestCase { } @Test - public void testTypeAndBytes() { - System.out.println("test encoding and decoding of type and bytes"); + public void testEncodingAndDecodingOfTypeAndBytes() { for (byte type = 0; type < TYPE_LIMIT; ++type) { for (int n = 0; n < MAX_NUM_SIZE; ++n) { for (int pre = 0; (pre == 0) || (pre < n); ++pre) { @@ -307,9 +298,7 @@ public class BinaryFormatTestCase { } @Test - public void testEmpty() { - System.out.println("test encoding empty slime"); - + public void testEncodingEmptySlime() { Slime slime = new Slime(); BufferedOutput expect = new BufferedOutput(); expect.put((byte)0); // num symbols @@ -321,8 +310,7 @@ public class BinaryFormatTestCase { } @Test - public void testBasic() { - System.out.println("test encoding slime holding a single basic value"); + public void testEncodingSlimeHoldingASingleBasicValue() { { Slime slime = new Slime(); slime.setBool(false); @@ -427,8 +415,7 @@ public class BinaryFormatTestCase { } @Test - public void testArray() { - System.out.println("test encoding slime holding an array of various basic values"); + public void testEncodingSlimeArray() { Slime slime = new Slime(); Cursor c = slime.setArray(); byte[] data = { 'd', 'a', 't', 'a' }; @@ -452,8 +439,7 @@ public class BinaryFormatTestCase { } @Test - public void testObject() { - System.out.println("test encoding slime holding an object of various basic values"); + public void testEncodingSlimeObject() { Slime slime = new Slime(); Cursor c = slime.setObject(); byte[] data = { 'd', 'a', 't', 'a' }; @@ -478,8 +464,7 @@ public class BinaryFormatTestCase { } @Test - public void testNesting() { - System.out.println("test encoding slime holding a more complex structure"); + public void testEncodingComplexSlimeStructure() { Slime slime = new Slime(); Cursor c1 = slime.setObject(); c1.setLong("bar", 10); @@ -503,8 +488,7 @@ public class BinaryFormatTestCase { } @Test - public void testSymbolReuse() { - System.out.println("test encoding slime reusing symbols"); + public void testEncodingSlimeReusingSymbols() { Slime slime = new Slime(); Cursor c1 = slime.setArray(); { @@ -533,8 +517,7 @@ public class BinaryFormatTestCase { } @Test - public void testOptionalDecodeOrder() { - System.out.println("test decoding slime with different symbol order"); + public void testDecodingSlimeWithDifferentSymbolOrder() { byte[] data = { 5, // num symbols 1, 'd', 1, 'e', 1, 'f', 1, 'b', 1, 'c', // symbol table @@ -564,4 +547,5 @@ public class BinaryFormatTestCase { assertThat(c.field("f").asData(), is(expd)); assertThat(c.entry(5).valid(), is(false)); // not ARRAY } + } |