aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
committerJon Bratseth <bratseth@verizonmedia.com>2019-04-03 21:30:28 +0200
commit5792d3a23890edaa5d32b0f6bfc726c3e9956f3a (patch)
tree2b65d4f48b92bf7ec846b3efd5d5259244bc234a /model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
parent6eb80166172e10255841fd3d3cf70bed09d3d8c1 (diff)
Add tensor value type
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java26
1 files changed, 10 insertions, 16 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
index 1a564661ccb..7ae50a0549d 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ConcatV2.java
@@ -21,20 +21,15 @@ public class ConcatV2 extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) {
- return null;
- }
+ if ( ! inputs.stream().map(IntermediateOperation::type).allMatch(Optional::isPresent)) return null;
IntermediateOperation concatDimOp = inputs.get(inputs.size() - 1); // ConcatV2: concat dimension is the last input
- if (!concatDimOp.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a constant.");
- }
+ if ( ! concatDimOp.getConstantValue().isPresent())
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a constant.");
+
Tensor concatDimTensor = concatDimOp.getConstantValue().get().asTensor();
- if (concatDimTensor.type().rank() != 0) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "concat dimension must be a scalar.");
- }
+ if (concatDimTensor.type().rank() != 0)
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Concat dimension must be a scalar.");
OrderedTensorType aType = inputs.get(0).type().get();
concatDimensionIndex = (int)concatDimTensor.asDouble();
@@ -42,10 +37,9 @@ public class ConcatV2 extends IntermediateOperation {
for (int i = 1; i < inputs.size() - 1; ++i) {
OrderedTensorType bType = inputs.get(i).type().get();
- if (bType.rank() != aType.rank()) {
- throw new IllegalArgumentException("ConcatV2 in " + name + ": " +
- "inputs must have save rank.");
- }
+ if (bType.rank() != aType.rank())
+ throw new IllegalArgumentException("ConcatV2 in " + name + ": Inputs must have the same rank.");
+
for (int j = 0; j < aType.rank(); ++j) {
long dimSizeA = aType.dimensions().get(j).size().orElse(-1L);
long dimSizeB = bType.dimensions().get(j).size().orElse(-1L);
@@ -58,7 +52,7 @@ public class ConcatV2 extends IntermediateOperation {
}
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
int dimensionIndex = 0;
for (TensorType.Dimension dimension : aType.dimensions()) {
if (dimensionIndex == concatDimensionIndex) {