diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-23 16:20:32 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-23 16:20:32 -0700 |
commit | 4e44e5472829c033c3d995c618f2febcc4463eb7 (patch) | |
tree | 402dc48f0fce44759ce7bca8068c6b98097dd031 | |
parent | 2ee637ff5ef12924e77d5fbf087fb9fb803f0143 (diff) |
Use ExpressionFunction
9 files changed, 88 insertions, 88 deletions
diff --git a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java index 04481a3bc8d..2c6a7941772 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/VespaModel.java @@ -237,8 +237,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromSource(new ModelName(model.name()), model.name(), profile, queryProfiles.getRegistry(), model); - for (Map.Entry<String, ImportedModel.ExpressionWithInputs> entry : convertedModel.expressions().entrySet()) { - profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().expression()), false); // TODO: Use inputs + for (Map.Entry<String, ExpressionFunction> entry : convertedModel.expressions().entrySet()) { + profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().getBody()), false); // TODO: Use arguments } } } @@ -249,8 +249,8 @@ public final class VespaModel extends AbstractConfigProducerRoot implements Seri RankProfile profile = new RankProfile(modelName, this, rankProfileRegistry); rankProfileRegistry.add(profile); ConvertedModel convertedModel = ConvertedModel.fromStore(new ModelName(modelName), modelName, profile); - for (Map.Entry<String, ImportedModel.ExpressionWithInputs> entry : convertedModel.expressions().entrySet()) { - profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().expression()), false); // TODO: Use inputs + for (Map.Entry<String, ExpressionFunction> entry : convertedModel.expressions().entrySet()) { + profile.addFunction(new ExpressionFunction(entry.getKey(), entry.getValue().getBody()), false); // TODO: Use inputs } } } diff --git a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java index d72a22f7c5e..fb0109ed32e 100644 --- a/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java +++ b/config-model/src/main/java/com/yahoo/vespa/model/ml/ConvertedModel.java @@ -48,6 +48,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.HashMap; import java.util.HashSet; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Optional; @@ -67,14 +68,14 @@ public class ConvertedModel { private final ModelName modelName; private final String modelDescription; - private final ImmutableMap<String, ImportedModel.ExpressionWithInputs> expressions; + private final ImmutableMap<String, ExpressionFunction> expressions; /** The source importedModel, or empty if this was created from a stored converted model */ private final Optional<ImportedModel> sourceModel; private ConvertedModel(ModelName modelName, String modelDescription, - Map<String, ImportedModel.ExpressionWithInputs> expressions, + Map<String, ExpressionFunction> expressions, Optional<ImportedModel> sourceModel) { this.modelName = modelName; this.modelDescription = modelDescription; @@ -132,23 +133,23 @@ public class ConvertedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public Map<String, ImportedModel.ExpressionWithInputs> expressions() { return expressions; } + public Map<String, ExpressionFunction> expressions() { return expressions; } /** * Returns the expression matching the given arguments. */ public ExpressionNode expression(FeatureArguments arguments, RankProfileTransformContext context) { - ImportedModel.ExpressionWithInputs expression = selectExpression(arguments); + ExpressionFunction expression = selectExpression(arguments); if (sourceModel.isPresent()) // we should verify - verifyInputs(expression.expression(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); - return expression.expression().getRoot(); + verifyInputs(expression.getBody(), sourceModel.get(), context.rankProfile(), context.queryProfiles()); + return expression.getBody().getRoot(); } - private ImportedModel.ExpressionWithInputs selectExpression(FeatureArguments arguments) { + private ExpressionFunction selectExpression(FeatureArguments arguments) { if (expressions.isEmpty()) throw new IllegalArgumentException("No expressions available in " + this); - ImportedModel.ExpressionWithInputs expression = expressions.get(arguments.toName()); + ExpressionFunction expression = expressions.get(arguments.toName()); if (expression != null) return expression; if ( ! arguments.signature().isPresent()) { @@ -158,7 +159,7 @@ public class ConvertedModel { } if ( ! arguments.output().isPresent()) { - List<Map.Entry<String, ImportedModel.ExpressionWithInputs>> entriesWithTheRightPrefix = + List<Map.Entry<String, ExpressionFunction>> entriesWithTheRightPrefix = expressions.entrySet().stream().filter(entry -> entry.getKey().startsWith(arguments.signature().get() + ".")).collect(Collectors.toList()); if (entriesWithTheRightPrefix.size() < 1) throw new IllegalArgumentException("No expressions named '" + arguments.signature().get() + @@ -179,10 +180,10 @@ public class ConvertedModel { // ----------------------- Static model conversion/storage below here - private static Map<String, ImportedModel.ExpressionWithInputs> convertAndStore(ImportedModel model, - RankProfile profile, - QueryProfileRegistry queryProfiles, - ModelStore store) { + private static Map<String, ExpressionFunction> convertAndStore(ImportedModel model, + RankProfile profile, + QueryProfileRegistry queryProfiles, + ModelStore store) { // Add constants Set<String> constantsReplacedByFunctions = new HashSet<>(); model.smallConstants().forEach((k, v) -> transformSmallConstant(store, profile, k, v)); @@ -193,8 +194,8 @@ public class ConvertedModel { addGeneratedFunctions(model, profile); // Add expressions - Map<String, ImportedModel.ExpressionWithInputs> expressions = new HashMap<>(); - for (Pair<String, ImportedModel.ExpressionWithInputs> output : model.outputExpressions()) { + Map<String, ExpressionFunction> expressions = new HashMap<>(); + for (Pair<String, ExpressionFunction> output : model.outputExpressions()) { addExpression(output.getSecond(), output.getFirst(), constantsReplacedByFunctions, model, store, profile, queryProfiles, @@ -210,21 +211,21 @@ public class ConvertedModel { return expressions; } - private static void addExpression(ImportedModel.ExpressionWithInputs expression, + private static void addExpression(ExpressionFunction expression, String expressionName, Set<String> constantsReplacedByFunctions, ImportedModel model, ModelStore store, RankProfile profile, QueryProfileRegistry queryProfiles, - Map<String, ImportedModel.ExpressionWithInputs> expressions) { - expression = expression.with(replaceConstantsByFunctions(expression.expression(), constantsReplacedByFunctions)); - reduceBatchDimensions(expression.expression(), model, profile, queryProfiles); + Map<String, ExpressionFunction> expressions) { + expression = expression.withBody(replaceConstantsByFunctions(expression.getBody(), constantsReplacedByFunctions)); + reduceBatchDimensions(expression.getBody(), model, profile, queryProfiles); store.writeExpression(expressionName, expression); expressions.put(expressionName, expression); } - private static Map<String, ImportedModel.ExpressionWithInputs> convertStored(ModelStore store, RankProfile profile) { + private static Map<String, ExpressionFunction> convertStored(ModelStore store, RankProfile profile) { for (Pair<String, Tensor> constant : store.readSmallConstants()) profile.addConstant(constant.getFirst(), asValue(constant.getSecond())); @@ -525,15 +526,15 @@ public class ConvertedModel { * @param name the name of this ranking expression - may have 1-3 parts separated by dot where the first part * is always the model name */ - void writeExpression(String name, ImportedModel.ExpressionWithInputs expression) { - StringBuilder b = new StringBuilder(expression.expression().getRoot().toString()); - for (Map.Entry<String, TensorType> input : expression.inputs().entrySet()) + void writeExpression(String name, ExpressionFunction expression) { + StringBuilder b = new StringBuilder(expression.getBody().getRoot().toString()); + for (Map.Entry<String, TensorType> input : expression.argumentTypes().entrySet()) b.append('\n').append(input.getKey()).append('\t').append(input.getValue()); application.getFile(modelFiles.expressionPath(name)).writeFile(new StringReader(b.toString())); } - Map<String, ImportedModel.ExpressionWithInputs> readExpressions() { - Map<String, ImportedModel.ExpressionWithInputs> expressions = new HashMap<>(); + Map<String, ExpressionFunction> readExpressions() { + Map<String, ExpressionFunction> expressions = new HashMap<>(); ApplicationFile expressionPath = application.getFile(modelFiles.expressionsPath()); if ( ! expressionPath.exists() || ! expressionPath.isDirectory()) return Collections.emptyMap(); for (ApplicationFile expressionFile : expressionPath.listFiles()) { @@ -551,18 +552,18 @@ public class ConvertedModel { return expressions; } - private ImportedModel.ExpressionWithInputs readExpression(String name, BufferedReader reader) + private ExpressionFunction readExpression(String name, BufferedReader reader) throws IOException, ParseException { // First line is expression RankingExpression expression = new RankingExpression(name, reader.readLine()); // Next lines are inputs on the format name\ttensorTypeSpec - Map<String, TensorType> inputs = new HashMap<>(); + Map<String, TensorType> inputs = new LinkedHashMap<>(); String line; while (null != (line = reader.readLine())) { String[] parts = line.split("\t"); inputs.put(parts[0], TensorType.fromSpec(parts[1])); } - return new ImportedModel.ExpressionWithInputs(expression, inputs); + return new ExpressionFunction(name, new ArrayList<>(inputs.keySet()), expression, inputs, Optional.empty()); } /** Adds this function expression to the application package so it can be read later. */ diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 63d3f9df663..848ad68a6c0 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -60,8 +60,8 @@ public class ExpressionFunction { this(name, arguments, body, ImmutableMap.of(), Optional.empty()); } - private ExpressionFunction(String name, List<String> arguments, RankingExpression body, - Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { + public ExpressionFunction(String name, List<String> arguments, RankingExpression body, + Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.arguments = arguments==null ? ImmutableList.of() : ImmutableList.copyOf(arguments); this.body = Objects.requireNonNull(body, "body cannot be null"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 88b5645e2e5..979487827a8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -11,9 +12,11 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.regex.Pattern; /** @@ -108,27 +111,38 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public List<Pair<String, ExpressionWithInputs>> outputExpressions() { - List<Pair<String, ExpressionWithInputs>> expressions = new ArrayList<>(); + public List<Pair<String, ExpressionFunction>> outputExpressions() { + List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), signatureEntry.getValue().outputExpression(outputEntry.getKey()))); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs expressions.add(new Pair<>(signatureEntry.getKey(), - new ExpressionWithInputs(expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap()))); + new ExpressionFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().keySet()), + expressions().get(signatureEntry.getKey()), + signatureEntry.getValue().inputMap(), + Optional.empty()))); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); expressions.add(new Pair<>(singleEntry.getKey(), - new ExpressionWithInputs(singleEntry.getValue(), inputs))); + new ExpressionFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue(), + inputs, + Optional.empty()))); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { expressions.add(new Pair<>(expressionEntry.getKey(), - new ExpressionWithInputs(expressionEntry.getValue(), inputs))); + new ExpressionFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue(), + inputs, + Optional.empty()))); } } } @@ -144,8 +158,8 @@ public class ImportedModel { public class Signature { private final String name; - private final Map<String, String> inputs = new HashMap<>(); - private final Map<String, String> outputs = new HashMap<>(); + private final Map<String, String> inputs = new LinkedHashMap<>(); + private final Map<String, String> outputs = new LinkedHashMap<>(); private final Map<String, String> skippedOutputs = new HashMap<>(); private final List<String> importWarnings = new ArrayList<>(); @@ -190,8 +204,12 @@ public class ImportedModel { public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } /** Returns the expression this output references */ - public ExpressionWithInputs outputExpression(String outputName) { - return new ExpressionWithInputs(owner().expressions().get(outputs.get(outputName)), inputMap()); + public ExpressionFunction outputExpression(String outputName) { + return new ExpressionFunction(outputName, + new ArrayList<>(inputs.keySet()), + owner().expressions().get(outputs.get(outputName)), + inputMap(), + Optional.empty()); } @Override @@ -204,28 +222,4 @@ public class ImportedModel { } - /** - * An expression, with the inputs (bindings) which must be supplied to evaluate it. - * All non-scalar (non-empty tensor type) inputs are always present here. Inputs not - * given explicitly here (but present in the expression) are always scalar. - */ - public static class ExpressionWithInputs { - - private final RankingExpression expression; - private final ImmutableMap<String, TensorType> inputs; - - public ExpressionWithInputs(RankingExpression expression, Map<String, TensorType> inputs) { - this.expression = Objects.requireNonNull(expression, "expression cannot be null"); - this.inputs = ImmutableMap.copyOf(inputs); - } - - public RankingExpression expression() { return expression; } - public ImmutableMap<String, TensorType> inputs() { return inputs; } - - public ExpressionWithInputs with(RankingExpression newExpression) { - return new ExpressionWithInputs(newExpression, inputs); - } - - } - } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index 3a1c9ec9551..62bbc9ae81f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -20,11 +21,11 @@ public class BatchNormImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.expression().getName()); - model.assertEqualResult("X", output.expression().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName()); + model.assertEqualResult("X", output.getBody().getName()); + assertEquals("{x=tensor(d0[],d1[784])}", output.arguments().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 4c35d843f5d..2a894adc92c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -29,13 +30,13 @@ public class DropoutImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("outputs/Maximum", output.expression().getName()); + assertEquals("outputs/Maximum", output.getBody().getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - output.expression().getRoot().toString()); - model.assertEqualResult("X", output.expression().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); + output.getBody().getRoot().toString()); + model.assertEqualResult("X", output.getBody().getName()); + assertEquals("{x=tensor(d0[],d1[784])}", output.getBody().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java index b3e281ad25d..3d8d5d5a570 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -20,10 +21,10 @@ public class MnistImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/outputs/add", output.expression().getName()); - model.assertEqualResultSum("input", output.expression().getName(), 0.00001); + assertEquals("dnn/outputs/add", output.getBody().getName()); + model.assertEqualResultSum("input", output.getBody().getName(), 0.00001); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index b5655cfbfa5..bcdfde67dc0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -1,5 +1,6 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; @@ -41,14 +42,14 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); // Check signature - ImportedModel.ExpressionWithInputs output = model.defaultSignature().outputExpression("add"); + ExpressionFunction output = model.defaultSignature().outputExpression("add"); assertNotNull(output); - assertEquals("add", output.expression().getName()); + assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.expression().getRoot().toString()); + output.getBody().getRoot().toString()); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); - assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.inputs().toString()); + assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.getBody().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index 4a0362c0229..b14a4a5b430 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -56,12 +57,12 @@ public class TensorFlowMnistSoftmaxImportTestCase { // ... signature outputs assertEquals(1, signature.outputs().size()); - ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("add", output.expression().getName()); + assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", - output.expression().getRoot().toString()); - assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); + output.getBody().getRoot().toString()); + assertEquals("{x=tensor(d0[],d1[784])}", output.getBody().toString()); // Test execution model.assertEqualResult("Placeholder", "MatMul"); |