summaryrefslogtreecommitdiffstats
path: root/vespajlib
diff options
context:
space:
mode:
authorArne Juul <arnej@yahoo-inc.com>2019-08-20 12:21:37 +0000
committerArne Juul <arnej@yahoo-inc.com>2019-08-20 12:21:37 +0000
commit7df067cfb84f0d6e00e87bf69276d7a353c9f972 (patch)
treec537a1291c9f91a47e7a660cc49de11f722783bb /vespajlib
parentd88f2b235136691dcf08014cca60121ad2e3b62a (diff)
use same rules for cell value type resolving as C++
* pick cell value type from tensors with dimensions only * in Concat, use the expected combined cell value type for unit tensor
Diffstat (limited to 'vespajlib')
-rw-r--r--vespajlib/abi-spec.json1
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/TensorType.java12
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java9
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java3
-rw-r--r--vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java3
5 files changed, 19 insertions, 9 deletions
diff --git a/vespajlib/abi-spec.json b/vespajlib/abi-spec.json
index 9c425570a7e..6b37d58f3c7 100644
--- a/vespajlib/abi-spec.json
+++ b/vespajlib/abi-spec.json
@@ -1383,6 +1383,7 @@
"public"
],
"methods": [
+ "public static varargs com.yahoo.tensor.TensorType$Value combinedValueType(com.yahoo.tensor.TensorType[])",
"public static com.yahoo.tensor.TensorType fromSpec(java.lang.String)",
"public com.yahoo.tensor.TensorType$Value valueType()",
"public int rank()",
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
index 319947607d2..d64a62143f4 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/TensorType.java
@@ -87,6 +87,16 @@ public class TensorType {
this.dimensions = ImmutableList.copyOf(dimensionList);
}
+ static public Value combinedValueType(TensorType ... types) {
+ List<Value> valueTypes = new ArrayList<>();
+ for (TensorType type : types) {
+ if (type.rank() > 0) {
+ valueTypes.add(type.valueType());
+ }
+ }
+ return Value.largestOf(valueTypes);
+ }
+
/**
* Returns a tensor type instance from a string on the format
* <code>tensor(dimension1, dimension2, ...)</code>
@@ -456,7 +466,7 @@ public class TensorType {
* The value type will be the largest of the value types of the input types
*/
public Builder(TensorType ... types) {
- this.valueType = TensorType.Value.largestOf(Arrays.stream(types).map(type -> type.valueType()).collect(Collectors.toList()));
+ this.valueType = TensorType.combinedValueType(types);
for (TensorType type : types)
addDimensionsOf(type);
}
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
index a48ac19fbff..42c6fe2f4aa 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Concat.java
@@ -89,8 +89,9 @@ public class Concat extends PrimitiveTensorFunction {
public <NAMETYPE extends TypeContext.Name> Tensor evaluate(EvaluationContext<NAMETYPE> context) {
Tensor a = argumentA.evaluate(context);
Tensor b = argumentB.evaluate(context);
- a = ensureIndexedDimension(dimension, a);
- b = ensureIndexedDimension(dimension, b);
+ TensorType.Value combinedValueType = TensorType.combinedValueType(a.type(), b.type());
+ a = ensureIndexedDimension(dimension, a, combinedValueType);
+ b = ensureIndexedDimension(dimension, b, combinedValueType);
IndexedTensor aIndexed = (IndexedTensor) a; // If you get an exception here you have implemented a mixed tensor
IndexedTensor bIndexed = (IndexedTensor) b;
@@ -128,7 +129,7 @@ public class Concat extends PrimitiveTensorFunction {
}
}
- private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor) {
+ private Tensor ensureIndexedDimension(String dimensionName, Tensor tensor, TensorType.Value combinedValueType) {
Optional<TensorType.Dimension> dimension = tensor.type().dimension(dimensionName);
if ( dimension.isPresent() ) {
if ( ! dimension.get().isIndexed())
@@ -141,7 +142,7 @@ public class Concat extends PrimitiveTensorFunction {
if (tensor.type().dimensions().stream().anyMatch(d -> ! d.isIndexed()))
throw new IllegalArgumentException("Concat requires an indexed tensor, " +
"but got a tensor with type " + tensor.type());
- Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(tensor.type().valueType())
+ Tensor unitTensor = Tensor.Builder.of(new TensorType.Builder(combinedValueType)
.indexed(dimensionName, 1)
.build())
.cell(1,0)
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
index 062e0d92e80..2939b964f04 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/Join.java
@@ -390,8 +390,7 @@ public class Join extends PrimitiveTensorFunction {
private static TensorType commonDimensions(Tensor a, Tensor b) {
TensorType aType = a.type();
TensorType bType = b.type();
- TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.Value.largestOf(aType.valueType(),
- bType.valueType()));
+ TensorType.Builder typeBuilder = new TensorType.Builder(TensorType.combinedValueType(aType, bType));
for (int i = 0; i < aType.dimensions().size(); ++i) {
TensorType.Dimension aDim = aType.dimensions().get(i);
for (int j = 0; j < bType.dimensions().size(); ++j) {
diff --git a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
index db950e6c8b9..1134e8177ad 100644
--- a/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
+++ b/vespajlib/src/main/java/com/yahoo/tensor/functions/ReduceJoin.java
@@ -268,8 +268,7 @@ public class ReduceJoin extends CompositeTensorFunction {
}
private TensorType dimensionsInCommon(IndexedTensor a, IndexedTensor b) {
- TensorType.Builder builder = new TensorType.Builder(TensorType.Value.largestOf(a.type().valueType(),
- b.type().valueType()));
+ TensorType.Builder builder = new TensorType.Builder(TensorType.combinedValueType(a.type(), b.type()));
for (TensorType.Dimension aDim : a.type().dimensions()) {
for (TensorType.Dimension bDim : b.type().dimensions()) {
if (aDim.name().equals(bDim.name())) {