diff options
Diffstat (limited to 'vespajlib')
8 files changed, 66 insertions, 58 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json index 9e9d32a5a6e..ebca0a4d852 100644 --- a/vespajlib/abi-spec.json +++ b/vespajlib/abi-spec.json @@ -679,7 +679,8 @@ "public static void ensureSmaller(java.lang.String, java.lang.Comparable, java.lang.String, java.lang.Comparable)", "public static void ensure(java.lang.String, boolean)", "public static varargs void ensure(boolean, java.lang.Object[])", - "public static void ensureInstanceOf(java.lang.String, java.lang.Object, java.lang.Class)" + "public static void ensureInstanceOf(java.lang.String, java.lang.Object, java.lang.Class)", + "public static void ensureNotInstanceOf(java.lang.String, java.lang.Object, java.lang.Class)" ], "fields": [] }, diff --git a/vespajlib/src/main/java/com/yahoo/protect/Validator.java b/vespajlib/src/main/java/com/yahoo/protect/Validator.java index 49fe7716ba2..ee4a93c2f01 100644 --- a/vespajlib/src/main/java/com/yahoo/protect/Validator.java +++ b/vespajlib/src/main/java/com/yahoo/protect/Validator.java @@ -68,14 +68,10 @@ public abstract class Validator { * Throws an IllegalArgumentException if the first argument is not strictly * smaller than the second argument * - * @param smallDescription - * description of the smallest argument - * @param small - * the smallest argument - * @param largeDescription - * description of the largest argument - * @param large - * the largest argument + * @param smallDescription description of the smallest argument + * @param small the smallest argument + * @param largeDescription description of the largest argument + * @param large the largest argument */ @SuppressWarnings({ "rawtypes", "unchecked" }) public static void ensureSmaller(String smallDescription, Comparable small, String largeDescription, Comparable large) { @@ -115,14 +111,10 @@ public abstract class Validator { /** * Ensures that an item is of a particular class * - * @param description - * a description of the item to be checked - * @param item - * the item to check the type of - * @param type - * the type the given item should be instanceof - * @throws IllegalArgumentException - * if the given item is not of the correct type + * @param description a description of the item to be checked + * @param item the item to check the type of + * @param type the type the given item should be instanceof + * @throws IllegalArgumentException if the given item is not of the correct type */ public static void ensureInstanceOf(String description, Object item, Class<?> type) { if ( ! type.isAssignableFrom(item.getClass())) { @@ -131,4 +123,19 @@ public abstract class Validator { } } + /** + * Ensures that an item is not of a particular class + * + * @param description a description of the item to be checked + * @param item the item to check the type of + * @param type the type the given item should NOT be instanceof + * @throws IllegalArgumentException if the given item is of the wrong type + */ + public static void ensureNotInstanceOf(String description, Object item, Class<?> type) { + if ( type.isAssignableFrom(item.getClass())) { + throw new IllegalArgumentException(description + " " + item + " should NOT be an instance of " + type + + " but is " + item.getClass()); + } + } + } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java index 37a4bf375d0..651bec6a1aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java @@ -126,6 +126,14 @@ public class TypeResolver { first.name().equals(second.name())); } + private static boolean firstIsSmaller(Dimension first, Dimension second) { + return (first.type() == Dimension.Type.indexedBound && + second.type() == Dimension.Type.indexedBound && + first.name().equals(second.name()) && + first.size().isPresent() && second.size().isPresent() && + first.size().get() < second.size().get()); + } + static public TensorType join(TensorType lhs, TensorType rhs) { Value cellType = Value.DOUBLE; if (lhs.rank() > 0 && rhs.rank() > 0) { @@ -153,6 +161,10 @@ public class TypeResolver { map.put(dim.name(), dim); } else if (firstIsBoundSecond(other, dim)) { map.put(dim.name(), other); + } else if (dim.isMapped() && other.isIndexed()) { + map.put(dim.name(), dim); // {} and [] -> {}. Note: this is not allowed in C++ + } else if (dim.isIndexed() && other.isMapped()) { + map.put(dim.name(), other); // {} and [] -> {}. Note: this is not allowed in C++ } else { throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs); } @@ -215,9 +227,13 @@ public class TypeResolver { Dimension other = map.get(dim.name()); if (! other.equals(dim)) { if (firstIsBoundSecond(dim, other)) { - map.put(dim.name(), dim); + map.put(dim.name(), other); // [N] and [] -> [] } else if (firstIsBoundSecond(other, dim)) { - map.put(dim.name(), other); + map.put(dim.name(), dim); // [N] and [] -> [] + } else if (firstIsSmaller(dim, other)) { + map.put(dim.name(), dim); // [N] and [M] -> [ min(N,M] ]. + } else if (firstIsSmaller(other, dim)) { + map.put(dim.name(), other); // [N] and [M] -> [ min(N,M] ]. } else { throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs); } diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java index c6f8171bd18..fe8b2f417aa 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java @@ -3,6 +3,7 @@ package com.yahoo.tensor.functions; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -47,7 +48,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM @Override public TensorType type(TypeContext<NAMETYPE> context) { - return new TensorType(valueType, argument.type(context).dimensions()); + return TypeResolver.cell_cast(argument.type(context), valueType); } @Override @@ -56,12 +57,11 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM if (tensor.type().valueType() == valueType) { return tensor; } - TensorType type = new TensorType(valueType, tensor.type().dimensions()); + TensorType type = TypeResolver.cell_cast(tensor.type(), valueType); return cast(tensor, type); } private Tensor cast(Tensor tensor, TensorType type) { - Tensor.Builder builder = Tensor.Builder.of(type); TensorType.Value fromValueType = tensor.type().valueType(); switch (fromValueType) { case DOUBLE: diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java index fff2ddaf320..59a452588ca 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java @@ -7,6 +7,7 @@ import com.yahoo.tensor.IndexedTensor; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -60,44 +61,20 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET @Override public TensorType type(TypeContext<NAMETYPE> context) { - return type(argumentA.type(context), argumentB.type(context)); - } - - /** Returns the type resulting from concatenating a and b */ - private TensorType type(TensorType a, TensorType b) { - // TODO: Fail if concat dimension is present but not indexed in a or b - TensorType.Builder builder = new TensorType.Builder(a, b); - if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) { - builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) + - b.sizeOfDimension(dimension).orElse(1L))); - /* - MutableLong concatSize = new MutableLong(0); - a.sizeOfDimension(dimension).ifPresent(concatSize::add); - b.sizeOfDimension(dimension).ifPresent(concatSize::add); - builder.set(TensorType.Dimension.indexed(dimension, concatSize.get())); - */ - } - return builder.build(); - } - - /** Returns true if this dimension is present and unbound */ - private boolean unboundIn(TensorType type, String dimensionName) { - Optional<TensorType.Dimension> dimension = type.dimension(dimensionName); - return dimension.isPresent() && ! dimension.get().size().isPresent(); + return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension); } @Override public Tensor evaluate(EvaluationContext<NAMETYPE> context) { Tensor a = argumentA.evaluate(context); Tensor b = argumentB.evaluate(context); - TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type()); - a = ensureIndexedDimension(dimension, a, combinedValueType); - b = ensureIndexedDimension(dimension, b, combinedValueType); + TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension); + + a = ensureIndexedDimension(dimension, a, concatType.valueType()); + b = ensureIndexedDimension(dimension, b, concatType.valueType()); IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor IndexedTensor bIndexed = (IndexedTensor) b; - - TensorType concatType = type(a.type(), b.type()); DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension); Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java index 5419d04a4fb..d43b7889982 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -49,7 +50,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP /** Returns the type resulting from applying Join to the two given types */ public static TensorType outputType(TensorType a, TensorType b) { try { - return new TensorType.Builder(false, a, b).build(); + return TypeResolver.join(a, b); } catch (IllegalArgumentException e) { throw new IllegalArgumentException("Can not join " + a + " and " + b, e); diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java index d5633bde36c..4aa09f3f4e3 100644 --- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java +++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java @@ -9,6 +9,7 @@ import com.yahoo.tensor.PartialAddress; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorAddress; import com.yahoo.tensor.TensorType; +import com.yahoo.tensor.TypeResolver; import com.yahoo.tensor.evaluation.EvaluationContext; import com.yahoo.tensor.evaluation.Name; import com.yahoo.tensor.evaluation.TypeContext; @@ -48,9 +49,7 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY /** Returns the type resulting from applying Merge to the two given types */ public static TensorType outputType(TensorType a, TensorType b) { - Optional<TensorType> outputType = a.dimensionwiseGeneralizationWith(b); - if (outputType.isPresent()) return outputType.get(); - throw new IllegalArgumentException("Cannot merge " + a + " and " + b + ": Arguments must have compatible types"); + return TypeResolver.merge(a, b); } public DoubleBinaryOperator merger() { return merger; } diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java index 8e4205c8c27..7eee50c6785 100644 --- a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java +++ b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java @@ -77,10 +77,14 @@ public class TypeResolverTestCase { checkJoin("tensor(x{})", "tensor<bfloat16>(y{})", "tensor(x{},y{})"); checkJoin("tensor(x{})", "tensor<float>(y{})", "tensor(x{},y{})"); checkJoin("tensor(x{})", "tensor<int8>(y{})", "tensor(x{},y{})"); + // specific for Java + checkJoin("tensor(x[])", "tensor(x{})", "tensor(x{})"); + checkJoin("tensor(x[3])", "tensor(x{})", "tensor(x{})"); + checkJoin("tensor(x{})", "tensor(x[])", "tensor(x{})"); + checkJoin("tensor(x{})", "tensor(x[3])", "tensor(x{})"); // dimension mismatch should fail: checkJoinFails("tensor(x[3])", "tensor(x[5])"); checkJoinFails("tensor(x[5])", "tensor(x[3])"); - checkJoinFails("tensor(x{})", "tensor(x[5])"); } @Test @@ -156,6 +160,7 @@ public class TypeResolverTestCase { checkMerge("tensor(x{},y{})", "tensor<float>(x{},y{})", "tensor(x{},y{})"); checkMerge("tensor(x{},y{})", "tensor<int8>(x{},y{})", "tensor(x{},y{})"); checkMerge("tensor(y{})", "tensor(y{})", "tensor(y{})"); + checkMerge("tensor(x{})", "tensor(x[5])", "tensor(x{})"); checkMergeFails("tensor(a[10])", "tensor()"); checkMergeFails("tensor(a[10])", "tensor(x{},y{},z{})"); checkMergeFails("tensor<bfloat16>(x[5])", "tensor()"); @@ -168,7 +173,6 @@ public class TypeResolverTestCase { checkMergeFails("tensor(x[3])", "tensor(x[5])"); checkMergeFails("tensor(x[5])", "tensor(x[3])"); checkMergeFails("tensor(x{})", "tensor()"); - checkMergeFails("tensor(x{})", "tensor(x[5])"); checkMergeFails("tensor(x{},y{})", "tensor(x{},z{})"); checkMergeFails("tensor(y{})", "tensor()"); } @@ -221,11 +225,14 @@ public class TypeResolverTestCase { checkConcat("tensor<float>(x[3])", "tensor()", "x", "tensor<float>(x[4])"); checkConcat("tensor<bfloat16>(x[3])", "tensor()", "x", "tensor<bfloat16>(x[4])"); checkConcat("tensor<int8>(x[3])", "tensor()", "x", "tensor<int8>(x[4])"); + // specific for Java + checkConcat("tensor(x[])", "tensor(x[2])", "x", "tensor(x[])"); + checkConcat("tensor(x[])", "tensor(x[2])", "y", "tensor(x[],y[2])"); + checkConcat("tensor(x[3])", "tensor(x[2])", "y", "tensor(x[2],y[2])"); // invalid combinations must fail checkConcatFails("tensor(x{})", "tensor(x[2])", "x"); checkConcatFails("tensor(x{})", "tensor(x{})", "x"); checkConcatFails("tensor(x{})", "tensor()", "x"); - checkConcatFails("tensor(x[3])", "tensor(x[2])", "y"); } @Test |