diff options
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java')
-rw-r--r-- | model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java | 38 |
1 files changed, 38 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java index 0ee54f839bc..c3980b8fe93 100644 --- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java +++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java @@ -3,6 +3,7 @@ package ai.vespa.rankingexpression.importer.operations; import ai.vespa.rankingexpression.importer.DimensionRenamer; +import ai.vespa.rankingexpression.importer.IntermediateGraph; import ai.vespa.rankingexpression.importer.OrderedTensorType; import com.yahoo.searchlib.rankingexpression.Reference; import com.yahoo.searchlib.rankingexpression.evaluation.Value; @@ -58,6 +59,8 @@ public abstract class IntermediateOperation { protected abstract OrderedTensorType lazyGetType(); protected abstract TensorFunction lazyGetFunction(); + public String modelName() { return modelName; } + /** Returns the Vespa tensor type of this operation if it exists */ public Optional<OrderedTensorType> type() { if (type == null) { @@ -99,6 +102,20 @@ public abstract class IntermediateOperation { /** Add dimension name constraints for this operation */ public void addDimensionNameConstraints(DimensionRenamer renamer) { } + /** Conveinence method to adds dimensions and constraints of the given tensor type */ + protected void addConstraintsFrom(OrderedTensorType type, DimensionRenamer renamer) { + for (int i = 0; i < type.dimensions().size(); i++) { + renamer.addDimension(type.dimensions().get(i).name()); + + // Each dimension is distinct: + for (int j = i + 1; j < type.dimensions().size(); j++) { + renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(), + DimensionRenamer.Constraint.notEqual(false), + this); + } + } + } + /** Performs dimension rename for this operation */ public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); } @@ -175,6 +192,12 @@ public abstract class IntermediateOperation { .collect(Collectors.toList())); } + public abstract IntermediateOperation withInputs(List<IntermediateOperation> inputs); + + String asString(Optional<OrderedTensorType> type) { + return type.map(t -> t.toString()).orElse("(unknown)"); + } + /** * A method signature input and output has the form name:index. * This returns the name part without the index. @@ -203,4 +226,19 @@ public abstract class IntermediateOperation { Optional<List<Value>> getList(String key); } + public abstract String operationName(); + + @Override + public String toString() { + return operationName() + + inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) + + ")"; + } + + public String toFullString() { + return "\t" + lazyGetType() + ":\t" + operationName() + + inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) + + ")"; + } + } |