summaryrefslogtreecommitdiffstats
path: root/model-integration
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@verizonmedia.com>2019-07-02 09:29:08 -0700
committerJon Bratseth <bratseth@verizonmedia.com>2019-07-02 09:29:08 -0700
commit6c8e1b26bc33ba89f8fed9354fe2666dc796a485 (patch)
tree659dff20a3656f4a0bc888be5ba92a70f3b30de8 /model-integration
parentbe02e47ab5eda6d6d314c39a4f414678d09b9b9e (diff)
Allow extending beyond the last tensor dimension
Diffstat (limited to 'model-integration')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java5
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java19
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java2
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java1
5 files changed, 16 insertions, 13 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
index f22e89cc8bb..19c2026d457 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java
@@ -52,9 +52,8 @@ public abstract class ModelImporter implements MlModelImporter {
*/
protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) {
ImportedModel model = new ImportedModel(graph.name(), modelSource);
- log.log(Level.FINE, () -> "Intermediate graph created from '" + modelSource + "':\n" +
- ExpressionFormatter.inTwoColumnMode(20).format(graph.toFullString()));
- System.out.println("Intermediate graph created from '" + modelSource + "':\n" + graph.toFullString());
+ log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" +
+ ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString()));
graph.optimize();
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
index 18ff602148a..8f029fc9c4a 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java
@@ -30,7 +30,7 @@ public class ExpandDims extends IntermediateOperation {
if ( ! allInputTypesPresent(2)) return null;
IntermediateOperation axisOperation = inputs().get(1);
- if ( !axisOperation.getConstantValue().isPresent()) {
+ if ( ! axisOperation.getConstantValue().isPresent()) {
throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant.");
}
Tensor axis = axisOperation.getConstantValue().get().asTensor();
@@ -47,18 +47,23 @@ public class ExpandDims extends IntermediateOperation {
expandDimensions = new ArrayList<>();
int dimensionIndex = 0;
for (TensorType.Dimension dimension : inputType.dimensions()) {
- if (dimensionIndex == dimensionToInsert) {
- String name = String.format("%s_%d", vespaName(), dimensionIndex);
- expandDimensions.add(name);
- typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
- }
+ if (dimensionIndex == dimensionToInsert)
+ addDimension(dimensionIndex, typeBuilder);
typeBuilder.add(dimension);
dimensionIndex++;
}
-
+ if (dimensionIndex == inputType.dimensions().size()) { // Insert last dimension
+ addDimension(dimensionIndex, typeBuilder);
+ }
return typeBuilder.build();
}
+ private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) {
+ String name = String.format("%s_%d", vespaName(), dimensionIndex);
+ expandDimensions.add(name);
+ typeBuilder.add(TensorType.Dimension.indexed(name, 1L));
+ }
+
@Override
protected TensorFunction lazyGetFunction() {
if ( ! allInputFunctionsPresent(2)) return null;
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
index 5fba7cc6da0..52aec71fa3f 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java
@@ -41,7 +41,7 @@ public class Map extends IntermediateOperation {
@Override
public String toFullString() {
- return "\t" + lazyGetType() + "\tMap(" + inputs().get(0).toFullString() + ", " + operator + ")";
+ return "\t" + lazyGetType() + ":\tMap(" + inputs().get(0).toFullString() + ", " + operator + ")";
}
}
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
index 3f7c32294ed..cf6cc722b9e 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java
@@ -80,7 +80,7 @@ public class MatMul extends IntermediateOperation {
throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this +
" but got just " + dimensions + " from\n" +
- ExpressionFormatter.inTwoColumnMode(70).format(supplier.toFullString()));
+ ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString()));
}
@Override
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
index 8b5d06e8d56..046ab2a1646 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java
@@ -104,7 +104,6 @@ public class Sum extends IntermediateOperation {
builder.add(TensorType.Dimension.indexed(dimension.name(), 1L));
}
}
- System.out.println("----------> Sum input type is " + inputType + ", keepDimensions: " + keepDimensions + ", result: " + builder.build());
return builder.build();
}