diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-03-08 10:37:23 +0100 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-03-08 10:37:23 +0100 |
commit | 2005e5dd57b2bafd1150917560da8066b39d64f7 (patch) | |
tree | 48a3f6cd071cb13f9a6c2112c34fc9306a5792b2 /searchlib/src/main | |
parent | 4e24502412bbac1ff6e8ae4ee5fa26590996189f (diff) |
Refactor only
Diffstat (limited to 'searchlib/src/main')
2 files changed, 24 insertions, 24 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java index db762d5ddb0..acc7875623b 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/OrderedTensorType.java @@ -119,23 +119,20 @@ public class OrderedTensorType { return true; } - public static void verifyType(NodeDef node, OrderedTensorType type) { - if (type == null) { - return; - } + public void verifyType(NodeDef node) { TensorShapeProto shape = tensorFlowShape(node); - if (shape != null && type.type != null) { - if (shape.getDimCount() != type.type.rank()) { + if (shape != null) { + if (shape.getDimCount() != type.rank()) { throw new IllegalArgumentException("TensorFlow shape of '" + node.getName() + "' " + - "does not match Vespa shape"); + "does not match Vespa shape"); } - for (int tensorFlowIndex = 0; tensorFlowIndex < type.dimensions.size(); ++tensorFlowIndex) { - int vespaIndex = type.dimensionMap[tensorFlowIndex]; + for (int tensorFlowIndex = 0; tensorFlowIndex < dimensions.size(); ++tensorFlowIndex) { + int vespaIndex = dimensionMap[tensorFlowIndex]; TensorShapeProto.Dim tensorFlowDimension = shape.getDim(tensorFlowIndex); - TensorType.Dimension vespaDimension = type.type().dimensions().get(vespaIndex); + TensorType.Dimension vespaDimension = type().dimensions().get(vespaIndex); if (tensorFlowDimension.getSize() != vespaDimension.size().orElse(-1L)) { throw new IllegalArgumentException("TensorFlow dimensions of '" + node.getName() + "' " + - "does not match Vespa dimensions"); + "does not match Vespa dimensions"); } } } @@ -145,23 +142,23 @@ public class OrderedTensorType { AttrValue attrValueList = node.getAttrMap().get("_output_shapes"); if (attrValueList == null) { throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "does not exist"); + "does not exist"); } if (attrValueList.getValueCase() != AttrValue.ValueCase.LIST) { throw new IllegalArgumentException("_output_shapes attribute of '" + node.getName() + "' " + - "is not of expected type"); + "is not of expected type"); } List<TensorShapeProto> shapeList = attrValueList.getList().getShapeList(); return shapeList.get(0); // support multiple outputs? } - public static OrderedTensorType rename(OrderedTensorType type, DimensionRenamer renamer) { - List<TensorType.Dimension> renamedDimensions = new ArrayList<>(type.dimensions.size()); - for (TensorType.Dimension dimension : type.dimensions) { + public OrderedTensorType rename(DimensionRenamer renamer) { + List<TensorType.Dimension> renamedDimensions = new ArrayList<>(dimensions.size()); + for (TensorType.Dimension dimension : dimensions) { String oldName = dimension.name(); Optional<String> newName = renamer.dimensionNameOf(oldName); if (!newName.isPresent()) - return type; // presumably, already renamed + return this; // presumably, already renamed TensorType.Dimension.Type dimensionType = dimension.type(); if (dimensionType == TensorType.Dimension.Type.indexedBound) { renamedDimensions.add(TensorType.Dimension.indexed(newName.get(), dimension.size().get())); @@ -210,20 +207,21 @@ public class OrderedTensorType { if (size >= 0) { if (vespaDimension.type() != TensorType.Dimension.Type.indexedBound) { throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension types"); + "dimension types"); } if (!vespaDimension.size().isPresent()) { throw new IllegalArgumentException("Tensor dimension is indexed bound but does " + - "not have a size"); + "not have a size"); } if (vespaDimension.size().get() != size) { throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension sizes. TensorFlow: " + size + " Vespa: " + vespaDimension.size().get()); + "dimension sizes. TensorFlow: " + size + " Vespa: " + + vespaDimension.size().get()); } } else { if (vespaDimension.type() != TensorType.Dimension.Type.indexedUnbound) { throw new IllegalArgumentException("Non-agreement between TensorFlow and Vespa " + - "dimension types"); + "dimension types"); } } this.dimensions.add(vespaDimension); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java index 5d711aac100..2533148e5be 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/tensorflow/importer/operations/TensorFlowOperation.java @@ -60,7 +60,9 @@ public abstract class TensorFlowOperation { if (type == null) { type = lazyGetType(); } - OrderedTensorType.verifyType(node, type); + if (type != null) { + type.verifyType(node); + } return Optional.ofNullable(type); } @@ -96,7 +98,7 @@ public abstract class TensorFlowOperation { public void addDimensionNameConstraints(DimensionRenamer renamer) { } /** Performs dimension rename for this operation */ - public void renameDimensions(DimensionRenamer renamer) { type = OrderedTensorType.rename(type, renamer); } + public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } /** Return true for operations that are inputs to the model itself (as opposed to inputs to the operation) */ public boolean isInput() { return false; } @@ -131,7 +133,7 @@ public abstract class TensorFlowOperation { } if (inputs.size() != expected) { throw new IllegalArgumentException("Expected " + expected + " inputs " + - "for '" + node.getName() + "', got " + inputs.size()); + "for '" + node.getName() + "', got " + inputs.size()); } return inputs.stream().map(func).allMatch(Optional::isPresent); } |