diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java | 21 |
1 files changed, 7 insertions, 14 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 8ae6d81b8d4..c64b9ded601 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -27,20 +27,15 @@ public class ExpandDims extends IntermediateOperation { @Override protected OrderedTensorType lazyGetType() { - if (!allInputTypesPresent(2)) { - return null; - } + if ( ! allInputTypesPresent(2)) return null; IntermediateOperation axisOperation = inputs().get(1); if (!axisOperation.getConstantValue().isPresent()) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis must be a constant."); + throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); - if (axis.type().rank() != 0) { - throw new IllegalArgumentException("ExpandDims in " + name + ": " + - "axis argument must be a scalar."); - } + if (axis.type().rank() != 0) + throw new IllegalArgumentException("ExpandDims in " + name + ": Axis argument must be a scalar."); OrderedTensorType inputType = inputs.get(0).type().get(); int dimensionToInsert = (int)axis.asDouble(); @@ -48,7 +43,7 @@ public class ExpandDims extends IntermediateOperation { dimensionToInsert = inputType.dimensions().size() - dimensionToInsert; } - OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(); + OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType()); expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { @@ -66,12 +61,10 @@ public class ExpandDims extends IntermediateOperation { @Override protected TensorFunction lazyGetFunction() { - if (!allInputFunctionsPresent(2)) { - return null; - } + if ( ! allInputFunctionsPresent(2)) return null; // multiply with a generated tensor created from the reduced dimensions - TensorType.Builder typeBuilder = new TensorType.Builder(); + TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType()); for (String name : expandDimensions) { typeBuilder.indexed(name, 1); } |