aboutsummaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-22 09:21:05 +0200
committerLester Solbakken <lesters@oath.com>2021-04-22 09:21:05 +0200
commitbd531088df9eb9f7e083fd03d91e5d3f19a1664b (patch)
tree4c891ef1e0f41244fc150fabe955f163425e2bcc /vespajlib
parent1f7363ff53144c40fd27c4332b1cb3619b1525d6 (diff)
Wire in tensor cell type resolving for concat in Java
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java16
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java28
-rw-r--r--vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java5
4 files changed, 24 insertions, 31 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
index f9bc9072cfa..651bec6a1aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TypeResolver.java
@@ -126,6 +126,14 @@ public class TypeResolver {
first.name().equals(second.name()));
}
+ private static boolean firstIsSmaller(Dimension first, Dimension second) {
+ return (first.type() == Dimension.Type.indexedBound &&
+ second.type() == Dimension.Type.indexedBound &&
+ first.name().equals(second.name()) &&
+ first.size().isPresent() && second.size().isPresent() &&
+ first.size().get() < second.size().get());
+ }
+
static public TensorType join(TensorType lhs, TensorType rhs) {
Value cellType = Value.DOUBLE;
if (lhs.rank() > 0 && rhs.rank() > 0) {
@@ -219,9 +227,13 @@ public class TypeResolver {
Dimension other = map.get(dim.name());
if (! other.equals(dim)) {
if (firstIsBoundSecond(dim, other)) {
- map.put(dim.name(), dim);
+ map.put(dim.name(), other); // [N] and [] -> []
} else if (firstIsBoundSecond(other, dim)) {
- map.put(dim.name(), other);
+ map.put(dim.name(), dim); // [N] and [] -> []
+ } else if (firstIsSmaller(dim, other)) {
+ map.put(dim.name(), dim); // [N] and [M] -> [ min(N,M] ].
+ } else if (firstIsSmaller(other, dim)) {
+ map.put(dim.name(), other); // [N] and [M] -> [ min(N,M] ].
} else {
throw new IllegalArgumentException("Unequal dimension " + dim.name() + " in " + lhs+ " and "+rhs);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index c6f8171bd18..fe8b2f417aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -47,7 +48,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType(valueType, argument.type(context).dimensions());
+ return TypeResolver.cell_cast(argument.type(context), valueType);
}
@Override
@@ -56,12 +57,11 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
if (tensor.type().valueType() == valueType) {
return tensor;
}
- TensorType type = new TensorType(valueType, tensor.type().dimensions());
+ TensorType type = TypeResolver.cell_cast(tensor.type(), valueType);
return cast(tensor, type);
}
private Tensor cast(Tensor tensor, TensorType type) {
- Tensor.Builder builder = Tensor.Builder.of(type);
TensorType.Value fromValueType = tensor.type().valueType();
switch (fromValueType) {
case DOUBLE:
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index fff2ddaf320..e71d1c717d3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -60,30 +61,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return type(argumentA.type(context), argumentB.type(context));
- }
-
- /** Returns the type resulting from concatenating a and b */
- private TensorType type(TensorType a, TensorType b) {
- // TODO: Fail if concat dimension is present but not indexed in a or b
- TensorType.Builder builder = new TensorType.Builder(a, b);
- if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) {
- builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) +
- b.sizeOfDimension(dimension).orElse(1L)));
- /*
- MutableLong concatSize = new MutableLong(0);
- a.sizeOfDimension(dimension).ifPresent(concatSize::add);
- b.sizeOfDimension(dimension).ifPresent(concatSize::add);
- builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
- */
- }
- return builder.build();
- }
-
- /** Returns true if this dimension is present and unbound */
- private boolean unboundIn(TensorType type, String dimensionName) {
- Optional<TensorType.Dimension> dimension = type.dimension(dimensionName);
- return dimension.isPresent() && ! dimension.get().size().isPresent();
+ return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
}
@Override
@@ -97,7 +75,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
- TensorType concatType = type(a.type(), b.type());
+ TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);
diff --git a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
index b7f998c6cf7..7eee50c6785 100644
--- a/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
+++ b/vespajlib/src/test/java/com/yahoo/tensor/TypeResolverTestCase.java
@@ -225,11 +225,14 @@ public class TypeResolverTestCase {
checkConcat("tensor<float>(x[3])", "tensor()", "x", "tensor<float>(x[4])");
checkConcat("tensor<bfloat16>(x[3])", "tensor()", "x", "tensor<bfloat16>(x[4])");
checkConcat("tensor<int8>(x[3])", "tensor()", "x", "tensor<int8>(x[4])");
+ // specific for Java
+ checkConcat("tensor(x[])", "tensor(x[2])", "x", "tensor(x[])");
+ checkConcat("tensor(x[])", "tensor(x[2])", "y", "tensor(x[],y[2])");
+ checkConcat("tensor(x[3])", "tensor(x[2])", "y", "tensor(x[2],y[2])");
// invalid combinations must fail
checkConcatFails("tensor(x{})", "tensor(x[2])", "x");
checkConcatFails("tensor(x{})", "tensor(x{})", "x");
checkConcatFails("tensor(x{})", "tensor()", "x");
- checkConcatFails("tensor(x[3])", "tensor(x[2])", "y");
}
@Test