diff options
author | Lester Solbakken <lesters@oath.com> | 2018-03-07 12:51:12 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-03-07 12:51:42 +0100 |
commit | a212f025728be7b2c80a90624cd1f4c5ebfd40b8 (patch) | |
tree | 7de2a19b6db7b2d320b911849e51d634972c4571 /searchlib/src/main | |
parent | 7c1236b77b7e2264bc6199ac9d3ff974d81462d2 (diff) |
Fix TensorFlow dimension renaming at join
Diffstat (limited to 'searchlib/src/main')
2 files changed, 5 insertions, 3 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java index 3742e443a06..db762d5ddb0 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java @@ -42,6 +42,8 @@ public class OrderedTensorType { return this.type; } + public int rank() { return dimensions.size(); } + public List<TensorType.Dimension> dimensions() { return dimensions; } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java index aa27ba2684d..3e6e036636d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/Join.java @@ -61,10 +61,10 @@ public class Join extends TensorFlowOperation { // "end" (highest numbered) dimensions of the first, but I'm not sure whether this is generally true. // Anyway, we move the dimensions of b to the last dimensions of a (instead of by default, the first). - TensorType a = inputs.get(0).type().get().type(); - TensorType b = inputs.get(1).type().get().type(); + OrderedTensorType a = inputs.get(0).type().get(); + OrderedTensorType b = inputs.get(1).type().get(); if (a.rank() < b.rank()) { - TensorType temp = a; + OrderedTensorType temp = a; a = b; b = temp; } |