summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2018-11-28 15:32:29 +0100
committerLester Solbakken <lesters@oath.com>2018-11-28 15:32:29 +0100
commitf8dd562f1ea3951fe22e0576a66528eecbc85093 (patch)
tree98a25dc8c703d2501d3e079486d551db31a67491 /model-integration
parent92160a2154e6c3f4fbe964f0a89b89443ea668ee (diff)
Set correct concat dimension for rename
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java14
1 files changed, 7 insertions, 7 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index a21fc5ff2f7..1a564661ccb 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -13,6 +13,7 @@ import java.util.Optional;
public class ConcatV2 extends IntermediateOperation {
private String concatDimensionName;
+ private int concatDimensionIndex;
public ConcatV2(String modelName, String nodeName, List<IntermediateOperation> inputs) {
super(modelName, nodeName, inputs);
@@ -36,9 +37,8 @@ public class ConcatV2 extends IntermediateOperation {
}
OrderedTensorType aType = inputs.get(0).type().get();
-
- int concatDim = (int)concatDimTensor.asDouble();
- long concatDimSize = aType.dimensions().get(concatDim).size().orElse(-1L);
+ concatDimensionIndex = (int)concatDimTensor.asDouble();
+ long concatDimSize = aType.dimensions().get(concatDimensionIndex).size().orElse(-1L);
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
@@ -49,7 +49,7 @@ public class ConcatV2 extends IntermediateOperation {
for (int j = 0; j < aType.rank(); ++j) {
long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
- if (j == concatDim) {
+ if (j == concatDimensionIndex) {
concatDimSize += dimSizeB;
} else if (dimSizeA != dimSizeB) {
throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
@@ -61,7 +61,7 @@ public class ConcatV2 extends IntermediateOperation {
OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
- if (dimensionIndex == concatDim) {
+ if (dimensionIndex == concatDimensionIndex) {
concatDimensionName = dimension.name();
typeBuilder.add(TensorType.Dimension.indexed(concatDimensionName, concatDimSize));
} else {
@@ -93,8 +93,8 @@ public class ConcatV2 extends IntermediateOperation {
OrderedTensorType a = inputs.get(0).type().get();
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType b = inputs.get(i).type().get();
- String bDim = b.dimensions().get(i).name();
- String aDim = a.dimensions().get(i).name();
+ String bDim = b.dimensions().get(concatDimensionIndex).name();
+ String aDim = a.dimensions().get(concatDimensionIndex).name();
renamer.addConstraint(aDim, bDim, DimensionRenamer::equals, this);
}
}