summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-23 16:20:32 -0700
committerJon Bratseth <bratseth@oath.com>2018-09-23 16:20:32 -0700
commit4e44e5472829c033c3d995c618f2febcc4463eb7 (patch)
tree402dc48f0fce44759ce7bca8068c6b98097dd031 /searchlib
parent2ee637ff5ef12924e77d5fbf087fb9fb803f0143 (diff)
Use ExpressionFunction
Diffstat (limited to 'searchlib')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java4
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java62
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java9
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java11
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java7
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java9
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java9
7 files changed, 55 insertions, 56 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index 63d3f9df663..848ad68a6c0 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -60,8 +60,8 @@ public class ExpressionFunction {
this(name, arguments, body, ImmutableMap.of(), Optional.empty());
}
- private ExpressionFunction(String name, List<String> arguments, RankingExpression body,
- Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
+ public ExpressionFunction(String name, List<String> arguments, RankingExpression body,
+ Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
this.name = Objects.requireNonNull(name, "name cannot be null");
this.arguments = arguments==null ? ImmutableList.of() : ImmutableList.copyOf(arguments);
this.body = Objects.requireNonNull(body, "body cannot be null");
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 88b5645e2e5..979487827a8 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -3,6 +3,7 @@ package com.yahoo.searchlib.rankingexpression.integration.ml;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -11,9 +12,11 @@ import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
+import java.util.Optional;
import java.util.regex.Pattern;
/**
@@ -108,27 +111,38 @@ public class ImportedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public List<Pair<String, ExpressionWithInputs>> outputExpressions() {
- List<Pair<String, ExpressionWithInputs>> expressions = new ArrayList<>();
+ public List<Pair<String, ExpressionFunction>> outputExpressions() {
+ List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
signatureEntry.getValue().outputExpression(outputEntry.getKey())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
- new ExpressionWithInputs(expressions().get(signatureEntry.getKey()),
- signatureEntry.getValue().inputMap())));
+ new ExpressionFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().keySet()),
+ expressions().get(signatureEntry.getKey()),
+ signatureEntry.getValue().inputMap(),
+ Optional.empty())));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
expressions.add(new Pair<>(singleEntry.getKey(),
- new ExpressionWithInputs(singleEntry.getValue(), inputs)));
+ new ExpressionFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue(),
+ inputs,
+ Optional.empty())));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
expressions.add(new Pair<>(expressionEntry.getKey(),
- new ExpressionWithInputs(expressionEntry.getValue(), inputs)));
+ new ExpressionFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue(),
+ inputs,
+ Optional.empty())));
}
}
}
@@ -144,8 +158,8 @@ public class ImportedModel {
public class Signature {
private final String name;
- private final Map<String, String> inputs = new HashMap<>();
- private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> inputs = new LinkedHashMap<>();
+ private final Map<String, String> outputs = new LinkedHashMap<>();
private final Map<String, String> skippedOutputs = new HashMap<>();
private final List<String> importWarnings = new ArrayList<>();
@@ -190,8 +204,12 @@ public class ImportedModel {
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
/** Returns the expression this output references */
- public ExpressionWithInputs outputExpression(String outputName) {
- return new ExpressionWithInputs(owner().expressions().get(outputs.get(outputName)), inputMap());
+ public ExpressionFunction outputExpression(String outputName) {
+ return new ExpressionFunction(outputName,
+ new ArrayList<>(inputs.keySet()),
+ owner().expressions().get(outputs.get(outputName)),
+ inputMap(),
+ Optional.empty());
}
@Override
@@ -204,28 +222,4 @@ public class ImportedModel {
}
- /**
- * An expression, with the inputs (bindings) which must be supplied to evaluate it.
- * All non-scalar (non-empty tensor type) inputs are always present here. Inputs not
- * given explicitly here (but present in the expression) are always scalar.
- */
- public static class ExpressionWithInputs {
-
- private final RankingExpression expression;
- private final ImmutableMap<String, TensorType> inputs;
-
- public ExpressionWithInputs(RankingExpression expression, Map<String, TensorType> inputs) {
- this.expression = Objects.requireNonNull(expression, "expression cannot be null");
- this.inputs = ImmutableMap.copyOf(inputs);
- }
-
- public RankingExpression expression() { return expression; }
- public ImmutableMap<String, TensorType> inputs() { return inputs; }
-
- public ExpressionWithInputs with(RankingExpression newExpression) {
- return new ExpressionWithInputs(newExpression, inputs);
- }
-
- }
-
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index 3a1c9ec9551..62bbc9ae81f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -20,11 +21,11 @@ public class BatchNormImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.expression().getName());
- model.assertEqualResult("X", output.expression().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName());
+ model.assertEqualResult("X", output.getBody().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.arguments().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 4c35d843f5d..2a894adc92c 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -29,13 +30,13 @@ public class DropoutImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("outputs/Maximum", output.expression().getName());
+ assertEquals("outputs/Maximum", output.getBody().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- output.expression().getRoot().toString());
- model.assertEqualResult("X", output.expression().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
+ output.getBody().getRoot().toString());
+ model.assertEqualResult("X", output.getBody().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.getBody().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
index b3e281ad25d..3d8d5d5a570 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -20,10 +21,10 @@ public class MnistImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/outputs/add", output.expression().getName());
- model.assertEqualResultSum("input", output.expression().getName(), 0.00001);
+ assertEquals("dnn/outputs/add", output.getBody().getName());
+ model.assertEqualResultSum("input", output.getBody().getName(), 0.00001);
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index b5655cfbfa5..bcdfde67dc0 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -1,5 +1,6 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
@@ -41,14 +42,14 @@ public class OnnxMnistSoftmaxImportTestCase {
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
// Check signature
- ImportedModel.ExpressionWithInputs output = model.defaultSignature().outputExpression("add");
+ ExpressionFunction output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
- assertEquals("add", output.expression().getName());
+ assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
- output.expression().getRoot().toString());
+ output.getBody().getRoot().toString());
assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
- assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.inputs().toString());
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.getBody().toString());
}
@Test
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index 4a0362c0229..b14a4a5b430 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -56,12 +57,12 @@ public class TensorFlowMnistSoftmaxImportTestCase {
// ... signature outputs
assertEquals(1, signature.outputs().size());
- ImportedModel.ExpressionWithInputs output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("add", output.expression().getName());
+ assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
- output.expression().getRoot().toString());
- assertEquals("{x=tensor(d0[],d1[784])}", output.inputs().toString());
+ output.getBody().getRoot().toString());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.getBody().toString());
// Test execution
model.assertEqualResult("Placeholder", "MatMul");