aboutsummaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorLester Solbakken <lesters@oath.com>2021-11-24 12:47:23 +0100
committerLester Solbakken <lesters@oath.com>2021-11-24 12:47:23 +0100
commita40eea2df1d64d8586768f1122da90a5756bef10 (patch)
treeece85a26fe47404a2dfeda1c273640b0bc0f7334
parentbf026551a16d12d4b4e75949933d080f39a85eef (diff)
Remove MultiFunctionEvaluator
-rw-r--r--config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java12
-rw-r--r--model-evaluation/abi-spec.json20
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java126
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java2
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java24
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java11
-rw-r--r--model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java120
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java4
-rw-r--r--model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java31
9 files changed, 125 insertions, 225 deletions
diff --git a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
index 70c4cb942bc..8ed229b2ff5 100644
--- a/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
+++ b/config-model/src/test/java/com/yahoo/vespa/model/container/ml/ModelsEvaluatorTest.java
@@ -3,12 +3,9 @@ package com.yahoo.vespa.model.container.ml;
import ai.vespa.models.evaluation.FunctionEvaluator;
import ai.vespa.models.evaluation.ModelsEvaluator;
-import ai.vespa.models.evaluation.MultiFunctionEvaluator;
import com.yahoo.tensor.Tensor;
import org.junit.Test;
-import java.util.Map;
-
import static org.junit.Assert.assertEquals;
/**
@@ -30,10 +27,11 @@ public class ModelsEvaluatorTest {
Tensor output = mul.bind("input1", input1).bind("input2", input2).evaluate();
assertEquals(6.0, output.sum().asDouble(), 1e-9);
- MultiFunctionEvaluator eval = modelsEvaluator.multiEvaluatorOf("mul");
- Map<String, Tensor> out = eval.bind("input1", input1).bind("input2", input2).evaluate();
- assertEquals(6.0, out.get("output1").sum().asDouble(), 1e-9);
- assertEquals(5.0, out.get("output2").sum().asDouble(), 1e-9);
+ FunctionEvaluator eval = modelsEvaluator.evaluatorOf("mul");
+ output = eval.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(6.0, output.sum().asDouble(), 1e-9);
+ assertEquals(6.0, eval.result("output1").sum().asDouble(), 1e-9);
+ assertEquals(5.0, eval.result("output2").sum().asDouble(), 1e-9);
// LightGBM model evaluation
FunctionEvaluator lgbm = modelsEvaluator.evaluatorOf("lightgbm_regression");
diff --git a/model-evaluation/abi-spec.json b/model-evaluation/abi-spec.json
index 3f23e7456ad..71dd7ffc2eb 100644
--- a/model-evaluation/abi-spec.json
+++ b/model-evaluation/abi-spec.json
@@ -6,6 +6,7 @@
"public"
],
"methods": [
+ "public com.yahoo.tensor.Tensor result(java.lang.String)",
"public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)",
"public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, double)",
"public ai.vespa.models.evaluation.FunctionEvaluator bind(java.lang.String, java.lang.String)",
@@ -13,7 +14,8 @@
"public ai.vespa.models.evaluation.FunctionEvaluator setMissingValue(double)",
"public com.yahoo.tensor.Tensor evaluate()",
"public com.yahoo.searchlib.rankingexpression.ExpressionFunction function()",
- "public ai.vespa.models.evaluation.LazyArrayContext context()"
+ "public ai.vespa.models.evaluation.LazyArrayContext context()",
+ "public java.util.List outputs()"
],
"fields": []
},
@@ -56,7 +58,6 @@
"public java.lang.String name()",
"public java.util.List functions()",
"public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String[])",
- "public varargs ai.vespa.models.evaluation.MultiFunctionEvaluator multiEvaluatorOf(java.lang.String[])",
"public java.lang.String toString()"
],
"fields": []
@@ -73,25 +74,10 @@
"public void <init>(java.util.Map)",
"public java.util.Map models()",
"public varargs ai.vespa.models.evaluation.FunctionEvaluator evaluatorOf(java.lang.String, java.lang.String[])",
- "public varargs ai.vespa.models.evaluation.MultiFunctionEvaluator multiEvaluatorOf(java.lang.String, java.lang.String[])",
"public ai.vespa.models.evaluation.Model requireModel(java.lang.String)"
],
"fields": []
},
- "ai.vespa.models.evaluation.MultiFunctionEvaluator": {
- "superClass": "java.lang.Object",
- "interfaces": [],
- "attributes": [
- "public"
- ],
- "methods": [
- "public ai.vespa.models.evaluation.MultiFunctionEvaluator bind(java.lang.String, com.yahoo.tensor.Tensor)",
- "public ai.vespa.models.evaluation.MultiFunctionEvaluator bind(java.lang.String, double)",
- "public java.util.Map evaluate()",
- "public java.util.List functions()"
- ],
- "fields": []
- },
"ai.vespa.models.evaluation.RankProfilesConfigImporter": {
"superClass": "java.lang.Object",
"interfaces": [],
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
index aa13cb96845..7a992cb7aa9 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/FunctionEvaluator.java
@@ -8,24 +8,37 @@ import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.HashMap;
+import java.util.HashSet;
+import java.util.List;
import java.util.Map;
+import java.util.Set;
import java.util.stream.Collectors;
/**
- * An evaluator which can be used to evaluate a single function once.
+ * An evaluator which can be used to evaluate a function once.
*
* @author bratseth
*/
// This wraps all access to the context and the ranking expression to avoid incorrect usage
public class FunctionEvaluator {
- private final ExpressionFunction function;
- private final LazyArrayContext context;
+ private final List<ExpressionFunction> functions;
+ private final Map<String, LazyArrayContext> contexts;
+ private final Map<String, Tensor> results;
private boolean evaluated = false;
FunctionEvaluator(ExpressionFunction function, LazyArrayContext context) {
- this.function = function;
- this.context = context;
+ this(List.of(function), Map.of(function.getName(), context));
+ }
+
+ FunctionEvaluator(List<ExpressionFunction> functions, Map<String, LazyArrayContext> contexts) {
+ this.functions = List.copyOf(functions);
+ this.contexts = Map.copyOf(contexts);
+ this.results = new HashMap<>();
+ }
+
+ public Tensor result(String name) {
+ return results.get(name);
}
/**
@@ -38,15 +51,14 @@ public class FunctionEvaluator {
public FunctionEvaluator bind(String name, Tensor value) {
if (evaluated)
throw new IllegalStateException("Cannot bind a new value in a used evaluator");
- TensorType requiredType = function.argumentTypes().get(name);
- if (requiredType == null)
- throw new IllegalArgumentException("'" + name + "' is not a valid argument in " + function +
- ". Expected arguments: " + function.argumentTypes().entrySet().stream()
- .map(e -> e.getKey() + ": " + e.getValue())
- .collect(Collectors.joining(", ")));
- if ( ! value.type().isAssignableTo(requiredType))
- throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type());
- context.put(name, new TensorValue(value));
+ for (ExpressionFunction function : functions) {
+ if (function.argumentTypes().containsKey(name)) {
+ TensorType requiredType = function.argumentTypes().get(name);
+ if ( ! value.type().isAssignableTo(requiredType))
+ throw new IllegalArgumentException("'" + name + "' must be of type " + requiredType + ", not " + value.type());
+ contexts.get(function.getName()).put(name, new TensorValue(value));
+ }
+ }
return this;
}
@@ -73,7 +85,11 @@ public class FunctionEvaluator {
public FunctionEvaluator bind(String name, String value) {
if (evaluated)
throw new IllegalStateException("Cannot bind a new value in a used evaluator");
- context.put(name, new StringValue(value));
+ for (ExpressionFunction function : functions) {
+ if (function.argumentTypes().containsKey(name)) {
+ contexts.get(function.getName()).put(name, new StringValue(value));
+ }
+ }
return this;
}
@@ -86,7 +102,9 @@ public class FunctionEvaluator {
public FunctionEvaluator setMissingValue(Tensor value) {
if (evaluated)
throw new IllegalStateException("Cannot change the missing value in a used evaluator");
- context.setMissingValue(value);
+ for (LazyArrayContext context : contexts.values()) {
+ context.setMissingValue(value);
+ }
return this;
}
@@ -102,18 +120,31 @@ public class FunctionEvaluator {
public Tensor evaluate() {
checkArguments();
- evaluated = true;
evaluateOnnxModels();
- return function.getBody().evaluate(context).asTensor();
+
+ Tensor defaultResult = null;
+ for (ExpressionFunction function: functions) {
+ LazyArrayContext context = contexts.get(function.getName());
+ Tensor result = function.getBody().evaluate(context).asTensor();
+ results.put(function.getName(), function.getBody().evaluate(context).asTensor());
+ if (defaultResult == null) {
+ defaultResult = result;
+ }
+ }
+ evaluated = true;
+ return defaultResult;
}
void checkArguments() {
- for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
- checkArgument(argument.getKey(), argument.getValue());
+ for (ExpressionFunction function : functions) {
+ LazyArrayContext context = contexts.get(function.getName());
+ for (Map.Entry<String, TensorType> argument : function.argumentTypes().entrySet()) {
+ checkArgument(argument.getKey(), argument.getValue(), context);
+ }
}
}
- private void checkArgument(String name, TensorType type) {
+ private void checkArgument(String name, TensorType type, LazyArrayContext context) {
if (context.isMissing(name))
throw new IllegalStateException("Missing argument '" + name + "': Must be bound to a value of type " + type);
if (! context.get(name).type().isAssignableTo(type))
@@ -124,23 +155,52 @@ public class FunctionEvaluator {
* Evaluate ONNX models (if not already evaluated) and add the result back to the context.
*/
private void evaluateOnnxModels() {
- for (Map.Entry<String, OnnxModel> entry : context().onnxModels().entrySet()) {
- String onnxFeature = entry.getKey();
- OnnxModel onnxModel = entry.getValue();
- if (context.get(onnxFeature).equals(context.defaultValue())) {
- Map<String, Tensor> inputs = new HashMap<>();
- for (Map.Entry<String, TensorType> input: onnxModel.inputs().entrySet()) {
- inputs.put(input.getKey(), context.get(input.getKey()).asTensor());
+ Set<OnnxModel> onnxModels = new HashSet<>();
+ for (LazyArrayContext context : contexts.values()) {
+ onnxModels.addAll(context.onnxModels().values());
+ }
+
+ for (OnnxModel onnxModel : onnxModels) {
+
+ // Gather inputs from all functions. Inputs with the same name must have the same value.
+ Map<String, Tensor> inputs = new HashMap<>();
+ for (LazyArrayContext context : contexts.values()) {
+ for (OnnxModel functionModel : context.onnxModels().values()) {
+ if (functionModel.name().equals(onnxModel.name())) {
+ for (String inputName: onnxModel.inputs().keySet()) {
+ inputs.put(inputName, context.get(inputName).asTensor());
+ }
+ }
}
- Tensor result = onnxModel.evaluate(inputs, function.getName()); // Function name is output of model
- context.put(onnxFeature, new TensorValue(result));
}
+
+ // Evaluate model once.
+ Map<String, Tensor> outputs = onnxModel.evaluate(inputs);
+
+ // Add outputs back to the context of the functions that need them; they won't be recalculated.
+ for (ExpressionFunction function : functions) {
+ LazyArrayContext context = contexts.get(function.getName());
+ for (Map.Entry<String, OnnxModel> entry : context.onnxModels().entrySet()) {
+ String onnxFeature = entry.getKey();
+ OnnxModel functionModel = entry.getValue();
+ if (functionModel.name().equals(onnxModel.name())) {
+ Tensor result = outputs.get(function.getName()); // Function name is output of model
+ context.put(onnxFeature, new TensorValue(result));
+ }
+ }
+ }
+
}
}
- /** Returns the function evaluated by this */
- public ExpressionFunction function() { return function; }
+ /** Returns the default function evaluated by this */
+ public ExpressionFunction function() { return functions.get(0); }
- public LazyArrayContext context() { return context; }
+ public LazyArrayContext context() { return contexts.get(function().getName()); }
+
+ /** Returns the names of the outputs of this function */
+ public List<String> outputs() {
+ return functions.stream().map(ExpressionFunction::getName).collect(Collectors.toList());
+ }
}
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
index d97235d11d2..cc53f38f800 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/LazyArrayContext.java
@@ -153,7 +153,7 @@ public final class LazyArrayContext extends Context implements ContextIndex {
/** The mapping from variable name to index */
private final ImmutableMap<String, Integer> nameToIndex;
- /** The names which needs to be bound externally when invoking this (i.e not constant or invocation */
+ /** The names which needs to be bound externally when invoking this (i.e. not constant or invocation) */
private final ImmutableSet<String> arguments;
/** The current values set */
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
index 84ab6e81840..ab24986e542 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/Model.java
@@ -12,6 +12,7 @@ import com.yahoo.tensor.TensorType;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
+import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -181,9 +182,7 @@ public class Model {
*/
public FunctionEvaluator evaluatorOf(String ... names) { // TODO: Parameter overloading?
if (names.length == 0) {
- if (functions.size() > 1)
- throwUndeterminedFunction("More than one function is available in " + this + ", but no name is given");
- return evaluatorOf(functions.get(0));
+ return evaluatorOf(functions);
}
else if (names.length == 1) {
String name = names[0];
@@ -230,20 +229,13 @@ public class Model {
return new FunctionEvaluator(function, requireContextPrototype(function.getName()).copy());
}
- /**
- * Returns an evaluator which can be used to evaluate the given model in a single thread once.
- *
- * @param names The names identifying the outputs. If none are passed, evaluates all outputs.
- * @throws IllegalArgumentException if the function is not present.
- */
- public MultiFunctionEvaluator multiEvaluatorOf(String ... names) {
- List<FunctionEvaluator> evaluators;
- if (names.length == 0) {
- evaluators = functions.stream().map(this::evaluatorOf).collect(Collectors.toList());
- } else {
- evaluators = Arrays.stream(names).map(this::evaluatorOf).collect(Collectors.toList());
+ /** Returns a single-use evaluator of a function */
+ private FunctionEvaluator evaluatorOf(List<ExpressionFunction> functions) {
+ Map<String, LazyArrayContext> contexts = new HashMap<>();
+ for (ExpressionFunction function : functions) {
+ contexts.put(function.getName(), requireContextPrototype(function.getName()).copy());
}
- return new MultiFunctionEvaluator(evaluators);
+ return new FunctionEvaluator(functions, contexts);
}
private void throwUndeterminedFunction(String message) {
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
index bd00f5510c6..01427ca811a 100644
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
+++ b/model-evaluation/src/main/java/ai/vespa/models/evaluation/ModelsEvaluator.java
@@ -60,17 +60,6 @@ public class ModelsEvaluator extends AbstractComponent {
return requireModel(modelName).evaluatorOf(names);
}
- /**
- * Returns a model evaluator which can be used to evaluate multiple functions in a model
- *
- * @param modelName the name of the model
- * @param names the names of the outputs to evaluate, or none if all should be evaluated
- * @throws IllegalArgumentException if the function or model is not present
- */
- public MultiFunctionEvaluator multiEvaluatorOf(String modelName, String ... names) {
- return requireModel(modelName).multiEvaluatorOf(names);
- }
-
/** Returns the given model, or throws a IllegalArgumentException if it does not exist */
public Model requireModel(String name) {
Model model = models.get(name);
diff --git a/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java b/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java
deleted file mode 100644
index 53d470ecc19..00000000000
--- a/model-evaluation/src/main/java/ai/vespa/models/evaluation/MultiFunctionEvaluator.java
+++ /dev/null
@@ -1,120 +0,0 @@
-// Copyright Yahoo. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
-package ai.vespa.models.evaluation;
-
-import com.yahoo.searchlib.rankingexpression.evaluation.TensorValue;
-import com.yahoo.tensor.Tensor;
-import com.yahoo.tensor.TensorType;
-
-import java.util.HashMap;
-import java.util.HashSet;
-import java.util.List;
-import java.util.Map;
-import java.util.Set;
-
-/**
- * An evaluator which can be used to evaluate a model with multiple outputs.
- * This will ensure that ONNX models are only evaluated once.
- *
- * @author lesters
- */
-public class MultiFunctionEvaluator {
-
- private final List<FunctionEvaluator> functions;
- private boolean evaluated = false;
-
- MultiFunctionEvaluator(List<FunctionEvaluator> functions) {
- this.functions = functions;
- }
-
- /**
- * Binds the given variable referred in this expression to the given value.
- *
- * @param name the variable to bind
- * @param value the value this becomes bound to
- * @return this for chaining
- */
- public MultiFunctionEvaluator bind(String name, Tensor value) {
- if (evaluated)
- throw new IllegalStateException("Cannot bind a new value in a used evaluator");
- for (FunctionEvaluator function : functions) {
- if (function.function().argumentTypes().containsKey(name)) {
- function.bind(name, value); // only bind input to the functions that need them
- }
- }
- return this;
- }
-
- /**
- * Binds the given variable referred in this expression to the given value.
- * This is equivalent to <code>bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build())</code>
- *
- * @param name the variable to bind
- * @param value the value this becomes bound to
- * @return this for chaining
- */
- public MultiFunctionEvaluator bind(String name, double value) {
- return bind(name, Tensor.Builder.of(TensorType.empty).cell(value).build());
- }
-
- public Map<String, Tensor> evaluate() {
- for (FunctionEvaluator function : functions) {
- function.checkArguments();
- }
-
- evaluateOnnxModels(); // evaluate each ONNX model only once
-
- Map<String, Tensor> results = new HashMap<>();
- for (FunctionEvaluator function : functions) {
- results.put(function.function().getName(), function.evaluate());
- }
- evaluated = true;
- return results;
- }
-
- /**
- * Evaluate all ONNX models across all functions once and add the result
- * back to the functions' context.
- */
- private void evaluateOnnxModels() {
- Set<OnnxModel> onnxModels = new HashSet<>();
- for (FunctionEvaluator function : functions) {
- onnxModels.addAll(function.context().onnxModels().values());
- }
-
- for (OnnxModel onnxModel : onnxModels) {
-
- // Gather inputs from all functions. Inputs with the same name must have the same value.
- Map<String, Tensor> inputs = new HashMap<>();
- for (FunctionEvaluator function : functions) {
- for (OnnxModel functionModel : function.context().onnxModels().values()) {
- if (functionModel.name().equals(onnxModel.name())) {
- for (String inputName: onnxModel.inputs().keySet()) {
- inputs.put(inputName, function.context().get(inputName).asTensor());
- }
- }
- }
- }
-
- // Evaluate model once.
- Map<String, Tensor> outputs = onnxModel.evaluate(inputs);
-
- // Add outputs back to the context of the functions that need them; they won't be recalculated.
- for (FunctionEvaluator function : functions) {
- for (Map.Entry<String, OnnxModel> entry : function.context().onnxModels().entrySet()) {
- String onnxFeature = entry.getKey();
- OnnxModel functionModel = entry.getValue();
- if (functionModel.name().equals(onnxModel.name())) {
- Tensor result = outputs.get(function.function().getName()); // Function name is output of model
- function.context().put(onnxFeature, new TensorValue(result));
- }
- }
- }
-
- }
- }
-
- public List<FunctionEvaluator> functions() {
- return functions;
- }
-
-}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
index 4cb52216137..3e065d25ad2 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/ModelsEvaluatorTest.java
@@ -95,8 +95,8 @@ public class ModelsEvaluatorTest {
evaluator.bind("argNone", Tensor.from(TensorType.fromSpec("tensor(d1{})"), "{{d1:foo}:0.1}"));
evaluator.evaluate();
}
- catch (IllegalArgumentException e) {
- assertEquals("'argNone' is not a valid argument in function 'test'. Expected arguments: arg2: tensor(d1{}), arg1: tensor(d0[1])",
+ catch (IllegalStateException e) {
+ assertEquals("Argument 'arg2' must be bound to a value of type tensor(d1{})",
Exceptions.toMessageString(e));
}
diff --git a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
index 59ab378e43a..ae77af264a1 100644
--- a/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
+++ b/model-evaluation/src/test/java/ai/vespa/models/evaluation/OnnxEvaluatorTest.java
@@ -35,27 +35,22 @@ public class OnnxEvaluatorTest {
assertTrue(models.models().containsKey("add_mul"));
assertTrue(models.models().containsKey("one_layer"));
+ Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]");
+ Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]");
+
FunctionEvaluator function = models.evaluatorOf("add_mul", "output1");
- function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]"));
- function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"));
- assertEquals(6.0, function.evaluate().sum().asDouble(), delta);
+ Tensor result = function.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(6.0, result.sum().asDouble(), delta);
function = models.evaluatorOf("add_mul", "output2");
- function.bind("input1", Tensor.from("tensor<float>(d0[1]):[2]"));
- function.bind("input2", Tensor.from("tensor<float>(d0[1]):[3]"));
- assertEquals(5.0, function.evaluate().sum().asDouble(), delta);
-
- MultiFunctionEvaluator evaluator = models.multiEvaluatorOf("add_mul");
- Tensor input1 = Tensor.from("tensor<float>(d0[1]):[2]");
- Tensor input2 = Tensor.from("tensor<float>(d0[1]):[3]");
- Map<String, Tensor> result = evaluator.bind("input1", input1).bind("input2", input2).evaluate();
- assertEquals(6.0, result.get("output1").sum().asDouble(), delta);
- assertEquals(5.0, result.get("output2").sum().asDouble(), delta);
-
- evaluator = models.multiEvaluatorOf("add_mul", "output1");
- result = evaluator.bind("input1", input1).bind("input2", input2).evaluate();
- assertTrue("Result does not contain requested output", result.containsKey("output1"));
- assertFalse("Result contains output that was not requested", result.containsKey("output2"));
+ result = function.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(5.0, result.sum().asDouble(), delta);
+
+ function = models.evaluatorOf("add_mul"); // contains two models
+ result = function.bind("input1", input1).bind("input2", input2).evaluate();
+ assertEquals(6.0, result.sum().asDouble(), delta);
+ assertEquals(6.0, function.result("output1").sum().asDouble(), delta);
+ assertEquals(5.0, function.result("output2").sum().asDouble(), delta);
function = models.evaluatorOf("one_layer");
function.bind("input", Tensor.from("tensor<float>(d0[2],d1[3]):[[0.1, 0.2, 0.3],[0.4,0.5,0.6]]"));