summaryrefslogtreecommitdiffstats
path: root/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
diff options
context:
space:
mode:
Diffstat (limited to 'searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java8
1 files changed, 3 insertions, 5 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
index e91c2305f7d..ff87412396d 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/importer/operations/Reshape.java
@@ -2,8 +2,8 @@
package com.yahoo.searchlib.rankingexpression.integration.ml.importer.operations;
import com.yahoo.searchlib.rankingexpression.evaluation.DoubleValue;
-import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType;
+import com.yahoo.searchlib.rankingexpression.integration.ml.importer.DimensionRenamer;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticNode;
import com.yahoo.searchlib.rankingexpression.rule.ArithmeticOperator;
import com.yahoo.searchlib.rankingexpression.rule.ComparisonNode;
@@ -24,8 +24,6 @@ import java.util.Iterator;
import java.util.List;
import java.util.stream.Collectors;
-import static com.yahoo.searchlib.rankingexpression.integration.ml.importer.OrderedTensorType.tensorSize;
-
public class Reshape extends IntermediateOperation {
public Reshape(String modelName, String nodeName, List<IntermediateOperation> inputs) {
@@ -52,7 +50,7 @@ public class Reshape extends IntermediateOperation {
int size = cell.getValue().intValue();
if (size < 0) {
size = -1 * (int)shape.reduce(Reduce.Aggregator.prod).asDouble() /
- tensorSize(inputType.type()).intValue();
+ OrderedTensorType.tensorSize(inputType.type()).intValue();
}
outputTypeBuilder.add(TensorType.Dimension.indexed(
String.format("%s_%d", vespaName(), dimensionIndex), size));
@@ -82,7 +80,7 @@ public class Reshape extends IntermediateOperation {
}
public static TensorFunction reshape(TensorFunction inputFunction, TensorType inputType, TensorType outputType) {
- if (!tensorSize(inputType).equals(tensorSize(outputType))) {
+ if (!OrderedTensorType.tensorSize(inputType).equals(OrderedTensorType.tensorSize(outputType))) {
throw new IllegalArgumentException("New and old shape of tensor must have the same size when reshaping");
}