summaryrefslogtreecommitdiffstats
path: root/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
diff options
context:
space:
mode:
Diffstat (limited to 'model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java')
-rw-r--r--model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java38
1 files changed, 38 insertions, 0 deletions
diff --git a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
index 0ee54f839bc..c3980b8fe93 100644
--- a/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
+++ b/model-integration/src/main/java/ai/vespa/rankingexpression/importer/operations/IntermediateOperation.java
@@ -3,6 +3,7 @@
package ai.vespa.rankingexpression.importer.operations;
import ai.vespa.rankingexpression.importer.DimensionRenamer;
+import ai.vespa.rankingexpression.importer.IntermediateGraph;
import ai.vespa.rankingexpression.importer.OrderedTensorType;
import com.yahoo.searchlib.rankingexpression.Reference;
import com.yahoo.searchlib.rankingexpression.evaluation.Value;
@@ -58,6 +59,8 @@ public abstract class IntermediateOperation {
protected abstract OrderedTensorType lazyGetType();
protected abstract TensorFunction lazyGetFunction();
+ public String modelName() { return modelName; }
+
/** Returns the Vespa tensor type of this operation if it exists */
public Optional<OrderedTensorType> type() {
if (type == null) {
@@ -99,6 +102,20 @@ public abstract class IntermediateOperation {
/** Add dimension name constraints for this operation */
public void addDimensionNameConstraints(DimensionRenamer renamer) { }
+ /** Conveinence method to adds dimensions and constraints of the given tensor type */
+ protected void addConstraintsFrom(OrderedTensorType type, DimensionRenamer renamer) {
+ for (int i = 0; i < type.dimensions().size(); i++) {
+ renamer.addDimension(type.dimensions().get(i).name());
+
+ // Each dimension is distinct:
+ for (int j = i + 1; j < type.dimensions().size(); j++) {
+ renamer.addConstraint(type.dimensions().get(i).name(), type.dimensions().get(j).name(),
+ DimensionRenamer.Constraint.notEqual(false),
+ this);
+ }
+ }
+ }
+
/** Performs dimension rename for this operation */
public void renameDimensions(DimensionRenamer renamer) { type = type.rename(renamer); }
@@ -175,6 +192,12 @@ public abstract class IntermediateOperation {
.collect(Collectors.toList()));
}
+ public abstract IntermediateOperation withInputs(List<IntermediateOperation> inputs);
+
+ String asString(Optional<OrderedTensorType> type) {
+ return type.map(t -> t.toString()).orElse("(unknown)");
+ }
+
/**
* A method signature input and output has the form name:index.
* This returns the name part without the index.
@@ -203,4 +226,19 @@ public abstract class IntermediateOperation {
Optional<List<Value>> getList(String key);
}
+ public abstract String operationName();
+
+ @Override
+ public String toString() {
+ return operationName() +
+ inputs().stream().map(input -> asString(input.type())).collect(Collectors.joining(", ")) +
+ ")";
+ }
+
+ public String toFullString() {
+ return "\t" + lazyGetType() + ":\t" + operationName() +
+ inputs().stream().map(input -> input.toFullString()).collect(Collectors.joining(", ")) +
+ ")";
+ }
+
}