diff options
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java')
-rw-r--r-- | searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java | 8 |
1 files changed, 3 insertions, 5 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java index e91c2305f7d..ff87412396d 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java @@ -2,8 +2,8 @@ package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations; import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue; -import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType; +import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode; import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator; import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode; @@ -24,8 +24,6 @@ import java.util.Iterator; import java.util.List; import java.util.stream.Collectors; -import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize; - public class Reshape extends IntermediateOperation { public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) { @@ -52,7 +50,7 @@ public class Reshape extends IntermediateOperation { int size = cell.getValue().intValue(); if (size < 0) { size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() / - tensorSize(inputType.type()).intValue(); + OrderedTensorType.tensorSize(inputType.type()).intValue(); } outputTypeBuilder.add(TensorType.Dimension.indexed( String.format("%s_%d", vespaName(), dimensionIndex), size)); @@ -82,7 +80,7 @@ public class Reshape extends IntermediateOperation { } public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) { - if (!tensorSize(inputType).equals(tensorSize(outputType))) { + if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) { throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping"); } |