summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json3
-rw-r--r--vespajlib/src/main/java/com/yahoo/protect/Validator.java39
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java20
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java35
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java5
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java13
8 files changed, 66 insertions, 58 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 9e9d32a5a6e..ebca0a4d852 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -679,7 +679,8 @@
"public static void ensureSmaller(java.lang.String, java.lang.Comparable, java.lang.String, java.lang.Comparable)",
"public static void ensure(java.lang.String, boolean)",
"public static varargs void ensure(boolean, java.lang.Object[])",
- "public static void ensureInstanceOf(java.lang.String, java.lang.Object, java.lang.Class)"
+ "public static void ensureInstanceOf(java.lang.String, java.lang.Object, java.lang.Class)",
+ "public static void ensureNotInstanceOf(java.lang.String, java.lang.Object, java.lang.Class)"
],
"fields": []
},
diff --git a/vespajlib/src/main/java/com/yahoo/protect/Validator.java b/vespajlib/src/main/java/com/yahoo/protect/Validator.java
index 49fe7716ba2..ee4a93c2f01 100644
--- a/vespajlib/src/main/java/com/yahoo/protect/Validator.java
+++ b/vespajlib/src/main/java/com/yahoo/protect/Validator.java
@@ -68,14 +68,10 @@ public abstract class Validator {
* Throws an IllegalArgumentException if the first argument is not strictly
* smaller than the second argument
*
- * @param smallDescription
- * description of the smallest argument
- * @param small
- * the smallest argument
- * @param largeDescription
- * description of the largest argument
- * @param large
- * the largest argument
+ * @param smallDescription description of the smallest argument
+ * @param small the smallest argument
+ * @param largeDescription description of the largest argument
+ * @param large the largest argument
*/
@SuppressWarnings({ "rawtypes", "unchecked" })
public static void ensureSmaller(String smallDescription, Comparable small, String largeDescription, Comparable large) {
@@ -115,14 +111,10 @@ public abstract class Validator {
/**
* Ensures that an item is of a particular class
*
- * @param description
- * a description of the item to be checked
- * @param item
- * the item to check the type of
- * @param type
- * the type the given item should be instanceof
- * @throws IllegalArgumentException
- * if the given item is not of the correct type
+ * @param description a description of the item to be checked
+ * @param item the item to check the type of
+ * @param type the type the given item should be instanceof
+ * @throws IllegalArgumentException if the given item is not of the correct type
*/
public static void ensureInstanceOf(String description, Object item, Class<?> type) {
if ( ! type.isAssignableFrom(item.getClass())) {
@@ -131,4 +123,19 @@ public abstract class Validator {
}
}
+ /**
+ * Ensures that an item is not of a particular class
+ *
+ * @param description a description of the item to be checked
+ * @param item the item to check the type of
+ * @param type the type the given item should NOT be instanceof
+ * @throws IllegalArgumentException if the given item is of the wrong type
+ */
+ public static void ensureNotInstanceOf(String description, Object item, Class<?> type) {
+ if ( type.isAssignableFrom(item.getClass())) {
+ throw new IllegalArgumentException(description + " " + item + " should NOT be an instance of " + type +
+ " but is " + item.getClass());
+ }
+ }
+
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
index 37a4bf375d0..651bec6a1aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
@@ -126,6 +126,14 @@ public class TypeResolver {
first.name().equals(second.name()));
}
+ private static boolean firstIsSmaller(Dimension first, Dimension second) {
+ return (first.type() == Dimension.Type.indexedBound &&
+ second.type() == Dimension.Type.indexedBound &&
+ first.name().equals(second.name()) &&
+ first.size().isPresent() && second.size().isPresent() &&
+ first.size().get() < second.size().get());
+ }
+
static public TensorType join(TensorType lhs, TensorType rhs) {
Value cellType = Value.DOUBLE;
if (lhs.rank() > 0 && rhs.rank() > 0) {
@@ -153,6 +161,10 @@ public class TypeResolver {
map.put(dim.name(), dim);
} else if (firstIsBoundSecond(other, dim)) {
map.put(dim.name(), other);
+ } else if (dim.isMapped() && other.isIndexed()) {
+ map.put(dim.name(), dim); // {} and [] -> {}. Note: this is not allowed in C++
+ } else if (dim.isIndexed() && other.isMapped()) {
+ map.put(dim.name(), other); // {} and [] -> {}. Note: this is not allowed in C++
} else {
throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs);
}
@@ -215,9 +227,13 @@ public class TypeResolver {
Dimension other = map.get(dim.name());
if (! other.equals(dim)) {
if (firstIsBoundSecond(dim, other)) {
- map.put(dim.name(), dim);
+ map.put(dim.name(), other); // [N] and [] -> []
} else if (firstIsBoundSecond(other, dim)) {
- map.put(dim.name(), other);
+ map.put(dim.name(), dim); // [N] and [] -> []
+ } else if (firstIsSmaller(dim, other)) {
+ map.put(dim.name(), dim); // [N] and [M] -> [ min(N,M] ].
+ } else if (firstIsSmaller(other, dim)) {
+ map.put(dim.name(), other); // [N] and [M] -> [ min(N,M] ].
} else {
throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index c6f8171bd18..fe8b2f417aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -47,7 +48,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType(valueType, argument.type(context).dimensions());
+ return TypeResolver.cell_cast(argument.type(context), valueType);
}
@Override
@@ -56,12 +57,11 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
if (tensor.type().valueType() == valueType) {
return tensor;
}
- TensorType type = new TensorType(valueType, tensor.type().dimensions());
+ TensorType type = TypeResolver.cell_cast(tensor.type(), valueType);
return cast(tensor, type);
}
private Tensor cast(Tensor tensor, TensorType type) {
- Tensor.Builder builder = Tensor.Builder.of(type);
TensorType.Value fromValueType = tensor.type().valueType();
switch (fromValueType) {
case DOUBLE:
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index fff2ddaf320..59a452588ca 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -60,44 +61,20 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return type(argumentA.type(context), argumentB.type(context));
- }
-
- /** Returns the type resulting from concatenating a and b */
- private TensorType type(TensorType a, TensorType b) {
- // TODO: Fail if concat dimension is present but not indexed in a or b
- TensorType.Builder builder = new TensorType.Builder(a, b);
- if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) {
- builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) +
- b.sizeOfDimension(dimension).orElse(1L)));
- /*
- MutableLong concatSize = new MutableLong(0);
- a.sizeOfDimension(dimension).ifPresent(concatSize::add);
- b.sizeOfDimension(dimension).ifPresent(concatSize::add);
- builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
- */
- }
- return builder.build();
- }
-
- /** Returns true if this dimension is present and unbound */
- private boolean unboundIn(TensorType type, String dimensionName) {
- Optional<TensorType.Dimension> dimension = type.dimension(dimensionName);
- return dimension.isPresent() && ! dimension.get().size().isPresent();
+ return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
}
@Override
public Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type());
- a = ensureIndexedDimension(dimension, a, combinedValueType);
- b = ensureIndexedDimension(dimension, b, combinedValueType);
+ TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
+
+ a = ensureIndexedDimension(dimension, a, concatType.valueType());
+ b = ensureIndexedDimension(dimension, b, concatType.valueType());
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
-
- TensorType concatType = type(a.type(), b.type());
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 5419d04a4fb..d43b7889982 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -49,7 +50,7 @@ public class Join<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETYP
/** Returns the type resulting from applying Join to the two given types */
public static TensorType outputType(TensorType a, TensorType b) {
try {
- return new TensorType.Builder(false, a, b).build();
+ return TypeResolver.join(a, b);
}
catch (IllegalArgumentException e) {
throw new IllegalArgumentException("Can not join " + a + " and " + b, e);
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
index d5633bde36c..4aa09f3f4e3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Merge.java
@@ -9,6 +9,7 @@ import com.yahoo.tensor.PartialAddress;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -48,9 +49,7 @@ public class Merge<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMETY
/** Returns the type resulting from applying Merge to the two given types */
public static TensorType outputType(TensorType a, TensorType b) {
- Optional<TensorType> outputType = a.dimensionwiseGeneralizationWith(b);
- if (outputType.isPresent()) return outputType.get();
- throw new IllegalArgumentException("Cannot merge " + a + " and " + b + ": Arguments must have compatible types");
+ return TypeResolver.merge(a, b);
}
public DoubleBinaryOperator merger() { return merger; }
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
index 8e4205c8c27..7eee50c6785 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
@@ -77,10 +77,14 @@ public class TypeResolverTestCase {
checkJoin("tensor(x{})", "tensor<bfloat16>(y{})", "tensor(x{},y{})");
checkJoin("tensor(x{})", "tensor<float>(y{})", "tensor(x{},y{})");
checkJoin("tensor(x{})", "tensor<int8>(y{})", "tensor(x{},y{})");
+ // specific for Java
+ checkJoin("tensor(x[])", "tensor(x{})", "tensor(x{})");
+ checkJoin("tensor(x[3])", "tensor(x{})", "tensor(x{})");
+ checkJoin("tensor(x{})", "tensor(x[])", "tensor(x{})");
+ checkJoin("tensor(x{})", "tensor(x[3])", "tensor(x{})");
// dimension mismatch should fail:
checkJoinFails("tensor(x[3])", "tensor(x[5])");
checkJoinFails("tensor(x[5])", "tensor(x[3])");
- checkJoinFails("tensor(x{})", "tensor(x[5])");
}
@Test
@@ -156,6 +160,7 @@ public class TypeResolverTestCase {
checkMerge("tensor(x{},y{})", "tensor<float>(x{},y{})", "tensor(x{},y{})");
checkMerge("tensor(x{},y{})", "tensor<int8>(x{},y{})", "tensor(x{},y{})");
checkMerge("tensor(y{})", "tensor(y{})", "tensor(y{})");
+ checkMerge("tensor(x{})", "tensor(x[5])", "tensor(x{})");
checkMergeFails("tensor(a[10])", "tensor()");
checkMergeFails("tensor(a[10])", "tensor(x{},y{},z{})");
checkMergeFails("tensor<bfloat16>(x[5])", "tensor()");
@@ -168,7 +173,6 @@ public class TypeResolverTestCase {
checkMergeFails("tensor(x[3])", "tensor(x[5])");
checkMergeFails("tensor(x[5])", "tensor(x[3])");
checkMergeFails("tensor(x{})", "tensor()");
- checkMergeFails("tensor(x{})", "tensor(x[5])");
checkMergeFails("tensor(x{},y{})", "tensor(x{},z{})");
checkMergeFails("tensor(y{})", "tensor()");
}
@@ -221,11 +225,14 @@ public class TypeResolverTestCase {
checkConcat("tensor<float>(x[3])", "tensor()", "x", "tensor<float>(x[4])");
checkConcat("tensor<bfloat16>(x[3])", "tensor()", "x", "tensor<bfloat16>(x[4])");
checkConcat("tensor<int8>(x[3])", "tensor()", "x", "tensor<int8>(x[4])");
+ // specific for Java
+ checkConcat("tensor(x[])", "tensor(x[2])", "x", "tensor(x[])");
+ checkConcat("tensor(x[])", "tensor(x[2])", "y", "tensor(x[],y[2])");
+ checkConcat("tensor(x[3])", "tensor(x[2])", "y", "tensor(x[2],y[2])");
// invalid combinations must fail
checkConcatFails("tensor(x{})", "tensor(x[2])", "x");
checkConcatFails("tensor(x{})", "tensor(x{})", "x");
checkConcatFails("tensor(x{})", "tensor()", "x");
- checkConcatFails("tensor(x[3])", "tensor(x[2])", "y");
}
@Test