summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-23 16:20:32 -0700
committerJon Bratseth <bratseth@oath.com>2018-09-23 16:20:32 -0700
commit4e44e5472829c033c3d995c618f2febcc4463eb7 (patch)
tree402dc48f0fce44759ce7bca8068c6b98097dd031
parent2ee637ff5ef12924e77d5fbf087fb9fb803f0143 (diff)
Use ExpressionFunction
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java8
-rw-r--r--config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java57
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java62
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java9
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java11
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java9
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java9
9 files changed, 88 insertions, 88 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
index 04481a3bc8d..2c6a7941772 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java
@@ -237,8 +237,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()),
model.name(), profile, queryProfiles.getRegistry(), model);
- for (Map.Entry<String, ImportedModel.ExpressionWithInputs> entry : convertedModel.expressions().entrySet()) {
- profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().expression()), false); // TODO: Use inputs
+ for (Map.Entry<String, ExpressionFunction> entry : convertedModel.expressions().entrySet()) {
+ profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().getBody()), false); // TODO: Use arguments
}
}
}
@@ -249,8 +249,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri
RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry);
rankProfileRegistry.add(profile);
ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile);
- for (Map.Entry<String, ImportedModel.ExpressionWithInputs> entry : convertedModel.expressions().entrySet()) {
- profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().expression()), false); // TODO: Use inputs
+ for (Map.Entry<String, ExpressionFunction> entry : convertedModel.expressions().entrySet()) {
+ profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().getBody()), false); // TODO: Use inputs
}
}
}
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
index d72a22f7c5e..fb0109ed32e 100644
--- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
+++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java
@@ -48,6 +48,7 @@ import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -67,14 +68,14 @@ public class ConvertedModel {
private final ModelName modelName;
private final String modelDescription;
- private final ImmutableMap<String, ImportedModel.ExpressionWithInputs> expressions;
+ private final ImmutableMap<String, ExpressionFunction> expressions;
/** The source importedModel, or empty if this was created from a stored converted model */
private final Optional<ImportedModel> sourceModel;
private ConvertedModel(ModelName modelName,
String modelDescription,
- Map<String, ImportedModel.ExpressionWithInputs> expressions,
+ Map<String, ExpressionFunction> expressions,
Optional<ImportedModel> sourceModel) {
this.modelName = modelName;
this.modelDescription = modelDescription;
@@ -132,23 +133,23 @@ public class ConvertedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public Map<String, ImportedModel.ExpressionWithInputs> expressions() { return expressions; }
+ public Map<String, ExpressionFunction> expressions() { return expressions; }
/**
* Returns the expression matching the given arguments.
*/
public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) {
- ImportedModel.ExpressionWithInputs expression = selectExpression(arguments);
+ ExpressionFunction expression = selectExpression(arguments);
if (sourceModel.isPresent()) // we should verify
- verifyInputs(expression.expression(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
- return expression.expression().getRoot();
+ verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles());
+ return expression.getBody().getRoot();
}
- private ImportedModel.ExpressionWithInputs selectExpression(FeatureArguments arguments) {
+ private ExpressionFunction selectExpression(FeatureArguments arguments) {
if (expressions.isEmpty())
throw new IllegalArgumentException("No expressions available in " + this);
- ImportedModel.ExpressionWithInputs expression = expressions.get(arguments.toName());
+ ExpressionFunction expression = expressions.get(arguments.toName());
if (expression != null) return expression;
if ( ! arguments.signature().isPresent()) {
@@ -158,7 +159,7 @@ public class ConvertedModel {
}
if ( ! arguments.output().isPresent()) {
- List<Map.Entry<String, ImportedModel.ExpressionWithInputs>> entriesWithTheRightPrefix =
+ List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix =
expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList());
if (entriesWithTheRightPrefix.size() < 1)
throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() +
@@ -179,10 +180,10 @@ public class ConvertedModel {
// ----------------------- Static model conversion/storage below here
- private static Map<String, ImportedModel.ExpressionWithInputs> convertAndStore(ImportedModel model,
- RankProfile profile,
- QueryProfileRegistry queryProfiles,
- ModelStore store) {
+ private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model,
+ RankProfile profile,
+ QueryProfileRegistry queryProfiles,
+ ModelStore store) {
// Add constants
Set<String> constantsReplacedByFunctions = new HashSet<>();
model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v));
@@ -193,8 +194,8 @@ public class ConvertedModel {
addGeneratedFunctions(model, profile);
// Add expressions
- Map<String, ImportedModel.ExpressionWithInputs> expressions = new HashMap<>();
- for (Pair<String, ImportedModel.ExpressionWithInputs> output : model.outputExpressions()) {
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
+ for (Pair<String, ExpressionFunction> output : model.outputExpressions()) {
addExpression(output.getSecond(), output.getFirst(),
constantsReplacedByFunctions,
model, store, profile, queryProfiles,
@@ -210,21 +211,21 @@ public class ConvertedModel {
return expressions;
}
- private static void addExpression(ImportedModel.ExpressionWithInputs expression,
+ private static void addExpression(ExpressionFunction expression,
String expressionName,
Set<String> constantsReplacedByFunctions,
ImportedModel model,
ModelStore store,
RankProfile profile,
QueryProfileRegistry queryProfiles,
- Map<String, ImportedModel.ExpressionWithInputs> expressions) {
- expression = expression.with(replaceConstantsByFunctions(expression.expression(), constantsReplacedByFunctions));
- reduceBatchDimensions(expression.expression(), model, profile, queryProfiles);
+ Map<String, ExpressionFunction> expressions) {
+ expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions));
+ reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles);
store.writeExpression(expressionName, expression);
expressions.put(expressionName, expression);
}
- private static Map<String, ImportedModel.ExpressionWithInputs> convertStored(ModelStore store, RankProfile profile) {
+ private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) {
for (Pair<String, Tensor> constant : store.readSmallConstants())
profile.addConstant(constant.getFirst(), asValue(constant.getSecond()));
@@ -525,15 +526,15 @@ public class ConvertedModel {
* @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part
* is always the model name
*/
- void writeExpression(String name, ImportedModel.ExpressionWithInputs expression) {
- StringBuilder b = new StringBuilder(expression.expression().getRoot().toString());
- for (Map.Entry<String, TensorType> input : expression.inputs().entrySet())
+ void writeExpression(String name, ExpressionFunction expression) {
+ StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString());
+ for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet())
b.append('\n').append(input.getKey()).append('\t').append(input.getValue());
application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString()));
}
- Map<String, ImportedModel.ExpressionWithInputs> readExpressions() {
- Map<String, ImportedModel.ExpressionWithInputs> expressions = new HashMap<>();
+ Map<String, ExpressionFunction> readExpressions() {
+ Map<String, ExpressionFunction> expressions = new HashMap<>();
ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath());
if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap();
for (ApplicationFile expressionFile : expressionPath.listFiles()) {
@@ -551,18 +552,18 @@ public class ConvertedModel {
return expressions;
}
- private ImportedModel.ExpressionWithInputs readExpression(String name, BufferedReader reader)
+ private ExpressionFunction readExpression(String name, BufferedReader reader)
throws IOException, ParseException {
// First line is expression
RankingExpression expression = new RankingExpression(name, reader.readLine());
// Next lines are inputs on the format name\ttensorTypeSpec
- Map<String, TensorType> inputs = new HashMap<>();
+ Map<String, TensorType> inputs = new LinkedHashMap<>();
String line;
while (null != (line = reader.readLine())) {
String[] parts = line.split("\t");
inputs.put(parts[0], TensorType.fromSpec(parts[1]));
}
- return new ImportedModel.ExpressionWithInputs(expression, inputs);
+ return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty());
}
/** Adds this function expression to the application package so it can be read later. */
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 63d3f9df663..848ad68a6c0 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -60,8 +60,8 @@ public class ExpressionFunction {
this(name, arguments, body, ImmutableMap.of(), Optional.empty());
}
- private ExpressionFunction(String name, List<String> arguments, RankingExpression body,
- Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
+ public ExpressionFunction(String name, List<String> arguments, RankingExpression body,
+ Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.arguments = arguments==null ? ImmutableList.of() : ImmutableList.copyOf(arguments);
this.body = Objects.requireNonNull(body, "body cannot be null");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 88b5645e2e5..979487827a8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -11,9 +12,11 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
import java.util.regex.Pattern;
/**
@@ -108,27 +111,38 @@ public class ImportedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public List<Pair<String, ExpressionWithInputs>> outputExpressions() {
- List<Pair<String, ExpressionWithInputs>> expressions = new ArrayList<>();
+ public List<Pair<String, ExpressionFunction>> outputExpressions() {
+ List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
signatureEntry.getValue().outputExpression(outputEntry.getKey())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
- new ExpressionWithInputs(expressions().get(signatureEntry.getKey()),
- signatureEntry.getValue().inputMap())));
+ new ExpressionFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().keySet()),
+ expressions().get(signatureEntry.getKey()),
+ signatureEntry.getValue().inputMap(),
+ Optional.empty())));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
expressions.add(new Pair<>(singleEntry.getKey(),
- new ExpressionWithInputs(singleEntry.getValue(), inputs)));
+ new ExpressionFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue(),
+ inputs,
+ Optional.empty())));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
expressions.add(new Pair<>(expressionEntry.getKey(),
- new ExpressionWithInputs(expressionEntry.getValue(), inputs)));
+ new ExpressionFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue(),
+ inputs,
+ Optional.empty())));
}
}
}
@@ -144,8 +158,8 @@ public class ImportedModel {
public class Signature {
private final String name;
- private final Map<String, String> inputs = new HashMap<>();
- private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> inputs = new LinkedHashMap<>();
+ private final Map<String, String> outputs = new LinkedHashMap<>();
private final Map<String, String> skippedOutputs = new HashMap<>();
private final List<String> importWarnings = new ArrayList<>();
@@ -190,8 +204,12 @@ public class ImportedModel {
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
/** Returns the expression this output references */
- public ExpressionWithInputs outputExpression(String outputName) {
- return new ExpressionWithInputs(owner().expressions().get(outputs.get(outputName)), inputMap());
+ public ExpressionFunction outputExpression(String outputName) {
+ return new ExpressionFunction(outputName,
+ new ArrayList<>(inputs.keySet()),
+ owner().expressions().get(outputs.get(outputName)),
+ inputMap(),
+ Optional.empty());
}
@Override
@@ -204,28 +222,4 @@ public class ImportedModel {
}
- /**
- * An expression, with the inputs (bindings) which must be supplied to evaluate it.
- * All non-scalar (non-empty tensor type) inputs are always present here. Inputs not
- * given explicitly here (but present in the expression) are always scalar.
- */
- public static class ExpressionWithInputs {
-
- private final RankingExpression expression;
- private final ImmutableMap<String, TensorType> inputs;
-
- public ExpressionWithInputs(RankingExpression expression, Map<String, TensorType> inputs) {
- this.expression = Objects.requireNonNull(expression, "expression cannot be null");
- this.inputs = ImmutableMap.copyOf(inputs);
- }
-
- public RankingExpression expression() { return expression; }
- public ImmutableMap<String, TensorType> inputs() { return inputs; }
-
- public ExpressionWithInputs with(RankingExpression newExpression) {
- return new ExpressionWithInputs(newExpression, inputs);
- }
-
- }
-
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index 3a1c9ec9551..62bbc9ae81f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -20,11 +21,11 @@ public class BatchNormImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.expression().getName());
- model.assertEqualResult("X", output.expression().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName());
+ model.assertEqualResult("X", output.getBody().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.arguments().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 4c35d843f5d..2a894adc92c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -29,13 +30,13 @@ public class DropoutImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("outputs/Maximum", output.expression().getName());
+ assertEquals("outputs/Maximum", output.getBody().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- output.expression().getRoot().toString());
- model.assertEqualResult("X", output.expression().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
+ output.getBody().getRoot().toString());
+ model.assertEqualResult("X", output.getBody().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.getBody().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
index b3e281ad25d..3d8d5d5a570 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -20,10 +21,10 @@ public class MnistImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/outputs/add", output.expression().getName());
- model.assertEqualResultSum("input", output.expression().getName(), 0.00001);
+ assertEquals("dnn/outputs/add", output.getBody().getName());
+ model.assertEqualResultSum("input", output.getBody().getName(), 0.00001);
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index b5655cfbfa5..bcdfde67dc0 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -1,5 +1,6 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
@@ -41,14 +42,14 @@ public class OnnxMnistSoftmaxImportTestCase {
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
// Check signature
- ImportedModel.ExpressionWithInputs output = model.defaultSignature().outputExpression("add");
+ ExpressionFunction output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
- assertEquals("add", output.expression().getName());
+ assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
- output.expression().getRoot().toString());
+ output.getBody().getRoot().toString());
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
- assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.inputs().toString());
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.getBody().toString());
}
@Test
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index 4a0362c0229..b14a4a5b430 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -56,12 +57,12 @@ public class TensorFlowMnistSoftmaxImportTestCase {
// ... signature outputs
assertEquals(1, signature.outputs().size());
- ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("add", output.expression().getName());
+ assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
- output.expression().getRoot().toString());
- assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
+ output.getBody().getRoot().toString());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.getBody().toString());
// Test execution
model.assertEqualResult("Placeholder", "MatMul");