diff options
author | Lester Solbakken <lesters@oath.com> | 2018-11-28 15:32:29 +0100 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2018-11-28 15:32:29 +0100 |
commit | f8dd562f1ea3951fe22e0576a66528eecbc85093 (patch) | |
tree | 98a25dc8c703d2501d3e079486d551db31a67491 /model-integration | |
parent | 92160a2154e6c3f4fbe964f0a89b89443ea668ee (diff) |
Set correct concat dimension for rename
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java | 14 |
1 files changed, 7 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index a21fc5ff2f7..1a564661ccb 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -13,6 +13,7 @@ import java.util.Optional; public class ConcatV2 extends IntermediateOperation { private String concatDimensionName; + private int concatDimensionIndex; public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) { super(modelName, nodeName, inputs); @@ -36,9 +37,8 @@ public class ConcatV2 extends IntermediateOperation { } OrderedTensorType aType = inputs.get(0).type().get(); - - int concatDim = (int)concatDimTensor.asDouble(); - long concatDimSize = aType.dimensions().get(concatDim).size().orElse(-1L); + concatDimensionIndex = (int)concatDimTensor.asDouble(); + long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L); for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType bType = inputs.get(i).type().get(); @@ -49,7 +49,7 @@ public class ConcatV2 extends IntermediateOperation { for (int j = 0; j < aType.rank(); ++j) { long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); long dimSizeB = bType.dimensions().get(j).size().orElse(-1L); - if (j == concatDim) { + if (j == concatDimensionIndex) { concatDimSize += dimSizeB; } else if (dimSizeA != dimSizeB) { throw new IllegalArgumentException("ConcatV2 in " + name + ": " + @@ -61,7 +61,7 @@ public class ConcatV2 extends IntermediateOperation { OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); int dimensionIndex = 0; for (TensorType.Dimension dimension : aType.dimensions()) { - if (dimensionIndex == concatDim) { + if (dimensionIndex == concatDimensionIndex) { concatDimensionName = dimension.name(); typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize)); } else { @@ -93,8 +93,8 @@ public class ConcatV2 extends IntermediateOperation { OrderedTensorType a = inputs.get(0).type().get(); for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType b = inputs.get(i).type().get(); - String bDim = b.dimensions().get(i).name(); - String aDim = a.dimensions().get(i).name(); + String bDim = b.dimensions().get(concatDimensionIndex).name(); + String aDim = a.dimensions().get(concatDimensionIndex).name(); renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this); } } |