diff options
author | Lester Solbakken <lesters@users.noreply.github.com> | 2018-10-01 14:22:55 +0200 |
---|---|---|
committer | GitHub <noreply@github.com> | 2018-10-01 14:22:55 +0200 |
commit | fbca8fc6115fbf924cc688d927c50d8e9d99a321 (patch) | |
tree | caf9b072fdaf5b7aff2c6dad5056402caed3a393 /searchlib | |
parent | e317da1b538ced3dd49d7f582a1c942a4a00d772 (diff) | |
parent | da1a20ab27fff180baf3f574774c3bbb57488fee (diff) |
Merge pull request #7155 from vespa-engine/bratseth/expose-type-information
Bratseth/expose type information
Diffstat (limited to 'searchlib')
6 files changed, 27 insertions, 18 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java index f6502a9801d..787b857839d 100755 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java @@ -11,6 +11,7 @@ import com.yahoo.text.Utf8; import java.security.MessageDigest; import java.security.NoSuchAlgorithmException; +import java.util.ArrayList; import java.util.Collections; import java.util.Deque; import java.util.HashMap; @@ -97,7 +98,12 @@ public class ExpressionFunction { return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType)); } - public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) { + /** Returns a copy of this with the given argument and argument type added */ + public ExpressionFunction withArgument(String argument, TensorType type) { + List<String> arguments = new ArrayList<>(this.arguments); + arguments.add(argument); + Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes); + argumentTypes.put(argument, type); return new ExpressionFunction(name, arguments, body, argumentTypes, returnType); } diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java index 17157ab385f..8aa7446cae7 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java @@ -9,7 +9,6 @@ import com.yahoo.searchlib.rankingexpression.rule.TruthOperator; * In a boolean context doubles are true if they are different from 0.0 * * @author bratseth - * @since 5.1.5 */ public final class DoubleValue extends DoubleCompatibleValue { diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java index 9ff391a5cfe..f26f2dea04f 100644 --- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java +++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java @@ -121,7 +121,7 @@ public class ImportedModel { if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs expressions.add(new Pair<>(signatureEntry.getKey(), new ExpressionFunction(signatureEntry.getKey(), - new ArrayList<>(signatureEntry.getValue().inputs().keySet()), + new ArrayList<>(signatureEntry.getValue().inputs().values()), expressions().get(signatureEntry.getKey()), signatureEntry.getValue().inputMap(), Optional.empty()))); @@ -182,8 +182,11 @@ public class ImportedModel { /** Returns the name and type of all inputs in this signature as an immutable map */ public Map<String, TensorType> inputMap() { ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>(); + // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to* + // in the model, as these are the names which must actually be bound, if we are to avoid creating an + // "input mapping" to accomodate this complexity in for (Map.Entry<String, String> inputEntry : inputs().entrySet()) - inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue())); + inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue())); return inputs.build(); } @@ -207,7 +210,7 @@ public class ImportedModel { /** Returns the expression this output references */ public ExpressionFunction outputExpression(String outputName) { return new ExpressionFunction(outputName, - new ArrayList<>(inputs.keySet()), + new ArrayList<>(inputs.values()), owner().expressions().get(outputs.get(outputName)), inputMap(), Optional.empty()); diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java index 593e7b54c10..e325c3d11b4 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java @@ -15,17 +15,18 @@ public class BatchNormImportTestCase { @Test public void testBatchNormImport() { - TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved"); + TestableTensorFlowModel model = new TestableTensorFlowModel("test", + "src/test/files/integration/tensorflow/batch_norm/saved"); ImportedModel.Signature signature = model.get().signature("serving_default"); assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName()); - model.assertEqualResult("X", output.getBody().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + ExpressionFunction function = signature.outputExpression("y"); + assertNotNull(function); + assertEquals("dnn/batch_normalization_3/batchnorm/add_1", function.getBody().getName()); + model.assertEqualResult("X", function.getBody().getName()); + assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java index 59712c0152f..8ca5a9a7888 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java @@ -30,13 +30,13 @@ public class DropoutImportTestCase { assertEquals("Has skipped outputs", 0, model.get().signature("serving_default").skippedOutputs().size()); - ExpressionFunction output = signature.outputExpression("y"); - assertNotNull(output); - assertEquals("outputs/Maximum", output.getBody().getName()); + ExpressionFunction function = signature.outputExpression("y"); + assertNotNull(function); + assertEquals("outputs/Maximum", function.getBody().getName()); assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))", - output.getBody().getRoot().toString()); - model.assertEqualResult("X", output.getBody().getName()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + function.getBody().getRoot().toString()); + model.assertEqualResult("X", function.getBody().getName()); + assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString()); } } diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java index 0a48ecfce21..feba40601e3 100644 --- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java +++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java @@ -62,7 +62,7 @@ public class TensorFlowMnistSoftmaxImportTestCase { assertEquals("add", output.getBody().getName()); assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))", output.getBody().getRoot().toString()); - assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString()); + assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString()); // Test execution model.assertEqualResult("Placeholder", "MatMul"); |