aboutsummaryrefslogtreecommitdiffstats
path: root/searchlib/src/main
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-03-08 10:37:23 +0100
committerJon Bratseth <bratseth@oath.com>2018-03-08 10:37:23 +0100
commit2005e5dd57b2bafd1150917560da8066b39d64f7 (patch)
tree48a3f6cd071cb13f9a6c2112c34fc9306a5792b2 /searchlib/src/main
parent4e24502412bbac1ff6e8ae4ee5fa26590996189f (diff)
Refactor only
Diffstat (limited to 'searchlib/src/main')
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java40
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java8
2 files changed, 24 insertions, 24 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
index db762d5ddb0..acc7875623b 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java
@@ -119,23 +119,20 @@ public class OrderedTensorType {
return true;
}
- public static void verifyType(NodeDef node, OrderedTensorType type) {
- if (type == null) {
- return;
- }
+ public void verifyType(NodeDef node) {
TensorShapeProto shape = tensorFlowShape(node);
- if (shape != null && type.type != null) {
- if (shape.getDimCount() != type.type.rank()) {
+ if (shape != null) {
+ if (shape.getDimCount() != type.rank()) {
throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " +
- "does not match Vespa shape");
+ "does not match Vespa shape");
}
- for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions.size(); ++tensorFlowIndex) {
- int vespaIndex = type.dimensionMap[tensorFlowIndex];
+ for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) {
+ int vespaIndex = dimensionMap[tensorFlowIndex];
TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex);
- TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex);
+ TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex);
if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) {
throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " +
- "does not match Vespa dimensions");
+ "does not match Vespa dimensions");
}
}
}
@@ -145,23 +142,23 @@ public class OrderedTensorType {
AttrValue attrValueList = node.getAttrMap().get("_output_shapes");
if (attrValueList == null) {
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "does not exist");
+ "does not exist");
}
if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) {
throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " +
- "is not of expected type");
+ "is not of expected type");
}
List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList();
return shapeList.get(0); // support multiple outputs?
}
- public static OrderedTensorType rename(OrderedTensorType type, DimensionRenamer renamer) {
- List<TensorType.Dimension> renamedDimensions = new ArrayList<>(type.dimensions.size());
- for (TensorType.Dimension dimension : type.dimensions) {
+ public OrderedTensorType rename(DimensionRenamer renamer) {
+ List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size());
+ for (TensorType.Dimension dimension : dimensions) {
String oldName = dimension.name();
Optional<String> newName = renamer.dimensionNameOf(oldName);
if (!newName.isPresent())
- return type; // presumably, already renamed
+ return this; // presumably, already renamed
TensorType.Dimension.Type dimensionType = dimension.type();
if (dimensionType == TensorType.Dimension.Type.indexedBound) {
renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get()));
@@ -210,20 +207,21 @@ public class OrderedTensorType {
if (size >= 0) {
if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) {
throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
+ "dimension types");
}
if (!vespaDimension.size().isPresent()) {
throw new IllegalArgumentException("Tensor dimension is indexed bound but does " +
- "not have a size");
+ "not have a size");
}
if (vespaDimension.size().get() != size) {
throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension sizes. TensorFlow: " + size + " Vespa: " + vespaDimension.size().get());
+ "dimension sizes. TensorFlow: " + size + " Vespa: " +
+ vespaDimension.size().get());
}
} else {
if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) {
throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " +
- "dimension types");
+ "dimension types");
}
}
this.dimensions.add(vespaDimension);
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
index 5d711aac100..2533148e5be 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java
@@ -60,7 +60,9 @@ public abstract class TensorFlowOperation {
if (type == null) {
type = lazyGetType();
}
- OrderedTensorType.verifyType(node, type);
+ if (type != null) {
+ type.verifyType(node);
+ }
return Optional.ofNullable(type);
}
@@ -96,7 +98,7 @@ public abstract class TensorFlowOperation {
public void addDimensionNameConstraints(DimensionRenamer renamer) { }
/** Performs dimension rename for this operation */
- public void renameDimensions(DimensionRenamer renamer) { type = OrderedTensorType.rename(type, renamer); }
+ public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
/** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */
public boolean isInput() { return false; }
@@ -131,7 +133,7 @@ public abstract class TensorFlowOperation {
}
if (inputs.size() != expected) {
throw new IllegalArgumentException("Expected " + expected + " inputs " +
- "for '" + node.getName() + "', got " + inputs.size());
+ "for '" + node.getName() + "', got " + inputs.size());
}
return inputs.stream().map(func).allMatch(Optional::isPresent);
}