diff options
Diffstat (limited to 'model-integration')
5 files changed, 16 insertions, 13 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java index f22e89cc8bb..19c2026d457 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/ModelImporter.java @@ -52,9 +52,8 @@ public abstract class ModelImporter implements MlModelImporter { */ protected static ImportedModel convertIntermediateGraphToModel(IntermediateGraph graph, String modelSource) { ImportedModel model = new ImportedModel(graph.name(), modelSource); - log.log(Level.FINE, () -> "Intermediate graph created from '" + modelSource + "':\n" + - ExpressionFormatter.inTwoColumnMode(20).format(graph.toFullString())); - System.out.println("Intermediate graph created from '" + modelSource + "':\n" + graph.toFullString()); + log.log(Level.FINER, () -> "Intermediate graph created from '" + modelSource + "':\n" + + ExpressionFormatter.inTwoColumnMode(70, 50).format(graph.toFullString())); graph.optimize(); diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java index 18ff602148a..8f029fc9c4a 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/ExpandDims.java @@ -30,7 +30,7 @@ public class ExpandDims extends IntermediateOperation { if ( ! allInputTypesPresent(2)) return null; IntermediateOperation axisOperation = inputs().get(1); - if ( !axisOperation.getConstantValue().isPresent()) { + if ( ! axisOperation.getConstantValue().isPresent()) { throw new IllegalArgumentException("ExpandDims in " + name + ": Axis must be a constant."); } Tensor axis = axisOperation.getConstantValue().get().asTensor(); @@ -47,18 +47,23 @@ public class ExpandDims extends IntermediateOperation { expandDimensions = new ArrayList<>(); int dimensionIndex = 0; for (TensorType.Dimension dimension : inputType.dimensions()) { - if (dimensionIndex == dimensionToInsert) { - String name = String.format("%s_%d", vespaName(), dimensionIndex); - expandDimensions.add(name); - typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); - } + if (dimensionIndex == dimensionToInsert) + addDimension(dimensionIndex, typeBuilder); typeBuilder.add(dimension); dimensionIndex++; } - + if (dimensionIndex == inputType.dimensions().size()) { // Insert last dimension + addDimension(dimensionIndex, typeBuilder); + } return typeBuilder.build(); } + private void addDimension(int dimensionIndex, OrderedTensorType.Builder typeBuilder) { + String name = String.format("%s_%d", vespaName(), dimensionIndex); + expandDimensions.add(name); + typeBuilder.add(TensorType.Dimension.indexed(name, 1L)); + } + @Override protected TensorFunction lazyGetFunction() { if ( ! allInputFunctionsPresent(2)) return null; diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java index 5fba7cc6da0..52aec71fa3f 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Map.java @@ -41,7 +41,7 @@ public class Map extends IntermediateOperation { @Override public String toFullString() { - return "\t" + lazyGetType() + "\tMap(" + inputs().get(0).toFullString() + ", " + operator + ")"; + return "\t" + lazyGetType() + ":\tMap(" + inputs().get(0).toFullString() + ", " + operator + ")"; } } diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java index 3f7c32294ed..cf6cc722b9e 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/MatMul.java @@ -80,7 +80,7 @@ public class MatMul extends IntermediateOperation { throw new IllegalArgumentException("Expected 2 dimensions in the " + inputDescription + " to " + this + " but got just " + dimensions + " from\n" + - ExpressionFormatter.inTwoColumnMode(70).format(supplier.toFullString())); + ExpressionFormatter.inTwoColumnMode(70, 50).format(supplier.toFullString())); } @Override diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java index 8b5d06e8d56..046ab2a1646 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/Sum.java @@ -104,7 +104,6 @@ public class Sum extends IntermediateOperation { builder.add(TensorType.Dimension.indexed(dimension.name(), 1L)); } } - System.out.println("----------> Sum input type is " + inputType + ", keepDimensions: " + keepDimensions + ", result: " + builder.build()); return builder.build(); } |