diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-04-03 21:30:28 +0200 |
commit | 5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch) | |
tree | 2b65d4f48b92bf7ec846b3efd5d5259244bc234a /model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java | |
parent | 6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff) |
Add tensor value type
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java | 26 |
1 files changed, 10 insertions, 16 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java index 1a564661ccb..7ae50a0549d 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java @@ -21,20 +21,15 @@ public class ConcatV2 extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) { - return null; - } + if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null; IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input - if (!concatDimOp.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a constant."); - } + if ( ! concatDimOp.getConstantValue().isPresent()) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a constant."); + Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor(); - if (concatDimTensor.type().rank() != 0) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "concat dimension must be a scalar."); - } + if (concatDimTensor.type().rank() != 0) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a scalar."); OrderedTensorType aType = inputs.get(0).type().get(); concatDimensionIndex = (int)concatDimTensor.asDouble(); @@ -42,10 +37,9 @@ public class ConcatV2 extends IntermediateOperation { for (int i = 1; i < inputs.size() - 1; ++i) { OrderedTensorType bType = inputs.get(i).type().get(); - if (bType.rank() != aType.rank()) { - throw new IllegalArgumentException("ConcatV2 in " + name + ": " + - "inputs must have save rank."); - } + if (bType.rank() != aType.rank()) + throw new IllegalArgumentException("ConcatV2 in " + name + ": Inputs must have the same rank."); + for (int j = 0; j < aType.rank(); ++j) { long dimSizeA = aType.dimensions().get(j).size().orElse(-1L); long dimSizeB = bType.dimensions().get(j).size().orElse(-1L); @@ -58,7 +52,7 @@ public class ConcatV2 extends IntermediateOperation { } } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); int dimensionIndex = 0; for (TensorType.Dimension dimension : aType.dimensions()) { if (dimensionIndex == concatDimensionIndex) { |