summaryrefslogtreecommitdiffstats
path: root/searchlib
diff options
context:
space:
mode:
authorJon Bratseth <bratseth@oath.com>2018-09-25 15:49:22 -0700
committerJon Bratseth <bratseth@oath.com>2018-09-25 15:49:22 -0700
commit11884899e39c54abeb79bacbe723df0ff34ce869 (patch)
tree674025004f825c9cc12a075f992c0b2d1d45509e /searchlib
parent0246064bbfb9657515f516e2fea12d593cd13016 (diff)
Revert "Merge pull request #7094 from vespa-engine/revert-7070-bratseth/rank-type-information-2"
This reverts commit 0246064bbfb9657515f516e2fea12d593cd13016, reversing changes made to f627463a8100090ec109d27c3aeb439a3395a34f.
Diffstat (limited to 'searchlib')
-rwxr-xr-xsearchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java48
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java79
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java3
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java1
-rw-r--r--searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java13
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java4
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java16
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java8
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java26
-rw-r--r--searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java14
11 files changed, 153 insertions, 67 deletions
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
index da34ab8822d..f6502a9801d 100755
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/ExpressionFunction.java
@@ -2,8 +2,11 @@
package com.yahoo.searchlib.rankingexpression;
import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
+import com.yahoo.log.event.Collection;
import com.yahoo.searchlib.rankingexpression.rule.ExpressionNode;
import com.yahoo.searchlib.rankingexpression.rule.SerializationContext;
+import com.yahoo.tensor.TensorType;
import com.yahoo.text.Utf8;
import java.security.MessageDigest;
@@ -13,9 +16,14 @@ import java.util.Deque;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
/**
- * A function defined by a ranking expression
+ * A function defined by a ranking expression, optionally containing type information
+ * for inputs and outputs.
+ *
+ * Immutable, but note that ranking expressions are *not* immutable.
*
* @author Simon Thoresen Hult
* @author bratseth
@@ -24,8 +32,13 @@ public class ExpressionFunction {
private final String name;
private final ImmutableList<String> arguments;
+
+ /** Types of the inputs, if known. The keys here is any subset (including empty and identity) of the argument list */
+ private final ImmutableMap<String, TensorType> argumentTypes;
private final RankingExpression body;
+ private final Optional<TensorType> returnType;
+
/**
* Constructs a new function with no arguments
*
@@ -44,9 +57,18 @@ public class ExpressionFunction {
* @param body the ranking expression that defines this function
*/
public ExpressionFunction(String name, List<String> arguments, RankingExpression body) {
- this.name = name;
+ this(name, arguments, body, ImmutableMap.of(), Optional.empty());
+ }
+
+ public ExpressionFunction(String name, List<String> arguments, RankingExpression body,
+ Map<String, TensorType> argumentTypes, Optional<TensorType> returnType) {
+ this.name = Objects.requireNonNull(name, "name cannot be null");
this.arguments = arguments==null ? ImmutableList.of() : ImmutableList.copyOf(arguments);
- this.body = body;
+ this.body = Objects.requireNonNull(body, "body cannot be null");
+ if ( ! this.arguments.containsAll(argumentTypes.keySet()))
+ throw new IllegalArgumentException("Argument type keys must be a subset of the argument keys");
+ this.argumentTypes = ImmutableMap.copyOf(argumentTypes);
+ this.returnType = Objects.requireNonNull(returnType, "returnType cannot be null");
}
public String getName() { return name; }
@@ -56,9 +78,27 @@ public class ExpressionFunction {
public RankingExpression getBody() { return body; }
+ /** Returns the types of the arguments of this, if specified. The keys of this may be any subset of the arguments */
+ public Map<String, TensorType> argumentTypes() { return argumentTypes; }
+
+ /** Returns the return type of this, or empty if not specified */
+ public Optional<TensorType> returnType() { return returnType; }
+
+ public ExpressionFunction withName(String name) {
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
+ }
+
/** Returns a copy of this with the body changed to the given value */
public ExpressionFunction withBody(RankingExpression body) {
- return new ExpressionFunction(name, arguments, body);
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
+ }
+
+ public ExpressionFunction withReturnType(TensorType returnType) {
+ return new ExpressionFunction(name, arguments, body, argumentTypes, Optional.of(returnType));
+ }
+
+ public ExpressionFunction withArgumentTypes(Map<String, TensorType> argumentTypes) {
+ return new ExpressionFunction(name, arguments, body, argumentTypes, returnType);
}
/**
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
index 282a4c5e0a9..9ff391a5cfe 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ImportedModel.java
@@ -1,15 +1,22 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.ImmutableMap;
import com.yahoo.collections.Pair;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
import java.util.ArrayList;
+import java.util.Collection;
import java.util.Collections;
import java.util.HashMap;
+import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
+import java.util.Objects;
+import java.util.Optional;
import java.util.regex.Pattern;
/**
@@ -26,12 +33,11 @@ public class ImportedModel {
private final String source;
private final Map<String, Signature> signatures = new HashMap<>();
- private final Map<String, TensorType> arguments = new HashMap<>();
+ private final Map<String, TensorType> inputs = new HashMap<>();
private final Map<String, Tensor> smallConstants = new HashMap<>();
private final Map<String, Tensor> largeConstants = new HashMap<>();
private final Map<String, RankingExpression> expressions = new HashMap<>();
private final Map<String, RankingExpression> functions = new HashMap<>();
- private final Map<String, TensorType> requiredFunctions = new HashMap<>();
/**
* Creates a new imported model.
@@ -49,11 +55,11 @@ public class ImportedModel {
/** Returns the name of this model, which can only contain the characters in [A-Za-z0-9_] */
public String name() { return name; }
- /** Returns the source path (directiry or file) of this model */
+ /** Returns the source path (directory or file) of this model */
public String source() { return source; }
- /** Returns an immutable map of the arguments ("Placeholders") of this */
- public Map<String, TensorType> arguments() { return Collections.unmodifiableMap(arguments); }
+ /** Returns an immutable map of the inputs of this */
+ public Map<String, TensorType> inputs() { return Collections.unmodifiableMap(inputs); }
/**
* Returns an immutable map of the small constants of this.
@@ -71,7 +77,7 @@ public class ImportedModel {
/**
* Returns an immutable map of the expressions of this - corresponding to graph nodes
- * which are not Inputs/Placeholders or Variables (which instead become respectively arguments and constants).
+ * which are not Inputs/Placeholders or Variables (which instead become respectively inputs and constants).
* Note that only nodes recursively referenced by a placeholder/input are added.
*/
public Map<String, RankingExpression> expressions() { return Collections.unmodifiableMap(expressions); }
@@ -82,9 +88,6 @@ public class ImportedModel {
*/
public Map<String, RankingExpression> functions() { return Collections.unmodifiableMap(functions); }
- /** Returns an immutable map of the functions that must be provided by the environment running this model */
- public Map<String, TensorType> requiredFunctions() { return Collections.unmodifiableMap(requiredFunctions); }
-
/** Returns an immutable map of the signatures of this */
public Map<String, Signature> signatures() { return Collections.unmodifiableMap(signatures); }
@@ -96,12 +99,11 @@ public class ImportedModel {
/** Convenience method for returning a default signature */
Signature defaultSignature() { return signature(defaultSignatureName); }
- void argument(String name, TensorType argumentType) { arguments.put(name, argumentType); }
+ void input(String name, TensorType argumentType) { inputs.put(name, argumentType); }
void smallConstant(String name, Tensor constant) { smallConstants.put(name, constant); }
void largeConstant(String name, Tensor constant) { largeConstants.put(name, constant); }
void expression(String name, RankingExpression expression) { expressions.put(name, expression); }
void function(String name, RankingExpression expression) { functions.put(name, expression); }
- void requiredFunction(String name, TensorType type) { requiredFunctions.put(name, type); }
/**
* Returns all the output expressions of this indexed by name. The names consist of one or two parts
@@ -109,24 +111,39 @@ public class ImportedModel {
* if signatures are used, or the expression name if signatures are not used and there are multiple
* expressions, and the second is the output name if signature names are used.
*/
- public List<Pair<String, RankingExpression>> outputExpressions() {
- List<Pair<String, RankingExpression>> expressions = new ArrayList<>();
+ public List<Pair<String, ExpressionFunction>> outputExpressions() {
+ List<Pair<String, ExpressionFunction>> expressions = new ArrayList<>();
for (Map.Entry<String, Signature> signatureEntry : signatures().entrySet()) {
for (Map.Entry<String, String> outputEntry : signatureEntry.getValue().outputs().entrySet())
expressions.add(new Pair<>(signatureEntry.getKey() + "." + outputEntry.getKey(),
- expressions().get(outputEntry.getValue())));
+ signatureEntry.getValue().outputExpression(outputEntry.getKey())
+ .withName(signatureEntry.getKey() + "." + outputEntry.getKey())));
if (signatureEntry.getValue().outputs().isEmpty()) // fallback: Signature without outputs
expressions.add(new Pair<>(signatureEntry.getKey(),
- expressions().get(signatureEntry.getKey())));
+ new ExpressionFunction(signatureEntry.getKey(),
+ new ArrayList<>(signatureEntry.getValue().inputs().keySet()),
+ expressions().get(signatureEntry.getKey()),
+ signatureEntry.getValue().inputMap(),
+ Optional.empty())));
}
if (signatures().isEmpty()) { // fallback for models without signatures
if (expressions().size() == 1) {
Map.Entry<String, RankingExpression> singleEntry = this.expressions.entrySet().iterator().next();
- expressions.add(new Pair<>(singleEntry.getKey(), singleEntry.getValue()));
+ expressions.add(new Pair<>(singleEntry.getKey(),
+ new ExpressionFunction(singleEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ singleEntry.getValue(),
+ inputs,
+ Optional.empty())));
}
else {
for (Map.Entry<String, RankingExpression> expressionEntry : expressions().entrySet()) {
- expressions.add(new Pair<>(expressionEntry.getKey(), expressionEntry.getValue()));
+ expressions.add(new Pair<>(expressionEntry.getKey(),
+ new ExpressionFunction(expressionEntry.getKey(),
+ new ArrayList<>(inputs.keySet()),
+ expressionEntry.getValue(),
+ inputs,
+ Optional.empty())));
}
}
}
@@ -134,7 +151,7 @@ public class ImportedModel {
}
/**
- * A signature is a set of named inputs and outputs, where the inputs maps to argument
+ * A signature is a set of named inputs and outputs, where the inputs maps to input
* ("placeholder") names+types, and outputs maps to expressions nodes.
* Note that TensorFlow supports multiple signatures in their format, but ONNX has no explicit
* concept of signatures. For now, we handle ONNX models as having a single signature.
@@ -142,8 +159,8 @@ public class ImportedModel {
public class Signature {
private final String name;
- private final Map<String, String> inputs = new HashMap<>();
- private final Map<String, String> outputs = new HashMap<>();
+ private final Map<String, String> inputs = new LinkedHashMap<>();
+ private final Map<String, String> outputs = new LinkedHashMap<>();
private final Map<String, String> skippedOutputs = new HashMap<>();
private final List<String> importWarnings = new ArrayList<>();
@@ -158,12 +175,20 @@ public class ImportedModel {
/**
* Returns an immutable map of the inputs (evaluation context) of this. This is a map from input name
- * to argument (Placeholder) name in the owner of this
+ * in this signature to input name in the owning model
*/
public Map<String, String> inputs() { return Collections.unmodifiableMap(inputs); }
- /** Returns the type of the argument this input references */
- public TensorType inputArgument(String inputName) { return owner().arguments().get(inputs.get(inputName)); }
+ /** Returns the name and type of all inputs in this signature as an immutable map */
+ public Map<String, TensorType> inputMap() {
+ ImmutableMap.Builder<String, TensorType> inputs = new ImmutableMap.Builder<>();
+ for (Map.Entry<String, String> inputEntry : inputs().entrySet())
+ inputs.put(inputEntry.getKey(), owner().inputs().get(inputEntry.getValue()));
+ return inputs.build();
+ }
+
+ /** Returns the type of the input this input references */
+ public TensorType inputArgument(String inputName) { return owner().inputs().get(inputs.get(inputName)); }
/** Returns an immutable list of the expression names of this */
public Map<String, String> outputs() { return Collections.unmodifiableMap(outputs); }
@@ -180,7 +205,13 @@ public class ImportedModel {
public List<String> importWarnings() { return Collections.unmodifiableList(importWarnings); }
/** Returns the expression this output references */
- public RankingExpression outputExpression(String outputName) { return owner().expressions().get(outputs.get(outputName)); }
+ public ExpressionFunction outputExpression(String outputName) {
+ return new ExpressionFunction(outputName,
+ new ArrayList<>(inputs.keySet()),
+ owner().expressions().get(outputs.get(outputName)),
+ inputMap(),
+ Optional.empty());
+ }
@Override
public String toString() { return "signature '" + name + "'"; }
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
index d25502fd149..b7138ad87e3 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/ModelImporter.java
@@ -187,8 +187,7 @@ public abstract class ModelImporter {
if (operation.isInput()) {
// All inputs must have dimensions with standard naming convention: d0, d1, ...
OrderedTensorType standardNamingConvention = OrderedTensorType.standardType(operation.type().get());
- model.argument(operation.vespaName(), standardNamingConvention.type());
- model.requiredFunction(operation.vespaName(), standardNamingConvention.type());
+ model.input(operation.vespaName(), standardNamingConvention.type());
}
}
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
index 917b0d6a389..e6bb5f40b3f 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxImporter.java
@@ -2,7 +2,6 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
-import com.yahoo.io.IOUtils;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.IntermediateGraph;
import com.yahoo.searchlib.rankingexpression.integration.ml.importer.onnx.GraphImporter;
import onnx.Onnx;
diff --git a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
index 796c13a8669..94d663b4954 100644
--- a/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
+++ b/searchlib/src/main/java/com/yahoo/searchlib/rankingexpression/rule/SerializationContext.java
@@ -3,6 +3,8 @@ package com.yahoo.searchlib.rankingexpression.rule;
import com.google.common.collect.ImmutableMap;
import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
+import com.yahoo.searchlib.rankingexpression.RankingExpression;
+import com.yahoo.tensor.TensorType;
import java.util.Collection;
import java.util.Collections;
@@ -80,9 +82,14 @@ public class SerializationContext extends FunctionReferenceContext {
serializedFunctions.put(name, expressionString);
}
- /** Returns the existing serialization of a function, or null if none */
- public String getFunctionSerialization(String name) {
- return serializedFunctions.get(name);
+ /** Adds the serialization of the an argument type to a function */
+ public void addArgumentTypeSerialization(String functionName, String argumentName, TensorType type) {
+ serializedFunctions.put("rankingExpression(" + functionName + ")." + argumentName + ".type", type.toString());
+ }
+
+ /** Adds the serialization of the return type of a function */
+ public void addFunctionTypeSerialization(String functionName, TensorType type) {
+ serializedFunctions.put("rankingExpression(" + functionName + ").type", type.toString());
}
@Override
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
index 118eba2cd96..969bc318391 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/aggregation/GroupingSerializationTest.java
@@ -15,6 +15,7 @@ import org.junit.Test;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
+import java.nio.charset.StandardCharsets;
import java.util.Arrays;
import static org.junit.Assert.fail;
@@ -40,7 +41,8 @@ public class GroupingSerializationTest {
t.assertMatch(new FloatResultNode(7.3));
t.assertMatch(new StringResultNode("7.3"));
t.assertMatch(new StringResultNode(
- new String(new byte[]{(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c})));
+ new String(new byte[]{(byte)0xe5, (byte)0xa6, (byte)0x82, (byte)0xe6, (byte)0x9e, (byte)0x9c},
+ StandardCharsets.UTF_8)));
t.assertMatch(new RawResultNode(new byte[]{'7', '.', '4'}));
t.assertMatch(new IntegerBucketResultNode());
t.assertMatch(new FloatBucketResultNode());
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
index bf9684082f4..593e7b54c10 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/BatchNormImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -20,10 +21,11 @@ public class BatchNormImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getName());
- model.assertEqualResult("X", output.getName());
+ assertEquals("dnn/batch_normalization_3/batchnorm/add_1", output.getBody().getName());
+ model.assertEqualResult("X", output.getBody().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
index a8f7542f3a4..59712c0152f 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/DropoutImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.TensorType;
import org.junit.Test;
@@ -19,22 +20,23 @@ public class DropoutImportTestCase {
TestableTensorFlowModel model = new TestableTensorFlowModel("test", "src/test/files/integration/tensorflow/dropout/saved");
// Check required functions
- assertEquals(1, model.get().requiredFunctions().size());
- assertTrue(model.get().requiredFunctions().containsKey("X"));
+ assertEquals(1, model.get().inputs().size());
+ assertTrue(model.get().inputs().containsKey("X"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().requiredFunctions().get("X"));
+ model.get().inputs().get("X"));
ImportedModel.Signature signature = model.get().signature("serving_default");
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("outputs/Maximum", output.getName());
+ assertEquals("outputs/Maximum", output.getBody().getName());
assertEquals("join(join(imported_ml_function_test_outputs_BiasAdd, reduce(constant(test_outputs_Const), sum, d1), f(a,b)(a * b)), imported_ml_function_test_outputs_BiasAdd, f(a,b)(max(a,b)))",
- output.getRoot().toString());
- model.assertEqualResult("X", output.getName());
+ output.getBody().getRoot().toString());
+ model.assertEqualResult("X", output.getBody().getName());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
index add66eece1a..3d8d5d5a570 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/MnistImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import org.junit.Test;
@@ -20,11 +21,10 @@ public class MnistImportTestCase {
assertEquals("Has skipped outputs",
0, model.get().signature("serving_default").skippedOutputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("dnn/outputs/add", output.getName());
- model.assertEqualResultSum("input", output.getName(), 0.00001);
+ assertEquals("dnn/outputs/add", output.getBody().getName());
+ model.assertEqualResultSum("input", output.getBody().getName(), 0.00001);
}
-
}
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
index e20ac16a691..b6e83404ab1 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/OnnxMnistSoftmaxImportTestCase.java
@@ -1,5 +1,6 @@
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.searchlib.rankingexpression.evaluation.Context;
import com.yahoo.searchlib.rankingexpression.evaluation.MapContext;
@@ -27,27 +28,28 @@ public class OnnxMnistSoftmaxImportTestCase {
Tensor constant0 = model.largeConstants().get("test_Variable");
assertNotNull(constant0);
assertEquals(new TensorType.Builder().indexed("d2", 784).indexed("d1", 10).build(),
- constant0.type());
+ constant0.type());
assertEquals(7840, constant0.size());
Tensor constant1 = model.largeConstants().get("test_Variable_1");
assertNotNull(constant1);
- assertEquals(new TensorType.Builder().indexed("d1", 10).build(),
- constant1.type());
+ assertEquals(new TensorType.Builder().indexed("d1", 10).build(), constant1.type());
assertEquals(10, constant1.size());
- // Check required functions (inputs)
- assertEquals(1, model.requiredFunctions().size());
- assertTrue(model.requiredFunctions().containsKey("Placeholder"));
- assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.requiredFunctions().get("Placeholder"));
+ // Check inputs
+ assertEquals(1, model.inputs().size());
+ assertTrue(model.inputs().containsKey("Placeholder"));
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"), model.inputs().get("Placeholder"));
- // Check outputs
- RankingExpression output = model.defaultSignature().outputExpression("add");
+ // Check signature
+ ExpressionFunction output = model.defaultSignature().outputExpression("add");
assertNotNull(output);
- assertEquals("add", output.getName());
+ assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable), f(a,b)(a * b)), sum, d2), constant(test_Variable_1), f(a,b)(a + b))",
- output.getRoot().toString());
+ output.getBody().getRoot().toString());
+ assertEquals(TensorType.fromSpec("tensor(d0[],d1[784])"),
+ model.inputs().get(model.defaultSignature().inputs().get("Placeholder")));
+ assertEquals("{Placeholder=tensor(d0[],d1[784])}", output.argumentTypes().toString());
}
@Test
diff --git a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
index ef28eb4678f..0a48ecfce21 100644
--- a/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
+++ b/searchlib/src/test/java/com/yahoo/searchlib/rankingexpression/integration/ml/TensorFlowMnistSoftmaxImportTestCase.java
@@ -1,6 +1,7 @@
// Copyright 2018 Yahoo Holdings. Licensed under the terms of the Apache 2.0 license. See LICENSE in the project root.
package com.yahoo.searchlib.rankingexpression.integration.ml;
+import com.yahoo.searchlib.rankingexpression.ExpressionFunction;
import com.yahoo.searchlib.rankingexpression.RankingExpression;
import com.yahoo.tensor.Tensor;
import com.yahoo.tensor.TensorType;
@@ -38,10 +39,10 @@ public class TensorFlowMnistSoftmaxImportTestCase {
assertEquals(0, model.get().functions().size());
// Check required functions
- assertEquals(1, model.get().requiredFunctions().size());
- assertTrue(model.get().requiredFunctions().containsKey("Placeholder"));
+ assertEquals(1, model.get().inputs().size());
+ assertTrue(model.get().inputs().containsKey("Placeholder"));
assertEquals(new TensorType.Builder().indexed("d0").indexed("d1", 784).build(),
- model.get().requiredFunctions().get("Placeholder"));
+ model.get().inputs().get("Placeholder"));
// Check signatures
assertEquals(1, model.get().signatures().size());
@@ -56,11 +57,12 @@ public class TensorFlowMnistSoftmaxImportTestCase {
// ... signature outputs
assertEquals(1, signature.outputs().size());
- RankingExpression output = signature.outputExpression("y");
+ ExpressionFunction output = signature.outputExpression("y");
assertNotNull(output);
- assertEquals("add", output.getName());
+ assertEquals("add", output.getBody().getName());
assertEquals("join(reduce(join(rename(Placeholder, (d0, d1), (d0, d2)), constant(test_Variable_read), f(a,b)(a * b)), sum, d2), constant(test_Variable_1_read), f(a,b)(a + b))",
- output.getRoot().toString());
+ output.getBody().getRoot().toString());
+ assertEquals("{x=tensor(d0[],d1[784])}", output.argumentTypes().toString());
// Test execution
model.assertEqualResult("Placeholder", "MatMul");