diff options
Diffstat (limited to 'model-evaluation/src')
10 files changed, 267 insertions, 5 deletions
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java index f173a6b453f..66612b7ccc3 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java @@ -8,11 +8,19 @@ import com.yahoo.searchlib.rankingexpression.evaluation.ExpressionOptimizer; import com.yahoo.stream.CustomCollectors; import com.yahoo.tensor.TensorType; +import com.yahoo.searchlib.rankingexpression.RankingExpression; +import com.yahoo.searchlib.rankingexpression.rule.CompositeNode; +import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode; +import com.yahoo.searchlib.rankingexpression.rule.ReferenceNode; +import com.yahoo.searchlib.rankingexpression.transform.TransformContext; +import com.yahoo.searchlib.rankingexpression.transform.ExpressionTransformer; + import java.util.Arrays; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; +import java.util.logging.Logger; import java.util.stream.Collectors; /** @@ -23,6 +31,8 @@ import java.util.stream.Collectors; @Beta public class Model implements AutoCloseable { + private static final Logger logger = Logger.getLogger(Model.class.getName()); + /** The prefix generated by model-integration/../IntermediateOperation */ private final static String INTERMEDIATE_OPERATION_FUNCTION_PREFIX = "imported_ml_function_"; @@ -54,6 +64,60 @@ public class Model implements AutoCloseable { List.of()); } + static class OnnxReplacer extends ExpressionTransformer<TransformContext> { + private final List<OnnxModel> onnxModels; + private final Map<String, TensorType> declaredTypes; + + private OnnxModel getModel(String name) { + for (var m : onnxModels) if (m.name().equals(name)) return m; + return null; + } + public OnnxReplacer(List<OnnxModel> onnxModels, + Map<String, TensorType> declaredTypes) + { + this.onnxModels = onnxModels; + this.declaredTypes = declaredTypes; + } + + @Override + public ExpressionNode transform(ExpressionNode node, TransformContext context) { + var orig = node; + if (node instanceof ReferenceNode r) { + var ref = r.reference(); + if (ref.name().equals("onnx") || ref.name().equals("onnxModel")) { + logger.fine("consider replacing: " + ref); + var m = getModel(ref.simpleArgument().orElse(null)); + if (m != null) { + // Load the model (if not already loaded) to extract inputs + m.load(); + var expr = m.getExpressionForOutput(ref.output()); + if (expr != null) { + logger.fine("Replacing " + node + " => " + expr); + node = expr; + for (var inputSpec : m.inputSpecs) { + var old = declaredTypes.put(inputSpec.source, inputSpec.wantedType); + if (old != null && ! old.equals(inputSpec.wantedType)) { + throw new IllegalArgumentException("Conflicting types needed for " + inputSpec.source + "; " + old + " != " + inputSpec.wantedType); + } + } + } else { + logger.fine("no output named " + ref.output() + " from " + m); + } + } else { + logger.fine("no onnx model named " + ref.simpleArgument()); + } + } + } + if (node instanceof CompositeNode c) { + node = transformChildren(c, context); + } + if (node != orig) { + logger.fine("transformed: " + orig + " => " + node); + } + return node; + } + } + Model(String name, Map<FunctionReference, ExpressionFunction> functions, Map<FunctionReference, ExpressionFunction> referencedFunctions, @@ -68,6 +132,8 @@ public class Model implements AutoCloseable { Map<String, LazyArrayContext> contextBuilder = new LinkedHashMap<>(); for (Map.Entry<FunctionReference, ExpressionFunction> function : functions.entrySet()) { try { + var body = function.getValue().getBody(); + body.setRoot(new OnnxReplacer(onnxModels, declaredTypes).transform(body.getRoot(), null)); LazyArrayContext context = new LazyArrayContext(function.getValue(), bindingExtractor, referencedFunctions, constants, this); contextBuilder.put(function.getValue().getName(), context); if (function.getValue().returnType().isEmpty()) { diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java index a50d9e36d74..67c016074de 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxExpressionNode.java @@ -82,7 +82,7 @@ class OnnxExpressionNode extends CompositeNode { Value inputValue = inputRefs.get(i).evaluate(context); inputs.put(modelInputs.get(i), inputValue.asTensor()); } - return new TensorValue(model.evaluate(inputs, onnxOutputName)); + return new TensorValue(model.unmappedEvaluate(inputs, onnxOutputName)); } @Override @@ -94,9 +94,9 @@ class OnnxExpressionNode extends CompositeNode { @Override public StringBuilder toString(StringBuilder b, SerializationContext context, Deque<String> path, CompositeNode parent) { b.append("onnx_expression_node(").append(model.name()).append(")"); - if (outputAs != null && ! outputAs.equals("")) { + if (outputAs != null && ! outputAs.equals("")) { b.append(".").append(outputAs); - } - return b; + } + return b; } } diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java index 59febf7cdbf..ad27f9d2d15 100644 --- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java +++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/OnnxModel.java @@ -193,7 +193,11 @@ class OnnxModel implements AutoCloseable { if (onnxName == null) { throw new IllegalArgumentException("evaluate ONNX model " + name() + ": no output available as: " + output); } - return evaluator().evaluate(mapped, onnxName); + return unmappedEvaluate(mapped, onnxName); + } + + Tensor unmappedEvaluate(Map<String, Tensor> inputs, String onnxOutputName) { + return evaluator().evaluate(inputs, onnxOutputName); } private OnnxEvaluator evaluator() { diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java index 65eb55ae46d..ff1f9afcc91 100644 --- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java +++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/RankProfileImportingTest.java @@ -60,4 +60,18 @@ public class RankProfileImportingTest { ModelTester tester = new ModelTester("src/test/resources/config/ranking-macros/"); assertEquals(5, tester.models().size()); } + + @Test + public void testImportingAdvancedGlobalPhase() { + ModelTester tester = new ModelTester("src/test/resources/config/advanced-global-phase/"); + assertEquals(6, tester.models().size()); + Model m = tester.models().get("global_phase"); + assertEquals("global_phase", m.name()); + var func = m.function("globalphase"); + assertEquals("globalphase", func.getName()); + var args = func.argumentTypes(); + assertEquals(2, args.size()); + assertEquals("tensor(d0[2])", args.get("attribute(doc_vec)").toString()); + assertEquals("tensor(d0[2])", args.get("query(query_vec)").toString()); + } } diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/matrix.json b/model-evaluation/src/test/resources/config/advanced-global-phase/matrix.json new file mode 100644 index 00000000000..cd9bae2501e --- /dev/null +++ b/model-evaluation/src/test/resources/config/advanced-global-phase/matrix.json @@ -0,0 +1,6 @@ +{ "cells": [ + { "address": { "d0": "0", "d1": "0" }, "value": 1.0 }, + { "address": { "d0": "1", "d1": "0" }, "value": 1.0 }, + { "address": { "d0": "0", "d1": "1" }, "value": 1.0 }, + { "address": { "d0": "1", "d1": "1" }, "value": 1.0 } +]} diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/multiply_add.onnx b/model-evaluation/src/test/resources/config/advanced-global-phase/multiply_add.onnx new file mode 100644 index 00000000000..5cbd44f6635 --- /dev/null +++ b/model-evaluation/src/test/resources/config/advanced-global-phase/multiply_add.onnx @@ -0,0 +1,23 @@ +:Ý +* +
model_input_1 +
model_input_2XA"MatMul +( +XA +
model_input_3model_output_1"Addmultiply_addZ +
model_input_1 + + +Z +
model_input_2 + + +Z +
model_input_3 + + +b +model_output_1 + + +B
\ No newline at end of file diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/onnx-models.cfg b/model-evaluation/src/test/resources/config/advanced-global-phase/onnx-models.cfg new file mode 100644 index 00000000000..821087cbd8c --- /dev/null +++ b/model-evaluation/src/test/resources/config/advanced-global-phase/onnx-models.cfg @@ -0,0 +1,16 @@ +model[0].name "multiply_add" +model[0].fileref "multiply_add.onnx" +model[0].input[0].name "model_input_2" +model[0].input[0].source "rankingExpression(input_two)" +model[0].input[1].name "model_input_1" +model[0].input[1].source "rankingExpression(input_one)" +model[0].input[2].name "model_input_3" +model[0].input[2].source "rankingExpression(input_three)" +model[0].output[0].name "model_output_1" +model[0].output[0].as "multiply_add_output" +model[0].dry_run_on_setup true +model[0].stateless_execution_mode "" +model[0].stateless_interop_threads -1 +model[0].stateless_intraop_threads -1 +model[0].gpu_device -1 +model[0].gpu_device_required false diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/rank-profiles.cfg b/model-evaluation/src/test/resources/config/advanced-global-phase/rank-profiles.cfg new file mode 100644 index 00000000000..bb725f13d91 --- /dev/null +++ b/model-evaluation/src/test/resources/config/advanced-global-phase/rank-profiles.cfg @@ -0,0 +1,130 @@ +rankprofile[0].name "default" +rankprofile[0].fef.property[0].name "vespa.type.attribute.doc_vec" +rankprofile[0].fef.property[0].value "tensor(d0[2])" +rankprofile[1].name "unranked" +rankprofile[1].fef.property[0].name "vespa.rank.firstphase" +rankprofile[1].fef.property[0].value "value(0)" +rankprofile[1].fef.property[1].name "vespa.hitcollector.heapsize" +rankprofile[1].fef.property[1].value "0" +rankprofile[1].fef.property[2].name "vespa.hitcollector.arraysize" +rankprofile[1].fef.property[2].value "0" +rankprofile[1].fef.property[3].name "vespa.dump.ignoredefaultfeatures" +rankprofile[1].fef.property[3].value "true" +rankprofile[1].fef.property[4].name "vespa.type.attribute.doc_vec" +rankprofile[1].fef.property[4].value "tensor(d0[2])" +rankprofile[2].name "basebase" +rankprofile[2].fef.property[0].name "vespa.type.attribute.doc_vec" +rankprofile[2].fef.property[0].value "tensor(d0[2])" +rankprofile[3].name "base" +rankprofile[3].fef.property[0].name "rankingExpression(input_one).rankingScript" +rankprofile[3].fef.property[0].value "constant(matrix)" +rankprofile[3].fef.property[1].name "rankingExpression(input_one).type" +rankprofile[3].fef.property[1].value "tensor(d0[2],d1[2])" +rankprofile[3].fef.property[2].name "rankingExpression(input_two).rankingScript" +rankprofile[3].fef.property[2].value "attribute(doc_vec)" +rankprofile[3].fef.property[3].name "rankingExpression(input_two).type" +rankprofile[3].fef.property[3].value "tensor(d0[2])" +rankprofile[3].fef.property[4].name "rankingExpression(input_three).rankingScript" +rankprofile[3].fef.property[4].value "query(query_vec)" +rankprofile[3].fef.property[5].name "rankingExpression(input_three).type" +rankprofile[3].fef.property[5].value "tensor(d0[2])" +rankprofile[3].fef.property[6].name "rankingExpression(fn_query_vec).rankingScript" +rankprofile[3].fef.property[6].value "query(query_vec)" +rankprofile[3].fef.property[7].name "rankingExpression(fn_query_vec).type" +rankprofile[3].fef.property[7].value "tensor(d0[2])" +rankprofile[3].fef.property[8].name "rankingExpression(fn_doc_vec).rankingScript" +rankprofile[3].fef.property[8].value "attribute(doc_vec)" +rankprofile[3].fef.property[9].name "rankingExpression(fn_doc_vec).type" +rankprofile[3].fef.property[9].value "tensor(d0[2])" +rankprofile[3].fef.property[10].name "vespa.rank.firstphase" +rankprofile[3].fef.property[10].value "rankingExpression(firstphase)" +rankprofile[3].fef.property[11].name "rankingExpression(firstphase).rankingScript" +rankprofile[3].fef.property[11].value "-attribute(score)" +rankprofile[3].fef.property[12].name "vespa.summary.feature" +rankprofile[3].fef.property[12].value "query(query_vec)" +rankprofile[3].fef.property[13].name "vespa.summary.feature" +rankprofile[3].fef.property[13].value "onnx(multiply_add).multiply_add_output" +rankprofile[3].fef.property[14].name "vespa.type.attribute.doc_vec" +rankprofile[3].fef.property[14].value "tensor(d0[2])" +rankprofile[3].fef.property[15].name "vespa.type.query.query_vec" +rankprofile[3].fef.property[15].value "tensor(d0[2])" +rankprofile[4].name "global_phase" +rankprofile[4].fef.property[0].name "rankingExpression(fn_query_vec).rankingScript" +rankprofile[4].fef.property[0].value "query(query_vec)" +rankprofile[4].fef.property[1].name "rankingExpression(fn_query_vec).type" +rankprofile[4].fef.property[1].value "tensor(d0[2])" +rankprofile[4].fef.property[2].name "rankingExpression(input_one).rankingScript" +rankprofile[4].fef.property[2].value "constant(matrix)" +rankprofile[4].fef.property[3].name "rankingExpression(input_one).type" +rankprofile[4].fef.property[3].value "tensor(d0[2],d1[2])" +rankprofile[4].fef.property[4].name "rankingExpression(input_two).rankingScript" +rankprofile[4].fef.property[4].value "attribute(doc_vec)" +rankprofile[4].fef.property[5].name "rankingExpression(input_two).type" +rankprofile[4].fef.property[5].value "tensor(d0[2])" +rankprofile[4].fef.property[6].name "rankingExpression(input_three).rankingScript" +rankprofile[4].fef.property[6].value "query(query_vec)" +rankprofile[4].fef.property[7].name "rankingExpression(input_three).type" +rankprofile[4].fef.property[7].value "tensor(d0[2])" +rankprofile[4].fef.property[8].name "rankingExpression(fn_doc_vec).rankingScript" +rankprofile[4].fef.property[8].value "attribute(doc_vec)" +rankprofile[4].fef.property[9].name "rankingExpression(fn_doc_vec).type" +rankprofile[4].fef.property[9].value "tensor(d0[2])" +rankprofile[4].fef.property[10].name "vespa.rank.firstphase" +rankprofile[4].fef.property[10].value "rankingExpression(firstphase)" +rankprofile[4].fef.property[11].name "rankingExpression(firstphase).rankingScript" +rankprofile[4].fef.property[11].value "-attribute(score)" +rankprofile[4].fef.property[12].name "vespa.rank.globalphase" +rankprofile[4].fef.property[12].value "rankingExpression(globalphase)" +rankprofile[4].fef.property[13].name "rankingExpression(globalphase).rankingScript" +rankprofile[4].fef.property[13].value "reduce(onnx(multiply_add).multiply_add_output - rankingExpression(fn_query_vec), sum)" +rankprofile[4].fef.property[14].name "vespa.summary.feature" +rankprofile[4].fef.property[14].value "query(query_vec)" +rankprofile[4].fef.property[15].name "vespa.summary.feature" +rankprofile[4].fef.property[15].value "onnx(multiply_add).multiply_add_output" +rankprofile[4].fef.property[16].name "vespa.match.feature" +rankprofile[4].fef.property[16].value "attribute(doc_vec)" +rankprofile[4].fef.property[17].name "vespa.globalphase.rerankcount" +rankprofile[4].fef.property[17].value "3" +rankprofile[4].fef.property[18].name "vespa.type.attribute.doc_vec" +rankprofile[4].fef.property[18].value "tensor(d0[2])" +rankprofile[4].fef.property[19].name "vespa.type.query.query_vec" +rankprofile[4].fef.property[19].value "tensor(d0[2])" +rankprofile[5].name "second_phase" +rankprofile[5].fef.property[0].name "rankingExpression(fn_query_vec).rankingScript" +rankprofile[5].fef.property[0].value "query(query_vec)" +rankprofile[5].fef.property[1].name "rankingExpression(fn_query_vec).type" +rankprofile[5].fef.property[1].value "tensor(d0[2])" +rankprofile[5].fef.property[2].name "rankingExpression(input_one).rankingScript" +rankprofile[5].fef.property[2].value "constant(matrix)" +rankprofile[5].fef.property[3].name "rankingExpression(input_one).type" +rankprofile[5].fef.property[3].value "tensor(d0[2],d1[2])" +rankprofile[5].fef.property[4].name "rankingExpression(input_two).rankingScript" +rankprofile[5].fef.property[4].value "attribute(doc_vec)" +rankprofile[5].fef.property[5].name "rankingExpression(input_two).type" +rankprofile[5].fef.property[5].value "tensor(d0[2])" +rankprofile[5].fef.property[6].name "rankingExpression(input_three).rankingScript" +rankprofile[5].fef.property[6].value "query(query_vec)" +rankprofile[5].fef.property[7].name "rankingExpression(input_three).type" +rankprofile[5].fef.property[7].value "tensor(d0[2])" +rankprofile[5].fef.property[8].name "rankingExpression(fn_doc_vec).rankingScript" +rankprofile[5].fef.property[8].value "attribute(doc_vec)" +rankprofile[5].fef.property[9].name "rankingExpression(fn_doc_vec).type" +rankprofile[5].fef.property[9].value "tensor(d0[2])" +rankprofile[5].fef.property[10].name "vespa.rank.firstphase" +rankprofile[5].fef.property[10].value "rankingExpression(firstphase)" +rankprofile[5].fef.property[11].name "rankingExpression(firstphase).rankingScript" +rankprofile[5].fef.property[11].value "-attribute(score)" +rankprofile[5].fef.property[12].name "vespa.rank.secondphase" +rankprofile[5].fef.property[12].value "rankingExpression(secondphase)" +rankprofile[5].fef.property[13].name "rankingExpression(secondphase).rankingScript" +rankprofile[5].fef.property[13].value "reduce(onnx(multiply_add).multiply_add_output - rankingExpression(fn_query_vec), sum)" +rankprofile[5].fef.property[14].name "vespa.summary.feature" +rankprofile[5].fef.property[14].value "query(query_vec)" +rankprofile[5].fef.property[15].name "vespa.summary.feature" +rankprofile[5].fef.property[15].value "onnx(multiply_add).multiply_add_output" +rankprofile[5].fef.property[16].name "vespa.hitcollector.heapsize" +rankprofile[5].fef.property[16].value "3" +rankprofile[5].fef.property[17].name "vespa.type.attribute.doc_vec" +rankprofile[5].fef.property[17].value "tensor(d0[2])" +rankprofile[5].fef.property[18].name "vespa.type.query.query_vec" +rankprofile[5].fef.property[18].value "tensor(d0[2])" diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-constants.cfg b/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-constants.cfg new file mode 100644 index 00000000000..9ee78311b79 --- /dev/null +++ b/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-constants.cfg @@ -0,0 +1,3 @@ +constant[0].name "matrix" +constant[0].fileref "matrix.json" +constant[0].type "tensor(d0[2],d1[2])" diff --git a/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-expressions.cfg b/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-expressions.cfg new file mode 100644 index 00000000000..e69de29bb2d --- /dev/null +++ b/model-evaluation/src/test/resources/config/advanced-global-phase/ranking-expressions.cfg |