diff options
author | Lester Solbakken <lesters@oath.com> | 2020-06-03 16:29:51 +0200 |
---|---|---|
committer | Lester Solbakken <lesters@oath.com> | 2020-06-03 16:29:51 +0200 |
commit | b8dbf8cc1705d2e568153175b1d2520b5cb72cbf (patch) | |
tree | f968c2d242bbe0bd46753a33b420926d93601234 /model-integration | |
parent | c42ccbb874b581c7394a31dcb6a5e0e715d46e18 (diff) |
Ensure model name is valid as a dimension name
Diffstat (limited to 'model-integration')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java | 25 |
1 files changed, 15 insertions, 10 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index af134fac6cf..8ab15f4a330 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -59,7 +59,7 @@ public abstract class IntermediateOperation { IntermediateOperation(String modelName, String name, List<IntermediateOperation> inputs) { this.name = name; - this.modelName = modelName; + this.modelName = ensureValidAsDimensionName(modelName); this.inputs = new ArrayList<>(inputs); this.inputs.forEach(i -> i.outputs.add(this)); } @@ -118,8 +118,8 @@ public abstract class IntermediateOperation { // Each dimension is distinct and ordered correctly: for (int j = i + 1; j < type.dimensions().size(); j++) { renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), - DimensionRenamer.Constraint.notEqual(false), - this); + DimensionRenamer.Constraint.notEqual(false), + this); } } } @@ -188,7 +188,7 @@ public abstract class IntermediateOperation { boolean verifyInputs(int expected, Function<IntermediateOperation, Optional<?>> func) { if (inputs.size() != expected) { throw new IllegalArgumentException("Expected " + expected + " inputs for '" + - name + "', got " + inputs.size()); + name + "', got " + inputs.size()); } return inputs.stream().map(func).allMatch(Optional::isPresent); } @@ -321,8 +321,8 @@ public abstract class IntermediateOperation { */ TensorType.Value resultValueType() { return TensorType.Value.largestOf(inputs.stream() - .map(input -> input.type().get().type().valueType()) - .collect(Collectors.toList())); + .map(input -> input.type().get().type().valueType()) + .collect(Collectors.toList())); } public abstract IntermediateOperation withInputs(List<IntermediateOperation> inputs); @@ -351,17 +351,22 @@ public abstract class IntermediateOperation { public abstract String operationName(); + /** Required due to tensor dimension name restrictions */ + private static String ensureValidAsDimensionName(String modelName) { + return modelName.replaceAll("[^\\w\\d\\$@_]", "_"); + } + @Override public String toString() { return operationName() + "(" + - inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) + - ")"; + inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) + + ")"; } public String toFullString() { return "\t" + type + ":\t" + operationName() + "(" + - inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + - ")"; + inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + + ")"; } /** |