diff options
author | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-02 09:29:08 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@verizonmedia.com> | 2019-07-02 09:29:08 -0700 |
commit | 6c8e1b26bc33ba89f8fed9354fe2666dc796a485 (patch) | |
tree | 659dff20a3656f4a0bc888be5ba92a70f3b30de8 | |
parent | be02e47ab5eda6d6d314c39a4f414678d09b9b9e (diff) |
Allow extending beyond the last tensor dimension
7 files changed, 136 insertions, 45 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index f22e89cc8bb..19c2026d457 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -52,9 +52,8 @@ public abstract class ModelImporter implements MlModelImporter { */ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { ImportedModel model = new ImportedModel(graph.name(), modelSource); - log.log(Level.FINE, () -> "Intermediate graph created from '" + modelSource + "':\n" + - ExpressionFormatter.inTwoColumnMode(20).format(graph.toFullString())); - System.out.println("Intermediate graph created from '" + modelSource + "':\n" + graph.toFullString()); + log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" + + ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString())); graph.optimize(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 18ff602148a..8f029fc9c4a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -30,7 +30,7 @@ public class ExpandDims extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; IntermediateOperation axisOperation = inputs().get(1); - if ( !axisOperation.getConstantValue().isPresent()) { + if ( ! axisOperation.getConstantValue().isPresent()) { throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); @@ -47,18 +47,23 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { - if (dimensionIndex == dimensionToInsert) { - String name = String.format("%s_%d", vespaName(), dimensionIndex); - expandDimensions.add(name); - typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); - } + if (dimensionIndex == dimensionToInsert) + addDimension(dimensionIndex, typeBuilder); typeBuilder.add(dimension); dimensionIndex++; } - + if (dimensionIndex == inputType.dimensions().size()) { // Insert last dimension + addDimension(dimensionIndex, typeBuilder); + } return typeBuilder.build(); } + private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + @Override protected TensorFunction lazyGetFunction() { if ( ! allInputFunctionsPresent(2)) return null; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index 5fba7cc6da0..52aec71fa3f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -41,7 +41,7 @@ public class Map extends IntermediateOperation { @Override public String toFullString() { - return "\t" + lazyGetType() + "\tMap(" + inputs().get(0).toFullString() + ", " + operator + ")"; + return "\t" + lazyGetType() + ":\tMap(" + inputs().get(0).toFullString() + ", " + operator + ")"; } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 3f7c32294ed..cf6cc722b9e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -80,7 +80,7 @@ public class MatMul extends IntermediateOperation { throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + " but got just " + dimensions + " from\n" + - ExpressionFormatter.inTwoColumnMode(70).format(supplier.toFullString())); + ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index 8b5d06e8d56..046ab2a1646 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -104,7 +104,6 @@ public class Sum extends IntermediateOperation { builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); } } - System.out.println("----------> Sum input type is " + inputType + ", keepDimensions: " + keepDimensions + ", result: " + builder.build()); return builder.build(); } diff --git a/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java index b7670ab70e6..280b75f9cbb 100644 --- a/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java +++ b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java @@ -15,15 +15,26 @@ public class ExpressionFormatter { private static final int indentUnit = 2; /** The size of the first column, or 0 if none */ - private int firstColumnSize = 0; - - private ExpressionFormatter(int firstColumnSize) { - this.firstColumnSize = firstColumnSize; + private final int firstColumnLength; + + /** + * The desired size of the second column (or the entire line if no first column), + * or 0 to split into multiple lines as much as possible. + * Setting this collects larger chunks to one line across markup + * but will not split too long lines that have no markup. + */ + private final int secondColumnLength; + + private ExpressionFormatter(int firstColumnLength, int secondColumnLength) { + this.firstColumnLength = firstColumnLength; + this.secondColumnLength = secondColumnLength; } public String format(String parenthesisExpression) { StringBuilder b = new StringBuilder(); format(parenthesisExpression, 0, b); + while (b.length() > 0 && Character.isWhitespace(b.charAt(b.length() - 1))) + b.setLength(b.length() - 1); return b.toString(); } @@ -34,9 +45,15 @@ public class ExpressionFormatter { Markup next = Markup.next(expression); appendIndent( ! next.isClose() || next.position() > 0 ? indent : indent - 2, b); + + int endOfBalancedChunk = endOfBalancedChunk(expression, Math.max(0, secondColumnLength - indent)); if (next.isEmpty()) { b.append(expression); } + else if (endOfBalancedChunk > 0) { + b.append(expression, 0, endOfBalancedChunk + 1).append("\n"); + format(expression.substring(endOfBalancedChunk + 1), indent, b); + } else if (next.isComma()) { b.append(expression, 0, next.position() + 1).append("\n"); format(expression.substring(next.position() + 1), indent, b); @@ -55,8 +72,25 @@ public class ExpressionFormatter { } } + /** Returns the position of the end of a balanced chunk of at most the given size, or 0 if there is no such chunk */ + private int endOfBalancedChunk(String expression, int maxSize) { + int chunkSize = 0; + int i = 0; + int nesting = 0; + while (i < maxSize && i < expression.length()) { + if (expression.charAt(i) == '\t') return chunkSize; + if (expression.charAt(i) == '(') nesting++; + if (expression.charAt(i) == ')') nesting--; + if (nesting < 0) return chunkSize; + if (nesting == 0 && ( expression.charAt(i)==')' || expression.charAt(i)==',')) + chunkSize = i; + i++; + } + return chunkSize; + } + private String appendFirstColumn(String expression, StringBuilder b) { - if (firstColumnSize == 0) return expression; + if (firstColumnLength == 0) return expression; while (expression.charAt(0) == ' ') expression = expression.substring(1); @@ -65,11 +99,11 @@ public class ExpressionFormatter { int tab2 = expression.indexOf('\t', 1); if (tab2 >= 0) { String firstColumn = expression.substring(1, tab2); - b.append(asSize(firstColumnSize, firstColumn)).append(" "); + b.append(asSize(firstColumnLength, firstColumn)).append(" "); return expression.substring(tab2 + 1); } } - appendIndent(firstColumnSize + 1, b); + appendIndent(firstColumnLength + 1, b); return expression; } @@ -86,11 +120,15 @@ public class ExpressionFormatter { /** Convenience method creating a formatter and using it to format the given expression */ public static String on(String parenthesisExpression) { - return new ExpressionFormatter(0).format(parenthesisExpression); + return new ExpressionFormatter(0, 80).format(parenthesisExpression); + } + + public static ExpressionFormatter withLineLength(int maxLineLength) { + return new ExpressionFormatter(0, maxLineLength); } - public static ExpressionFormatter inTwoColumnMode(int firstColumnSize) { - return new ExpressionFormatter(firstColumnSize); + public static ExpressionFormatter inTwoColumnMode(int firstColumnSize, int secondColumnSize) { + return new ExpressionFormatter(firstColumnSize, secondColumnSize); } /** Contains the next position of each kind of markup, or Integer.MAX_VALUE if not present */ diff --git a/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java index 6dfb2e6fc8a..7251ccef521 100644 --- a/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java +++ b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java @@ -18,8 +18,13 @@ public class ExpressionFormatterTest { " baz(\n" + " )\n" + " )\n" + - ")\n"; - assertPrettyPrint(expected, "foo(bar(baz()))"); + ")"; + assertPrettyPrint(expected, "foo(bar(baz()))", 0); + } + + @Test + public void testBasicDense() { + assertPrettyPrint("foo(bar(baz()))", "foo(bar(baz()))", 50); } @Test @@ -31,8 +36,8 @@ public class ExpressionFormatterTest { " hello world\n" + " )\n" + " )\n" + - ")\n"; - assertPrettyPrint(expected, "foo(bar(baz(hello world)))"); + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hello world)))", 0); } @Test @@ -45,8 +50,23 @@ public class ExpressionFormatterTest { " 37\n" + " )\n" + " )\n" + - ")\n"; - assertPrettyPrint(expected, "foo(bar(baz(hello world,37)))"); + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hello world,37)))", 0); + } + + @Test + public void testMultipleArgumentsSemiDense() { + String expected = + "foo(\n" + + " bar(\n" + + " baz(hi,37),\n" + + " baz(\n" + + " hello world,\n" + + " 37\n" + + " )\n" + + " )\n" + + ")"; + assertPrettyPrint(expected, "foo(bar(baz(hi,37),baz(hello world,37)))", 15); } @Test @@ -58,8 +78,8 @@ public class ExpressionFormatterTest { " baz(\n" + " )\n" + " )\n" + - " )\n"; - assertPrettyPrint(expected, "foo((bar(baz()))"); + " )"; + assertPrettyPrint(expected, "foo((bar(baz()))", 0); } @Test @@ -71,22 +91,22 @@ public class ExpressionFormatterTest { " )\n" + " )\n" + ")\n" + - ")\n"; - assertPrettyPrint(expected, "foo(bar(baz())))"); + ")"; + assertPrettyPrint(expected, "foo(bar(baz())))", 0); } @Test public void testNoParenthesis() { String expected = "foo bar baz"; - assertPrettyPrint(expected, "foo bar baz"); + assertPrettyPrint(expected, "foo bar baz", 0); } @Test public void testEmpty() { String expected = ""; - assertPrettyPrint(expected, ""); + assertPrettyPrint(expected, "", 0); } @Test @@ -98,8 +118,8 @@ public class ExpressionFormatterTest { "2: hello world\n" + " )\n" + "t(o )\n" + - " )\n"; - ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3); + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world)\tt(o)@olong:\t))")); } @@ -113,28 +133,58 @@ public class ExpressionFormatterTest { "3: 37\n" + " )\n" + "t(o )\n" + - " )\n"; - ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3); + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world,\t3:\t37)\tt(o)@olong:\t))")); } @Test - public void test2ColumnModeMultipleArgumentsWithSpaces() { + public void test2ColumnModeMultipleArgumentsSemiDense() { String expected = "1: foo(\n" + " bar(\n" + + " baz(hi,37),\n" + + " boz(\n" + + "2: hello world,\n" + + "3: 5\n" + + " )\n" + + "t(o )\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 15); + assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(hi,37),boz(\t2:\thello world,\t3:\t5)\tt(o)@olong:\t))")); + } + + @Test + public void test2ColumnModeMultipleArgumentsWithSpaces() { + String expected = + " foo(\n" + + "1: bar(\n" + " baz(\n" + "2: hello world,\n" + "3: 37\n" + " )\n" + "t(o )\n" + - " )\n"; - ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3); - assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world, \t3:\t37)\tt(o)@olong:\t))")); + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0); + assertEquals(expected, pp.format("foo(\t1:\tbar(baz(\t2:\thello world, \t3:\t37)\tt(o)@olong:\t))")); + } + + @Test + public void testTwoColumnLambdaFunction() { + String expected = + " join(\n" + + " a,\n" + + " join(\n" + + " b, c, f(a, b)(a * b)\n" + + " )\n" + + " , f(a, b)(a * b)\n" + + " )"; + ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(5, 25); + assertEquals(expected, pp.format("join(a, join(b, c, f(a, b)(a * b)), f(a, b)(a * b))")); } - private void assertPrettyPrint(String expected, String expression) { - assertEquals(expected, ExpressionFormatter.on(expression)); + private void assertPrettyPrint(String expected, String expression, int lineLength) { + assertEquals(expected, ExpressionFormatter.withLineLength(lineLength).format(expression)); } } |