summaryrefslogtreecommitdiffstats
path: root/vespajlib/src/main/java/com/yahoo/tensor/functions
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-04-22 09:21:05 +0200
committerLester Solbakken <lesters@oath.com>2021-04-22 09:21:05 +0200
commitbd531088df9eb9f7e083fd03d91e5d3f19a1664b (patch)
tree4c891ef1e0f41244fc150fabe955f163425e2bcc /vespajlib/src/main/java/com/yahoo/tensor/functions
parent1f7363ff53144c40fd27c4332b1cb3619b1525d6 (diff)
Wire in tensor cell type resolving for concat in Java
Diffstat (limited to 'vespajlib/src/main/java/com/yahoo/tensor/functions')
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java6
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java28
2 files changed, 6 insertions, 28 deletions
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
index c6f8171bd18..fe8b2f417aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/CellCast.java
@@ -3,6 +3,7 @@ package com.yahoo.tensor.functions;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -47,7 +48,7 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return new TensorType(valueType, argument.type(context).dimensions());
+ return TypeResolver.cell_cast(argument.type(context), valueType);
}
@Override
@@ -56,12 +57,11 @@ public class CellCast<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAM
if (tensor.type().valueType() == valueType) {
return tensor;
}
- TensorType type = new TensorType(valueType, tensor.type().dimensions());
+ TensorType type = TypeResolver.cell_cast(tensor.type(), valueType);
return cast(tensor, type);
}
private Tensor cast(Tensor tensor, TensorType type) {
- Tensor.Builder builder = Tensor.Builder.of(type);
TensorType.Value fromValueType = tensor.type().valueType();
switch (fromValueType) {
case DOUBLE:
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index fff2ddaf320..e71d1c717d3 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -7,6 +7,7 @@ import com.yahoo.tensor.IndexedTensor;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorAddress;
import com.yahoo.tensor.TensorType;
+import com.yahoo.tensor.TypeResolver;
import com.yahoo.tensor.evaluation.EvaluationContext;
import com.yahoo.tensor.evaluation.Name;
import com.yahoo.tensor.evaluation.TypeContext;
@@ -60,30 +61,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
@Override
public TensorType type(TypeContext<NAMETYPE> context) {
- return type(argumentA.type(context), argumentB.type(context));
- }
-
- /** Returns the type resulting from concatenating a and b */
- private TensorType type(TensorType a, TensorType b) {
- // TODO: Fail if concat dimension is present but not indexed in a or b
- TensorType.Builder builder = new TensorType.Builder(a, b);
- if ( ! unboundIn(a, dimension) && ! unboundIn(b, dimension)) {
- builder.set(TensorType.Dimension.indexed(dimension, a.sizeOfDimension(dimension).orElse(1L) +
- b.sizeOfDimension(dimension).orElse(1L)));
- /*
- MutableLong concatSize = new MutableLong(0);
- a.sizeOfDimension(dimension).ifPresent(concatSize::add);
- b.sizeOfDimension(dimension).ifPresent(concatSize::add);
- builder.set(TensorType.Dimension.indexed(dimension, concatSize.get()));
- */
- }
- return builder.build();
- }
-
- /** Returns true if this dimension is present and unbound */
- private boolean unboundIn(TensorType type, String dimensionName) {
- Optional<TensorType.Dimension> dimension = type.dimension(dimensionName);
- return dimension.isPresent() && ! dimension.get().size().isPresent();
+ return TypeResolver.concat(argumentA.type(context), argumentB.type(context), dimension);
}
@Override
@@ -97,7 +75,7 @@ public class Concat<NAMETYPE extends Name> extends PrimitiveTensorFunction<NAMET
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
- TensorType concatType = type(a.type(), b.type());
+ TensorType concatType = TypeResolver.concat(a.type(), b.type(), dimension);
DimensionSizes concatSize = concatSize(concatType, aIndexed, bIndexed, dimension);
Tensor.Builder builder = Tensor.Builder.of(concatType, concatSize);