aboutsummaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java21
1 files changed, 7 insertions, 14 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 8ae6d81b8d4..c64b9ded601 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -27,20 +27,15 @@ public class ExpandDims extends IntermediateOperation {
@Override
protected OrderedTensorType lazyGetType() {
- if (!allInputTypesPresent(2)) {
- return null;
- }
+ if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation axisOperation = inputs().get(1);
if (!axisOperation.getConstantValue().isPresent()) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis must be a constant.");
+ throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
- if (axis.type().rank() != 0) {
- throw new IllegalArgumentException("ExpandDims in " + name + ": " +
- "axis argument must be a scalar.");
- }
+ if (axis.type().rank() != 0)
+ throw new IllegalArgumentException("ExpandDims in " + name + ": Axis argument must be a scalar.");
OrderedTensorType inputType = inputs.get(0).type().get();
int dimensionToInsert = (int)axis.asDouble();
@@ -48,7 +43,7 @@ public class ExpandDims extends IntermediateOperation {
dimensionToInsert = inputType.dimensions().size() - dimensionToInsert;
}
- OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder();
+ OrderedTensorType.Builder typeBuilder = new OrderedTensorType.Builder(resultValueType());
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
@@ -66,12 +61,10 @@ public class ExpandDims extends IntermediateOperation {
@Override
protected TensorFunction lazyGetFunction() {
- if (!allInputFunctionsPresent(2)) {
- return null;
- }
+ if ( ! allInputFunctionsPresent(2)) return null;
// multiply with a generated tensor created from the reduced dimensions
- TensorType.Builder typeBuilder = new TensorType.Builder();
+ TensorType.Builder typeBuilder = new TensorType.Builder(resultValueType());
for (String name : expandDimensions) {
typeBuilder.indexed(name, 1);
}