diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-30 13:20:49 -0500 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-06-30 13:20:49 -0500 |
commit | 40144341bdbbfcec9f21ee3784e3e3cf5e320c91 (patch) | |
tree | f21c5039c36cc550efcfeaf9fbd77e7c2e9434d3 | |
parent | 32a5521059e08308b5abae10d6b5e8ce1589e705 (diff) |
Output the intermediate graph
14 files changed, 268 insertions, 7 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java index 61ce9d98e69..b2f5d104890 100644 --- a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java +++ b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java @@ -73,14 +73,13 @@ public class BlendingSearcher extends Searcher { } /** - * Produce a single blended result list from a group of hitgroups. + * Produce a single blended hit list from a group of hitgroups. * - * It is assumed that the results are ordered in hitgroups. If not, the blend will not be performed + * This assumes that all hits are organized into hitgroups. If not, blending will not be performed. */ protected Result blendResults(Result result, Query q, int offset, int hits, Execution execution) { //Assert that there are more than one hitgroup and that there are only hitgroups on the lowest level - boolean foundNonGroup = false; Iterator<Hit> hitIterator = result.hits().iterator(); List<HitGroup> groups = new ArrayList<>(); @@ -89,14 +88,14 @@ public class BlendingSearcher extends Searcher { if (hit instanceof HitGroup) { groups.add((HitGroup)hit); hitIterator.remove(); - } else if(!hit.isMeta()) { + } else if ( ! hit.isMeta()) { foundNonGroup = true; } } - if(foundNonGroup) { + if( foundNonGroup) { result.hits().addError(ErrorMessage.createUnspecifiedError("Blendingsearcher could not blend - there are toplevel hits" + - " that are not hitgroups")); + " that are not hitgroups")); return result; } if (groups.size() == 0) { diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java index aec98d06874..54d4bd3cb0a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java @@ -104,4 +104,16 @@ public class IntermediateGraph { } } + @Override + public String toString() { + return "intermediate graph for '" + modelName + "'"; + } + + public String toFullString() { + StringBuilder b = new StringBuilder(); + for (var input : index.entrySet()) + b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n"); + return b.toString(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index 99bfa08db43..b88d7423a82 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -11,12 +11,15 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.functions.Rename; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ParenthesisExpressionPrettyPrinter; +import com.yahoo.text.Text; import com.yahoo.yolean.Exceptions; import java.io.File; import java.util.List; import java.util.Map; import java.util.Optional; +import java.util.logging.Level; import java.util.logging.Logger; /** @@ -50,6 +53,9 @@ public abstract class ModelImporter implements MlModelImporter { */ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { ImportedModel model = new ImportedModel(graph.name(), modelSource); + log.log(Level.FINE, () -> "Intermediate graph created from '" + modelSource + "':\n" + + ParenthesisExpressionPrettyPrinter.prettyPrint(graph.toFullString())); + System.out.println("Intermediate graph created from '" + modelSource + "':\n" + graph.toFullString()); graph.optimize(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java index d6ea00ca453..7d1b6a61e2e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java @@ -54,4 +54,8 @@ public class Argument extends IntermediateOperation { return false; } + @Override + public String toString() { + return "Argument(" + standardNamingType + ")" + " : " + lazyGetType(); + } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java index 41d421b1f5a..6571e77a198 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java @@ -86,4 +86,10 @@ public class Const extends IntermediateOperation { } return value.get(); } + + @Override + public String toString() { + return "Const(" + type + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index c64b9ded601..8f7d3755005 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -88,7 +88,7 @@ public class ExpandDims extends IntermediateOperation { List<String> renamedDimensions = new ArrayList<>(expandDimensions.size()); for (String name : expandDimensions) { Optional<String> newName = renamer.dimensionNameOf(name); - if (!newName.isPresent()) { + if ( ! newName.isPresent()) { return; // presumably, already renamed } renamedDimensions.add(newName.get()); @@ -96,4 +96,15 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = renamedDimensions; } + @Override + public String toString() { + return "ExpandDims(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + expandDimensions + ")"; + } + + @Override + public String toFullString() { + return "ExpandDims(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ", " + expandDimensions + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 0ee54f839bc..78bed31f5b0 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -203,4 +203,10 @@ public abstract class IntermediateOperation { Optional<List<Value>> getList(String key); } + public String toFullString() { return toString(); } + + String asString(Optional<OrderedTensorType> type) { + return type.map(t -> t.toString()).orElse("(unknown)"); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java index c2d75153586..c5e6ae49a25 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java @@ -111,4 +111,15 @@ public class Join extends IntermediateOperation { return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1); } + @Override + public String toString() { + return "Join(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + operator + ")"; + } + + @Override + public String toFullString() { + return "Join(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ", " + operator + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index e0842d820f9..4f70c46e459 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -34,4 +34,14 @@ public class Map extends IntermediateOperation { return new com.yahoo.tensor.functions.Map(input.get(), operator); } + @Override + public String toString() { + return "Map(" + asString(inputs().get(0).type()) + ", " + operator + ")"; + } + + @Override + public String toFullString() { + return "Map(" + inputs().get(0).toFullString() + ", " + operator + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 9a76662529d..73aa40927be 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -5,6 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.tensor.TensorType; import com.yahoo.tensor.functions.TensorFunction; +import com.yahoo.text.ParenthesisExpressionPrettyPrinter; +import com.yahoo.text.Text; import java.util.List; import java.util.Optional; @@ -51,6 +53,12 @@ public class MatMul extends IntermediateOperation { List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions(); List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions(); + assertTwoDimensions(aDimensions, inputs.get(0), "first argument"); + assertTwoDimensions(bDimensions, inputs.get(1), "second argument"); + + System.out.println("Dimensions in a: " + aDimensions); + System.out.println("Dimensions in b: " + bDimensions); + String aDim0 = aDimensions.get(0).name(); String aDim1 = aDimensions.get(1).name(); String bDim0 = bDimensions.get(0).name(); @@ -67,4 +75,24 @@ public class MatMul extends IntermediateOperation { renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this); } + private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) { + if (dimensions.size() >= 2) return; + + + throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + + " but got just " + dimensions + " from\n" + + ParenthesisExpressionPrettyPrinter.prettyPrint(supplier.toFullString())); + } + + @Override + public String toFullString() { + return "MatMul(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ")" + " : " + lazyGetType(); + } + + @Override + public String toString() { + return "MatMul(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ")"; + } + } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index 46b95233d11..df5c4e9cbfa 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -107,4 +107,15 @@ public class Sum extends IntermediateOperation { return builder.build(); } + @Override + public String toString() { + return "Sum(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + reduceDimensions + ")"; + } + + @Override + public String toFullString() { + return "Sum(" + inputs().get(0).toFullString() + ", " + + inputs().get(1).toFullString() + ", " + reduceDimensions + ")" + " : " + lazyGetType(); + } + } diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java new file mode 100644 index 00000000000..be0ab4b894a --- /dev/null +++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java @@ -0,0 +1,28 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package ai.vespa.rankingexpression.importer.tensorflow; + +import ai.vespa.rankingexpression.importer.ImportedModel; +import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction; +import org.junit.Assert; +import org.junit.Test; + +import static org.junit.Assert.assertNotNull; + +/** + * @author bratseth + */ +public class Issue9662TestCase { + + @Test + public void testImporting() { + TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/9662"); + ImportedModel.Signature signature = model.get().signature("serving_default"); + Assert.assertEquals("Should have no skipped outputs", + 0, model.get().signature("serving_default").skippedOutputs().size()); + + ImportedMlFunction output = signature.outputFunction("y", "y"); + assertNotNull(output); + model.assertEqualResultSum("input", "dnn/outputs/add", 0.00001); + } + +} diff --git a/vespajlib/src/main/java/com/yahoo/text/ParenthesisExpressionPrettyPrinter.java b/vespajlib/src/main/java/com/yahoo/text/ParenthesisExpressionPrettyPrinter.java new file mode 100644 index 00000000000..ad235d78679 --- /dev/null +++ b/vespajlib/src/main/java/com/yahoo/text/ParenthesisExpressionPrettyPrinter.java @@ -0,0 +1,47 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.text; + +/** + * Pretty prints any parenthesis expression + * + * @author bratseth + */ +public class ParenthesisExpressionPrettyPrinter { + + private static final int indentUnit = 2; + + public static String prettyPrint(String parenthesisExpression) { + StringBuilder b = new StringBuilder(); + prettyPrint(parenthesisExpression, 0, b); + return b.toString(); + } + + private static void prettyPrint(String expression, int indent, StringBuilder b) { + int nextStartParenthesis = expression.indexOf("("); + int nextEndParenthesis = expression.indexOf(")"); + if (nextStartParenthesis < 0) + nextStartParenthesis = Integer.MAX_VALUE; + if (nextEndParenthesis < 0) + nextEndParenthesis = Integer.MAX_VALUE; + + boolean start = nextStartParenthesis < nextEndParenthesis; + int nextParenthesis = Math.min(nextStartParenthesis, nextEndParenthesis); + + int effectiveIndent = start || nextParenthesis > 0 ? indent : indent - 2; + b.append(" ".repeat(Math.max(0, effectiveIndent))); + if (nextParenthesis == Integer.MAX_VALUE) { + b.append(expression); + } + else { + if (! start && nextParenthesis > 0) { + b.append(expression, 0, nextParenthesis).append("\n"); + b.append(" ".repeat(Math.max(0, indent - 2))).append(")\n"); + } + else { + b.append(expression, 0, nextParenthesis + 1).append("\n"); + } + prettyPrint(expression.substring(nextParenthesis + 1), indent + (start ? indentUnit : -indentUnit), b); + } + } + +} diff --git a/vespajlib/src/test/java/com/yahoo/text/ParenthesisExpressionPrettyPrinterTest.java b/vespajlib/src/test/java/com/yahoo/text/ParenthesisExpressionPrettyPrinterTest.java new file mode 100644 index 00000000000..79bdc6a5318 --- /dev/null +++ b/vespajlib/src/test/java/com/yahoo/text/ParenthesisExpressionPrettyPrinterTest.java @@ -0,0 +1,82 @@ +// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. +package com.yahoo.text; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +/** + * @author bratseth + */ +public class ParenthesisExpressionPrettyPrinterTest { + + @Test + public void testBasic() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + ")\n"; + assertPrettyPrint(expected, "foo(bar(baz()))"); + } + + @Test + public void testInnerContent() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " hello world\n" + + " )\n" + + " )\n" + + ")\n"; + assertPrettyPrint(expected, "foo(bar(baz(hello world)))"); + } + @Test + public void testUnmatchedStart() { + String expected = + "foo(\n" + + " (\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + " )\n" + + " "; + assertPrettyPrint(expected, "foo((bar(baz()))"); + } + + @Test + public void testUnmatchedEnd() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(\n" + + " )\n" + + " )\n" + + ")\n" + + ")\n"; + assertPrettyPrint(expected, "foo(bar(baz())))"); + } + + @Test + public void testNoParenthesis() { + String expected = + "foo bar baz"; + assertPrettyPrint(expected, "foo bar baz"); + } + + @Test + public void testEmpty() { + String expected = + ""; + assertPrettyPrint(expected, ""); + } + + private void assertPrettyPrint(String expected, String expression) { + assertEquals(expected, ParenthesisExpressionPrettyPrinter.prettyPrint(expression)); + } + +} |