diff options
author | Jon Bratseth <bratseth@oath.com> | 2018-09-23 16:20:32 -0700 |
---|---|---|
committer | Jon Bratseth <bratseth@oath.com> | 2018-09-23 16:20:32 -0700 |
commit | 4e44e5472829c033c3d995c618f2febcc4463eb7 (patch) | |
tree | 402dc48f0fce44759ce7bca8068c6b98097dd031 /searchlib | |
parent | 2ee637ff5ef12924e77d5fbf087fb9fb803f0143 (diff) |
Use ExpressionFunction
Diffstat (limited to 'searchlib')
7 files changed, 55 insertions, 56 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index 63d3f9df663..848ad68a6c0 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -60,8 +60,8 @@ public class ExpressionFunction { this(name, arguments, body, ImmutableMap.of(), Optional.empty()); } - private ExpressionFunction(String name, List<String> arguments, RankingExpression body, - Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { + public ExpressionFunction(String name, List<String> arguments, RankingExpression body, + Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) { this.name = Objects.requireNonNull(name, "name cannot be null"); this.arguments = arguments==null ? ImmutableList.of() : ImmutableList.copyOf(arguments); this.body = Objects.requireNonNull(body, "body cannot be null"); diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 88b5645e2e5..979487827a8 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.yahoo.collections.Pair; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -11,9 +12,11 @@ import java.util.ArrayList; import java.util.Collection; import java.util.Collections; import java.util.HashMap; +import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Objects; +import java.util.Optional; import java.util.regex.Pattern; /** @@ -108,27 +111,38 @@ public class ImportedModel { * if signatures are used, or the expression name if signatures are not used and there are multiple * expressions, and the second is the output name if signature names are used. */ - public List<Pair<String, ExpressionWithInputs>> outputExpressions() { - List<Pair<String, ExpressionWithInputs>> expressions = new ArrayList<>(); + public List<Pair<String, ExpressionFunction>> outputExpressions() { + List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>(); for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) { for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet()) expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(), signatureEntry.getValue().outputExpression(outputEntry.getKey()))); if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs expressions.add(new Pair<>(signatureEntry.getKey(), - new ExpressionWithInputs(expressions().get(signatureEntry.getKey()), - signatureEntry.getValue().inputMap()))); + new ExpressionFunction(signatureEntry.getKey(), + new ArrayList<>(signatureEntry.getValue().inputs().keySet()), + expressions().get(signatureEntry.getKey()), + signatureEntry.getValue().inputMap(), + Optional.empty()))); } if (signatures().isEmpty()) { // fallback for models without signatures if (expressions().size() == 1) { Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next(); expressions.add(new Pair<>(singleEntry.getKey(), - new ExpressionWithInputs(singleEntry.getValue(), inputs))); + new ExpressionFunction(singleEntry.getKey(), + new ArrayList<>(inputs.keySet()), + singleEntry.getValue(), + inputs, + Optional.empty()))); } else { for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) { expressions.add(new Pair<>(expressionEntry.getKey(), - new ExpressionWithInputs(expressionEntry.getValue(), inputs))); + new ExpressionFunction(expressionEntry.getKey(), + new ArrayList<>(inputs.keySet()), + expressionEntry.getValue(), + inputs, + Optional.empty()))); } } } @@ -144,8 +158,8 @@ public class ImportedModel { public class Signature { private final String name; - private final Map<String, String> inputs = new HashMap<>(); - private final Map<String, String> outputs = new HashMap<>(); + private final Map<String, String> inputs = new LinkedHashMap<>(); + private final Map<String, String> outputs = new LinkedHashMap<>(); private final Map<String, String> skippedOutputs = new HashMap<>(); private final List<String> importWarnings = new ArrayList<>(); @@ -190,8 +204,12 @@ public class ImportedModel { public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); } /** Returns the expression this output references */ - public ExpressionWithInputs outputExpression(String outputName) { - return new ExpressionWithInputs(owner().expressions().get(outputs.get(outputName)), inputMap()); + public ExpressionFunction outputExpression(String outputName) { + return new ExpressionFunction(outputName, + new ArrayList<>(inputs.keySet()), + owner().expressions().get(outputs.get(outputName)), + inputMap(), + Optional.empty()); } @Override @@ -204,28 +222,4 @@ public class ImportedModel { } - /** - * An expression, with the inputs (bindings) which must be supplied to evaluate it. - * All non-scalar (non-empty tensor type) inputs are always present here. Inputs not - * given explicitly here (but present in the expression) are always scalar. - */ - public static class ExpressionWithInputs { - - private final RankingExpression expression; - private final ImmutableMap<String, TensorType> inputs; - - public ExpressionWithInputs(RankingExpression expression, Map<String, TensorType> inputs) { - this.expression = Objects.requireNonNull(expression, "expression cannot be null"); - this.inputs = ImmutableMap.copyOf(inputs); - } - - public RankingExpression expression() { return expression; } - public ImmutableMap<String, TensorType> inputs() { return inputs; } - - public ExpressionWithInputs with(RankingExpression newExpression) { - return new ExpressionWithInputs(newExpression, inputs); - } - - } - } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index 3a1c9ec9551..62bbc9ae81f 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -20,11 +21,11 @@ public class BatchNormImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.expression().getName()); - model.assertEqualResult("X", output.expression().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName()); + model.assertEqualResult("X", output.getBody().getName()); + assertEquals("{x=tensor(d0[],d1[784])}", output.arguments().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 4c35d843f5d..2a894adc92c 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.TensorType; import org.junit.Test; @@ -29,13 +30,13 @@ public class DropoutImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("outputs/Maximum", output.expression().getName()); + assertEquals("outputs/Maximum", output.getBody().getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - output.expression().getRoot().toString()); - model.assertEqualResult("X", output.expression().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); + output.getBody().getRoot().toString()); + model.assertEqualResult("X", output.getBody().getName()); + assertEquals("{x=tensor(d0[],d1[784])}", output.getBody().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java index b3e281ad25d..3d8d5d5a570 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import org.junit.Test; @@ -20,10 +21,10 @@ public class MnistImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("dnn/outputs/add", output.expression().getName()); - model.assertEqualResultSum("input", output.expression().getName(), 0.00001); + assertEquals("dnn/outputs/add", output.getBody().getName()); + model.assertEqualResultSum("input", output.getBody().getName(), 0.00001); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java index b5655cfbfa5..bcdfde67dc0 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java @@ -1,5 +1,6 @@ package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.searchlib.rankingexpression.evaluation.Context; import com.yahoo.searchlib.rankingexpression.evaluation.MapContext; @@ -41,14 +42,14 @@ public class OnnxMnistSoftmaxImportTestCase { assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder")); // Check signature - ImportedModel.ExpressionWithInputs output = model.defaultSignature().outputExpression("add"); + ExpressionFunction output = model.defaultSignature().outputExpression("add"); assertNotNull(output); - assertEquals("add", output.expression().getName()); + assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))", - output.expression().getRoot().toString()); + output.getBody().getRoot().toString()); assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get(model.defaultSignature().inputs().get("Placeholder"))); - assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.inputs().toString()); + assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.getBody().toString()); } @Test diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index 4a0362c0229..b14a4a5b430 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -1,6 +1,7 @@ // Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root. package com.yahoo.searchlib.rankingexpression.integration.ml; +import com.yahoo.searchlib.rankingexpression.ExpressionFunction; import com.yahoo.searchlib.rankingexpression.RankingExpression; import com.yahoo.tensor.Tensor; import com.yahoo.tensor.TensorType; @@ -56,12 +57,12 @@ public class TensorFlowMnistSoftmaxImportTestCase { // ... signature outputs assertEquals(1, signature.outputs().size()); - ImportedModel.ExpressionWithInputs output = signature.outputExpression("y"); + ExpressionFunction output = signature.outputExpression("y"); assertNotNull(output); - assertEquals("add", output.expression().getName()); + assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", - output.expression().getRoot().toString()); - assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString()); + output.getBody().getRoot().toString()); + assertEquals("{x=tensor(d0[],d1[784])}", output.getBody().toString()); // Test execution model.assertEqualResult("Placeholder", "MatMul"); |