summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorLester Solbakken <lesters@users.noreply.github.com>2018-10-01 14:22:55 +0200
committerGitHub <noreply@github.com>2018-10-01 14:22:55 +0200
commitfbca8fc6115fbf924cc688d927c50d8e9d99a321 (patch)
treecaf9b072fdaf5b7aff2c6dad5056402caed3a393 /searchlib
parente317da1b538ced3dd49d7f582a1c942a4a00d772 (diff)
parentda1a20ab27fff180baf3f574774c3bbb57488fee (diff)
Merge pull request #7155 from vespa-engine/bratseth/expose-type-information
Bratseth/expose type information
Diffstat (limited to 'searchlib')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java8
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java9
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java12
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java2
6 files changed, 27 insertions, 18 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index f6502a9801d..787b857839d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -11,6 +11,7 @@ import com.yahoo.text.Utf8;
import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
+import java.util.ArrayList;
import java.util.Collections;
import java.util.Deque;
import java.util.HashMap;
@@ -97,7 +98,12 @@ public class ExpressionFunction {
return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType));
}
- public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) {
+ /** Returns a copy of this with the given argument and argument type added */
+ public ExpressionFunction withArgument(String argument, TensorType type) {
+ List<String> arguments = new ArrayList<>(this.arguments);
+ arguments.add(argument);
+ Map<String, TensorType> argumentTypes = new HashMap<>(this.argumentTypes);
+ argumentTypes.put(argument, type);
return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
index 17157ab385f..8aa7446cae7 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/evaluation/DoubleValue.java
@@ -9,7 +9,6 @@ import com.yahoo.searchlib.rankingexpression.rule.TruthOperator;
* In a boolean context doubles are true if they are different from 0.0
*
* @author bratseth
- * @since 5.1.5
*/
public final class DoubleValue extends DoubleCompatibleValue {
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 9ff391a5cfe..f26f2dea04f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -121,7 +121,7 @@ public class ImportedModel {
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
new ExpressionFunction(signatureEntry.getKey(),
- new ArrayList<>(signatureEntry.getValue().inputs().keySet()),
+ new ArrayList<>(signatureEntry.getValue().inputs().values()),
expressions().get(signatureEntry.getKey()),
signatureEntry.getValue().inputMap(),
Optional.empty())));
@@ -182,8 +182,11 @@ public class ImportedModel {
/** Returns the name and type of all inputs in this signature as an immutable map */
public Map<String, TensorType> inputMap() {
ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
+ // Note: We're naming inputs by their actual name (used in the expression, given by what the input maps *to*
+ // in the model, as these are the names which must actually be bound, if we are to avoid creating an
+ // "input mapping" to accomodate this complexity in
for (Map.Entry<String, String> inputEntry : inputs().entrySet())
- inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue()));
+ inputs.put(inputEntry.getValue(), owner().inputs().get(inputEntry.getValue()));
return inputs.build();
}
@@ -207,7 +210,7 @@ public class ImportedModel {
/** Returns the expression this output references */
public ExpressionFunction outputExpression(String outputName) {
return new ExpressionFunction(outputName,
- new ArrayList<>(inputs.keySet()),
+ new ArrayList<>(inputs.values()),
owner().expressions().get(outputs.get(outputName)),
inputMap(),
Optional.empty());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index 593e7b54c10..e325c3d11b4 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -15,17 +15,18 @@ public class BatchNormImportTestCase {
@Test
public void testBatchNormImport() {
- TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/batch_norm/saved");
+ TestableTensorFlowModel model = new TestableTensorFlowModel("test",
+ "src/test/files/integration/tensorflow/batch_norm/saved");
ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName());
- model.assertEqualResult("X", output.getBody().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ ExpressionFunction function = signature.outputExpression("y");
+ assertNotNull(function);
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", function.getBody().getName());
+ model.assertEqualResult("X", function.getBody().getName());
+ assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index 59712c0152f..8ca5a9a7888 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -30,13 +30,13 @@ public class DropoutImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- ExpressionFunction output = signature.outputExpression("y");
- assertNotNull(output);
- assertEquals("outputs/Maximum", output.getBody().getName());
+ ExpressionFunction function = signature.outputExpression("y");
+ assertNotNull(function);
+ assertEquals("outputs/Maximum", function.getBody().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- output.getBody().getRoot().toString());
- model.assertEqualResult("X", output.getBody().getName());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ function.getBody().getRoot().toString());
+ model.assertEqualResult("X", function.getBody().getName());
+ assertEquals("{X=tensor(d0[],d1[784])}", function.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index 0a48ecfce21..feba40601e3 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -62,7 +62,7 @@ public class TensorFlowMnistSoftmaxImportTestCase {
assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
output.getBody().getRoot().toString());
- assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
// Test execution
model.assertEqualResult("Placeholder", "MatMul");