diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-30 13:20:49 -0500 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-30 13:20:49 -0500 |
commit | 40144341bdbbfcec9f21ee3784e3e3cf5e320c91 (patch) | |
tree | f21c5039c36cc550efcfeaf9fbd77e7c2e9434d3 /model-integration | |
parent | 32a5521059e08308b5abae10d6b5e8ce1589e705 (diff) |
Output the intermediate graph
Diffstat (limited to 'model-integration')
11 files changed, 134 insertions, 1 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index aec98d06874..54d4bd3cb0a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -104,4 +104,16 @@ public class IntermediateGraph { } } + @Override + public String toString() { + return "intermediate graph for '" + modelName + "'"; + } + + public String toFullString() { + StringBuilder b = new StringBuilder(); + for (var input : index.entrySet()) + b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n"); + return b.toString(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 99bfa08db43..b88d7423a82 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -11,12 +11,15 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ParenthesisExpressionPrettyPrinter; +import com.yahoo.text.Text; import com.yahoo.yolean.Exceptions; import java.io.File; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Level; import java.util.logging.Logger; /** @@ -50,6 +53,9 @@ public abstract class ModelImporter implements MlModelImporter { */ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { ImportedModel model = new ImportedModel(graph.name(), modelSource); + log.log(Level.FINE, () -> "Intermediate graph created from '" + modelSource + "':\n" + + ParenthesisExpressionPrettyPrinter.prettyPrint(graph.toFullString())); + System.out.println("Intermediate graph created from '" + modelSource + "':\n" + graph.toFullString()); graph.optimize(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index d6ea00ca453..7d1b6a61e2e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -54,4 +54,8 @@ public class Argument extends IntermediateOperation { return false; } + @Override + public String toString() { + return "Argument(" + standardNamingType + ")" + " : " + lazyGetType(); + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 41d421b1f5a..6571e77a198 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -86,4 +86,10 @@ public class Const extends IntermediateOperation { } return value.get(); } + + @Override + public String toString() { + return "Const(" + type + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index c64b9ded601..8f7d3755005 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -88,7 +88,7 @@ public class ExpandDims extends IntermediateOperation { List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); for (String name : expandDimensions) { Optional<String> newName = renamer.dimensionNameOf(name); - if (!newName.isPresent()) { + if ( ! newName.isPresent()) { return; // presumably, already renamed } renamedDimensions.add(newName.get()); @@ -96,4 +96,15 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = renamedDimensions; } + @Override + public String toString() { + return "ExpandDims(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + expandDimensions + ")"; + } + + @Override + public String toFullString() { + return "ExpandDims(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ", " + expandDimensions + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 0ee54f839bc..78bed31f5b0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -203,4 +203,10 @@ public abstract class IntermediateOperation { Optional<List<Value>> getList(String key); } + public String toFullString() { return toString(); } + + String asString(Optional<OrderedTensorType> type) { + return type.map(t -> t.toString()).orElse("(unknown)"); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index c2d75153586..c5e6ae49a25 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -111,4 +111,15 @@ public class Join extends IntermediateOperation { return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); } + @Override + public String toString() { + return "Join(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + operator + ")"; + } + + @Override + public String toFullString() { + return "Join(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ", " + operator + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index e0842d820f9..4f70c46e459 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -34,4 +34,14 @@ public class Map extends IntermediateOperation { return new com.yahoo.tensor.functions.Map(input.get(), operator); } + @Override + public String toString() { + return "Map(" + asString(inputs().get(0).type()) + ", " + operator + ")"; + } + + @Override + public String toFullString() { + return "Map(" + inputs().get(0).toFullString() + ", " + operator + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 9a76662529d..73aa40927be 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -5,6 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ParenthesisExpressionPrettyPrinter; +import com.yahoo.text.Text; import java.util.List; import java.util.Optional; @@ -51,6 +53,12 @@ public class MatMul extends IntermediateOperation { List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); + assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + + System.out.println("Dimensions in a: " + aDimensions); + System.out.println("Dimensions in b: " + bDimensions); + String aDim0 = aDimensions.get(0).name(); String aDim1 = aDimensions.get(1).name(); String bDim0 = bDimensions.get(0).name(); @@ -67,4 +75,24 @@ public class MatMul extends IntermediateOperation { renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); } + private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { + if (dimensions.size() >= 2) return; + + + throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + + " but got just " + dimensions + " from\n" + + ParenthesisExpressionPrettyPrinter.prettyPrint(supplier.toFullString())); + } + + @Override + public String toFullString() { + return "MatMul(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ")" + " : " + lazyGetType(); + } + + @Override + public String toString() { + return "MatMul(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index 46b95233d11..df5c4e9cbfa 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -107,4 +107,15 @@ public class Sum extends IntermediateOperation { return builder.build(); } + @Override + public String toString() { + return "Sum(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + reduceDimensions + ")"; + } + + @Override + public String toFullString() { + return "Sum(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ", " + reduceDimensions + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java new file mode 100644 index 00000000000..be0ab4b894a --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java @@ -0,0 +1,28 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.tensorflow; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class Issue9662TestCase { + + @Test + public void testImporting() { + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/9662"); + ImportedModel.Signature signature = model.get().signature("serving_default"); + Assert.assertEquals("Should have no skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + ImportedMlFunction output = signature.outputFunction("y", "y"); + assertNotNull(output); + model.assertEqualResultSum("input", "dnn/outputs/add", 0.00001); + } + +} |