summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-06-30 13:20:49 -0500
committerJon Bratseth <bratseth@verizonmedia.com>2019-06-30 13:20:49 -0500
commit40144341bdbbfcec9f21ee3784e3e3cf5e320c91 (patch)
treef21c5039c36cc550efcfeaf9fbd77e7c2e9434d3
parent32a5521059e08308b5abae10d6b5e8ce1589e705 (diff)
Output the intermediate graph
-rw-r--r--container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java12
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java4
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java13
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java6
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java11
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java10
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java28
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java11
-rw-r--r--model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java28
-rw-r--r--vespajlib/src/main/java/com/yahoo/text/ParenthesisExpressionPrettyPrinter.java47
-rw-r--r--vespajlib/src/test/java/com/yahoo/text/ParenthesisExpressionPrettyPrinterTest.java82
14 files changed, 268 insertions, 7 deletions
diff --git a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java
index 61ce9d98e69..b2f5d104890 100644
--- a/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java
+++ b/container-search/src/main/java/com/yahoo/prelude/searcher/BlendingSearcher.java
@@ -73,14 +73,13 @@ public class BlendingSearcher extends Searcher {
}
/**
- * Produce a single blended result list from a group of hitgroups.
+ * Produce a single blended hit list from a group of hitgroups.
*
- * It is assumed that the results are ordered in hitgroups. If not, the blend will not be performed
+ * This assumes that all hits are organized into hitgroups. If not, blending will not be performed.
*/
protected Result blendResults(Result result, Query q, int offset, int hits, Execution execution) {
//Assert that there are more than one hitgroup and that there are only hitgroups on the lowest level
-
boolean foundNonGroup = false;
Iterator<Hit> hitIterator = result.hits().iterator();
List<HitGroup> groups = new ArrayList<>();
@@ -89,14 +88,14 @@ public class BlendingSearcher extends Searcher {
if (hit instanceof HitGroup) {
groups.add((HitGroup)hit);
hitIterator.remove();
- } else if(!hit.isMeta()) {
+ } else if ( ! hit.isMeta()) {
foundNonGroup = true;
}
}
- if(foundNonGroup) {
+ if( foundNonGroup) {
result.hits().addError(ErrorMessage.createUnspecifiedError("Blendingsearcher could not blend - there are toplevel hits" +
- " that are not hitgroups"));
+ " that are not hitgroups"));
return result;
}
if (groups.size() == 0) {
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
index aec98d06874..54d4bd3cb0a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/IntermediateGraph.java
@@ -104,4 +104,16 @@ public class IntermediateGraph {
}
}
+ @Override
+ public String toString() {
+ return "intermediate graph for '" + modelName + "'";
+ }
+
+ public String toFullString() {
+ StringBuilder b = new StringBuilder();
+ for (var input : index.entrySet())
+ b.append(input.getKey()).append(": ").append(input.getValue().toFullString()).append("\n");
+ return b.toString();
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index 99bfa08db43..b88d7423a82 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -11,12 +11,15 @@ import com.yahoo.searchlib.rankingexpression.parser.ParseException;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.functions.Rename;
import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.text.ParenthesisExpressionPrettyPrinter;
+import com.yahoo.text.Text;
import com.yahoo.yolean.Exceptions;
import java.io.File;
import java.util.List;
import java.util.Map;
import java.util.Optional;
+import java.util.logging.Level;
import java.util.logging.Logger;
/**
@@ -50,6 +53,9 @@ public abstract class ModelImporter implements MlModelImporter {
*/
protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
ImportedModel model = new ImportedModel(graph.name(), modelSource);
+ log.log(Level.FINE, () -> "Intermediate graph created from '" + modelSource + "':\n" +
+ ParenthesisExpressionPrettyPrinter.prettyPrint(graph.toFullString()));
+ System.out.println("Intermediate graph created from '" + modelSource + "':\n" + graph.toFullString());
graph.optimize();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
index d6ea00ca453..7d1b6a61e2e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Argument.java
@@ -54,4 +54,8 @@ public class Argument extends IntermediateOperation {
return false;
}
+ @Override
+ public String toString() {
+ return "Argument(" + standardNamingType + ")" + " : " + lazyGetType();
+ }
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
index 41d421b1f5a..6571e77a198 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Const.java
@@ -86,4 +86,10 @@ public class Const extends IntermediateOperation {
}
return value.get();
}
+
+ @Override
+ public String toString() {
+ return "Const(" + type + ")" + " : " + lazyGetType();
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index c64b9ded601..8f7d3755005 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -88,7 +88,7 @@ public class ExpandDims extends IntermediateOperation {
List<String> renamedDimensions = new ArrayList<>(expandDimensions.size());
for (String name : expandDimensions) {
Optional<String> newName = renamer.dimensionNameOf(name);
- if (!newName.isPresent()) {
+ if ( ! newName.isPresent()) {
return; // presumably, already renamed
}
renamedDimensions.add(newName.get());
@@ -96,4 +96,15 @@ public class ExpandDims extends IntermediateOperation {
expandDimensions = renamedDimensions;
}
+ @Override
+ public String toString() {
+ return "ExpandDims(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + expandDimensions + ")";
+ }
+
+ @Override
+ public String toFullString() {
+ return "ExpandDims(" + inputs().get(0).toFullString() + ", " +
+ inputs().get(1).toFullString() + ", " + expandDimensions + ")" + " : " + lazyGetType();
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 0ee54f839bc..78bed31f5b0 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -203,4 +203,10 @@ public abstract class IntermediateOperation {
Optional<List<Value>> getList(String key);
}
+ public String toFullString() { return toString(); }
+
+ String asString(Optional<OrderedTensorType> type) {
+ return type.map(t -> t.toString()).orElse("(unknown)");
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
index c2d75153586..c5e6ae49a25 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Join.java
@@ -111,4 +111,15 @@ public class Join extends IntermediateOperation {
return a.rank() < b.rank() ? inputs.get(0) : inputs.get(1);
}
+ @Override
+ public String toString() {
+ return "Join(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + operator + ")";
+ }
+
+ @Override
+ public String toFullString() {
+ return "Join(" + inputs().get(0).toFullString() + ", " +
+ inputs().get(1).toFullString() + ", " + operator + ")" + " : " + lazyGetType();
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
index e0842d820f9..4f70c46e459 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
@@ -34,4 +34,14 @@ public class Map extends IntermediateOperation {
return new com.yahoo.tensor.functions.Map(input.get(), operator);
}
+ @Override
+ public String toString() {
+ return "Map(" + asString(inputs().get(0).type()) + ", " + operator + ")";
+ }
+
+ @Override
+ public String toFullString() {
+ return "Map(" + inputs().get(0).toFullString() + ", " + operator + ")" + " : " + lazyGetType();
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 9a76662529d..73aa40927be 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -5,6 +5,8 @@ import ai.vespa.rankingexpression.importer.DimensionRenamer;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.tensor.TensorType;
import com.yahoo.tensor.functions.TensorFunction;
+import com.yahoo.text.ParenthesisExpressionPrettyPrinter;
+import com.yahoo.text.Text;
import java.util.List;
import java.util.Optional;
@@ -51,6 +53,12 @@ public class MatMul extends IntermediateOperation {
List<TensorType.Dimension> aDimensions = inputs.get(0).type().get().dimensions();
List<TensorType.Dimension> bDimensions = inputs.get(1).type().get().dimensions();
+ assertTwoDimensions(aDimensions, inputs.get(0), "first argument");
+ assertTwoDimensions(bDimensions, inputs.get(1), "second argument");
+
+ System.out.println("Dimensions in a: " + aDimensions);
+ System.out.println("Dimensions in b: " + bDimensions);
+
String aDim0 = aDimensions.get(0).name();
String aDim1 = aDimensions.get(1).name();
String bDim0 = bDimensions.get(0).name();
@@ -67,4 +75,24 @@ public class MatMul extends IntermediateOperation {
renamer.addConstraint(bDim0, bDim1, DimensionRenamer::greaterThan, this);
}
+ private void assertTwoDimensions(List<TensorType.Dimension> dimensions, IntermediateOperation supplier, String inputDescription) {
+ if (dimensions.size() >= 2) return;
+
+
+ throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
+ " but got just " + dimensions + " from\n" +
+ ParenthesisExpressionPrettyPrinter.prettyPrint(supplier.toFullString()));
+ }
+
+ @Override
+ public String toFullString() {
+ return "MatMul(" + inputs().get(0).toFullString() + ", " +
+ inputs().get(1).toFullString() + ")" + " : " + lazyGetType();
+ }
+
+ @Override
+ public String toString() {
+ return "MatMul(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ")";
+ }
+
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
index 46b95233d11..df5c4e9cbfa 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
@@ -107,4 +107,15 @@ public class Sum extends IntermediateOperation {
return builder.build();
}
+ @Override
+ public String toString() {
+ return "Sum(" + asString(inputs().get(0).type()) + ", " + asString(inputs().get(1).type()) + ", " + reduceDimensions + ")";
+ }
+
+ @Override
+ public String toFullString() {
+ return "Sum(" + inputs().get(0).toFullString() + ", " +
+ inputs().get(1).toFullString() + ", " + reduceDimensions + ")" + " : " + lazyGetType();
+ }
+
}
diff --git a/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java
new file mode 100644
index 00000000000..be0ab4b894a
--- /dev/null
+++ b/model-integration/src/test/java/ai/vespa/rankingexpression/importer/tensorflow/Issue9662TestCase.java
@@ -0,0 +1,28 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package ai.vespa.rankingexpression.importer.tensorflow;
+
+import ai.vespa.rankingexpression.importer.ImportedModel;
+import ai.vespa.rankingexpression.importer.configmodelview.ImportedMlFunction;
+import org.junit.Assert;
+import org.junit.Test;
+
+import static org.junit.Assert.assertNotNull;
+
+/**
+ * @author bratseth
+ */
+public class Issue9662TestCase {
+
+ @Test
+ public void testImporting() {
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/models/tensorflow/9662");
+ ImportedModel.Signature signature = model.get().signature("serving_default");
+ Assert.assertEquals("Should have no skipped outputs",
+ 0, model.get().signature("serving_default").skippedOutputs().size());
+
+ ImportedMlFunction output = signature.outputFunction("y", "y");
+ assertNotNull(output);
+ model.assertEqualResultSum("input", "dnn/outputs/add", 0.00001);
+ }
+
+}
diff --git a/vespajlib/src/main/java/com/yahoo/text/ParenthesisExpressionPrettyPrinter.java b/vespajlib/src/main/java/com/yahoo/text/ParenthesisExpressionPrettyPrinter.java
new file mode 100644
index 00000000000..ad235d78679
--- /dev/null
+++ b/vespajlib/src/main/java/com/yahoo/text/ParenthesisExpressionPrettyPrinter.java
@@ -0,0 +1,47 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.text;
+
+/**
+ * Pretty prints any parenthesis expression
+ *
+ * @author bratseth
+ */
+public class ParenthesisExpressionPrettyPrinter {
+
+ private static final int indentUnit = 2;
+
+ public static String prettyPrint(String parenthesisExpression) {
+ StringBuilder b = new StringBuilder();
+ prettyPrint(parenthesisExpression, 0, b);
+ return b.toString();
+ }
+
+ private static void prettyPrint(String expression, int indent, StringBuilder b) {
+ int nextStartParenthesis = expression.indexOf("(");
+ int nextEndParenthesis = expression.indexOf(")");
+ if (nextStartParenthesis < 0)
+ nextStartParenthesis = Integer.MAX_VALUE;
+ if (nextEndParenthesis < 0)
+ nextEndParenthesis = Integer.MAX_VALUE;
+
+ boolean start = nextStartParenthesis < nextEndParenthesis;
+ int nextParenthesis = Math.min(nextStartParenthesis, nextEndParenthesis);
+
+ int effectiveIndent = start || nextParenthesis > 0 ? indent : indent - 2;
+ b.append(" ".repeat(Math.max(0, effectiveIndent)));
+ if (nextParenthesis == Integer.MAX_VALUE) {
+ b.append(expression);
+ }
+ else {
+ if (! start && nextParenthesis > 0) {
+ b.append(expression, 0, nextParenthesis).append("\n");
+ b.append(" ".repeat(Math.max(0, indent - 2))).append(")\n");
+ }
+ else {
+ b.append(expression, 0, nextParenthesis + 1).append("\n");
+ }
+ prettyPrint(expression.substring(nextParenthesis + 1), indent + (start ? indentUnit : -indentUnit), b);
+ }
+ }
+
+}
diff --git a/vespajlib/src/test/java/com/yahoo/text/ParenthesisExpressionPrettyPrinterTest.java b/vespajlib/src/test/java/com/yahoo/text/ParenthesisExpressionPrettyPrinterTest.java
new file mode 100644
index 00000000000..79bdc6a5318
--- /dev/null
+++ b/vespajlib/src/test/java/com/yahoo/text/ParenthesisExpressionPrettyPrinterTest.java
@@ -0,0 +1,82 @@
+// Copyright 2019 Oath Inc. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
+package com.yahoo.text;
+
+import org.junit.Test;
+
+import static org.junit.Assert.assertEquals;
+
+/**
+ * @author bratseth
+ */
+public class ParenthesisExpressionPrettyPrinterTest {
+
+ @Test
+ public void testBasic() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " )\n" +
+ " )\n" +
+ ")\n";
+ assertPrettyPrint(expected, "foo(bar(baz()))");
+ }
+
+ @Test
+ public void testInnerContent() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " hello world\n" +
+ " )\n" +
+ " )\n" +
+ ")\n";
+ assertPrettyPrint(expected, "foo(bar(baz(hello world)))");
+ }
+ @Test
+ public void testUnmatchedStart() {
+ String expected =
+ "foo(\n" +
+ " (\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " )\n" +
+ " )\n" +
+ " )\n" +
+ " ";
+ assertPrettyPrint(expected, "foo((bar(baz()))");
+ }
+
+ @Test
+ public void testUnmatchedEnd() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(\n" +
+ " )\n" +
+ " )\n" +
+ ")\n" +
+ ")\n";
+ assertPrettyPrint(expected, "foo(bar(baz())))");
+ }
+
+ @Test
+ public void testNoParenthesis() {
+ String expected =
+ "foo bar baz";
+ assertPrettyPrint(expected, "foo bar baz");
+ }
+
+ @Test
+ public void testEmpty() {
+ String expected =
+ "";
+ assertPrettyPrint(expected, "");
+ }
+
+ private void assertPrettyPrint(String expected, String expression) {
+ assertEquals(expected, ParenthesisExpressionPrettyPrinter.prettyPrint(expression));
+ }
+
+}