summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-02 09:29:08 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-02 09:29:08 -0700
commit6c8e1b26bc33ba89f8fed9354fe2666dc796a485 (patch)
tree659dff20a3656f4a0bc888be5ba92a70f3b30de8
parentbe02e47ab5eda6d6d314c39a4f414678d09b9b9e (diff)
Allow extending beyond the last tensor dimension
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java19
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java1
-rw-r--r--vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java58
-rw-r--r--vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java94
7 files changed, 136 insertions, 45 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index f22e89cc8bb..19c2026d457 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -52,9 +52,8 @@ public abstract class ModelImporter implements MlModelImporter {
*/
protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
ImportedModel model = new ImportedModel(graph.name(), modelSource);
- log.log(Level.FINE, () -> "Intermediate graph created from '" + modelSource + "':\n" +
- ExpressionFormatter.inTwoColumnMode(20).format(graph.toFullString()));
- System.out.println("Intermediate graph created from '" + modelSource + "':\n" + graph.toFullString());
+ log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" +
+ ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString()));
graph.optimize();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 18ff602148a..8f029fc9c4a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -30,7 +30,7 @@ public class ExpandDims extends IntermediateOperation {
if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation axisOperation = inputs().get(1);
- if ( !axisOperation.getConstantValue().isPresent()) {
+ if ( ! axisOperation.getConstantValue().isPresent()) {
throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
@@ -47,18 +47,23 @@ public class ExpandDims extends IntermediateOperation {
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
- if (dimensionIndex == dimensionToInsert) {
- String name = String.format("%s_%d", vespaName(), dimensionIndex);
- expandDimensions.add(name);
- typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
- }
+ if (dimensionIndex == dimensionToInsert)
+ addDimension(dimensionIndex, typeBuilder);
typeBuilder.add(dimension);
dimensionIndex++;
}
-
+ if (dimensionIndex == inputType.dimensions().size()) { // Insert last dimension
+ addDimension(dimensionIndex, typeBuilder);
+ }
return typeBuilder.build();
}
+ private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) {
+ String name = String.format("%s_%d", vespaName(), dimensionIndex);
+ expandDimensions.add(name);
+ typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
+ }
+
@Override
protected TensorFunction lazyGetFunction() {
if ( ! allInputFunctionsPresent(2)) return null;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
index 5fba7cc6da0..52aec71fa3f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
@@ -41,7 +41,7 @@ public class Map extends IntermediateOperation {
@Override
public String toFullString() {
- return "\t" + lazyGetType() + "\tMap(" + inputs().get(0).toFullString() + ", " + operator + ")";
+ return "\t" + lazyGetType() + ":\tMap(" + inputs().get(0).toFullString() + ", " + operator + ")";
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 3f7c32294ed..cf6cc722b9e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -80,7 +80,7 @@ public class MatMul extends IntermediateOperation {
throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
" but got just " + dimensions + " from\n" +
- ExpressionFormatter.inTwoColumnMode(70).format(supplier.toFullString()));
+ ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
index 8b5d06e8d56..046ab2a1646 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
@@ -104,7 +104,6 @@ public class Sum extends IntermediateOperation {
builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
}
}
- System.out.println("----------> Sum input type is " + inputType + ", keepDimensions: " + keepDimensions + ", result: " + builder.build());
return builder.build();
}
diff --git a/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java
index b7670ab70e6..280b75f9cbb 100644
--- a/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java
+++ b/vespajlib/src/main/java/com/yahoo/text/ExpressionFormatter.java
@@ -15,15 +15,26 @@ public class ExpressionFormatter {
private static final int indentUnit = 2;
/** The size of the first column, or 0 if none */
- private int firstColumnSize = 0;
-
- private ExpressionFormatter(int firstColumnSize) {
- this.firstColumnSize = firstColumnSize;
+ private final int firstColumnLength;
+
+ /**
+ * The desired size of the second column (or the entire line if no first column),
+ * or 0 to split into multiple lines as much as possible.
+ * Setting this collects larger chunks to one line across markup
+ * but will not split too long lines that have no markup.
+ */
+ private final int secondColumnLength;
+
+ private ExpressionFormatter(int firstColumnLength, int secondColumnLength) {
+ this.firstColumnLength = firstColumnLength;
+ this.secondColumnLength = secondColumnLength;
}
public String format(String parenthesisExpression) {
StringBuilder b = new StringBuilder();
format(parenthesisExpression, 0, b);
+ while (b.length() > 0 && Character.isWhitespace(b.charAt(b.length() - 1)))
+ b.setLength(b.length() - 1);
return b.toString();
}
@@ -34,9 +45,15 @@ public class ExpressionFormatter {
Markup next = Markup.next(expression);
appendIndent( ! next.isClose() || next.position() > 0 ? indent : indent - 2, b);
+
+ int endOfBalancedChunk = endOfBalancedChunk(expression, Math.max(0, secondColumnLength - indent));
if (next.isEmpty()) {
b.append(expression);
}
+ else if (endOfBalancedChunk > 0) {
+ b.append(expression, 0, endOfBalancedChunk + 1).append("\n");
+ format(expression.substring(endOfBalancedChunk + 1), indent, b);
+ }
else if (next.isComma()) {
b.append(expression, 0, next.position() + 1).append("\n");
format(expression.substring(next.position() + 1), indent, b);
@@ -55,8 +72,25 @@ public class ExpressionFormatter {
}
}
+ /** Returns the position of the end of a balanced chunk of at most the given size, or 0 if there is no such chunk */
+ private int endOfBalancedChunk(String expression, int maxSize) {
+ int chunkSize = 0;
+ int i = 0;
+ int nesting = 0;
+ while (i < maxSize && i < expression.length()) {
+ if (expression.charAt(i) == '\t') return chunkSize;
+ if (expression.charAt(i) == '(') nesting++;
+ if (expression.charAt(i) == ')') nesting--;
+ if (nesting < 0) return chunkSize;
+ if (nesting == 0 && ( expression.charAt(i)==')' || expression.charAt(i)==','))
+ chunkSize = i;
+ i++;
+ }
+ return chunkSize;
+ }
+
private String appendFirstColumn(String expression, StringBuilder b) {
- if (firstColumnSize == 0) return expression;
+ if (firstColumnLength == 0) return expression;
while (expression.charAt(0) == ' ')
expression = expression.substring(1);
@@ -65,11 +99,11 @@ public class ExpressionFormatter {
int tab2 = expression.indexOf('\t', 1);
if (tab2 >= 0) {
String firstColumn = expression.substring(1, tab2);
- b.append(asSize(firstColumnSize, firstColumn)).append(" ");
+ b.append(asSize(firstColumnLength, firstColumn)).append(" ");
return expression.substring(tab2 + 1);
}
}
- appendIndent(firstColumnSize + 1, b);
+ appendIndent(firstColumnLength + 1, b);
return expression;
}
@@ -86,11 +120,15 @@ public class ExpressionFormatter {
/** Convenience method creating a formatter and using it to format the given expression */
public static String on(String parenthesisExpression) {
- return new ExpressionFormatter(0).format(parenthesisExpression);
+ return new ExpressionFormatter(0, 80).format(parenthesisExpression);
+ }
+
+ public static ExpressionFormatter withLineLength(int maxLineLength) {
+ return new ExpressionFormatter(0, maxLineLength);
}
- public static ExpressionFormatter inTwoColumnMode(int firstColumnSize) {
- return new ExpressionFormatter(firstColumnSize);
+ public static ExpressionFormatter inTwoColumnMode(int firstColumnSize, int secondColumnSize) {
+ return new ExpressionFormatter(firstColumnSize, secondColumnSize);
}
/** Contains the next position of each kind of markup, or Integer.MAX_VALUE if not present */
diff --git a/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java
index 6dfb2e6fc8a..7251ccef521 100644
--- a/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java
+++ b/vespajlib/src/test/java/com/yahoo/text/ExpressionFormatterTest.java
@@ -18,8 +18,13 @@ public class ExpressionFormatterTest {
" baz(\n" +
" )\n" +
" )\n" +
- ")\n";
- assertPrettyPrint(expected, "foo(bar(baz()))");
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz()))", 0);
+ }
+
+ @Test
+ public void testBasicDense() {
+ assertPrettyPrint("foo(bar(baz()))", "foo(bar(baz()))", 50);
}
@Test
@@ -31,8 +36,8 @@ public class ExpressionFormatterTest {
" hello world\n" +
" )\n" +
" )\n" +
- ")\n";
- assertPrettyPrint(expected, "foo(bar(baz(hello world)))");
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz(hello world)))", 0);
}
@Test
@@ -45,8 +50,23 @@ public class ExpressionFormatterTest {
" 37\n" +
" )\n" +
" )\n" +
- ")\n";
- assertPrettyPrint(expected, "foo(bar(baz(hello world,37)))");
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz(hello world,37)))", 0);
+ }
+
+ @Test
+ public void testMultipleArgumentsSemiDense() {
+ String expected =
+ "foo(\n" +
+ " bar(\n" +
+ " baz(hi,37),\n" +
+ " baz(\n" +
+ " hello world,\n" +
+ " 37\n" +
+ " )\n" +
+ " )\n" +
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz(hi,37),baz(hello world,37)))", 15);
}
@Test
@@ -58,8 +78,8 @@ public class ExpressionFormatterTest {
" baz(\n" +
" )\n" +
" )\n" +
- " )\n";
- assertPrettyPrint(expected, "foo((bar(baz()))");
+ " )";
+ assertPrettyPrint(expected, "foo((bar(baz()))", 0);
}
@Test
@@ -71,22 +91,22 @@ public class ExpressionFormatterTest {
" )\n" +
" )\n" +
")\n" +
- ")\n";
- assertPrettyPrint(expected, "foo(bar(baz())))");
+ ")";
+ assertPrettyPrint(expected, "foo(bar(baz())))", 0);
}
@Test
public void testNoParenthesis() {
String expected =
"foo bar baz";
- assertPrettyPrint(expected, "foo bar baz");
+ assertPrettyPrint(expected, "foo bar baz", 0);
}
@Test
public void testEmpty() {
String expected =
"";
- assertPrettyPrint(expected, "");
+ assertPrettyPrint(expected, "", 0);
}
@Test
@@ -98,8 +118,8 @@ public class ExpressionFormatterTest {
"2: hello world\n" +
" )\n" +
"t(o )\n" +
- " )\n";
- ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3);
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0);
assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world)\tt(o)@olong:\t))"));
}
@@ -113,28 +133,58 @@ public class ExpressionFormatterTest {
"3: 37\n" +
" )\n" +
"t(o )\n" +
- " )\n";
- ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3);
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0);
assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world,\t3:\t37)\tt(o)@olong:\t))"));
}
@Test
- public void test2ColumnModeMultipleArgumentsWithSpaces() {
+ public void test2ColumnModeMultipleArgumentsSemiDense() {
String expected =
"1: foo(\n" +
" bar(\n" +
+ " baz(hi,37),\n" +
+ " boz(\n" +
+ "2: hello world,\n" +
+ "3: 5\n" +
+ " )\n" +
+ "t(o )\n" +
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 15);
+ assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(hi,37),boz(\t2:\thello world,\t3:\t5)\tt(o)@olong:\t))"));
+ }
+
+ @Test
+ public void test2ColumnModeMultipleArgumentsWithSpaces() {
+ String expected =
+ " foo(\n" +
+ "1: bar(\n" +
" baz(\n" +
"2: hello world,\n" +
"3: 37\n" +
" )\n" +
"t(o )\n" +
- " )\n";
- ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3);
- assertEquals(expected, pp.format("\t1:\tfoo(bar(baz(\t2:\thello world, \t3:\t37)\tt(o)@olong:\t))"));
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(3, 0);
+ assertEquals(expected, pp.format("foo(\t1:\tbar(baz(\t2:\thello world, \t3:\t37)\tt(o)@olong:\t))"));
+ }
+
+ @Test
+ public void testTwoColumnLambdaFunction() {
+ String expected =
+ " join(\n" +
+ " a,\n" +
+ " join(\n" +
+ " b, c, f(a, b)(a * b)\n" +
+ " )\n" +
+ " , f(a, b)(a * b)\n" +
+ " )";
+ ExpressionFormatter pp = ExpressionFormatter.inTwoColumnMode(5, 25);
+ assertEquals(expected, pp.format("join(a, join(b, c, f(a, b)(a * b)), f(a, b)(a * b))"));
}
- private void assertPrettyPrint(String expected, String expression) {
- assertEquals(expected, ExpressionFormatter.on(expression));
+ private void assertPrettyPrint(String expected, String expression, int lineLength) {
+ assertEquals(expected, ExpressionFormatter.withLineLength(lineLength).format(expression));
}
}