summaryrefslogtreecommitdiffstats
path: root/model-evaluation
diff options
context:
space:
mode:
authorArne Juul <arnej@yahooinc.com>2023-03-08 12:13:46 +0000
committerArne Juul <arnej@yahooinc.com>2023-03-08 15:10:34 +0000
commita5276a07e0e9a4b06a715609aeec9f6a78dadba2 (patch)
tree07506115593af8787776e9f4ec82aa8d9e80c090 /model-evaluation
parent4676e5bcd2300ef40059cc34a93c6fb9d8e06b57 (diff)
use OnnxExpressionNode
Diffstat (limited to 'model-evaluation')
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java66
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java8
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java6
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java14
-rw-r--r--model-evaluation/src/test/resources/config/advanced-global-phase/matrix.json6
-rw-r--r--model-evaluation/src/test/resources/config/advanced-global-phase/multiply_add.onnx23
-rw-r--r--model-evaluation/src/test/resources/config/advanced-global-phase/onnx-models.cfg16
-rw-r--r--model-evaluation/src/test/resources/config/advanced-global-phase/rank-profiles.cfg130
-rw-r--r--model-evaluation/src/test/resources/config/advanced-global-phase/ranking-constants.cfg3
-rw-r--r--model-evaluation/src/test/resources/config/advanced-global-phase/ranking-expressions.cfg0
10 files changed, 267 insertions, 5 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index f173a6b453f..66612b7ccc3 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -8,11 +8,19 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer;
import com.yahoo.stream.CustomCollectors;
import com.yahoo.tensor.TensorType;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.searchlib.rankingexpression.rule.CompositeNode;
+import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
+import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode;
+import com.yahoo.searchlib.rankingexpression.transform.TransformContext;
+import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer;
+
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.logging.Logger;
import java.util.stream.Collectors;
/**
@@ -23,6 +31,8 @@ import java.util.stream.Collectors;
@Beta
public class Model implements AutoCloseable {
+ private static final Logger logger = Logger.getLogger(Model.class.getName());
+
/** The prefix generated by model-integration/../IntermediateOperation */
private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_";
@@ -54,6 +64,60 @@ public class Model implements AutoCloseable {
List.of());
}
+    /**
+     * Expression transformer which replaces references to ONNX model outputs
+     * (features named "onnx" or "onnxModel") by the expression node which the
+     * referenced {@link OnnxModel} supplies for that output. For each replaced
+     * reference, the type wanted by every input of the model is recorded in the
+     * declared types map, keyed by the source of the input.
+     */
+    static class OnnxReplacer extends ExpressionTransformer<TransformContext> {
+
+        private final List<OnnxModel> onnxModels;
+        private final Map<String, TensorType> declaredTypes;
+
+        /**
+         * @param onnxModels the models in which the argument of an onnx reference is looked up by name
+         * @param declaredTypes map from input source to wanted type; entries are added to it by transform
+         */
+        public OnnxReplacer(List<OnnxModel> onnxModels,
+                            Map<String, TensorType> declaredTypes)
+        {
+            this.onnxModels = onnxModels;
+            this.declaredTypes = declaredTypes;
+        }
+
+        /** Returns the model having the given name, or null if no model has it */
+        private OnnxModel getModel(String name) {
+            for (var m : onnxModels) {
+                if (m.name().equals(name)) {
+                    return m;
+                }
+            }
+            return null;
+        }
+
+        /** Returns whether the given node refers to a feature named "onnx" or "onnxModel" */
+        private static boolean isOnnxReference(ReferenceNode node) {
+            var name = node.reference().name();
+            return name.equals("onnx") || name.equals("onnxModel");
+        }
+
+        /**
+         * Returns the expression provided by the referenced model for the referenced output,
+         * or the given node itself if the model or the output cannot be found.
+         *
+         * @throws IllegalArgumentException if a model input wants a type which differs from
+         *         the type already declared for the same source
+         */
+        private ExpressionNode replaceOnnxReference(ReferenceNode r) {
+            var ref = r.reference();
+            // Messages are passed as suppliers so that the (possibly large) expressions
+            // are only rendered to strings when FINE logging is enabled.
+            logger.fine(() -> "consider replacing: " + ref);
+            var m = getModel(ref.simpleArgument().orElse(null));
+            if (m == null) {
+                logger.fine(() -> "no onnx model named " + ref.simpleArgument());
+                return r;
+            }
+            // Load the model (if not already loaded) to extract inputs
+            m.load();
+            var expr = m.getExpressionForOutput(ref.output());
+            if (expr == null) {
+                logger.fine(() -> "no output named " + ref.output() + " from " + m);
+                return r;
+            }
+            logger.fine(() -> "Replacing " + r + " => " + expr);
+            for (var inputSpec : m.inputSpecs) {
+                // putIfAbsent keeps the existing declaration, so a conflict is reported
+                // without first overwriting the type which was already declared.
+                var old = declaredTypes.putIfAbsent(inputSpec.source, inputSpec.wantedType);
+                if (old != null && ! old.equals(inputSpec.wantedType)) {
+                    throw new IllegalArgumentException("Conflicting types needed for " + inputSpec.source + "; " + old + " != " + inputSpec.wantedType);
+                }
+            }
+            return expr;
+        }
+
+        /**
+         * Replaces the given node if it is an onnx reference which can be resolved,
+         * and then transforms the children of the resulting node if it is composite.
+         *
+         * @return the transformed node, or the given node if nothing was changed
+         */
+        @Override
+        public ExpressionNode transform(ExpressionNode node, TransformContext context) {
+            ExpressionNode result = node;
+            if (node instanceof ReferenceNode r && isOnnxReference(r)) {
+                result = replaceOnnxReference(r);
+            }
+            if (result instanceof CompositeNode c) {
+                result = transformChildren(c, context);
+            }
+            if (result != node) {
+                var transformed = result;
+                logger.fine(() -> "transformed: " + node + " => " + transformed);
+            }
+            return result;
+        }
+    }
+
Model(String name,
Map<FunctionReference, ExpressionFunction> functions,
Map<FunctionReference, ExpressionFunction> referencedFunctions,
@@ -68,6 +132,8 @@ public class Model implements AutoCloseable {
Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>();
for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) {
try {
+ var body = function.getValue().getBody();
+ body.setRoot(new OnnxReplacer(onnxModels, declaredTypes).transform(body.getRoot(), null));
LazyArrayContext context = new LazyArrayContext(function.getValue(), bindingExtractor, referencedFunctions, constants, this);
contextBuilder.put(function.getValue().getName(), context);
if (function.getValue().returnType().isEmpty()) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java
index a50d9e36d74..67c016074de 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java
@@ -82,7 +82,7 @@ class OnnxExpressionNode extends CompositeNode {
Value inputValue = inputRefs.get(i).evaluate(context);
inputs.put(modelInputs.get(i), inputValue.asTensor());
}
- return new TensorValue(model.evaluate(inputs, onnxOutputName));
+ return new TensorValue(model.unmappedEvaluate(inputs, onnxOutputName));
}
@Override
@@ -94,9 +94,9 @@ class OnnxExpressionNode extends CompositeNode {
@Override
public StringBuilder toString(StringBuilder b, SerializationContext context, Deque<String> path, CompositeNode parent) {
b.append("onnx_expression_node(").append(model.name()).append(")");
- if (outputAs != null && ! outputAs.equals("")) {
+ if (outputAs != null && ! outputAs.equals("")) {
b.append(".").append(outputAs);
- }
- return b;
+ }
+ return b;
}
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
index 59febf7cdbf..ad27f9d2d15 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java
@@ -193,7 +193,11 @@ class OnnxModel implements AutoCloseable {
if (onnxName == null) {
throw new IllegalArgumentException("evaluate ONNX model " + name() + ": no output available as: " + output);
}
- return evaluator().evaluate(mapped, onnxName);
+ return unmappedEvaluate(mapped, onnxName);
+ }
+
+ /**
+ * Evaluates this model by passing the given inputs and output name directly to the evaluator.
+ * NOTE(review): the mapping evaluate method appears to translate its arguments before
+ * delegating here, so callers presumably must already use the ONNX names - confirm.
+ */
+ Tensor unmappedEvaluate(Map<String, Tensor> inputs, String onnxOutputName) {
+ return evaluator().evaluate(inputs, onnxOutputName);
}
private OnnxEvaluator evaluator() {
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
index 65eb55ae46d..ff1f9afcc91 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java
@@ -60,4 +60,18 @@ public class RankProfileImportingTest {
ModelTester tester = new ModelTester("src/test/resources/config/ranking-macros/");
assertEquals(5, tester.models().size());
}
+
+    /** Imports the advanced-global-phase config and checks the argument types of the globalphase function. */
+    @Test
+    public void testImportingAdvancedGlobalPhase() {
+        var modelTester = new ModelTester("src/test/resources/config/advanced-global-phase/");
+        assertEquals(6, modelTester.models().size());
+
+        Model globalPhaseModel = modelTester.models().get("global_phase");
+        assertEquals("global_phase", globalPhaseModel.name());
+
+        var globalPhaseFunction = globalPhaseModel.function("globalphase");
+        assertEquals("globalphase", globalPhaseFunction.getName());
+
+        // Exactly two arguments are expected: the document attribute and the query tensor
+        var argumentTypes = globalPhaseFunction.argumentTypes();
+        assertEquals(2, argumentTypes.size());
+        assertEquals("tensor(d0[2])", argumentTypes.get("attribute(doc_vec)").toString());
+        assertEquals("tensor(d0[2])", argumentTypes.get("query(query_vec)").toString());
+    }
}
diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/matrix.json b/model-evaluation/src/test/resources/config/advanced-global-phase/matrix.json
new file mode 100644
index 00000000000..cd9bae2501e
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/advanced-global-phase/matrix.json
@@ -0,0 +1,6 @@
+{ "cells": [
+ { "address": { "d0": "0", "d1": "0" }, "value": 1.0 },
+ { "address": { "d0": "1", "d1": "0" }, "value": 1.0 },
+ { "address": { "d0": "0", "d1": "1" }, "value": 1.0 },
+ { "address": { "d0": "1", "d1": "1" }, "value": 1.0 }
+]}
diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/multiply_add.onnx b/model-evaluation/src/test/resources/config/advanced-global-phase/multiply_add.onnx
new file mode 100644
index 00000000000..5cbd44f6635
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/advanced-global-phase/multiply_add.onnx
@@ -0,0 +1,23 @@
+:Ý
+*
+ model_input_1
+ model_input_2XA"MatMul
+(
+XA
+ model_input_3model_output_1"Add multiply_addZ
+ model_input_1
+  
+
+Z
+ model_input_2
+
+ 
+Z
+ model_input_3
+
+ 
+b
+model_output_1
+
+ 
+B \ No newline at end of file
diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/onnx-models.cfg b/model-evaluation/src/test/resources/config/advanced-global-phase/onnx-models.cfg
new file mode 100644
index 00000000000..821087cbd8c
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/advanced-global-phase/onnx-models.cfg
@@ -0,0 +1,16 @@
+model[0].name "multiply_add"
+model[0].fileref "multiply_add.onnx"
+model[0].input[0].name "model_input_2"
+model[0].input[0].source "rankingExpression(input_two)"
+model[0].input[1].name "model_input_1"
+model[0].input[1].source "rankingExpression(input_one)"
+model[0].input[2].name "model_input_3"
+model[0].input[2].source "rankingExpression(input_three)"
+model[0].output[0].name "model_output_1"
+model[0].output[0].as "multiply_add_output"
+model[0].dry_run_on_setup true
+model[0].stateless_execution_mode ""
+model[0].stateless_interop_threads -1
+model[0].stateless_intraop_threads -1
+model[0].gpu_device -1
+model[0].gpu_device_required false
diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/rank-profiles.cfg b/model-evaluation/src/test/resources/config/advanced-global-phase/rank-profiles.cfg
new file mode 100644
index 00000000000..bb725f13d91
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/advanced-global-phase/rank-profiles.cfg
@@ -0,0 +1,130 @@
+rankprofile[0].name "default"
+rankprofile[0].fef.property[0].name "vespa.type.attribute.doc_vec"
+rankprofile[0].fef.property[0].value "tensor(d0[2])"
+rankprofile[1].name "unranked"
+rankprofile[1].fef.property[0].name "vespa.rank.firstphase"
+rankprofile[1].fef.property[0].value "value(0)"
+rankprofile[1].fef.property[1].name "vespa.hitcollector.heapsize"
+rankprofile[1].fef.property[1].value "0"
+rankprofile[1].fef.property[2].name "vespa.hitcollector.arraysize"
+rankprofile[1].fef.property[2].value "0"
+rankprofile[1].fef.property[3].name "vespa.dump.ignoredefaultfeatures"
+rankprofile[1].fef.property[3].value "true"
+rankprofile[1].fef.property[4].name "vespa.type.attribute.doc_vec"
+rankprofile[1].fef.property[4].value "tensor(d0[2])"
+rankprofile[2].name "basebase"
+rankprofile[2].fef.property[0].name "vespa.type.attribute.doc_vec"
+rankprofile[2].fef.property[0].value "tensor(d0[2])"
+rankprofile[3].name "base"
+rankprofile[3].fef.property[0].name "rankingExpression(input_one).rankingScript"
+rankprofile[3].fef.property[0].value "constant(matrix)"
+rankprofile[3].fef.property[1].name "rankingExpression(input_one).type"
+rankprofile[3].fef.property[1].value "tensor(d0[2],d1[2])"
+rankprofile[3].fef.property[2].name "rankingExpression(input_two).rankingScript"
+rankprofile[3].fef.property[2].value "attribute(doc_vec)"
+rankprofile[3].fef.property[3].name "rankingExpression(input_two).type"
+rankprofile[3].fef.property[3].value "tensor(d0[2])"
+rankprofile[3].fef.property[4].name "rankingExpression(input_three).rankingScript"
+rankprofile[3].fef.property[4].value "query(query_vec)"
+rankprofile[3].fef.property[5].name "rankingExpression(input_three).type"
+rankprofile[3].fef.property[5].value "tensor(d0[2])"
+rankprofile[3].fef.property[6].name "rankingExpression(fn_query_vec).rankingScript"
+rankprofile[3].fef.property[6].value "query(query_vec)"
+rankprofile[3].fef.property[7].name "rankingExpression(fn_query_vec).type"
+rankprofile[3].fef.property[7].value "tensor(d0[2])"
+rankprofile[3].fef.property[8].name "rankingExpression(fn_doc_vec).rankingScript"
+rankprofile[3].fef.property[8].value "attribute(doc_vec)"
+rankprofile[3].fef.property[9].name "rankingExpression(fn_doc_vec).type"
+rankprofile[3].fef.property[9].value "tensor(d0[2])"
+rankprofile[3].fef.property[10].name "vespa.rank.firstphase"
+rankprofile[3].fef.property[10].value "rankingExpression(firstphase)"
+rankprofile[3].fef.property[11].name "rankingExpression(firstphase).rankingScript"
+rankprofile[3].fef.property[11].value "-attribute(score)"
+rankprofile[3].fef.property[12].name "vespa.summary.feature"
+rankprofile[3].fef.property[12].value "query(query_vec)"
+rankprofile[3].fef.property[13].name "vespa.summary.feature"
+rankprofile[3].fef.property[13].value "onnx(multiply_add).multiply_add_output"
+rankprofile[3].fef.property[14].name "vespa.type.attribute.doc_vec"
+rankprofile[3].fef.property[14].value "tensor(d0[2])"
+rankprofile[3].fef.property[15].name "vespa.type.query.query_vec"
+rankprofile[3].fef.property[15].value "tensor(d0[2])"
+rankprofile[4].name "global_phase"
+rankprofile[4].fef.property[0].name "rankingExpression(fn_query_vec).rankingScript"
+rankprofile[4].fef.property[0].value "query(query_vec)"
+rankprofile[4].fef.property[1].name "rankingExpression(fn_query_vec).type"
+rankprofile[4].fef.property[1].value "tensor(d0[2])"
+rankprofile[4].fef.property[2].name "rankingExpression(input_one).rankingScript"
+rankprofile[4].fef.property[2].value "constant(matrix)"
+rankprofile[4].fef.property[3].name "rankingExpression(input_one).type"
+rankprofile[4].fef.property[3].value "tensor(d0[2],d1[2])"
+rankprofile[4].fef.property[4].name "rankingExpression(input_two).rankingScript"
+rankprofile[4].fef.property[4].value "attribute(doc_vec)"
+rankprofile[4].fef.property[5].name "rankingExpression(input_two).type"
+rankprofile[4].fef.property[5].value "tensor(d0[2])"
+rankprofile[4].fef.property[6].name "rankingExpression(input_three).rankingScript"
+rankprofile[4].fef.property[6].value "query(query_vec)"
+rankprofile[4].fef.property[7].name "rankingExpression(input_three).type"
+rankprofile[4].fef.property[7].value "tensor(d0[2])"
+rankprofile[4].fef.property[8].name "rankingExpression(fn_doc_vec).rankingScript"
+rankprofile[4].fef.property[8].value "attribute(doc_vec)"
+rankprofile[4].fef.property[9].name "rankingExpression(fn_doc_vec).type"
+rankprofile[4].fef.property[9].value "tensor(d0[2])"
+rankprofile[4].fef.property[10].name "vespa.rank.firstphase"
+rankprofile[4].fef.property[10].value "rankingExpression(firstphase)"
+rankprofile[4].fef.property[11].name "rankingExpression(firstphase).rankingScript"
+rankprofile[4].fef.property[11].value "-attribute(score)"
+rankprofile[4].fef.property[12].name "vespa.rank.globalphase"
+rankprofile[4].fef.property[12].value "rankingExpression(globalphase)"
+rankprofile[4].fef.property[13].name "rankingExpression(globalphase).rankingScript"
+rankprofile[4].fef.property[13].value "reduce(onnx(multiply_add).multiply_add_output - rankingExpression(fn_query_vec), sum)"
+rankprofile[4].fef.property[14].name "vespa.summary.feature"
+rankprofile[4].fef.property[14].value "query(query_vec)"
+rankprofile[4].fef.property[15].name "vespa.summary.feature"
+rankprofile[4].fef.property[15].value "onnx(multiply_add).multiply_add_output"
+rankprofile[4].fef.property[16].name "vespa.match.feature"
+rankprofile[4].fef.property[16].value "attribute(doc_vec)"
+rankprofile[4].fef.property[17].name "vespa.globalphase.rerankcount"
+rankprofile[4].fef.property[17].value "3"
+rankprofile[4].fef.property[18].name "vespa.type.attribute.doc_vec"
+rankprofile[4].fef.property[18].value "tensor(d0[2])"
+rankprofile[4].fef.property[19].name "vespa.type.query.query_vec"
+rankprofile[4].fef.property[19].value "tensor(d0[2])"
+rankprofile[5].name "second_phase"
+rankprofile[5].fef.property[0].name "rankingExpression(fn_query_vec).rankingScript"
+rankprofile[5].fef.property[0].value "query(query_vec)"
+rankprofile[5].fef.property[1].name "rankingExpression(fn_query_vec).type"
+rankprofile[5].fef.property[1].value "tensor(d0[2])"
+rankprofile[5].fef.property[2].name "rankingExpression(input_one).rankingScript"
+rankprofile[5].fef.property[2].value "constant(matrix)"
+rankprofile[5].fef.property[3].name "rankingExpression(input_one).type"
+rankprofile[5].fef.property[3].value "tensor(d0[2],d1[2])"
+rankprofile[5].fef.property[4].name "rankingExpression(input_two).rankingScript"
+rankprofile[5].fef.property[4].value "attribute(doc_vec)"
+rankprofile[5].fef.property[5].name "rankingExpression(input_two).type"
+rankprofile[5].fef.property[5].value "tensor(d0[2])"
+rankprofile[5].fef.property[6].name "rankingExpression(input_three).rankingScript"
+rankprofile[5].fef.property[6].value "query(query_vec)"
+rankprofile[5].fef.property[7].name "rankingExpression(input_three).type"
+rankprofile[5].fef.property[7].value "tensor(d0[2])"
+rankprofile[5].fef.property[8].name "rankingExpression(fn_doc_vec).rankingScript"
+rankprofile[5].fef.property[8].value "attribute(doc_vec)"
+rankprofile[5].fef.property[9].name "rankingExpression(fn_doc_vec).type"
+rankprofile[5].fef.property[9].value "tensor(d0[2])"
+rankprofile[5].fef.property[10].name "vespa.rank.firstphase"
+rankprofile[5].fef.property[10].value "rankingExpression(firstphase)"
+rankprofile[5].fef.property[11].name "rankingExpression(firstphase).rankingScript"
+rankprofile[5].fef.property[11].value "-attribute(score)"
+rankprofile[5].fef.property[12].name "vespa.rank.secondphase"
+rankprofile[5].fef.property[12].value "rankingExpression(secondphase)"
+rankprofile[5].fef.property[13].name "rankingExpression(secondphase).rankingScript"
+rankprofile[5].fef.property[13].value "reduce(onnx(multiply_add).multiply_add_output - rankingExpression(fn_query_vec), sum)"
+rankprofile[5].fef.property[14].name "vespa.summary.feature"
+rankprofile[5].fef.property[14].value "query(query_vec)"
+rankprofile[5].fef.property[15].name "vespa.summary.feature"
+rankprofile[5].fef.property[15].value "onnx(multiply_add).multiply_add_output"
+rankprofile[5].fef.property[16].name "vespa.hitcollector.heapsize"
+rankprofile[5].fef.property[16].value "3"
+rankprofile[5].fef.property[17].name "vespa.type.attribute.doc_vec"
+rankprofile[5].fef.property[17].value "tensor(d0[2])"
+rankprofile[5].fef.property[18].name "vespa.type.query.query_vec"
+rankprofile[5].fef.property[18].value "tensor(d0[2])"
diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-constants.cfg b/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-constants.cfg
new file mode 100644
index 00000000000..9ee78311b79
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-constants.cfg
@@ -0,0 +1,3 @@
+constant[0].name "matrix"
+constant[0].fileref "matrix.json"
+constant[0].type "tensor(d0[2],d1[2])"
diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-expressions.cfg b/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-expressions.cfg
new file mode 100644
index 00000000000..e69de29bb2d
--- /dev/null
+++ b/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-expressions.cfg