From b8dbf8cc1705d2e568153175b1d2520b5cb72cbf Mon Sep 17 00:00:00 2001 From: Lester Solbakken Date: Wed, 3 Jun 2020 16:29:51 +0200 Subject: Ensure model name is valid as a dimension name --- .../importer/operations/IntermediateOperation.java | 25 +++++++++++++--------- 1 file changed, 15 insertions(+), 10 deletions(-) (limited to 'model-integration') diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index af134fac6cf..8ab15f4a330 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -59,7 +59,7 @@ public abstract class IntermediateOperation { IntermediateOperation(String modelName, String name, List inputs) { this.name = name; - this.modelName = modelName; + this.modelName = ensureValidAsDimensionName(modelName); this.inputs = new ArrayList<>(inputs); this.inputs.forEach(i -> i.outputs.add(this)); } @@ -118,8 +118,8 @@ public abstract class IntermediateOperation { // Each dimension is distinct and ordered correctly: for (int j = i + 1; j < type.dimensions().size(); j++) { renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), - DimensionRenamer.Constraint.notEqual(false), - this); + DimensionRenamer.Constraint.notEqual(false), + this); } } } @@ -188,7 +188,7 @@ public abstract class IntermediateOperation { boolean verifyInputs(int expected, Function> func) { if (inputs.size() != expected) { throw new IllegalArgumentException("Expected " + expected + " inputs for '" + - name + "', got " + inputs.size()); + name + "', got " + inputs.size()); } return inputs.stream().map(func).allMatch(Optional::isPresent); } @@ -321,8 +321,8 @@ public abstract class IntermediateOperation { */ TensorType.Value resultValueType() { return TensorType.Value.largestOf(inputs.stream() - .map(input -> input.type().get().type().valueType()) - .collect(Collectors.toList())); + .map(input -> input.type().get().type().valueType()) + .collect(Collectors.toList())); } public abstract IntermediateOperation withInputs(List inputs); @@ -351,17 +351,22 @@ public abstract class IntermediateOperation { public abstract String operationName(); + /** Required due to tensor dimension name restrictions */ + private static String ensureValidAsDimensionName(String modelName) { + return modelName.replaceAll("[^\\w\\d\\$@_]", "_"); + } + @Override public String toString() { return operationName() + "(" + - inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) + - ")"; + inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) + + ")"; } public String toFullString() { return "\t" + type + ":\t" + operationName() + "(" + - inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + - ")"; + inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + + ")"; } /** -- cgit v1.2.3